diff options
-rw-r--r-- | ipaclient/remote_plugins/schema.py | 301 |
1 files changed, 177 insertions, 124 deletions
diff --git a/ipaclient/remote_plugins/schema.py b/ipaclient/remote_plugins/schema.py index cd1d5d607..70cd7536d 100644 --- a/ipaclient/remote_plugins/schema.py +++ b/ipaclient/remote_plugins/schema.py @@ -5,10 +5,8 @@ import collections import errno import fcntl -import glob import json import os -import re import sys import time import types @@ -65,8 +63,6 @@ USER_CACHE_PATH = ( '.cache' ) ) -SCHEMA_DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'schema') -SERVERS_DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers') logger = log_mgr.get_logger(__name__) @@ -274,15 +270,6 @@ class _SchemaObjectPlugin(_SchemaPlugin): schema_key = 'classes' -def _ensure_dir_created(d): - try: - os.makedirs(d) - except OSError as e: - if e.errno != errno.EEXIST: - raise RuntimeError("Unable to create cache directory: {}" - "".format(e)) - - class _LockedZipFile(zipfile.ZipFile): """ Add locking to zipfile.ZipFile Shared lock is used with read mode, exclusive with write mode. @@ -308,7 +295,10 @@ class _SchemaNameSpace(collections.Mapping): self._schema = schema def __getitem__(self, key): - return self._schema.read_namespace_member(self.name, key) + try: + return self._schema.read_namespace_member(self.name, key) + except KeyError: + raise KeyError(key) def __iter__(self): for key in self._schema.iter_namespace(self.name): @@ -322,6 +312,62 @@ class NotAvailable(Exception): pass +class ServerInfo(collections.MutableMapping): + _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers') + + def __init__(self, api): + hostname = DNSName(api.env.server).ToASCII() + self._path = os.path.join(self._DIR, hostname) + self._dict = {} + self._dirty = False + + self._read() + + def __enter__(self): + return self + + def __exit__(self, *_exc_info): + if self._dirty: + self._write() + + def _read(self): + try: + with open(self._path, 'r') as sc: + self._dict = json.load(sc) + except EnvironmentError as e: + if e.errno != errno.ENOENT: + logger.warning('Failed to read server info: {}'.format(e)) + + def _write(self): + try: + try: + os.makedirs(self._DIR) + except EnvironmentError as e: + if e.errno != errno.EEXIST: + raise + with open(self._path, 'w') as sc: + json.dump(self._dict, sc) + except EnvironmentError as e: + logger.warning('Failed to write server info: {}'.format(e)) + + def __getitem__(self, key): + return self._dict[key] + + def __setitem__(self, key, value): + self._dirty = key not in self._dict or self._dict[key] != value + self._dict[key] = value + + def __delitem__(self, key): + del self._dict[key] + self._dirty = True + + def __iter__(self): + return iter(self._dict) + + def __len__(self): + return len(self._dict) + + class Schema(object): """ Store and provide schema for commands and topics @@ -342,38 +388,76 @@ class Schema(object): u'Ping the remote IPA server to ...' """ - schema_path_template = os.path.join(SCHEMA_DIR, '{}') - servers_path_template = os.path.join(SERVERS_DIR, '{}') - ns_member_pattern_template = '^{}/(?P<name>.+)$' - ns_member_path_template = '{}/{}' namespaces = {'classes', 'commands', 'topics'} schema_info_path = 'schema' + _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'schema') - @classmethod - def _list(cls): - for f in glob.glob(cls.schema_path_template.format('*')): - yield os.path.splitext(os.path.basename(f))[0] + def __init__(self, api, server_info, client): + self._dict = {} + self._namespaces = {} + self._help = None - @classmethod - def _in_cache(cls, fingeprint): - return os.path.exists(cls.schema_path_template.format(fingeprint)) + for ns in self.namespaces: + self._dict[ns] = {} + self._namespaces[ns] = _SchemaNameSpace(self, ns) - def __init__(self, api, client): - self._api = api - self._client = client - self._dict = {} + is_known = False + if not api.env.force_schema_check: + try: + self._fingerprint = server_info['fingerprint'] + self._expiration = server_info['expiration'] + except KeyError: + pass + else: + is_known = True + + if is_known: + try: + self._read_schema() + except Exception: + pass + else: + return - def _open_server_info(self, hostname, mode): - encoded_hostname = DNSName(hostname).ToASCII() - path = self.servers_path_template.format(encoded_hostname) - return open(path, mode) + try: + self._fetch(client) + except NotAvailable: + raise + else: + self._write_schema() + finally: + try: + server_info['fingerprint'] = self._fingerprint + server_info['expiration'] = self._expiration + except AttributeError: + pass - def _get_schema(self): - client = self._client + def _open_schema(self, filename, mode): + path = os.path.join(self._DIR, filename) + return _LockedZipFile(path, mode) + + def _get_schema_fingerprint(self, schema): + schema_info = json.loads(schema.read(self.schema_info_path)) + return schema_info['fingerprint'] + + def _fetch(self, client): if not client.isconnected(): client.connect(verbose=False) - fps = [unicode(f) for f in Schema._list()] + fps = [] + try: + files = os.listdir(self._DIR) + except EnvironmentError: + pass + else: + for filename in files: + try: + with self._open_schema(filename, 'r') as schema: + fps.append( + unicode(self._get_schema_fingerprint(schema))) + except Exception: + continue + kwargs = {u'version': u'2.170'} if fps: kwargs[u'known_fingerprints'] = fps @@ -386,110 +470,80 @@ class Schema(object): ttl = e.ttl else: fp = schema['fingerprint'] - ttl = schema['ttl'] - self._store(fp, schema) - finally: - client.disconnect() + ttl = schema.pop('ttl', 0) - exp = ttl + time.time() - return (fp, exp) + for key, value in schema.items(): + if key in self.namespaces: + value = {m['full_name']: m for m in value} + self._dict[key] = value - def _ensure_cached(self): - no_info = False - try: - # pylint: disable=access-member-before-definition - fp = self._server_schema_fingerprint - exp = self._server_schema_expiration - except AttributeError: - try: - with self._open_server_info(self._api.env.server, 'r') as sc: - si = json.load(sc) - - fp = si['fingerprint'] - exp = si['expiration'] - except Exception as e: - no_info = True - if not (isinstance(e, EnvironmentError) and - e.errno == errno.ENOENT): # pylint: disable=no-member - logger.warning('Failed to load server properties: {}' - ''.format(e)) - - force_check = ((not getattr(self, '_schema_checked', False)) and - self._api.env.force_schema_check) - - if (force_check or - no_info or exp < time.time() or not Schema._in_cache(fp)): - (fp, exp) = self._get_schema() - self._schema_checked = True - _ensure_dir_created(SERVERS_DIR) - try: - with self._open_server_info(self._api.env.server, 'w') as sc: - json.dump(dict(fingerprint=fp, expiration=exp), sc) - except Exception as e: - logger.warning('Failed to store server properties: {}' - ''.format(e)) - - if not self._dict: - self._dict['fingerprint'] = fp - schema_info = self._read(self.schema_info_path) + self._fingerprint = fp + self._expiration = ttl + time.time() + + def _read_schema(self): + with self._open_schema(self._fingerprint, 'r') as schema: + self._dict['fingerprint'] = self._get_schema_fingerprint(schema) + schema_info = json.loads(schema.read(self.schema_info_path)) self._dict['version'] = schema_info['version'] - for ns in self.namespaces: - self._dict[ns] = _SchemaNameSpace(self, ns) - self._server_schema_fingerprintr = fp - self._server_schema_expiration = exp + for name in schema.namelist(): + ns, _slash, key = name.partition('/') + if ns in self.namespaces: + self._dict[ns][key] = {} def __getitem__(self, key): - self._ensure_cached() - return self._dict[key] + try: + return self._namespaces[key] + except KeyError: + return self._dict[key] - def _open_archive(self, mode, fp=None): - if not fp: - fp = self['fingerprint'] - arch_path = self.schema_path_template.format(fp) - return _LockedZipFile(arch_path, mode) - - def _store(self, fingerprint, schema={}): - _ensure_dir_created(SCHEMA_DIR) - - schema_info = dict(version=schema['version'], - fingerprint=schema['fingerprint']) - - with self._open_archive('w', fingerprint) as zf: - # store schema information - zf.writestr(self.schema_info_path, json.dumps(schema_info)) - # store namespaces - for namespace in self.namespaces: - for member in schema[namespace]: - path = self.ns_member_path_template.format( - namespace, - member['full_name'] - ) - zf.writestr(path, json.dumps(member)) + def _write_schema(self): + try: + os.makedirs(self._DIR) + except EnvironmentError as e: + if e.errno != errno.EEXIST: + logger.warning("Failed ti write schema: {}".format(e)) + return + + with self._open_schema(self._fingerprint, 'w') as schema: + schema_info = {} + for key, value in self._dict.items(): + if key in self.namespaces: + ns = value + for member in ns: + path = '{}/{}'.format(key, member) + schema.writestr(path, json.dumps(ns[member])) + else: + schema_info[key] = value + + schema.writestr(self.schema_info_path, json.dumps(schema_info)) def _read(self, path): - with self._open_archive('r') as zf: + with self._open_schema(self._fingerprint, 'r') as zf: return json.loads(zf.read(path)) def read_namespace_member(self, namespace, member): - path = self.ns_member_path_template.format(namespace, member) - return self._read(path) + value = self._dict[namespace][member] + + if (not value) or ('full_name' not in value): + path = '{}/{}'.format(namespace, member) + value = self._dict[namespace].setdefault( + member, {} + ).update(self._read(path)) + + return value def iter_namespace(self, namespace): - pattern = self.ns_member_pattern_template.format(namespace) - with self._open_archive('r') as zf: - for name in zf.namelist(): - r = re.match(pattern, name) - if r: - yield r.groups('name')[0] + return iter(self._dict[namespace]) def get_package(api, client): try: schema = api._schema except AttributeError: - schema = Schema(api, client) - object.__setattr__(api, '_schema', schema) + with ServerInfo(api.env.hostname) as server_info: + schema = Schema(api, server_info, client) + object.__setattr__(api, '_schema', schema) fingerprint = str(schema['fingerprint']) package_name = '{}${}'.format(__name__, fingerprint) @@ -509,10 +563,9 @@ def get_package(api, client): module = types.ModuleType(module_name) module.__file__ = os.path.join(package_dir, 'plugins.py') module.register = plugable.Registry() - for key, plugin_cls in (('commands', _SchemaCommandPlugin), - ('classes', _SchemaObjectPlugin)): - for full_name in schema[key]: - plugin = plugin_cls(full_name) + for plugin_cls in (_SchemaCommandPlugin, _SchemaObjectPlugin): + for full_name in schema[plugin_cls.schema_key]: + plugin = plugin_cls(str(full_name)) plugin = module.register()(plugin) sys.modules[module_name] = module |