summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPetr Viktorin <pviktori@redhat.com>2012-04-25 10:31:10 -0400
committerMartin Kosek <mkosek@redhat.com>2012-06-14 11:09:43 +0200
commit9960149e3f84564ab324bfb9db7c50063d87a7bd (patch)
treec3caa731e88c1c12479edd6a2e59a9518688626b
parentf52fa2a0185c8cc4e1c2cacda3eac59209e659f4 (diff)
downloadfreeipa-9960149e3f84564ab324bfb9db7c50063d87a7bd.zip
freeipa-9960149e3f84564ab324bfb9db7c50063d87a7bd.tar.gz
freeipa-9960149e3f84564ab324bfb9db7c50063d87a7bd.tar.xz
Rework the CallbackInterface
Fix several problems with the callback interface: - Automatically registered callbacks (i.e. methods named exc_callback, pre_callback etc) were registered on every instantiation. Fix: Do not register callbacks in __init__; instead return the method when asked for it. - The calling code had to distinguish between bound methods and plain functions by checking the 'im_self' attribute. Fix: Always return the "default" callback as an unbound method. Registered callbacks now always take the extra `self` argument, whether they happen to be bound methods or not. Calling code now always needs to pass the `self` argument. - Did not work well with inheritance: due to the fact that Python looks up missing attributes in superclasses, callbacks could get attached to a superclass if it was instantiated early enough. * Fix: Instead of attribute lookup, use a dictionary with class keys. - The interface included the callback types, which are LDAP-specific. Fix: Create generic register_callback and get_callback mehods, move LDAP-specific code to BaseLDAPCommand Update code that calls the callbacks. Add tests. Remove lint exceptions for CallbackInterface. * https://fedorahosted.org/freeipa/ticket/2674
-rw-r--r--ipalib/cli.py9
-rw-r--r--ipalib/plugins/baseldap.py334
-rwxr-xr-xmake-lint2
-rw-r--r--tests/test_xmlrpc/test_baseldap_plugin.py95
4 files changed, 229 insertions, 211 deletions
diff --git a/ipalib/cli.py b/ipalib/cli.py
index 8279345..d53e6cd 100644
--- a/ipalib/cli.py
+++ b/ipalib/cli.py
@@ -1195,8 +1195,13 @@ class cli(backend.Executioner):
param.label, param.confirm
)
- for callback in getattr(cmd, 'INTERACTIVE_PROMPT_CALLBACKS', []):
- callback(kw)
+ try:
+ callbacks = cmd.get_callbacks('interactive_prompt')
+ except AttributeError:
+ pass
+ else:
+ for callback in callbacks:
+ callback(cmd, kw)
def load_files(self, cmd, kw):
"""
diff --git a/ipalib/plugins/baseldap.py b/ipalib/plugins/baseldap.py
index b8ef43d..475222a 100644
--- a/ipalib/plugins/baseldap.py
+++ b/ipalib/plugins/baseldap.py
@@ -690,93 +690,57 @@ def _check_limit_object_class(attributes, attrs, allow_only):
if len(limitattrs) > 0 and allow_only:
raise errors.ObjectclassViolation(info='attribute "%(attribute)s" not allowed' % dict(attribute=limitattrs[0]))
+
class CallbackInterface(Method):
+ """Callback registration interface
+
+ This class's subclasses allow different types of callbacks to be added and
+ removed to them.
+ Registering a callback is done either by ``register_callback``, or by
+ defining a ``<type>_callback`` method.
+
+ Subclasses should define the `_callback_registry` attribute as a dictionary
+ mapping allowed callback types to (initially) empty dictionaries.
"""
- Callback registration interface
- """
- def __init__(self):
- #pylint: disable=E1003
- if not hasattr(self.__class__, 'PRE_CALLBACKS'):
- self.__class__.PRE_CALLBACKS = []
- if not hasattr(self.__class__, 'POST_CALLBACKS'):
- self.__class__.POST_CALLBACKS = []
- if not hasattr(self.__class__, 'EXC_CALLBACKS'):
- self.__class__.EXC_CALLBACKS = []
- if not hasattr(self.__class__, 'INTERACTIVE_PROMPT_CALLBACKS'):
- self.__class__.INTERACTIVE_PROMPT_CALLBACKS = []
- if hasattr(self, 'pre_callback'):
- self.register_pre_callback(self.pre_callback, True)
- if hasattr(self, 'post_callback'):
- self.register_post_callback(self.post_callback, True)
- if hasattr(self, 'exc_callback'):
- self.register_exc_callback(self.exc_callback, True)
- if hasattr(self, 'interactive_prompt_callback'):
- self.register_interactive_prompt_callback(
- self.interactive_prompt_callback, True) #pylint: disable=E1101
- super(Method, self).__init__()
- @classmethod
- def register_pre_callback(klass, callback, first=False):
- assert callable(callback)
- if not hasattr(klass, 'PRE_CALLBACKS'):
- klass.PRE_CALLBACKS = []
- if first:
- klass.PRE_CALLBACKS.insert(0, callback)
- else:
- klass.PRE_CALLBACKS.append(callback)
+ _callback_registry = dict()
@classmethod
- def register_post_callback(klass, callback, first=False):
- assert callable(callback)
- if not hasattr(klass, 'POST_CALLBACKS'):
- klass.POST_CALLBACKS = []
- if first:
- klass.POST_CALLBACKS.insert(0, callback)
- else:
- klass.POST_CALLBACKS.append(callback)
+ def get_callbacks(cls, callback_type):
+ """Yield callbacks of the given type"""
+ # Use one shared callback registry, keyed on class, to avoid problems
+ # with missing attributes being looked up in superclasses
+ callbacks = cls._callback_registry[callback_type].get(cls, [None])
+ for callback in callbacks:
+ if callback is None:
+ try:
+ yield getattr(cls, '%s_callback' % callback_type)
+ except AttributeError:
+ pass
+ else:
+ yield callback
@classmethod
- def register_exc_callback(klass, callback, first=False):
- assert callable(callback)
- if not hasattr(klass, 'EXC_CALLBACKS'):
- klass.EXC_CALLBACKS = []
- if first:
- klass.EXC_CALLBACKS.insert(0, callback)
- else:
- klass.EXC_CALLBACKS.append(callback)
+ def register_callback(cls, callback_type, callback, first=False):
+ """Register a callback
- @classmethod
- def register_interactive_prompt_callback(klass, callback, first=False):
+ :param callback_type: The callback type (e.g. 'pre', 'post')
+ :param callback: The callable added
+ :param first: If true, the new callback will be added before all
+ existing callbacks; otherwise it's added after them
+
+ Note that callbacks registered this way will be attached to this class
+ only, not to its subclasses.
+ """
assert callable(callback)
- if not hasattr(klass, 'INTERACTIVE_PROMPT_CALLBACKS'):
- klass.INTERACTIVE_PROMPT_CALLBACKS = []
+ try:
+ callbacks = cls._callback_registry[callback_type][cls]
+ except KeyError:
+ callbacks = cls._callback_registry[callback_type][cls] = [None]
if first:
- klass.INTERACTIVE_PROMPT_CALLBACKS.insert(0, callback)
+ callbacks.insert(0, callback)
else:
- klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback)
-
- def _exc_wrapper(self, keys, options, call_func):
- """Function wrapper that automatically calls exception callbacks"""
- def wrapped(*call_args, **call_kwargs):
- # call call_func first
- func = call_func
- callbacks = list(getattr(self, 'EXC_CALLBACKS', []))
- while True:
- try:
- return func(*call_args, **call_kwargs)
- except errors.ExecutionError, e:
- if not callbacks:
- raise
- # call exc_callback in the next loop
- callback = callbacks.pop(0)
- if hasattr(callback, 'im_self'):
- def exc_func(*args, **kwargs):
- return callback(keys, options, e, call_func, *args, **kwargs)
- else:
- def exc_func(*args, **kwargs):
- return callback(self, keys, options, e, call_func, *args, **kwargs)
- func = exc_func
- return wrapped
+ callbacks.append(callback)
class BaseLDAPCommand(CallbackInterface, Command):
@@ -802,6 +766,8 @@ last, after all sets and adds."""),
exclude='webui',
)
+ _callback_registry = dict(pre={}, post={}, exc={}, interactive_prompt={})
+
def _convert_2_dict(self, attrs):
"""
Convert a string in the form of name/value pairs into a dictionary.
@@ -961,6 +927,45 @@ last, after all sets and adds."""),
elif isinstance(entry_attrs[attr], (tuple, list)) and len(entry_attrs[attr]) == 1:
entry_attrs[attr] = entry_attrs[attr][0]
+ @classmethod
+ def register_pre_callback(cls, callback, first=False):
+ """Shortcut for register_callback('pre', ...)"""
+ cls.register_callback('pre', callback, first)
+
+ @classmethod
+ def register_post_callback(cls, callback, first=False):
+ """Shortcut for register_callback('post', ...)"""
+ cls.register_callback('post', callback, first)
+
+ @classmethod
+ def register_exc_callback(cls, callback, first=False):
+ """Shortcut for register_callback('exc', ...)"""
+ cls.register_callback('exc', callback, first)
+
+ @classmethod
+ def register_interactive_prompt_callback(cls, callback, first=False):
+ """Shortcut for register_callback('interactive_prompt', ...)"""
+ cls.register_callback('interactive_prompt', callback, first)
+
+ def _exc_wrapper(self, keys, options, call_func):
+ """Function wrapper that automatically calls exception callbacks"""
+ def wrapped(*call_args, **call_kwargs):
+ # call call_func first
+ func = call_func
+ callbacks = list(self.get_callbacks('exc'))
+ while True:
+ try:
+ return func(*call_args, **call_kwargs)
+ except errors.ExecutionError, e:
+ if not callbacks:
+ raise
+ # call exc_callback in the next loop
+ callback = callbacks.pop(0)
+ def exc_func(*args, **kwargs):
+ return callback(
+ self, keys, options, e, call_func, *args, **kwargs)
+ func = exc_func
+ return wrapped
class LDAPCreate(BaseLDAPCommand, crud.Create):
"""
@@ -1012,15 +1017,9 @@ class LDAPCreate(BaseLDAPCommand, crud.Create):
set(self.obj.default_attributes + entry_attrs.keys())
)
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(
- ldap, dn, entry_attrs, attrs_list, *keys, **options
- )
- else:
- dn = callback(
- self, ldap, dn, entry_attrs, attrs_list, *keys, **options
- )
+ for callback in self.get_callbacks('pre'):
+ dn = callback(
+ self, ldap, dn, entry_attrs, attrs_list, *keys, **options)
_check_single_value_attrs(self.params, entry_attrs)
ldap.get_schema()
@@ -1064,11 +1063,8 @@ class LDAPCreate(BaseLDAPCommand, crud.Create):
except errors.NotFound:
self.obj.handle_not_found(*keys)
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, entry_attrs, *keys, **options)
- else:
- dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+ for callback in self.get_callbacks('post'):
+ dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
entry_attrs['dn'] = dn
@@ -1173,11 +1169,8 @@ class LDAPRetrieve(LDAPQuery):
else:
attrs_list = list(self.obj.default_attributes)
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, attrs_list, *keys, **options)
- else:
- dn = callback(self, ldap, dn, attrs_list, *keys, **options)
+ for callback in self.get_callbacks('pre'):
+ dn = callback(self, ldap, dn, attrs_list, *keys, **options)
try:
(dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
@@ -1189,11 +1182,8 @@ class LDAPRetrieve(LDAPQuery):
if options.get('rights', False) and options.get('all', False):
entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, entry_attrs, *keys, **options)
- else:
- dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+ for callback in self.get_callbacks('post'):
+ dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
self.obj.convert_attribute_members(entry_attrs, *keys, **options)
entry_attrs['dn'] = dn
@@ -1268,15 +1258,9 @@ class LDAPUpdate(LDAPQuery, crud.Update):
_check_single_value_attrs(self.params, entry_attrs)
_check_empty_attrs(self.obj.params, entry_attrs)
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(
- ldap, dn, entry_attrs, attrs_list, *keys, **options
- )
- else:
- dn = callback(
- self, ldap, dn, entry_attrs, attrs_list, *keys, **options
- )
+ for callback in self.get_callbacks('pre'):
+ dn = callback(
+ self, ldap, dn, entry_attrs, attrs_list, *keys, **options)
ldap.get_schema()
_check_limit_object_class(self.api.Backend.ldap2.schema.attribute_types(self.obj.limit_object_classes), entry_attrs.keys(), allow_only=True)
@@ -1323,11 +1307,8 @@ class LDAPUpdate(LDAPQuery, crud.Update):
if options.get('rights', False) and options.get('all', False):
entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, entry_attrs, *keys, **options)
- else:
- dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
+ for callback in self.get_callbacks('post'):
+ dn = callback(self, ldap, dn, entry_attrs, *keys, **options)
self.obj.convert_attribute_members(entry_attrs, *keys, **options)
if self.obj.primary_key and keys[-1] is not None:
@@ -1362,11 +1343,8 @@ class LDAPDelete(LDAPMultiQuery):
nkeys = keys[:-1] + (pkey, )
dn = self.obj.get_dn(*nkeys, **options)
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, *nkeys, **options)
- else:
- dn = callback(self, ldap, dn, *nkeys, **options)
+ for callback in self.get_callbacks('pre'):
+ dn = callback(self, ldap, dn, *nkeys, **options)
def delete_subtree(base_dn):
truncated = True
@@ -1387,11 +1365,8 @@ class LDAPDelete(LDAPMultiQuery):
delete_subtree(dn)
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- result = callback(ldap, dn, *nkeys, **options)
- else:
- result = callback(self, ldap, dn, *nkeys, **options)
+ for callback in self.get_callbacks('post'):
+ result = callback(self, ldap, dn, *nkeys, **options)
return result
@@ -1503,13 +1478,8 @@ class LDAPAddMember(LDAPModMember):
dn = self.obj.get_dn(*keys, **options)
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, member_dns, failed, *keys, **options)
- else:
- dn = callback(
- self, ldap, dn, member_dns, failed, *keys, **options
- )
+ for callback in self.get_callbacks('pre'):
+ dn = callback(self, ldap, dn, member_dns, failed, *keys, **options)
completed = 0
for (attr, objs) in member_dns.iteritems():
@@ -1542,16 +1512,10 @@ class LDAPAddMember(LDAPModMember):
except errors.NotFound:
self.obj.handle_not_found(*keys)
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- (completed, dn) = callback(
- ldap, completed, failed, dn, entry_attrs, *keys, **options
- )
- else:
- (completed, dn) = callback(
- self, ldap, completed, failed, dn, entry_attrs, *keys,
- **options
- )
+ for callback in self.get_callbacks('post'):
+ (completed, dn) = callback(
+ self, ldap, completed, failed, dn, entry_attrs, *keys,
+ **options)
entry_attrs['dn'] = dn
self.obj.convert_attribute_members(entry_attrs, *keys, **options)
@@ -1602,13 +1566,8 @@ class LDAPRemoveMember(LDAPModMember):
dn = self.obj.get_dn(*keys, **options)
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, member_dns, failed, *keys, **options)
- else:
- dn = callback(
- self, ldap, dn, member_dns, failed, *keys, **options
- )
+ for callback in self.get_callbacks('pre'):
+ dn = callback(self, ldap, dn, member_dns, failed, *keys, **options)
completed = 0
for (attr, objs) in member_dns.iteritems():
@@ -1644,16 +1603,10 @@ class LDAPRemoveMember(LDAPModMember):
except errors.NotFound:
self.obj.handle_not_found(*keys)
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- (completed, dn) = callback(
- ldap, completed, failed, dn, entry_attrs, *keys, **options
- )
- else:
- (completed, dn) = callback(
- self, ldap, completed, failed, dn, entry_attrs, *keys,
- **options
- )
+ for callback in self.get_callbacks('post'):
+ (completed, dn) = callback(
+ self, ldap, completed, failed, dn, entry_attrs, *keys,
+ **options)
entry_attrs['dn'] = dn
@@ -1838,15 +1791,9 @@ class LDAPSearch(BaseLDAPCommand, crud.Search):
)
scope = ldap.SCOPE_ONELEVEL
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- (filter, base_dn, scope) = callback(
- ldap, filter, attrs_list, base_dn, scope, *args, **options
- )
- else:
- (filter, base_dn, scope) = callback(
- self, ldap, filter, attrs_list, base_dn, scope, *args, **options
- )
+ for callback in self.get_callbacks('pre'):
+ (filter, base_dn, scope) = callback(
+ self, ldap, filter, attrs_list, base_dn, scope, *args, **options)
try:
(entries, truncated) = self._exc_wrapper(args, options, ldap.find_entries)(
@@ -1857,11 +1804,8 @@ class LDAPSearch(BaseLDAPCommand, crud.Search):
except errors.NotFound:
(entries, truncated) = ([], False)
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- truncated = callback(ldap, entries, truncated, *args, **options)
- else:
- truncated = callback(self, ldap, entries, truncated, *args, **options)
+ for callback in self.get_callbacks('post'):
+ truncated = callback(self, ldap, entries, truncated, *args, **options)
if self.sort_result_entries:
if self.obj.primary_key:
@@ -1965,13 +1909,8 @@ class LDAPAddReverseMember(LDAPModReverseMember):
result = self.api.Command[self.show_command](keys[-1])['result']
dn = result['dn']
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, *keys, **options)
- else:
- dn = callback(
- self, ldap, dn, *keys, **options
- )
+ for callback in self.get_callbacks('pre'):
+ dn = callback(self, ldap, dn, *keys, **options)
if options.get('all', False):
attrs_list = ['*'] + self.obj.default_attributes
@@ -2006,16 +1945,10 @@ class LDAPAddReverseMember(LDAPModReverseMember):
except Exception, e:
raise errors.ReverseMemberError(verb=_('added'), exc=str(e))
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- (completed, dn) = callback(
- ldap, completed, failed, dn, entry_attrs, *keys, **options
- )
- else:
- (completed, dn) = callback(
- self, ldap, completed, failed, dn, entry_attrs, *keys,
- **options
- )
+ for callback in self.get_callbacks('post'):
+ (completed, dn) = callback(
+ self, ldap, completed, failed, dn, entry_attrs, *keys,
+ **options)
entry_attrs['dn'] = dn
return dict(
@@ -2072,13 +2005,8 @@ class LDAPRemoveReverseMember(LDAPModReverseMember):
result = self.api.Command[self.show_command](keys[-1])['result']
dn = result['dn']
- for callback in self.PRE_CALLBACKS:
- if hasattr(callback, 'im_self'):
- dn = callback(ldap, dn, *keys, **options)
- else:
- dn = callback(
- self, ldap, dn, *keys, **options
- )
+ for callback in self.get_callbacks('pre'):
+ dn = callback(self, ldap, dn, *keys, **options)
if options.get('all', False):
attrs_list = ['*'] + self.obj.default_attributes
@@ -2113,16 +2041,10 @@ class LDAPRemoveReverseMember(LDAPModReverseMember):
except Exception, e:
raise errors.ReverseMemberError(verb=_('removed'), exc=str(e))
- for callback in self.POST_CALLBACKS:
- if hasattr(callback, 'im_self'):
- (completed, dn) = callback(
- ldap, completed, failed, dn, entry_attrs, *keys, **options
- )
- else:
- (completed, dn) = callback(
- self, ldap, completed, failed, dn, entry_attrs, *keys,
- **options
- )
+ for callback in self.get_callbacks('post'):
+ (completed, dn) = callback(
+ self, ldap, completed, failed, dn, entry_attrs, *keys,
+ **options)
entry_attrs['dn'] = dn
return dict(
diff --git a/make-lint b/make-lint
index 7ecd59d..f619260 100755
--- a/make-lint
+++ b/make-lint
@@ -51,8 +51,6 @@ class IPATypeChecker(TypeChecker):
'ipalib.plugable.Plugin': ['Command', 'Object', 'Method', 'Property',
'Backend', 'env', 'debug', 'info', 'warning', 'error', 'critical',
'exception', 'context', 'log'],
- 'ipalib.plugins.baseldap.CallbackInterface': ['pre_callback',
- 'post_callback', 'exc_callback'],
'ipalib.plugins.misc.env': ['env'],
'ipalib.parameters.Param': ['cli_name', 'cli_short_name', 'label',
'doc', 'required', 'multivalue', 'primary_key', 'normalizer',
diff --git a/tests/test_xmlrpc/test_baseldap_plugin.py b/tests/test_xmlrpc/test_baseldap_plugin.py
index 0800a5d..6a8501f 100644
--- a/tests/test_xmlrpc/test_baseldap_plugin.py
+++ b/tests/test_xmlrpc/test_baseldap_plugin.py
@@ -24,11 +24,12 @@ Test the `ipalib.plugins.baseldap` module.
from ipalib import errors
from ipalib.plugins import baseldap
+
def test_exc_wrapper():
"""Test the CallbackInterface._exc_wrapper helper method"""
handled_exceptions = []
- class test_callback(baseldap.CallbackInterface):
+ class test_callback(baseldap.BaseLDAPCommand):
"""Fake IPA method"""
def test_fail(self):
self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
@@ -64,3 +65,95 @@ def test_exc_wrapper():
instance.test_fail()
assert handled_exceptions == [None, errors.ExecutionError]
+
+
+def test_callback_registration():
+ class callbacktest_base(baseldap.CallbackInterface):
+ _callback_registry = dict(test={})
+
+ def test_callback(self, param):
+ messages.append(('Base test_callback', param))
+
+ def registered_callback(self, param):
+ messages.append(('Base registered callback', param))
+ callbacktest_base.register_callback('test', registered_callback)
+
+ class SomeClass(object):
+ def registered_callback(self, command, param):
+ messages.append(('Registered callback from another class', param))
+ callbacktest_base.register_callback('test', SomeClass().registered_callback)
+
+ class callbacktest_subclass(callbacktest_base):
+ pass
+
+ def subclass_callback(self, param):
+ messages.append(('Subclass registered callback', param))
+ callbacktest_subclass.register_callback('test', subclass_callback)
+
+
+ messages = []
+ instance = callbacktest_base()
+ for callback in instance.get_callbacks('test'):
+ callback(instance, 42)
+ assert messages == [
+ ('Base test_callback', 42),
+ ('Base registered callback', 42),
+ ('Registered callback from another class', 42)]
+
+ messages = []
+ instance = callbacktest_subclass()
+ for callback in instance.get_callbacks('test'):
+ callback(instance, 42)
+ assert messages == [
+ ('Base test_callback', 42),
+ ('Subclass registered callback', 42)]
+
+
+def test_exc_callback_registration():
+ messages = []
+ class callbacktest_base(baseldap.BaseLDAPCommand):
+ """A method superclass with an exception callback"""
+ def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+ """Let the world know we saw the error, but don't handle it"""
+ messages.append('Base exc_callback')
+ raise exc
+
+ def test_fail(self):
+ """Raise a handled exception"""
+ try:
+ self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
+ except Exception:
+ pass
+
+ def fail(self, *args, **kwargs):
+ """Raise an error"""
+ raise errors.ExecutionError('failure')
+
+ base_instance = callbacktest_base()
+
+ class callbacktest_subclass(callbacktest_base):
+ pass
+
+ @callbacktest_subclass.register_exc_callback
+ def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+ """Subclass's private exception callback"""
+ messages.append('Subclass registered callback')
+ raise exc
+
+ subclass_instance = callbacktest_subclass()
+
+ # Make sure exception in base class is only handled by the base class
+ base_instance.test_fail()
+ assert messages == ['Base exc_callback']
+
+
+ @callbacktest_base.register_exc_callback
+ def exc_callback(self, keys, options, exc, call_func, *args, **kwargs):
+ """Callback on super class; doesn't affect the subclass"""
+ messages.append('Superclass registered callback')
+ raise exc
+
+ # Make sure exception in subclass is only handled by both
+ messages = []
+ subclass_instance.test_fail()
+ assert messages == ['Base exc_callback', 'Subclass registered callback']