From 9960149e3f84564ab324bfb9db7c50063d87a7bd Mon Sep 17 00:00:00 2001 From: Petr Viktorin Date: Wed, 25 Apr 2012 10:31:10 -0400 Subject: Rework the CallbackInterface Fix several problems with the callback interface: - Automatically registered callbacks (i.e. methods named exc_callback, pre_callback etc) were registered on every instantiation. Fix: Do not register callbacks in __init__; instead return the method when asked for it. - The calling code had to distinguish between bound methods and plain functions by checking the 'im_self' attribute. Fix: Always return the "default" callback as an unbound method. Registered callbacks now always take the extra `self` argument, whether they happen to be bound methods or not. Calling code now always needs to pass the `self` argument. - Did not work well with inheritance: due to the fact that Python looks up missing attributes in superclasses, callbacks could get attached to a superclass if it was instantiated early enough. * Fix: Instead of attribute lookup, use a dictionary with class keys. - The interface included the callback types, which are LDAP-specific. Fix: Create generic register_callback and get_callback mehods, move LDAP-specific code to BaseLDAPCommand Update code that calls the callbacks. Add tests. Remove lint exceptions for CallbackInterface. * https://fedorahosted.org/freeipa/ticket/2674 --- ipalib/cli.py | 9 +- ipalib/plugins/baseldap.py | 334 ++++++++++++------------------ make-lint | 2 - tests/test_xmlrpc/test_baseldap_plugin.py | 95 ++++++++- 4 files changed, 229 insertions(+), 211 deletions(-) diff --git a/ipalib/cli.py b/ipalib/cli.py index 8279345a9..d53e6cd40 100644 --- a/ipalib/cli.py +++ b/ipalib/cli.py @@ -1195,8 +1195,13 @@ class cli(backend.Executioner): param.label, param.confirm ) - for callback in getattr(cmd, 'INTERACTIVE_PROMPT_CALLBACKS', []): - callback(kw) + try: + callbacks = cmd.get_callbacks('interactive_prompt') + except AttributeError: + pass + else: + for callback in callbacks: + callback(cmd, kw) def load_files(self, cmd, kw): """ diff --git a/ipalib/plugins/baseldap.py b/ipalib/plugins/baseldap.py index b8ef43d47..475222a6a 100644 --- a/ipalib/plugins/baseldap.py +++ b/ipalib/plugins/baseldap.py @@ -690,93 +690,57 @@ def _check_limit_object_class(attributes, attrs, allow_only): if len(limitattrs) > 0 and allow_only: raise errors.ObjectclassViolation(info='attribute "%(attribute)s" not allowed' % dict(attribute=limitattrs[0])) + class CallbackInterface(Method): + """Callback registration interface + + This class's subclasses allow different types of callbacks to be added and + removed to them. + Registering a callback is done either by ``register_callback``, or by + defining a ``_callback`` method. + + Subclasses should define the `_callback_registry` attribute as a dictionary + mapping allowed callback types to (initially) empty dictionaries. """ - Callback registration interface - """ - def __init__(self): - #pylint: disable=E1003 - if not hasattr(self.__class__, 'PRE_CALLBACKS'): - self.__class__.PRE_CALLBACKS = [] - if not hasattr(self.__class__, 'POST_CALLBACKS'): - self.__class__.POST_CALLBACKS = [] - if not hasattr(self.__class__, 'EXC_CALLBACKS'): - self.__class__.EXC_CALLBACKS = [] - if not hasattr(self.__class__, 'INTERACTIVE_PROMPT_CALLBACKS'): - self.__class__.INTERACTIVE_PROMPT_CALLBACKS = [] - if hasattr(self, 'pre_callback'): - self.register_pre_callback(self.pre_callback, True) - if hasattr(self, 'post_callback'): - self.register_post_callback(self.post_callback, True) - if hasattr(self, 'exc_callback'): - self.register_exc_callback(self.exc_callback, True) - if hasattr(self, 'interactive_prompt_callback'): - self.register_interactive_prompt_callback( - self.interactive_prompt_callback, True) #pylint: disable=E1101 - super(Method, self).__init__() - @classmethod - def register_pre_callback(klass, callback, first=False): - assert callable(callback) - if not hasattr(klass, 'PRE_CALLBACKS'): - klass.PRE_CALLBACKS = [] - if first: - klass.PRE_CALLBACKS.insert(0, callback) - else: - klass.PRE_CALLBACKS.append(callback) + _callback_registry = dict() @classmethod - def register_post_callback(klass, callback, first=False): - assert callable(callback) - if not hasattr(klass, 'POST_CALLBACKS'): - klass.POST_CALLBACKS = [] - if first: - klass.POST_CALLBACKS.insert(0, callback) - else: - klass.POST_CALLBACKS.append(callback) + def get_callbacks(cls, callback_type): + """Yield callbacks of the given type""" + # Use one shared callback registry, keyed on class, to avoid problems + # with missing attributes being looked up in superclasses + callbacks = cls._callback_registry[callback_type].get(cls, [None]) + for callback in callbacks: + if callback is None: + try: + yield getattr(cls, '%s_callback' % callback_type) + except AttributeError: + pass + else: + yield callback @classmethod - def register_exc_callback(klass, callback, first=False): - assert callable(callback) - if not hasattr(klass, 'EXC_CALLBACKS'): - klass.EXC_CALLBACKS = [] - if first: - klass.EXC_CALLBACKS.insert(0, callback) - else: - klass.EXC_CALLBACKS.append(callback) + def register_callback(cls, callback_type, callback, first=False): + """Register a callback - @classmethod - def register_interactive_prompt_callback(klass, callback, first=False): + :param callback_type: The callback type (e.g. 'pre', 'post') + :param callback: The callable added + :param first: If true, the new callback will be added before all + existing callbacks; otherwise it's added after them + + Note that callbacks registered this way will be attached to this class + only, not to its subclasses. + """ assert callable(callback) - if not hasattr(klass, 'INTERACTIVE_PROMPT_CALLBACKS'): - klass.INTERACTIVE_PROMPT_CALLBACKS = [] + try: + callbacks = cls._callback_registry[callback_type][cls] + except KeyError: + callbacks = cls._callback_registry[callback_type][cls] = [None] if first: - klass.INTERACTIVE_PROMPT_CALLBACKS.insert(0, callback) + callbacks.insert(0, callback) else: - klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback) - - def _exc_wrapper(self, keys, options, call_func): - """Function wrapper that automatically calls exception callbacks""" - def wrapped(*call_args, **call_kwargs): - # call call_func first - func = call_func - callbacks = list(getattr(self, 'EXC_CALLBACKS', [])) - while True: - try: - return func(*call_args, **call_kwargs) - except errors.ExecutionError, e: - if not callbacks: - raise - # call exc_callback in the next loop - callback = callbacks.pop(0) - if hasattr(callback, 'im_self'): - def exc_func(*args, **kwargs): - return callback(keys, options, e, call_func, *args, **kwargs) - else: - def exc_func(*args, **kwargs): - return callback(self, keys, options, e, call_func, *args, **kwargs) - func = exc_func - return wrapped + callbacks.append(callback) class BaseLDAPCommand(CallbackInterface, Command): @@ -802,6 +766,8 @@ last, after all sets and adds."""), exclude='webui', ) + _callback_registry = dict(pre={}, post={}, exc={}, interactive_prompt={}) + def _convert_2_dict(self, attrs): """ Convert a string in the form of name/value pairs into a dictionary. @@ -961,6 +927,45 @@ last, after all sets and adds."""), elif isinstance(entry_attrs[attr], (tuple, list)) and len(entry_attrs[attr]) == 1: entry_attrs[attr] = entry_attrs[attr][0] + @classmethod + def register_pre_callback(cls, callback, first=False): + """Shortcut for register_callback('pre', ...)""" + cls.register_callback('pre', callback, first) + + @classmethod + def register_post_callback(cls, callback, first=False): + """Shortcut for register_callback('post', ...)""" + cls.register_callback('post', callback, first) + + @classmethod + def register_exc_callback(cls, callback, first=False): + """Shortcut for register_callback('exc', ...)""" + cls.register_callback('exc', callback, first) + + @classmethod + def register_interactive_prompt_callback(cls, callback, first=False): + """Shortcut for register_callback('interactive_prompt', ...)""" + cls.register_callback('interactive_prompt', callback, first) + + def _exc_wrapper(self, keys, options, call_func): + """Function wrapper that automatically calls exception callbacks""" + def wrapped(*call_args, **call_kwargs): + # call call_func first + func = call_func + callbacks = list(self.get_callbacks('exc')) + while True: + try: + return func(*call_args, **call_kwargs) + except errors.ExecutionError, e: + if not callbacks: + raise + # call exc_callback in the next loop + callback = callbacks.pop(0) + def exc_func(*args, **kwargs): + return callback( + self, keys, options, e, call_func, *args, **kwargs) + func = exc_func + return wrapped class LDAPCreate(BaseLDAPCommand, crud.Create): """ @@ -1012,15 +1017,9 @@ class LDAPCreate(BaseLDAPCommand, crud.Create): set(self.obj.default_attributes + entry_attrs.keys()) ) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback( - ldap, dn, entry_attrs, attrs_list, *keys, **options - ) - else: - dn = callback( - self, ldap, dn, entry_attrs, attrs_list, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback( + self, ldap, dn, entry_attrs, attrs_list, *keys, **options) _check_single_value_attrs(self.params, entry_attrs) ldap.get_schema() @@ -1064,11 +1063,8 @@ class LDAPCreate(BaseLDAPCommand, crud.Create): except errors.NotFound: self.obj.handle_not_found(*keys) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, entry_attrs, *keys, **options) - else: - dn = callback(self, ldap, dn, entry_attrs, *keys, **options) + for callback in self.get_callbacks('post'): + dn = callback(self, ldap, dn, entry_attrs, *keys, **options) entry_attrs['dn'] = dn @@ -1173,11 +1169,8 @@ class LDAPRetrieve(LDAPQuery): else: attrs_list = list(self.obj.default_attributes) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, attrs_list, *keys, **options) - else: - dn = callback(self, ldap, dn, attrs_list, *keys, **options) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, attrs_list, *keys, **options) try: (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)( @@ -1189,11 +1182,8 @@ class LDAPRetrieve(LDAPQuery): if options.get('rights', False) and options.get('all', False): entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, entry_attrs, *keys, **options) - else: - dn = callback(self, ldap, dn, entry_attrs, *keys, **options) + for callback in self.get_callbacks('post'): + dn = callback(self, ldap, dn, entry_attrs, *keys, **options) self.obj.convert_attribute_members(entry_attrs, *keys, **options) entry_attrs['dn'] = dn @@ -1268,15 +1258,9 @@ class LDAPUpdate(LDAPQuery, crud.Update): _check_single_value_attrs(self.params, entry_attrs) _check_empty_attrs(self.obj.params, entry_attrs) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback( - ldap, dn, entry_attrs, attrs_list, *keys, **options - ) - else: - dn = callback( - self, ldap, dn, entry_attrs, attrs_list, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback( + self, ldap, dn, entry_attrs, attrs_list, *keys, **options) ldap.get_schema() _check_limit_object_class(self.api.Backend.ldap2.schema.attribute_types(self.obj.limit_object_classes), entry_attrs.keys(), allow_only=True) @@ -1323,11 +1307,8 @@ class LDAPUpdate(LDAPQuery, crud.Update): if options.get('rights', False) and options.get('all', False): entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, entry_attrs, *keys, **options) - else: - dn = callback(self, ldap, dn, entry_attrs, *keys, **options) + for callback in self.get_callbacks('post'): + dn = callback(self, ldap, dn, entry_attrs, *keys, **options) self.obj.convert_attribute_members(entry_attrs, *keys, **options) if self.obj.primary_key and keys[-1] is not None: @@ -1362,11 +1343,8 @@ class LDAPDelete(LDAPMultiQuery): nkeys = keys[:-1] + (pkey, ) dn = self.obj.get_dn(*nkeys, **options) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, *nkeys, **options) - else: - dn = callback(self, ldap, dn, *nkeys, **options) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, *nkeys, **options) def delete_subtree(base_dn): truncated = True @@ -1387,11 +1365,8 @@ class LDAPDelete(LDAPMultiQuery): delete_subtree(dn) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - result = callback(ldap, dn, *nkeys, **options) - else: - result = callback(self, ldap, dn, *nkeys, **options) + for callback in self.get_callbacks('post'): + result = callback(self, ldap, dn, *nkeys, **options) return result @@ -1503,13 +1478,8 @@ class LDAPAddMember(LDAPModMember): dn = self.obj.get_dn(*keys, **options) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, member_dns, failed, *keys, **options) - else: - dn = callback( - self, ldap, dn, member_dns, failed, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, member_dns, failed, *keys, **options) completed = 0 for (attr, objs) in member_dns.iteritems(): @@ -1542,16 +1512,10 @@ class LDAPAddMember(LDAPModMember): except errors.NotFound: self.obj.handle_not_found(*keys) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - (completed, dn) = callback( - ldap, completed, failed, dn, entry_attrs, *keys, **options - ) - else: - (completed, dn) = callback( - self, ldap, completed, failed, dn, entry_attrs, *keys, - **options - ) + for callback in self.get_callbacks('post'): + (completed, dn) = callback( + self, ldap, completed, failed, dn, entry_attrs, *keys, + **options) entry_attrs['dn'] = dn self.obj.convert_attribute_members(entry_attrs, *keys, **options) @@ -1602,13 +1566,8 @@ class LDAPRemoveMember(LDAPModMember): dn = self.obj.get_dn(*keys, **options) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, member_dns, failed, *keys, **options) - else: - dn = callback( - self, ldap, dn, member_dns, failed, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, member_dns, failed, *keys, **options) completed = 0 for (attr, objs) in member_dns.iteritems(): @@ -1644,16 +1603,10 @@ class LDAPRemoveMember(LDAPModMember): except errors.NotFound: self.obj.handle_not_found(*keys) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - (completed, dn) = callback( - ldap, completed, failed, dn, entry_attrs, *keys, **options - ) - else: - (completed, dn) = callback( - self, ldap, completed, failed, dn, entry_attrs, *keys, - **options - ) + for callback in self.get_callbacks('post'): + (completed, dn) = callback( + self, ldap, completed, failed, dn, entry_attrs, *keys, + **options) entry_attrs['dn'] = dn @@ -1838,15 +1791,9 @@ class LDAPSearch(BaseLDAPCommand, crud.Search): ) scope = ldap.SCOPE_ONELEVEL - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - (filter, base_dn, scope) = callback( - ldap, filter, attrs_list, base_dn, scope, *args, **options - ) - else: - (filter, base_dn, scope) = callback( - self, ldap, filter, attrs_list, base_dn, scope, *args, **options - ) + for callback in self.get_callbacks('pre'): + (filter, base_dn, scope) = callback( + self, ldap, filter, attrs_list, base_dn, scope, *args, **options) try: (entries, truncated) = self._exc_wrapper(args, options, ldap.find_entries)( @@ -1857,11 +1804,8 @@ class LDAPSearch(BaseLDAPCommand, crud.Search): except errors.NotFound: (entries, truncated) = ([], False) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - truncated = callback(ldap, entries, truncated, *args, **options) - else: - truncated = callback(self, ldap, entries, truncated, *args, **options) + for callback in self.get_callbacks('post'): + truncated = callback(self, ldap, entries, truncated, *args, **options) if self.sort_result_entries: if self.obj.primary_key: @@ -1965,13 +1909,8 @@ class LDAPAddReverseMember(LDAPModReverseMember): result = self.api.Command[self.show_command](keys[-1])['result'] dn = result['dn'] - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, *keys, **options) - else: - dn = callback( - self, ldap, dn, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, *keys, **options) if options.get('all', False): attrs_list = ['*'] + self.obj.default_attributes @@ -2006,16 +1945,10 @@ class LDAPAddReverseMember(LDAPModReverseMember): except Exception, e: raise errors.ReverseMemberError(verb=_('added'), exc=str(e)) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - (completed, dn) = callback( - ldap, completed, failed, dn, entry_attrs, *keys, **options - ) - else: - (completed, dn) = callback( - self, ldap, completed, failed, dn, entry_attrs, *keys, - **options - ) + for callback in self.get_callbacks('post'): + (completed, dn) = callback( + self, ldap, completed, failed, dn, entry_attrs, *keys, + **options) entry_attrs['dn'] = dn return dict( @@ -2072,13 +2005,8 @@ class LDAPRemoveReverseMember(LDAPModReverseMember): result = self.api.Command[self.show_command](keys[-1])['result'] dn = result['dn'] - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, *keys, **options) - else: - dn = callback( - self, ldap, dn, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, *keys, **options) if options.get('all', False): attrs_list = ['*'] + self.obj.default_attributes @@ -2113,16 +2041,10 @@ class LDAPRemoveReverseMember(LDAPModReverseMember): except Exception, e: raise errors.ReverseMemberError(verb=_('removed'), exc=str(e)) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - (completed, dn) = callback( - ldap, completed, failed, dn, entry_attrs, *keys, **options - ) - else: - (completed, dn) = callback( - self, ldap, completed, failed, dn, entry_attrs, *keys, - **options - ) + for callback in self.get_callbacks('post'): + (completed, dn) = callback( + self, ldap, completed, failed, dn, entry_attrs, *keys, + **options) entry_attrs['dn'] = dn return dict( diff --git a/make-lint b/make-lint index 7ecd59d7e..f61926043 100755 --- a/make-lint +++ b/make-lint @@ -51,8 +51,6 @@ class IPATypeChecker(TypeChecker): 'ipalib.plugable.Plugin': ['Command', 'Object', 'Method', 'Property', 'Backend', 'env', 'debug', 'info', 'warning', 'error', 'critical', 'exception', 'context', 'log'], - 'ipalib.plugins.baseldap.CallbackInterface': ['pre_callback', - 'post_callback', 'exc_callback'], 'ipalib.plugins.misc.env': ['env'], 'ipalib.parameters.Param': ['cli_name', 'cli_short_name', 'label', 'doc', 'required', 'multivalue', 'primary_key', 'normalizer', diff --git a/tests/test_xmlrpc/test_baseldap_plugin.py b/tests/test_xmlrpc/test_baseldap_plugin.py index 0800a5d52..6a8501f76 100644 --- a/tests/test_xmlrpc/test_baseldap_plugin.py +++ b/tests/test_xmlrpc/test_baseldap_plugin.py @@ -24,11 +24,12 @@ Test the `ipalib.plugins.baseldap` module. from ipalib import errors from ipalib.plugins import baseldap + def test_exc_wrapper(): """Test the CallbackInterface._exc_wrapper helper method""" handled_exceptions = [] - class test_callback(baseldap.CallbackInterface): + class test_callback(baseldap.BaseLDAPCommand): """Fake IPA method""" def test_fail(self): self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2) @@ -64,3 +65,95 @@ def test_exc_wrapper(): instance.test_fail() assert handled_exceptions == [None, errors.ExecutionError] + + +def test_callback_registration(): + class callbacktest_base(baseldap.CallbackInterface): + _callback_registry = dict(test={}) + + def test_callback(self, param): + messages.append(('Base test_callback', param)) + + def registered_callback(self, param): + messages.append(('Base registered callback', param)) + callbacktest_base.register_callback('test', registered_callback) + + class SomeClass(object): + def registered_callback(self, command, param): + messages.append(('Registered callback from another class', param)) + callbacktest_base.register_callback('test', SomeClass().registered_callback) + + class callbacktest_subclass(callbacktest_base): + pass + + def subclass_callback(self, param): + messages.append(('Subclass registered callback', param)) + callbacktest_subclass.register_callback('test', subclass_callback) + + + messages = [] + instance = callbacktest_base() + for callback in instance.get_callbacks('test'): + callback(instance, 42) + assert messages == [ + ('Base test_callback', 42), + ('Base registered callback', 42), + ('Registered callback from another class', 42)] + + messages = [] + instance = callbacktest_subclass() + for callback in instance.get_callbacks('test'): + callback(instance, 42) + assert messages == [ + ('Base test_callback', 42), + ('Subclass registered callback', 42)] + + +def test_exc_callback_registration(): + messages = [] + class callbacktest_base(baseldap.BaseLDAPCommand): + """A method superclass with an exception callback""" + def exc_callback(self, keys, options, exc, call_func, *args, **kwargs): + """Let the world know we saw the error, but don't handle it""" + messages.append('Base exc_callback') + raise exc + + def test_fail(self): + """Raise a handled exception""" + try: + self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2) + except Exception: + pass + + def fail(self, *args, **kwargs): + """Raise an error""" + raise errors.ExecutionError('failure') + + base_instance = callbacktest_base() + + class callbacktest_subclass(callbacktest_base): + pass + + @callbacktest_subclass.register_exc_callback + def exc_callback(self, keys, options, exc, call_func, *args, **kwargs): + """Subclass's private exception callback""" + messages.append('Subclass registered callback') + raise exc + + subclass_instance = callbacktest_subclass() + + # Make sure exception in base class is only handled by the base class + base_instance.test_fail() + assert messages == ['Base exc_callback'] + + + @callbacktest_base.register_exc_callback + def exc_callback(self, keys, options, exc, call_func, *args, **kwargs): + """Callback on super class; doesn't affect the subclass""" + messages.append('Superclass registered callback') + raise exc + + # Make sure exception in subclass is only handled by both + messages = [] + subclass_instance.test_fail() + assert messages == ['Base exc_callback', 'Subclass registered callback'] -- cgit