diff options
Diffstat (limited to 'ipapython/ipaldap.py')
-rw-r--r-- | ipapython/ipaldap.py | 68 |
1 files changed, 50 insertions, 18 deletions
diff --git a/ipapython/ipaldap.py b/ipapython/ipaldap.py index 2965ba4a5..405f1ee2f 100644 --- a/ipapython/ipaldap.py +++ b/ipapython/ipaldap.py @@ -60,6 +60,11 @@ AUTOBIND_AUTO = 1 AUTOBIND_ENABLED = 2 AUTOBIND_DISABLED = 3 +TRUNCATED_SIZE_LIMIT = object() +TRUNCATED_TIME_LIMIT = object() +TRUNCATED_ADMIN_LIMIT = object() + + def unicode_from_utf8(val): ''' val is a UTF-8 encoded string, return a unicode object. @@ -971,11 +976,11 @@ class LDAPClient(object): except ldap.OBJECT_CLASS_VIOLATION: raise errors.ObjectclassViolation(info=info) except ldap.ADMINLIMIT_EXCEEDED: - raise errors.LimitsExceeded() + raise errors.AdminLimitExceeded() except ldap.SIZELIMIT_EXCEEDED: - raise errors.LimitsExceeded() + raise errors.SizeLimitExceeded() except ldap.TIMELIMIT_EXCEEDED: - raise errors.LimitsExceeded() + raise errors.TimeLimitExceeded() except ldap.NOT_ALLOWED_ON_RDN: raise errors.NotAllowedOnRDN(attr=info) except ldap.FILTER_ERROR: @@ -1003,6 +1008,20 @@ class LDAPClient(object): 'Unhandled LDAPError: %s: %s' % (type(e).__name__, str(e))) raise errors.DatabaseError(desc=desc, info=info) + @staticmethod + def handle_truncated_result(truncated): + if not truncated: + return + + if truncated is TRUNCATED_ADMIN_LIMIT: + raise errors.AdminLimitExceeded() + elif truncated is TRUNCATED_SIZE_LIMIT: + raise errors.SizeLimitExceeded() + elif truncated is TRUNCATED_TIME_LIMIT: + raise errors.TimeLimitExceeded() + else: + raise errors.LimitsExceeded() + @property def schema(self): """schema associated with this LDAP server""" @@ -1249,7 +1268,7 @@ class LDAPClient(object): return self.combine_filters(flts, rules) def get_entries(self, base_dn, scope=ldap.SCOPE_SUBTREE, filter=None, - attrs_list=None): + attrs_list=None, **kwargs): """Return a list of matching entries. :raises: errors.LimitsExceeded if the list is truncated by the server @@ -1260,13 +1279,21 @@ class LDAPClient(object): :param scope: search scope, see LDAP docs (default ldap2.SCOPE_SUBTREE) :param filter: LDAP filter to apply :param attrs_list: ist of attributes to return, all if None (default) - - Use the find_entries method for more options. + :param kwargs: additional keyword arguments. See find_entries method + for their description. """ entries, truncated = self.find_entries( base_dn=base_dn, scope=scope, filter=filter, attrs_list=attrs_list) - if truncated: - raise errors.LimitsExceeded() + try: + self.handle_truncated_result(truncated) + except errors.LimitsExceeded as e: + self.log.error( + "{} while getting entries (base DN: {}, filter: {})".format( + e, base_dn, filter + ) + ) + raise + return entries def find_entries(self, filter=None, attrs_list=None, base_dn=None, @@ -1357,6 +1384,15 @@ class LDAPClient(object): break else: cookie = '' + except ldap.ADMINLIMIT_EXCEEDED: + truncated = TRUNCATED_ADMIN_LIMIT + break + except ldap.SIZELIMIT_EXCEEDED: + truncated = TRUNCATED_SIZE_LIMIT + break + except ldap.TIMELIMIT_EXCEEDED: + truncated = TRUNCATED_TIME_LIMIT + break except ldap.LDAPError as e: # If paged search is in progress, try to cancel it if paged_search and cookie: @@ -1402,15 +1438,13 @@ class LDAPClient(object): search_kw = {attr: value, 'objectClass': object_class} filter = self.make_filter(search_kw, rules=self.MATCH_ALL) - (entries, truncated) = self.find_entries(filter, attrs_list, base_dn) + entries = self.get_entries( + base_dn, filter=filter, attrs_list=attrs_list) if len(entries) > 1: raise errors.SingleMatchExpected(found=len(entries)) - else: - if truncated: - raise errors.LimitsExceeded() - else: - return entries[0] + + return entries[0] def get_entry(self, dn, attrs_list=None, time_limit=None, size_limit=None): @@ -1423,13 +1457,11 @@ class LDAPClient(object): assert isinstance(dn, DN) - (entries, truncated) = self.find_entries( - None, attrs_list, dn, self.SCOPE_BASE, time_limit=time_limit, + entries = self.get_entries( + dn, self.SCOPE_BASE, None, attrs_list, time_limit=time_limit, size_limit=size_limit ) - if truncated: - raise errors.LimitsExceeded() return entries[0] def add_entry(self, entry): |