From 9960149e3f84564ab324bfb9db7c50063d87a7bd Mon Sep 17 00:00:00 2001 From: Petr Viktorin Date: Wed, 25 Apr 2012 10:31:10 -0400 Subject: Rework the CallbackInterface Fix several problems with the callback interface: - Automatically registered callbacks (i.e. methods named exc_callback, pre_callback etc) were registered on every instantiation. Fix: Do not register callbacks in __init__; instead return the method when asked for it. - The calling code had to distinguish between bound methods and plain functions by checking the 'im_self' attribute. Fix: Always return the "default" callback as an unbound method. Registered callbacks now always take the extra `self` argument, whether they happen to be bound methods or not. Calling code now always needs to pass the `self` argument. - Did not work well with inheritance: due to the fact that Python looks up missing attributes in superclasses, callbacks could get attached to a superclass if it was instantiated early enough. * Fix: Instead of attribute lookup, use a dictionary with class keys. - The interface included the callback types, which are LDAP-specific. Fix: Create generic register_callback and get_callback mehods, move LDAP-specific code to BaseLDAPCommand Update code that calls the callbacks. Add tests. Remove lint exceptions for CallbackInterface. * https://fedorahosted.org/freeipa/ticket/2674 --- ipalib/plugins/baseldap.py | 334 +++++++++++++++++---------------------------- 1 file changed, 128 insertions(+), 206 deletions(-) (limited to 'ipalib/plugins') diff --git a/ipalib/plugins/baseldap.py b/ipalib/plugins/baseldap.py index b8ef43d47..475222a6a 100644 --- a/ipalib/plugins/baseldap.py +++ b/ipalib/plugins/baseldap.py @@ -690,93 +690,57 @@ def _check_limit_object_class(attributes, attrs, allow_only): if len(limitattrs) > 0 and allow_only: raise errors.ObjectclassViolation(info='attribute "%(attribute)s" not allowed' % dict(attribute=limitattrs[0])) + class CallbackInterface(Method): + """Callback registration interface + + This class's subclasses allow different types of callbacks to be added and + removed to them. + Registering a callback is done either by ``register_callback``, or by + defining a ``_callback`` method. + + Subclasses should define the `_callback_registry` attribute as a dictionary + mapping allowed callback types to (initially) empty dictionaries. """ - Callback registration interface - """ - def __init__(self): - #pylint: disable=E1003 - if not hasattr(self.__class__, 'PRE_CALLBACKS'): - self.__class__.PRE_CALLBACKS = [] - if not hasattr(self.__class__, 'POST_CALLBACKS'): - self.__class__.POST_CALLBACKS = [] - if not hasattr(self.__class__, 'EXC_CALLBACKS'): - self.__class__.EXC_CALLBACKS = [] - if not hasattr(self.__class__, 'INTERACTIVE_PROMPT_CALLBACKS'): - self.__class__.INTERACTIVE_PROMPT_CALLBACKS = [] - if hasattr(self, 'pre_callback'): - self.register_pre_callback(self.pre_callback, True) - if hasattr(self, 'post_callback'): - self.register_post_callback(self.post_callback, True) - if hasattr(self, 'exc_callback'): - self.register_exc_callback(self.exc_callback, True) - if hasattr(self, 'interactive_prompt_callback'): - self.register_interactive_prompt_callback( - self.interactive_prompt_callback, True) #pylint: disable=E1101 - super(Method, self).__init__() - @classmethod - def register_pre_callback(klass, callback, first=False): - assert callable(callback) - if not hasattr(klass, 'PRE_CALLBACKS'): - klass.PRE_CALLBACKS = [] - if first: - klass.PRE_CALLBACKS.insert(0, callback) - else: - klass.PRE_CALLBACKS.append(callback) + _callback_registry = dict() @classmethod - def register_post_callback(klass, callback, first=False): - assert callable(callback) - if not hasattr(klass, 'POST_CALLBACKS'): - klass.POST_CALLBACKS = [] - if first: - klass.POST_CALLBACKS.insert(0, callback) - else: - klass.POST_CALLBACKS.append(callback) + def get_callbacks(cls, callback_type): + """Yield callbacks of the given type""" + # Use one shared callback registry, keyed on class, to avoid problems + # with missing attributes being looked up in superclasses + callbacks = cls._callback_registry[callback_type].get(cls, [None]) + for callback in callbacks: + if callback is None: + try: + yield getattr(cls, '%s_callback' % callback_type) + except AttributeError: + pass + else: + yield callback @classmethod - def register_exc_callback(klass, callback, first=False): - assert callable(callback) - if not hasattr(klass, 'EXC_CALLBACKS'): - klass.EXC_CALLBACKS = [] - if first: - klass.EXC_CALLBACKS.insert(0, callback) - else: - klass.EXC_CALLBACKS.append(callback) + def register_callback(cls, callback_type, callback, first=False): + """Register a callback - @classmethod - def register_interactive_prompt_callback(klass, callback, first=False): + :param callback_type: The callback type (e.g. 'pre', 'post') + :param callback: The callable added + :param first: If true, the new callback will be added before all + existing callbacks; otherwise it's added after them + + Note that callbacks registered this way will be attached to this class + only, not to its subclasses. + """ assert callable(callback) - if not hasattr(klass, 'INTERACTIVE_PROMPT_CALLBACKS'): - klass.INTERACTIVE_PROMPT_CALLBACKS = [] + try: + callbacks = cls._callback_registry[callback_type][cls] + except KeyError: + callbacks = cls._callback_registry[callback_type][cls] = [None] if first: - klass.INTERACTIVE_PROMPT_CALLBACKS.insert(0, callback) + callbacks.insert(0, callback) else: - klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback) - - def _exc_wrapper(self, keys, options, call_func): - """Function wrapper that automatically calls exception callbacks""" - def wrapped(*call_args, **call_kwargs): - # call call_func first - func = call_func - callbacks = list(getattr(self, 'EXC_CALLBACKS', [])) - while True: - try: - return func(*call_args, **call_kwargs) - except errors.ExecutionError, e: - if not callbacks: - raise - # call exc_callback in the next loop - callback = callbacks.pop(0) - if hasattr(callback, 'im_self'): - def exc_func(*args, **kwargs): - return callback(keys, options, e, call_func, *args, **kwargs) - else: - def exc_func(*args, **kwargs): - return callback(self, keys, options, e, call_func, *args, **kwargs) - func = exc_func - return wrapped + callbacks.append(callback) class BaseLDAPCommand(CallbackInterface, Command): @@ -802,6 +766,8 @@ last, after all sets and adds."""), exclude='webui', ) + _callback_registry = dict(pre={}, post={}, exc={}, interactive_prompt={}) + def _convert_2_dict(self, attrs): """ Convert a string in the form of name/value pairs into a dictionary. @@ -961,6 +927,45 @@ last, after all sets and adds."""), elif isinstance(entry_attrs[attr], (tuple, list)) and len(entry_attrs[attr]) == 1: entry_attrs[attr] = entry_attrs[attr][0] + @classmethod + def register_pre_callback(cls, callback, first=False): + """Shortcut for register_callback('pre', ...)""" + cls.register_callback('pre', callback, first) + + @classmethod + def register_post_callback(cls, callback, first=False): + """Shortcut for register_callback('post', ...)""" + cls.register_callback('post', callback, first) + + @classmethod + def register_exc_callback(cls, callback, first=False): + """Shortcut for register_callback('exc', ...)""" + cls.register_callback('exc', callback, first) + + @classmethod + def register_interactive_prompt_callback(cls, callback, first=False): + """Shortcut for register_callback('interactive_prompt', ...)""" + cls.register_callback('interactive_prompt', callback, first) + + def _exc_wrapper(self, keys, options, call_func): + """Function wrapper that automatically calls exception callbacks""" + def wrapped(*call_args, **call_kwargs): + # call call_func first + func = call_func + callbacks = list(self.get_callbacks('exc')) + while True: + try: + return func(*call_args, **call_kwargs) + except errors.ExecutionError, e: + if not callbacks: + raise + # call exc_callback in the next loop + callback = callbacks.pop(0) + def exc_func(*args, **kwargs): + return callback( + self, keys, options, e, call_func, *args, **kwargs) + func = exc_func + return wrapped class LDAPCreate(BaseLDAPCommand, crud.Create): """ @@ -1012,15 +1017,9 @@ class LDAPCreate(BaseLDAPCommand, crud.Create): set(self.obj.default_attributes + entry_attrs.keys()) ) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback( - ldap, dn, entry_attrs, attrs_list, *keys, **options - ) - else: - dn = callback( - self, ldap, dn, entry_attrs, attrs_list, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback( + self, ldap, dn, entry_attrs, attrs_list, *keys, **options) _check_single_value_attrs(self.params, entry_attrs) ldap.get_schema() @@ -1064,11 +1063,8 @@ class LDAPCreate(BaseLDAPCommand, crud.Create): except errors.NotFound: self.obj.handle_not_found(*keys) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, entry_attrs, *keys, **options) - else: - dn = callback(self, ldap, dn, entry_attrs, *keys, **options) + for callback in self.get_callbacks('post'): + dn = callback(self, ldap, dn, entry_attrs, *keys, **options) entry_attrs['dn'] = dn @@ -1173,11 +1169,8 @@ class LDAPRetrieve(LDAPQuery): else: attrs_list = list(self.obj.default_attributes) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, attrs_list, *keys, **options) - else: - dn = callback(self, ldap, dn, attrs_list, *keys, **options) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, attrs_list, *keys, **options) try: (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)( @@ -1189,11 +1182,8 @@ class LDAPRetrieve(LDAPQuery): if options.get('rights', False) and options.get('all', False): entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, entry_attrs, *keys, **options) - else: - dn = callback(self, ldap, dn, entry_attrs, *keys, **options) + for callback in self.get_callbacks('post'): + dn = callback(self, ldap, dn, entry_attrs, *keys, **options) self.obj.convert_attribute_members(entry_attrs, *keys, **options) entry_attrs['dn'] = dn @@ -1268,15 +1258,9 @@ class LDAPUpdate(LDAPQuery, crud.Update): _check_single_value_attrs(self.params, entry_attrs) _check_empty_attrs(self.obj.params, entry_attrs) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback( - ldap, dn, entry_attrs, attrs_list, *keys, **options - ) - else: - dn = callback( - self, ldap, dn, entry_attrs, attrs_list, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback( + self, ldap, dn, entry_attrs, attrs_list, *keys, **options) ldap.get_schema() _check_limit_object_class(self.api.Backend.ldap2.schema.attribute_types(self.obj.limit_object_classes), entry_attrs.keys(), allow_only=True) @@ -1323,11 +1307,8 @@ class LDAPUpdate(LDAPQuery, crud.Update): if options.get('rights', False) and options.get('all', False): entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, entry_attrs, *keys, **options) - else: - dn = callback(self, ldap, dn, entry_attrs, *keys, **options) + for callback in self.get_callbacks('post'): + dn = callback(self, ldap, dn, entry_attrs, *keys, **options) self.obj.convert_attribute_members(entry_attrs, *keys, **options) if self.obj.primary_key and keys[-1] is not None: @@ -1362,11 +1343,8 @@ class LDAPDelete(LDAPMultiQuery): nkeys = keys[:-1] + (pkey, ) dn = self.obj.get_dn(*nkeys, **options) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, *nkeys, **options) - else: - dn = callback(self, ldap, dn, *nkeys, **options) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, *nkeys, **options) def delete_subtree(base_dn): truncated = True @@ -1387,11 +1365,8 @@ class LDAPDelete(LDAPMultiQuery): delete_subtree(dn) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - result = callback(ldap, dn, *nkeys, **options) - else: - result = callback(self, ldap, dn, *nkeys, **options) + for callback in self.get_callbacks('post'): + result = callback(self, ldap, dn, *nkeys, **options) return result @@ -1503,13 +1478,8 @@ class LDAPAddMember(LDAPModMember): dn = self.obj.get_dn(*keys, **options) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, member_dns, failed, *keys, **options) - else: - dn = callback( - self, ldap, dn, member_dns, failed, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, member_dns, failed, *keys, **options) completed = 0 for (attr, objs) in member_dns.iteritems(): @@ -1542,16 +1512,10 @@ class LDAPAddMember(LDAPModMember): except errors.NotFound: self.obj.handle_not_found(*keys) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - (completed, dn) = callback( - ldap, completed, failed, dn, entry_attrs, *keys, **options - ) - else: - (completed, dn) = callback( - self, ldap, completed, failed, dn, entry_attrs, *keys, - **options - ) + for callback in self.get_callbacks('post'): + (completed, dn) = callback( + self, ldap, completed, failed, dn, entry_attrs, *keys, + **options) entry_attrs['dn'] = dn self.obj.convert_attribute_members(entry_attrs, *keys, **options) @@ -1602,13 +1566,8 @@ class LDAPRemoveMember(LDAPModMember): dn = self.obj.get_dn(*keys, **options) - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, member_dns, failed, *keys, **options) - else: - dn = callback( - self, ldap, dn, member_dns, failed, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, member_dns, failed, *keys, **options) completed = 0 for (attr, objs) in member_dns.iteritems(): @@ -1644,16 +1603,10 @@ class LDAPRemoveMember(LDAPModMember): except errors.NotFound: self.obj.handle_not_found(*keys) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - (completed, dn) = callback( - ldap, completed, failed, dn, entry_attrs, *keys, **options - ) - else: - (completed, dn) = callback( - self, ldap, completed, failed, dn, entry_attrs, *keys, - **options - ) + for callback in self.get_callbacks('post'): + (completed, dn) = callback( + self, ldap, completed, failed, dn, entry_attrs, *keys, + **options) entry_attrs['dn'] = dn @@ -1838,15 +1791,9 @@ class LDAPSearch(BaseLDAPCommand, crud.Search): ) scope = ldap.SCOPE_ONELEVEL - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - (filter, base_dn, scope) = callback( - ldap, filter, attrs_list, base_dn, scope, *args, **options - ) - else: - (filter, base_dn, scope) = callback( - self, ldap, filter, attrs_list, base_dn, scope, *args, **options - ) + for callback in self.get_callbacks('pre'): + (filter, base_dn, scope) = callback( + self, ldap, filter, attrs_list, base_dn, scope, *args, **options) try: (entries, truncated) = self._exc_wrapper(args, options, ldap.find_entries)( @@ -1857,11 +1804,8 @@ class LDAPSearch(BaseLDAPCommand, crud.Search): except errors.NotFound: (entries, truncated) = ([], False) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - truncated = callback(ldap, entries, truncated, *args, **options) - else: - truncated = callback(self, ldap, entries, truncated, *args, **options) + for callback in self.get_callbacks('post'): + truncated = callback(self, ldap, entries, truncated, *args, **options) if self.sort_result_entries: if self.obj.primary_key: @@ -1965,13 +1909,8 @@ class LDAPAddReverseMember(LDAPModReverseMember): result = self.api.Command[self.show_command](keys[-1])['result'] dn = result['dn'] - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, *keys, **options) - else: - dn = callback( - self, ldap, dn, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, *keys, **options) if options.get('all', False): attrs_list = ['*'] + self.obj.default_attributes @@ -2006,16 +1945,10 @@ class LDAPAddReverseMember(LDAPModReverseMember): except Exception, e: raise errors.ReverseMemberError(verb=_('added'), exc=str(e)) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - (completed, dn) = callback( - ldap, completed, failed, dn, entry_attrs, *keys, **options - ) - else: - (completed, dn) = callback( - self, ldap, completed, failed, dn, entry_attrs, *keys, - **options - ) + for callback in self.get_callbacks('post'): + (completed, dn) = callback( + self, ldap, completed, failed, dn, entry_attrs, *keys, + **options) entry_attrs['dn'] = dn return dict( @@ -2072,13 +2005,8 @@ class LDAPRemoveReverseMember(LDAPModReverseMember): result = self.api.Command[self.show_command](keys[-1])['result'] dn = result['dn'] - for callback in self.PRE_CALLBACKS: - if hasattr(callback, 'im_self'): - dn = callback(ldap, dn, *keys, **options) - else: - dn = callback( - self, ldap, dn, *keys, **options - ) + for callback in self.get_callbacks('pre'): + dn = callback(self, ldap, dn, *keys, **options) if options.get('all', False): attrs_list = ['*'] + self.obj.default_attributes @@ -2113,16 +2041,10 @@ class LDAPRemoveReverseMember(LDAPModReverseMember): except Exception, e: raise errors.ReverseMemberError(verb=_('removed'), exc=str(e)) - for callback in self.POST_CALLBACKS: - if hasattr(callback, 'im_self'): - (completed, dn) = callback( - ldap, completed, failed, dn, entry_attrs, *keys, **options - ) - else: - (completed, dn) = callback( - self, ldap, completed, failed, dn, entry_attrs, *keys, - **options - ) + for callback in self.get_callbacks('post'): + (completed, dn) = callback( + self, ldap, completed, failed, dn, entry_attrs, *keys, + **options) entry_attrs['dn'] = dn return dict( -- cgit