diff options
Diffstat (limited to 'base/common/python/pki/__init__.py')
-rw-r--r-- | base/common/python/pki/__init__.py | 190 |
1 files changed, 165 insertions, 25 deletions
diff --git a/base/common/python/pki/__init__.py b/base/common/python/pki/__init__.py index bbcffb8a4..4b18ea0ed 100644 --- a/base/common/python/pki/__init__.py +++ b/base/common/python/pki/__init__.py @@ -18,9 +18,12 @@ # Copyright (C) 2013 Red Hat, Inc. # All rights reserved. # - +''' +This module contains top-level classes and functions used by the Dogtag project. +''' import os import re +import requests CONF_DIR = '/etc/pki' @@ -33,8 +36,8 @@ PACKAGE_VERSION = SHARE_DIR + '/VERSION' def read_text(message, options=None, default=None, delimiter=':', - allowEmpty=True, caseSensitive=True): - + allow_empty=True, case_sensitive=True): + ''' get an input from the user. ''' if default: message = message + ' [' + default + ']' message = message + delimiter + ' ' @@ -45,20 +48,20 @@ def read_text(message, value = value.strip() if len(value) == 0: # empty value - if allowEmpty: + if allow_empty: value = default done = True break else: # non-empty value if options is not None: - for v in options: - if caseSensitive: - if v == value: + for val in options: + if case_sensitive: + if val == value: done = True break else: - if v.lower() == value.lower(): + if val.lower() == value.lower(): done = True break else: @@ -69,9 +72,9 @@ def read_text(message, def implementation_version(): - - with open(PACKAGE_VERSION, 'r') as f: - for line in f: + ''' Return implementation version ''' + with open(PACKAGE_VERSION, 'r') as input_file: + for line in input_file: line = line.strip('\n') # parse <key>: <value> @@ -90,27 +93,164 @@ def implementation_version(): raise Exception('Missing implementation version.') +class Attribute(object): + ''' + Class representing a key/value pair. + + This object is the basis of the representation of a ResourceMessage. + ''' + + def __init__(self, name, value): + ''' Constructor ''' + self.name = name + self.value = value + +class AttributeList(object): + ''' + Class representing a list of attributes. + + This class is needed because of a JavaMapper used in the REST API. + ''' + + def __init__(self): + ''' Constructor ''' + self.Attribute = [] + +class ResourceMessage(object): + ''' + This class is the basis for the various types of key requests. + It is essentially a list of attributes. + ''' + + def __init__(self, class_name): + ''' Constructor ''' + self.Attributes = AttributeList() + self.ClassName = class_name + + def add_attribute(self, name, value): + ''' Add an attribute to the list. ''' + attr = Attribute(name, value) + self.Attributes.Attribute.append(attr) + + def get_attribute_value(self, name): + ''' Get the value of a given attribute ''' + for attr in self.Attributes.Attribute: + if attr.name == name: + return attr.value + return None -class PKIException(Exception): - - def __init__(self, message, exception=None): +class PKIException(Exception, ResourceMessage): + ''' + Base exception class for REST Interface + ''' + def __init__(self, message, exception=None, code=None, class_name=None): + ''' Constructor ''' Exception.__init__(self, message) - + ResourceMessage.__init__(self, class_name) + self.code = code + self.message = message self.exception = exception + @classmethod + def from_json(cls, json_value): + ''' Construct exception from JSON ''' + ret = cls(json_value['Message'], json_value['Code'], json_value['ClassName']) + for attr in json_value['Attributes']['Attribute']: + print str(attr) + ret.add_attribute(attr["name"], attr["value"]) + return ret + +class BadRequestException(PKIException): + ''' Bad Request Exception: return code = 400 ''' + +class ConflictingOperationException(PKIException): + ''' Conflicting Operation Exception: return code = 409 ''' + +class ForbiddenException(PKIException): + ''' Forbidden Exception: return code = 403 ''' + +class HTTPGoneException(PKIException): + ''' Gone Exception: return code = 410 ''' + +class ResourceNotFoundException(PKIException): + ''' Not Found Exception: return code = 404 ''' + +class UnauthorizedException(PKIException): + ''' Unauthorized Exception: return code = 401 ''' + +class CertNotFoundException(ResourceNotFoundException): + ''' Cert Not Found Exception: return code = 404 ''' + +class GroupNotFoundException(ResourceNotFoundException): + ''' Group Not Found Exception: return code = 404 ''' + +class ProfileNotFoundException(ResourceNotFoundException): + ''' Profile Not Found Exception: return code = 404 ''' + +class RequestNotFoundException(ResourceNotFoundException): + ''' Request Not Found Exception: return code = 404 ''' + +class UserNotFoundException(ResourceNotFoundException): + ''' User Not Found Exception: return code = 404 ''' + +EXCEPTION_MAPPINGS = { + "com.netscape.certsrv.base.BadRequestException": "pki.BadRequestException", + "com.netscape.certsrv.base.ConflictingOperationException": "pki.ConflictingOperationException", + "com.netscape.certsrv.base.ForbiddenException": "pki.ForbiddenException", + "com.netscape.certsrv.base.HTTPGoneException": "pki.HTTPGoneException", + "com.netscape.certsrv.base.ResourceNotFoundException": "pki.ResourceNotFoundException", + "com.netscape.certsrv.cert.CertNotFoundException": "pki.CertNotFoundException", + "com.netscape.certsrv.group.GroupNotFoundException": "pki.GroupNotFoundException", + "com.netscape.certsrv.profile.ProfileNotFoundException": "pki.ProfileNotFoundException", + "com.netscape.certsrv.request.RequestNotFoundException": "pki.RequestNotFoundException", + "com.netscape.certsrv.base.UserNotFoundException": "pki.UserNotFoundException", + "com.netscape.certsrv.base.PKIException": "pki.PKIException"} + +def get_class( kls ): + ''' Get reference to the class specified by string kls ''' + parts = kls.split('.') + module = ".".join(parts[:-1]) + mod = __import__( module ) + for comp in parts[1:]: + mod = getattr(mod, comp) + return mod + +def handle_exceptions(): + ''' Decorator handling exceptions from REST methods. ''' + + def exceptions_decorator(fn_call): + ''' The actual decorator handler.''' + + def handler(inst, *args, **kwargs): + ''' Decorator to catch and re-throw PKIExceptions.''' + try: + return fn_call(inst, *args, **kwargs) + except requests.exceptions.HTTPError as exc: + clazz = exc.response.json()['ClassName'] + if clazz in EXCEPTION_MAPPINGS: + exception_class = get_class(EXCEPTION_MAPPINGS[clazz]) + pki_exception = exception_class.from_json(exc.response.json()) + raise pki_exception + else: + raise exc + + return handler + return exceptions_decorator + class PropertyFile(object): + ''' Class to manage property files ''' def __init__(self, filename, delimiter='='): - + ''' Constructor ''' self.filename = filename self.delimiter = delimiter self.lines = [] def read(self): - + ''' Read from propert file ''' self.lines = [] if not os.path.exists(self.filename): @@ -123,27 +263,27 @@ class PropertyFile(object): self.lines.append(line) def write(self): - + ''' Write to property file ''' # write all lines in the original order with open(self.filename, 'w') as f: for line in self.lines: f.write(line + '\n') def show(self): - + ''' Show contents of property file.''' for line in self.lines: print line def insert_line(self, index, line): - + ''' Insert line in property file ''' self.lines.insert(index, line) def remove_line(self, index): - + ''' Remove line from property file ''' self.lines.pop(index) def index(self, name): - + ''' Find the index (position) of a property in a property file ''' for i, line in enumerate(self.lines): # parse <key> <delimiter> <value> @@ -160,7 +300,7 @@ class PropertyFile(object): return -1 def get(self, name): - + ''' Get value for specified property ''' result = None for line in self.lines: @@ -180,7 +320,7 @@ class PropertyFile(object): return result def set(self, name, value, index=None): - + ''' Set value for specified property ''' for i, line in enumerate(self.lines): # parse <key> <delimiter> <value> @@ -202,7 +342,7 @@ class PropertyFile(object): self.insert_line(index, name + self.delimiter + value) def remove(self, name): - + ''' Remove property from property file ''' for i, line in enumerate(self.lines): # parse <key> <delimiter> <value> |