summaryrefslogtreecommitdiffstats
path: root/ipalib
diff options
context:
space:
mode:
authorJan Cholasta <jcholast@redhat.com>2016-06-20 07:41:08 +0200
committerJan Cholasta <jcholast@redhat.com>2016-06-28 13:30:49 +0200
commit79d1f5833547044a7cb2700454cacb2a0976dd5f (patch)
tree9ca54021b6b5dc8ee1dd02ab104b65188784d013 /ipalib
parent9a21964877c4bb64599e75ca708ec83a72abeb51 (diff)
downloadfreeipa-79d1f5833547044a7cb2700454cacb2a0976dd5f.tar.gz
freeipa-79d1f5833547044a7cb2700454cacb2a0976dd5f.tar.xz
freeipa-79d1f5833547044a7cb2700454cacb2a0976dd5f.zip
plugable: use plugin class as the key in API namespaces
When iterating over APINameSpace objects, use plugin class rather than its name as the key. https://fedorahosted.org/freeipa/ticket/4427 Reviewed-By: David Kupka <dkupka@redhat.com>
Diffstat (limited to 'ipalib')
-rw-r--r--ipalib/cli.py6
-rw-r--r--ipalib/frontend.py10
-rw-r--r--ipalib/plugable.py134
3 files changed, 82 insertions, 68 deletions
diff --git a/ipalib/cli.py b/ipalib/cli.py
index 374429f46..f60dc927d 100644
--- a/ipalib/cli.py
+++ b/ipalib/cli.py
@@ -790,11 +790,11 @@ class help(frontend.Local):
if type(t[2]) is dict):
self.print_commands(name, outfile)
elif name == "commands":
- mcl = max(len(s) for s in (self.Command))
- for cname in self.Command:
- cmd = self.Command[cname]
+ mcl = 0
+ for cmd in self.Command():
if cmd.NO_CLI:
continue
+ mcl = max(mcl, len(cmd.name))
writer('%s %s' % (to_cli(cmd.name).ljust(mcl), cmd.summary))
else:
raise HelpError(topic=name)
diff --git a/ipalib/frontend.py b/ipalib/frontend.py
index 3edd298f7..71db84e3e 100644
--- a/ipalib/frontend.py
+++ b/ipalib/frontend.py
@@ -389,7 +389,7 @@ class Command(HasParam):
>>> api.add_plugin(my_command)
>>> api.finalize()
>>> list(api.Command)
- ['my_command']
+ [<class '__main__.my_command'>]
>>> api.Command.my_command # doctest:+ELLIPSIS
ipalib.frontend.my_command()
@@ -1381,7 +1381,7 @@ class Method(Attribute, Command):
namespace:
>>> list(api.Method)
- ['user_add']
+ [<class '__main__.user_add'>]
>>> api.Method.user_add(version=u'2.88') # Will call user_add.run()
{'result': 'Added the user!'}
@@ -1392,7 +1392,7 @@ class Method(Attribute, Command):
plugin can also be accessed through the ``api.Command`` namespace:
>>> list(api.Command)
- ['user_add']
+ [<class '__main__.user_add'>]
>>> api.Command.user_add(version=u'2.88') # Will call user_add.run()
{'result': 'Added the user!'}
@@ -1400,7 +1400,7 @@ class Method(Attribute, Command):
`Object`:
>>> list(api.Object)
- ['user']
+ [<class '__main__.user'>]
>>> list(api.Object.user.methods)
['add']
>>> api.Object.user.methods.add(version=u'2.88') # Will call user_add.run()
@@ -1443,7 +1443,7 @@ class Updater(Plugin):
>>> api.add_plugin(my_update)
>>> api.finalize()
>>> list(api.Updater)
- ['my_update']
+ [<class '__main__.my_update'>]
>>> api.Updater.my_update # doctest:+ELLIPSIS
ipalib.frontend.my_update()
"""
diff --git a/ipalib/plugable.py b/ipalib/plugable.py
index 5a5d02fb0..575e9bd63 100644
--- a/ipalib/plugable.py
+++ b/ipalib/plugable.py
@@ -25,6 +25,7 @@ you are unfamiliar with this Python feature, see
http://docs.python.org/ref/sequence-types.html
"""
+import operator
import sys
import threading
import os
@@ -95,23 +96,23 @@ class Registry(object):
self.__registry = collections.OrderedDict()
def __call__(self, **kwargs):
- def register(klass):
+ def register(plugin):
"""
- Register the plugin ``klass``.
+ Register the plugin ``plugin``.
- :param klass: A subclass of `Plugin` to attempt to register.
+ :param plugin: A subclass of `Plugin` to attempt to register.
"""
- if not callable(klass):
- raise TypeError('plugin must be callable; got %r' % klass)
+ if not callable(plugin):
+ raise TypeError('plugin must be callable; got %r' % plugin)
# Raise DuplicateError if this exact class was already registered:
- if klass in self.__registry:
- raise errors.PluginDuplicateError(plugin=klass)
+ if plugin in self.__registry:
+ raise errors.PluginDuplicateError(plugin=plugin)
# The plugin is okay, add to __registry:
- self.__registry[klass] = dict(kwargs, klass=klass)
+ self.__registry[plugin] = dict(kwargs, plugin=plugin)
- return klass
+ return plugin
return register
@@ -270,44 +271,50 @@ class APINameSpace(collections.Mapping):
def __init__(self, api, base):
self.__api = api
self.__base = base
- self.__name_seq = None
- self.__name_set = None
+ self.__plugins = None
+ self.__plugins_by_key = None
def __enumerate(self):
- if self.__name_set is None:
- self.__name_set = frozenset(
- name for name, klass in six.iteritems(self.__api._API__plugins)
- if any(issubclass(b, self.__base) for b in klass.bases))
+ if self.__plugins is not None and self.__plugins_by_key is not None:
+ return
+
+ plugins = set()
+ key_dict = self.__plugins_by_key = {}
+
+ for plugin in self.__api._API__plugins:
+ if not any(issubclass(b, self.__base) for b in plugin.bases):
+ continue
+ plugins.add(plugin)
+ key_dict[plugin] = plugin
+ key_dict[plugin.name] = plugin
+
+ self.__plugins = sorted(plugins, key=operator.attrgetter('name'))
def __len__(self):
self.__enumerate()
- return len(self.__name_set)
+ return len(self.__plugins)
- def __contains__(self, name):
+ def __contains__(self, key):
self.__enumerate()
- return name in self.__name_set
+ return key in self.__plugins_by_key
def __iter__(self):
- if self.__name_seq is None:
- self.__enumerate()
- self.__name_seq = tuple(sorted(self.__name_set))
- return iter(self.__name_seq)
+ self.__enumerate()
+ return iter(self.__plugins)
- def __getitem__(self, name):
- name = getattr(name, '__name__', name)
- klass = self.__api._API__plugins[name]
- if not any(issubclass(b, self.__base) for b in klass.bases):
- raise KeyError(name)
- return self.__api._get(name)
+ def __getitem__(self, key):
+ self.__enumerate()
+ plugin = self.__plugins_by_key[key]
+ return self.__api._get(plugin)
def __call__(self):
return six.itervalues(self)
- def __getattr__(self, name):
+ def __getattr__(self, key):
try:
- return self[name]
+ return self[key]
except KeyError:
- raise AttributeError(name)
+ raise AttributeError(key)
class API(ReadOnly):
@@ -317,7 +324,8 @@ class API(ReadOnly):
def __init__(self):
super(API, self).__init__()
- self.__plugins = {}
+ self.__plugins = set()
+ self.__plugins_by_key = {}
self.__instances = {}
self.__next = {}
self.__done = set()
@@ -616,49 +624,51 @@ class API(ReadOnly):
raise errors.PluginModuleError(name=module.__name__)
- def add_plugin(self, klass, override=False):
+ def add_plugin(self, plugin, override=False):
"""
- Add the plugin ``klass``.
+ Add the plugin ``plugin``.
- :param klass: A subclass of `Plugin` to attempt to add.
+ :param plugin: A subclass of `Plugin` to attempt to add.
:param override: If true, override an already added plugin.
"""
- if not callable(klass):
- raise TypeError('plugin must be callable; got %r' % klass)
+ if not callable(plugin):
+ raise TypeError('plugin must be callable; got %r' % plugin)
# Find the base class or raise SubclassError:
- for base in klass.bases:
+ for base in plugin.bases:
if issubclass(base, self.bases):
break
else:
raise errors.PluginSubclassError(
- plugin=klass,
+ plugin=plugin,
bases=self.bases,
)
# Check override:
- prev = self.__plugins.get(klass.name)
+ prev = self.__plugins_by_key.get(plugin.name)
if prev:
if not override:
# Must use override=True to override:
raise errors.PluginOverrideError(
base=base.__name__,
- name=klass.name,
- plugin=klass,
+ name=plugin.name,
+ plugin=plugin,
)
- self.__next[klass] = prev
+ self.__plugins.remove(prev)
+ self.__next[plugin] = prev
else:
if override:
# There was nothing already registered to override:
raise errors.PluginMissingOverrideError(
base=base.__name__,
- name=klass.name,
- plugin=klass,
+ name=plugin.name,
+ plugin=plugin,
)
# The plugin is okay, add to sub_d:
- self.__plugins[klass.name] = klass
+ self.__plugins.add(plugin)
+ self.__plugins_by_key[plugin.name] = plugin
def finalize(self):
"""
@@ -673,19 +683,18 @@ class API(ReadOnly):
production_mode = self.is_production_mode()
for base in self.bases:
- name = base.__name__
-
- for klass in six.itervalues(self.__plugins):
- if not any(issubclass(b, base) for b in klass.bases):
+ for plugin in self.__plugins:
+ if not any(issubclass(b, base) for b in plugin.bases):
continue
if not self.env.plugins_on_demand:
- self._get(klass.name)
+ self._get(plugin)
+ name = base.__name__
if not production_mode:
assert not hasattr(self, name)
setattr(self, name, APINameSpace(self, base))
- for klass, instance in six.iteritems(self.__instances):
+ for instance in six.itervalues(self.__instances):
if not production_mode:
assert instance.api is self
if not self.env.plugins_on_demand:
@@ -698,19 +707,24 @@ class API(ReadOnly):
if not production_mode:
lock(self)
- def _get(self, name):
- klass = self.__plugins[name]
+ def _get(self, plugin):
+ if not callable(plugin):
+ raise TypeError('plugin must be callable; got %r' % plugin)
+ if plugin not in self.__plugins:
+ raise KeyError(plugin)
+
try:
- instance = self.__instances[klass]
+ instance = self.__instances[plugin]
except KeyError:
- instance = self.__instances[klass] = klass(self)
+ instance = self.__instances[plugin] = plugin(self)
+
return instance
- def get_plugin_next(self, klass):
- if not callable(klass):
- raise TypeError('plugin must be callable; got %r' % klass)
+ def get_plugin_next(self, plugin):
+ if not callable(plugin):
+ raise TypeError('plugin must be callable; got %r' % plugin)
- return self.__next[klass]
+ return self.__next[plugin]
class IPAHelpFormatter(optparse.IndentedHelpFormatter):