summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--ipaclient/remote_plugins/schema.py301
1 files changed, 177 insertions, 124 deletions
diff --git a/ipaclient/remote_plugins/schema.py b/ipaclient/remote_plugins/schema.py
index cd1d5d607..70cd7536d 100644
--- a/ipaclient/remote_plugins/schema.py
+++ b/ipaclient/remote_plugins/schema.py
@@ -5,10 +5,8 @@
import collections
import errno
import fcntl
-import glob
import json
import os
-import re
import sys
import time
import types
@@ -65,8 +63,6 @@ USER_CACHE_PATH = (
'.cache'
)
)
-SCHEMA_DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'schema')
-SERVERS_DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers')
logger = log_mgr.get_logger(__name__)
@@ -274,15 +270,6 @@ class _SchemaObjectPlugin(_SchemaPlugin):
schema_key = 'classes'
-def _ensure_dir_created(d):
- try:
- os.makedirs(d)
- except OSError as e:
- if e.errno != errno.EEXIST:
- raise RuntimeError("Unable to create cache directory: {}"
- "".format(e))
-
-
class _LockedZipFile(zipfile.ZipFile):
""" Add locking to zipfile.ZipFile
Shared lock is used with read mode, exclusive with write mode.
@@ -308,7 +295,10 @@ class _SchemaNameSpace(collections.Mapping):
self._schema = schema
def __getitem__(self, key):
- return self._schema.read_namespace_member(self.name, key)
+ try:
+ return self._schema.read_namespace_member(self.name, key)
+ except KeyError:
+ raise KeyError(key)
def __iter__(self):
for key in self._schema.iter_namespace(self.name):
@@ -322,6 +312,62 @@ class NotAvailable(Exception):
pass
+class ServerInfo(collections.MutableMapping):
+ _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'servers')
+
+ def __init__(self, api):
+ hostname = DNSName(api.env.server).ToASCII()
+ self._path = os.path.join(self._DIR, hostname)
+ self._dict = {}
+ self._dirty = False
+
+ self._read()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *_exc_info):
+ if self._dirty:
+ self._write()
+
+ def _read(self):
+ try:
+ with open(self._path, 'r') as sc:
+ self._dict = json.load(sc)
+ except EnvironmentError as e:
+ if e.errno != errno.ENOENT:
+ logger.warning('Failed to read server info: {}'.format(e))
+
+ def _write(self):
+ try:
+ try:
+ os.makedirs(self._DIR)
+ except EnvironmentError as e:
+ if e.errno != errno.EEXIST:
+ raise
+ with open(self._path, 'w') as sc:
+ json.dump(self._dict, sc)
+ except EnvironmentError as e:
+ logger.warning('Failed to write server info: {}'.format(e))
+
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ def __setitem__(self, key, value):
+ self._dirty = key not in self._dict or self._dict[key] != value
+ self._dict[key] = value
+
+ def __delitem__(self, key):
+ del self._dict[key]
+ self._dirty = True
+
+ def __iter__(self):
+ return iter(self._dict)
+
+ def __len__(self):
+ return len(self._dict)
+
+
class Schema(object):
"""
Store and provide schema for commands and topics
@@ -342,38 +388,76 @@ class Schema(object):
u'Ping the remote IPA server to ...'
"""
- schema_path_template = os.path.join(SCHEMA_DIR, '{}')
- servers_path_template = os.path.join(SERVERS_DIR, '{}')
- ns_member_pattern_template = '^{}/(?P<name>.+)$'
- ns_member_path_template = '{}/{}'
namespaces = {'classes', 'commands', 'topics'}
schema_info_path = 'schema'
+ _DIR = os.path.join(USER_CACHE_PATH, 'ipa', 'schema')
- @classmethod
- def _list(cls):
- for f in glob.glob(cls.schema_path_template.format('*')):
- yield os.path.splitext(os.path.basename(f))[0]
+ def __init__(self, api, server_info, client):
+ self._dict = {}
+ self._namespaces = {}
+ self._help = None
- @classmethod
- def _in_cache(cls, fingeprint):
- return os.path.exists(cls.schema_path_template.format(fingeprint))
+ for ns in self.namespaces:
+ self._dict[ns] = {}
+ self._namespaces[ns] = _SchemaNameSpace(self, ns)
- def __init__(self, api, client):
- self._api = api
- self._client = client
- self._dict = {}
+ is_known = False
+ if not api.env.force_schema_check:
+ try:
+ self._fingerprint = server_info['fingerprint']
+ self._expiration = server_info['expiration']
+ except KeyError:
+ pass
+ else:
+ is_known = True
+
+ if is_known:
+ try:
+ self._read_schema()
+ except Exception:
+ pass
+ else:
+ return
- def _open_server_info(self, hostname, mode):
- encoded_hostname = DNSName(hostname).ToASCII()
- path = self.servers_path_template.format(encoded_hostname)
- return open(path, mode)
+ try:
+ self._fetch(client)
+ except NotAvailable:
+ raise
+ else:
+ self._write_schema()
+ finally:
+ try:
+ server_info['fingerprint'] = self._fingerprint
+ server_info['expiration'] = self._expiration
+ except AttributeError:
+ pass
- def _get_schema(self):
- client = self._client
+ def _open_schema(self, filename, mode):
+ path = os.path.join(self._DIR, filename)
+ return _LockedZipFile(path, mode)
+
+ def _get_schema_fingerprint(self, schema):
+ schema_info = json.loads(schema.read(self.schema_info_path))
+ return schema_info['fingerprint']
+
+ def _fetch(self, client):
if not client.isconnected():
client.connect(verbose=False)
- fps = [unicode(f) for f in Schema._list()]
+ fps = []
+ try:
+ files = os.listdir(self._DIR)
+ except EnvironmentError:
+ pass
+ else:
+ for filename in files:
+ try:
+ with self._open_schema(filename, 'r') as schema:
+ fps.append(
+ unicode(self._get_schema_fingerprint(schema)))
+ except Exception:
+ continue
+
kwargs = {u'version': u'2.170'}
if fps:
kwargs[u'known_fingerprints'] = fps
@@ -386,110 +470,80 @@ class Schema(object):
ttl = e.ttl
else:
fp = schema['fingerprint']
- ttl = schema['ttl']
- self._store(fp, schema)
- finally:
- client.disconnect()
+ ttl = schema.pop('ttl', 0)
- exp = ttl + time.time()
- return (fp, exp)
+ for key, value in schema.items():
+ if key in self.namespaces:
+ value = {m['full_name']: m for m in value}
+ self._dict[key] = value
- def _ensure_cached(self):
- no_info = False
- try:
- # pylint: disable=access-member-before-definition
- fp = self._server_schema_fingerprint
- exp = self._server_schema_expiration
- except AttributeError:
- try:
- with self._open_server_info(self._api.env.server, 'r') as sc:
- si = json.load(sc)
-
- fp = si['fingerprint']
- exp = si['expiration']
- except Exception as e:
- no_info = True
- if not (isinstance(e, EnvironmentError) and
- e.errno == errno.ENOENT): # pylint: disable=no-member
- logger.warning('Failed to load server properties: {}'
- ''.format(e))
-
- force_check = ((not getattr(self, '_schema_checked', False)) and
- self._api.env.force_schema_check)
-
- if (force_check or
- no_info or exp < time.time() or not Schema._in_cache(fp)):
- (fp, exp) = self._get_schema()
- self._schema_checked = True
- _ensure_dir_created(SERVERS_DIR)
- try:
- with self._open_server_info(self._api.env.server, 'w') as sc:
- json.dump(dict(fingerprint=fp, expiration=exp), sc)
- except Exception as e:
- logger.warning('Failed to store server properties: {}'
- ''.format(e))
-
- if not self._dict:
- self._dict['fingerprint'] = fp
- schema_info = self._read(self.schema_info_path)
+ self._fingerprint = fp
+ self._expiration = ttl + time.time()
+
+ def _read_schema(self):
+ with self._open_schema(self._fingerprint, 'r') as schema:
+ self._dict['fingerprint'] = self._get_schema_fingerprint(schema)
+ schema_info = json.loads(schema.read(self.schema_info_path))
self._dict['version'] = schema_info['version']
- for ns in self.namespaces:
- self._dict[ns] = _SchemaNameSpace(self, ns)
- self._server_schema_fingerprintr = fp
- self._server_schema_expiration = exp
+ for name in schema.namelist():
+ ns, _slash, key = name.partition('/')
+ if ns in self.namespaces:
+ self._dict[ns][key] = {}
def __getitem__(self, key):
- self._ensure_cached()
- return self._dict[key]
+ try:
+ return self._namespaces[key]
+ except KeyError:
+ return self._dict[key]
- def _open_archive(self, mode, fp=None):
- if not fp:
- fp = self['fingerprint']
- arch_path = self.schema_path_template.format(fp)
- return _LockedZipFile(arch_path, mode)
-
- def _store(self, fingerprint, schema={}):
- _ensure_dir_created(SCHEMA_DIR)
-
- schema_info = dict(version=schema['version'],
- fingerprint=schema['fingerprint'])
-
- with self._open_archive('w', fingerprint) as zf:
- # store schema information
- zf.writestr(self.schema_info_path, json.dumps(schema_info))
- # store namespaces
- for namespace in self.namespaces:
- for member in schema[namespace]:
- path = self.ns_member_path_template.format(
- namespace,
- member['full_name']
- )
- zf.writestr(path, json.dumps(member))
+ def _write_schema(self):
+ try:
+ os.makedirs(self._DIR)
+ except EnvironmentError as e:
+ if e.errno != errno.EEXIST:
+ logger.warning("Failed ti write schema: {}".format(e))
+ return
+
+ with self._open_schema(self._fingerprint, 'w') as schema:
+ schema_info = {}
+ for key, value in self._dict.items():
+ if key in self.namespaces:
+ ns = value
+ for member in ns:
+ path = '{}/{}'.format(key, member)
+ schema.writestr(path, json.dumps(ns[member]))
+ else:
+ schema_info[key] = value
+
+ schema.writestr(self.schema_info_path, json.dumps(schema_info))
def _read(self, path):
- with self._open_archive('r') as zf:
+ with self._open_schema(self._fingerprint, 'r') as zf:
return json.loads(zf.read(path))
def read_namespace_member(self, namespace, member):
- path = self.ns_member_path_template.format(namespace, member)
- return self._read(path)
+ value = self._dict[namespace][member]
+
+ if (not value) or ('full_name' not in value):
+ path = '{}/{}'.format(namespace, member)
+ value = self._dict[namespace].setdefault(
+ member, {}
+ ).update(self._read(path))
+
+ return value
def iter_namespace(self, namespace):
- pattern = self.ns_member_pattern_template.format(namespace)
- with self._open_archive('r') as zf:
- for name in zf.namelist():
- r = re.match(pattern, name)
- if r:
- yield r.groups('name')[0]
+ return iter(self._dict[namespace])
def get_package(api, client):
try:
schema = api._schema
except AttributeError:
- schema = Schema(api, client)
- object.__setattr__(api, '_schema', schema)
+ with ServerInfo(api.env.hostname) as server_info:
+ schema = Schema(api, server_info, client)
+ object.__setattr__(api, '_schema', schema)
fingerprint = str(schema['fingerprint'])
package_name = '{}${}'.format(__name__, fingerprint)
@@ -509,10 +563,9 @@ def get_package(api, client):
module = types.ModuleType(module_name)
module.__file__ = os.path.join(package_dir, 'plugins.py')
module.register = plugable.Registry()
- for key, plugin_cls in (('commands', _SchemaCommandPlugin),
- ('classes', _SchemaObjectPlugin)):
- for full_name in schema[key]:
- plugin = plugin_cls(full_name)
+ for plugin_cls in (_SchemaCommandPlugin, _SchemaObjectPlugin):
+ for full_name in schema[plugin_cls.schema_key]:
+ plugin = plugin_cls(str(full_name))
plugin = module.register()(plugin)
sys.modules[module_name] = module