summaryrefslogtreecommitdiffstats
path: root/ipapython
diff options
context:
space:
mode:
authorPetr Vobornik <pvoborni@redhat.com>2015-03-25 13:39:43 +0100
committerPetr Vobornik <pvoborni@redhat.com>2015-04-14 19:31:54 +0200
commit11bd9d96f191066f7ba760549f00179c128a9787 (patch)
tree3d20a02cae3af4600793c6a339a2ceeceb9a6a03 /ipapython
parent0a1a3d73120bdf20ae05bcf663f14ca1a8b02c25 (diff)
downloadfreeipa-11bd9d96f191066f7ba760549f00179c128a9787.tar.gz
freeipa-11bd9d96f191066f7ba760549f00179c128a9787.tar.xz
freeipa-11bd9d96f191066f7ba760549f00179c128a9787.zip
performance: faster DN implementation
DN code was optimized to be faster if DNs are created from string. This is the major use case, since most DNs come from LDAP. With this patch, DN creation is almost 8-10x faster (with 30K-100K DNs). Second mojor use case - deepcopy in LDAPEntry is about 20x faster - done by custom __deepcopy__ function. The major change is that DN is no longer internally composed of RDNs and AVAs but it rather keeps the data in open ldap format - the same as output of str2dn function. Therefore, for immutable DNs, no other transformations are required on instantiation. The format is: DN: [RDN, RDN,...] RDN: [AVA, AVA,...] AVA: ['utf-8 encoded str - attr', 'utf-8 encode str -value', FLAG] FLAG: int Further indexing of DN object constructs an RDN which is just an encapsulation of the RDN part of open ldap representation. Indexing of RDN constructs AVA in the same fashion. Obtained EditableAVA, EditableRDN from EditableDN shares the respected lists of the open ldap repr. so that the change of value or attr is reflected in parent object. Reviewed-By: Petr Viktorin <pviktori@redhat.com>
Diffstat (limited to 'ipapython')
-rw-r--r--ipapython/dn.py591
1 files changed, 294 insertions, 297 deletions
diff --git a/ipapython/dn.py b/ipapython/dn.py
index 834291fbe..5b6570770 100644
--- a/ipapython/dn.py
+++ b/ipapython/dn.py
@@ -497,6 +497,97 @@ def _adjust_indices(start, end, length):
return start, end
+
+def _normalize_ava_input(val):
+ if not isinstance(val, basestring):
+ val = unicode(val).encode('utf-8')
+ elif isinstance(val, unicode):
+ val = val.encode('utf-8')
+ return val
+
+
+def str2rdn(value):
+ try:
+ rdns = str2dn(value.encode('utf-8'))
+ except DECODING_ERROR:
+ raise ValueError("malformed AVA string = \"%s\"" % value)
+ if len(rdns) != 1:
+ raise ValueError("multiple RDN's specified by \"%s\"" % (value))
+ return rdns[0]
+
+
+def get_ava(*args, **kwds):
+ """
+ Get AVA from args in open ldap format(raw). Optimized for construction
+ from openldap format.
+
+ Allowed formats of argument list:
+ 1) three args - open ldap format (attr and value have to be utf-8 encoded):
+ a) ['attr', 'value', 0]
+ 2) two args:
+ a) ['attr', 'value']
+ 3) one arg:
+ a) [('attr', 'value')]
+ b) [['attr', 'value']]
+ c) [AVA(..)]
+ d) ['attr=value']
+ """
+ ava = None
+ l = len(args)
+ if l == 3: # raw values - constructed FROM RDN
+ if kwds.get('mutable', False):
+ ava = args
+ else:
+ ava = (args[0], args[1], args[2])
+ elif l == 2: # user defined values
+ ava = [_normalize_ava_input(args[0]), _normalize_ava_input(args[1]), 0]
+ elif l == 1: # slow mode, tuple, string,
+ arg = args[0]
+ if isinstance(arg, AVA):
+ ava = arg.to_openldap()
+ elif isinstance(arg, (tuple, list)):
+ if len(arg) != 2:
+ raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (arg))
+ ava = [_normalize_ava_input(arg[0]), _normalize_ava_input(arg[1]), 0]
+ elif isinstance(arg, basestring):
+ rdn = str2rdn(arg)
+ if len(rdn) > 1:
+ raise TypeError("multiple AVA's specified by \"%s\"" % (arg))
+ ava = list(rdn[0])
+ else:
+ raise TypeError("with 1 argument, argument must be str, unicode, tuple or list, got %s instead" %
+ arg.__class__.__name__)
+ else:
+ raise TypeError("invalid number of arguments. 1-3 allowed")
+ return ava
+
+
+def sort_avas(rdn):
+ if len(rdn) <= 1:
+ return
+ rdn.sort(cmp=cmp_avas)
+
+
+def cmp_avas(a, b):
+ r = cmp(a[0].lower(), b[0].lower())
+ if r == 0:
+ r = cmp(a[1].lower(), b[1].lower())
+ return r
+
+
+def cmp_rdns(a, b):
+
+ l = len(a)
+ r = cmp(l, len(b))
+ if r != 0:
+ return r
+
+ for i, ava_a in enumerate(a):
+ r = cmp_avas(ava_a, b[i])
+ if r != 0:
+ return r
+ return 0
+
class AVA(object):
'''
AVA(arg0, ...)
@@ -552,100 +643,51 @@ class AVA(object):
syntax with proper escaping.
'''
is_mutable = False
- flags = 0
def __init__(self, *args, **kwds):
- if len(args) == 1:
- arg = args[0]
- if isinstance(arg, AVA):
- ava = (arg.attr, arg.value)
- elif isinstance(arg, basestring):
- try:
- rdns = str2dn(arg.encode('utf-8'))
- except DECODING_ERROR:
- raise ValueError("malformed AVA string = \"%s\"" % arg)
- if len(rdns) != 1:
- raise ValueError("multiple RDN's specified by \"%s\"" % (arg))
- rdn = rdns[0]
- if len(rdn) != 1:
- raise ValueError("multiple AVA's specified by \"%s\"" % (arg))
- ava = rdn[0]
- elif isinstance(arg, (tuple, list)):
- ava = arg
- if len(ava) != 2:
- raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (ava))
- else:
- raise TypeError("with 1 argument, argument must be str,unicode,tuple or list, got %s instead" % \
- arg.__class__.__name__)
-
- attr = ava[0]
- value = ava[1]
- elif len(args) == 2:
- attr = args[0]
- value = args[1]
- else:
- raise TypeError("takes 1 or 2 arguments (%d given)" % (len(args)))
-
- self._set_attr(attr)
- self._set_value(value)
+ self._ava = get_ava(*args, **{'mutable': self.is_mutable})
def _get_attr(self):
- return self._attr_unicode
+ return self._ava[0].decode('utf-8')
def _set_attr(self, new_attr):
- # Scalars only
- if isinstance(new_attr, (tuple, list)):
- raise TypeError("attr must be scalar, got %s" % type(new_attr))
-
try:
- if isinstance(new_attr, unicode):
- self._attr_unicode = new_attr
- elif isinstance(new_attr, str):
- self._attr_unicode = new_attr.decode('utf-8')
- else:
- self._attr_unicode = unicode(new_attr)
+ self._ava[0] = _normalize_ava_input(new_attr)
except Exception, e:
- raise ValueError('unable to convert attr "%s" to unicode: %s' % (new_attr, e))
+ raise ValueError('unable to convert attr "%s": %s' % (new_attr, e))
- attr = property(_get_attr)
+ attr = property(_get_attr)
def _get_value(self):
- return self._value_unicode
+ return self._ava[1].decode('utf-8')
def _set_value(self, new_value):
- # Scalars only
- if isinstance(new_value, (tuple, list)):
- raise TypeError("value must be scalar, got %s" % type(new_value))
-
try:
- if isinstance(new_value, unicode):
- self._value_unicode = new_value
- elif isinstance(new_value, str):
- self._value_unicode = new_value.decode('utf-8')
- else:
- self._value_unicode = unicode(new_value)
+ self._ava[1] = _normalize_ava_input(new_value)
except Exception, e:
- raise ValueError('unable to convert value "%s" to unicode: %s' % (new_value, e))
+ raise ValueError('unable to convert value "%s": %s' % (new_value, e))
value = property(_get_value)
- def _to_openldap(self):
- return [[(self._attr_unicode.encode('utf-8'), self._value_unicode.encode('utf-8'), self.flags)]]
+ def to_openldap(self):
+ return list(self._ava)
def __str__(self):
- return dn2str(self._to_openldap())
+ return dn2str([[self.to_openldap()]])
def __repr__(self):
return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__())
def __getitem__(self, key):
- if isinstance(key, basestring):
- if key == self._attr_unicode:
- return self._value_unicode
- raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
+
+ if key == 0:
+ return self.attr
+ elif key == 1:
+ return self.value
+ elif key == self.attr:
+ return self.value
else:
- raise TypeError("unsupported type for AVA indexing, must be basestring; not %s" % \
- (key.__class__.__name__))
+ raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
def __hash__(self):
# Hash is computed from AVA's string representation because it's immutable.
@@ -682,8 +724,7 @@ class AVA(object):
return False
# Perform comparison between objects of same type
- return self._attr_unicode.lower() == other.attr.lower() and \
- self._value_unicode.lower() == other.value.lower()
+ return cmp_avas(self._ava, other._ava) == 0
def __ne__(self, other):
return not self.__eq__(other)
@@ -694,11 +735,7 @@ class AVA(object):
if not isinstance(other, AVA):
raise TypeError("expected AVA but got %s" % (other.__class__.__name__))
- result = cmp(self._attr_unicode.lower(), other.attr.lower())
- if result != 0:
- return result
- result = cmp(self._value_unicode.lower(), other.value.lower())
- return result
+ return cmp_avas(self._ava, other._ava)
class EditableAVA(AVA):
'''
@@ -826,113 +863,96 @@ class RDN(object):
'''
is_mutable = False
- flags = 0
AVA_type = AVA
def __init__(self, *args, **kwds):
- self.avas = self._avas_from_sequence(args)
- self.avas.sort()
-
- def _ava_from_value(self, value):
- if isinstance(value, AVA):
- return self.AVA_type(value.attr, value.value)
- elif isinstance(value, RDN):
- avas = []
- for ava in value.avas:
- avas.append(self.AVA_type(ava.attr, ava.value))
- if len(avas) == 1:
- return avas[0]
- else:
- return avas
- elif isinstance(value, basestring):
- try:
- rdns = str2dn(value.encode('utf-8'))
- if len(rdns) != 1:
- raise ValueError("multiple RDN's specified by \"%s\"" % (value))
- rdn = rdns[0]
- if len(rdn) == 1:
- return self.AVA_type(rdn[0][0], rdn[0][1])
- else:
- avas = []
- for ava_tuple in rdn:
- avas.append(self.AVA_type(ava_tuple[0], ava_tuple[1]))
- return avas
- except DECODING_ERROR:
- raise ValueError("malformed RDN string = \"%s\"" % value)
- elif isinstance(value, (tuple, list)):
- if len(value) != 2:
- raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (value))
- return self.AVA_type(value)
- else:
- raise TypeError("must be str,unicode,tuple, or AVA, got %s instead" % \
- value.__class__.__name__)
-
+ self._avas = self._avas_from_sequence(args, kwds.get('raw', False))
- def _avas_from_sequence(self, seq):
+ def _avas_from_sequence(self, args, raw=False):
avas = []
+ sort = 0
+ ava_count = len(args)
- for item in seq:
- ava = self._ava_from_value(item)
- if isinstance(ava, list):
- avas.extend(ava)
- else:
- avas.append(ava)
+ if raw: # fast raw mode
+ try:
+ if self.is_mutable:
+ avas = args
+ else:
+ for arg in args:
+ avas.append((arg[0], arg[1], arg[2]))
+ except KeyError as e:
+ raise TypeError('all AVA values in RAW mode must be in open ldap format')
+ elif ava_count == 1 and isinstance(args[0], basestring):
+ avas = str2rdn(args[0])
+ sort = 1
+ elif ava_count == 1 and isinstance(args[0], RDN):
+ avas = args[0].to_openldap()
+ elif ava_count > 0:
+ sort = 1
+ for arg in args:
+ avas.append(get_ava(arg))
+ if sort:
+ sort_avas(avas)
return avas
- def _to_openldap(self):
- return [[(ava.attr.encode('utf-8'), ava.value.encode('utf-8'), self.flags) for ava in self.avas]]
+ def to_openldap(self):
+ return [list(a) for a in self._avas]
def __str__(self):
- return dn2str(self._to_openldap())
+ return dn2str([self.to_openldap()])
def __repr__(self):
return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__())
+ def _get_ava(self, ava):
+ return self.AVA_type(*ava)
+
def _next(self):
- for ava in self.avas:
- yield ava
+ for ava in self._avas:
+ yield self._get_ava(ava)
def __iter__(self):
return self._next()
def __len__(self):
- return len(self.avas)
+ return len(self._avas)
def __getitem__(self, key):
- if isinstance(key, (int, long, slice)):
- return self.avas[key]
+ if isinstance(key, (int, long)):
+ return self._get_ava(self._avas[key])
+ if isinstance(key, slice):
+ return [self._get_ava(ava) for ava in self._avas[key]]
elif isinstance(key, basestring):
- for ava in self.avas:
- if key == ava.attr:
- return ava.value
+ for ava in self._avas:
+ if key == ava[0].decode('utf-8'):
+ return ava[1].decode('utf-8')
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
else:
raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \
(key.__class__.__name__))
def _get_attr(self):
- if len(self.avas) == 0:
+ if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
- return self.avas[0].attr
+ return self._avas[0][0].decode('utf-8')
def _set_attr(self, new_attr):
- if len(self.avas) == 0:
+ if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
- self.avas[0].attr = new_attr
+ self._avas[0][0] = unicode(new_attr).encode('utf-8')
attr = property(_get_attr)
def _get_value(self):
- if len(self.avas) == 0:
+ if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
- return self.avas[0].value
+ return self._avas[0][1].decode('utf-8')
def _set_value(self, new_value):
- if len(self.avas) == 0:
+ if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
-
- self.avas[0].value = new_value
+ self._avas[0][1] = unicode(new_value).encode('utf-8')
value = property(_get_value)
@@ -959,7 +979,7 @@ class RDN(object):
return False
# Perform comparison between objects of same type
- return self.avas == other.avas
+ return cmp_rdns(self._avas, other._avas) == 0
def __ne__(self, other):
return not self.__eq__(other)
@@ -968,32 +988,23 @@ class RDN(object):
if not isinstance(other, RDN):
raise TypeError("expected RDN but got %s" % (other.__class__.__name__))
- result = cmp(len(self), len(other))
- if result != 0:
- return result
- i = 0
- while i < len(self):
- result = cmp(self[i], other[i])
- if result != 0:
- return result
- i += 1
- return 0
+ return cmp_rdns(self._avas, other._avas)
def __add__(self, other):
result = self.__class__(self)
if isinstance(other, RDN):
- for ava in other.avas:
- result.avas.append(self.AVA_type(ava.attr, ava.value))
+ for ava in other._avas:
+ result._avas.append((ava[0], ava[1], ava[2]))
elif isinstance(other, AVA):
- result.avas.append(self.AVA_type(other.attr, other.value))
+ result._avas.append(other.to_openldap())
elif isinstance(other, basestring):
rdn = self.__class__(other)
- for ava in rdn.avas:
- result.avas.append(self.AVA_type(ava.attr, ava.value))
+ for ava in rdn._avas:
+ result._avas.append((ava[0], ava[1], ava[2]))
else:
raise TypeError("expected RDN, AVA or basestring but got %s" % (other.__class__.__name__))
- result.avas.sort()
+ sort_avas(result._avas)
return result
class EditableRDN(RDN):
@@ -1016,24 +1027,22 @@ class EditableRDN(RDN):
AVA_type = EditableAVA
def __setitem__(self, key, value):
+
if isinstance(key, (int, long)):
- new_ava = self._ava_from_value(value)
- if isinstance(new_ava, list):
- raise TypeError("cannot assign multiple AVA's to single entry")
- self.avas[key] = new_ava
+ self._avas[key] = get_ava(value)
elif isinstance(key, slice):
avas = self._avas_from_sequence(value)
- self.avas[key] = avas
+ self._avas[key] = avas
elif isinstance(key, basestring):
- new_ava = self._ava_from_value(value)
- if isinstance(new_ava, list):
+ if isinstance(value, list):
raise TypeError("cannot assign multiple AVA's to single entry")
+ new_ava = get_ava(value)
found = False
i = 0
- while i < len(self.avas):
- if key == self.avas[i].attr:
+ while i < len(self._avas):
+ if key == self._avas[i][0].decode('utf-8'):
found = True
- self.avas[i] = new_ava
+ self._avas[i] = new_ava
break
i += 1
if not found:
@@ -1041,7 +1050,7 @@ class EditableRDN(RDN):
else:
raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \
(key.__class__.__name__))
- self.avas.sort()
+ sort_avas(self._avas)
attr = property(RDN._get_attr, RDN._set_attr)
value = property(RDN._get_value, RDN._set_value)
@@ -1051,18 +1060,15 @@ class EditableRDN(RDN):
# If __iadd__ is not available Python will emulate += by
# replacing the lhs object with the result of __add__ (if available).
if isinstance(other, RDN):
- for ava in other.avas:
- self.avas.append(self.AVA_type(ava.attr, ava.value))
+ self._avas.extend(other.to_openldap())
elif isinstance(other, AVA):
- self.avas.append(self.AVA_type(other.attr, other.value))
+ self._avas.append(other.to_openldap())
elif isinstance(other, basestring):
- rdn = self.__class__(other)
- for ava in rdn.avas:
- self.avas.append(self.AVA_type(ava.attr, ava.value))
+ self._avas.extend(self._avas_from_sequence([other]))
else:
raise TypeError("expected RDN, AVA or basestring but got %s" % (other.__class__.__name__))
- self.avas.sort()
+ sort_avas(self._avas)
return self
class DN(object):
@@ -1213,72 +1219,74 @@ class DN(object):
'''
is_mutable = False
- flags = 0
AVA_type = AVA
RDN_type = RDN
def __init__(self, *args, **kwds):
self.rdns = self._rdns_from_sequence(args)
- def _rdn_from_value(self, value):
- if isinstance(value, RDN):
- return self.RDN_type(value)
- elif isinstance(value, DN):
- rdns = []
- for rdn in value.rdns:
- rdns.append(self.RDN_type(rdn))
- if len(rdns) == 1:
- return rdns[0]
- else:
- return rdns
- elif isinstance(value, basestring):
- rdns = []
+ def _copy_rdns(self, rdns=None):
+ if not rdns:
+ rdns = self.rdns
+ return [[list(a) for a in rdn] for rdn in rdns]
+
+ def _rdns_from_value(self, value):
+ if isinstance(value, basestring):
try:
- dn_list = str2dn(value.encode('utf-8'))
- for rdn_list in dn_list:
- avas = []
- for ava_tuple in rdn_list:
- avas.append(self.AVA_type(ava_tuple[0], ava_tuple[1]))
- rdn = self.RDN_type(*avas)
- rdns.append(rdn)
+ if isinstance(value, unicode):
+ value = value.encode('utf-8')
+ rdns = str2dn(value)
+ if self.is_mutable:
+ self._copy_rdns(rdns) # AVAs to be list instead of tuple
except DECODING_ERROR:
raise ValueError("malformed RDN string = \"%s\"" % value)
- if len(rdns) == 1:
- return rdns[0]
- else:
- return rdns
- elif isinstance(value, (tuple, list)):
- if len(value) != 2:
- raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (value))
- rdn = self.RDN_type(value)
- return rdn
+ for rdn in rdns:
+ sort_avas(rdn)
+ elif isinstance(value, DN):
+ rdns = value._copy_rdns()
+ elif isinstance(value, (tuple, list, AVA)):
+ ava = get_ava(value)
+ rdns = [[ava]]
+ elif isinstance(value, RDN):
+ rdns = [value.to_openldap()]
else:
- raise TypeError("must be str,unicode,tuple, or RDN, got %s instead" % \
- value.__class__.__name__)
+ raise TypeError("must be str, unicode, tuple, or RDN or DN, got %s instead" %
+ type(value))
+ return rdns
def _rdns_from_sequence(self, seq):
rdns = []
for item in seq:
- rdn = self._rdn_from_value(item)
- if isinstance(rdn, list):
- rdns.extend(rdn)
- else:
- rdns.append(rdn)
+ rdn = self._rdns_from_value(item)
+ rdns.extend(rdn)
return rdns
- def _to_openldap(self):
- return [[(ava.attr.encode('utf-8'), ava.value.encode('utf-8'), self.flags) for ava in rdn] for rdn in self.rdns]
+ def __deepcopy__(self, memo):
+ if self.is_mutable:
+ cls = self.__class__
+ clone = cls.__new__(cls)
+ clone.rdns = self._copy_rdns()
+ return clone
+ return self
+
+ def _get_rdn(self, rdn):
+ return self.RDN_type(*rdn, **{'raw': True})
def __str__(self):
- return dn2str(self._to_openldap())
+ try:
+ return dn2str(self.rdns)
+ except Exception, e:
+ print len(self.rdns)
+ print self.rdns
+ raise
def __repr__(self):
return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__())
def _next(self):
for rdn in self.rdns:
- yield rdn
+ yield self._get_rdn(rdn)
def __iter__(self):
return self._next()
@@ -1287,12 +1295,20 @@ class DN(object):
return len(self.rdns)
def __getitem__(self, key):
- if isinstance(key, (int, long, slice)):
- return self.rdns[key]
+ if isinstance(key, (int, long)):
+ return self._get_rdn(self.rdns[key])
+ if isinstance(key, slice):
+ cls = self.__class__
+ new_dn = cls.__new__(cls)
+ new_dn.rdns = self.rdns[key]
+ if self.is_mutable:
+ new_dn.rdns = self._copy_rdns(new_dn.rdns)
+ return new_dn
elif isinstance(key, basestring):
for rdn in self.rdns:
- if key == rdn.attr:
- return rdn.value
+ for ava in rdn:
+ if key == ava[0].decode('utf-8'):
+ return ava[1].decode('utf-8')
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
else:
raise TypeError("unsupported type for DN indexing, must be int, basestring or slice; not %s" % \
@@ -1305,11 +1321,16 @@ class DN(object):
# hash value between two objects which compare as equal but
# differ in case must yield the same hash value.
- return hash(str(self).lower())
+ str_dn = ';,'.join([
+ '++'.join(
+ ['=='.join((atype, avalue or '')) for atype,avalue,dummy in rdn]
+ ) for rdn in self.rdns
+ ])
+ return hash(str_dn.lower())
def __eq__(self, other):
- # Try coercing string to DN, if successful compare to coerced object
- if isinstance(other, basestring):
+ # Try coercing to DN, if successful compare to coerced object
+ if isinstance(other, (basestring, RDN, AVA)):
try:
other_dn = DN(other)
return self.__eq__(other_dn)
@@ -1321,7 +1342,7 @@ class DN(object):
return False
# Perform comparison between objects of same type
- return self.rdns == other.rdns
+ return self.__cmp__(other) == 0
def __ne__(self, other):
return not self.__eq__(other)
@@ -1336,31 +1357,21 @@ class DN(object):
return self._cmp_sequence(other, 0, len(self))
def _cmp_sequence(self, pattern, self_start, pat_len):
+
self_idx = self_start
+ self_len = len(self)
pat_idx = 0
+ # and self_idx < self_len
while pat_idx < pat_len:
- result = cmp(self[self_idx], pattern[pat_idx])
- if result != 0:
- return result
+ r = cmp_rdns(self.rdns[self_idx], pattern.rdns[pat_idx])
+ if r != 0:
+ return r
self_idx += 1
pat_idx += 1
return 0
def __add__(self, other):
- result = self.__class__(self)
- if isinstance(other, DN):
- for rdn in other.rdns:
- result.rdns.append(self.RDN_type(rdn))
- elif isinstance(other, RDN):
- result.rdns.append(self.RDN_type(other))
- elif isinstance(other, basestring):
- dn = self.__class__(other)
- for rdn in dn.rdns:
- result.rdns.append(rdn)
- else:
- raise TypeError("expected DN, RDN or basestring but got %s" % (other.__class__.__name__))
-
- return result
+ return self.__class__(self, other)
# The implementation of startswith, endswith, tailmatch, adjust_indices
# was based on the Python's stringobject.c implementation
@@ -1402,10 +1413,10 @@ class DN(object):
arguments. Returns 0 if not found and 1 if found.
'''
+ if isinstance(pattern, RDN):
+ pattern = DN(pattern)
if isinstance(pattern, DN):
pat_len = len(pattern)
- elif isinstance(pattern, RDN):
- pat_len = 1
else:
raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__))
@@ -1423,16 +1434,16 @@ class DN(object):
if end-pat_len >= start:
start = end - pat_len
- if isinstance(pattern, DN):
- if end-start >= pat_len:
- return not self._cmp_sequence(pattern, start, pat_len)
- return 0
- else:
- return self.rdns[start] == pattern
+ if end-start >= pat_len:
+ return not self._cmp_sequence(pattern, start, pat_len)
+ return 0
+
def __contains__(self, other):
'Return the outcome of the test other in self. Note the reversed operands.'
+ if isinstance(other, RDN):
+ other = DN(other)
if isinstance(other, DN):
other_len = len(other)
end = len(self) - other_len
@@ -1443,16 +1454,13 @@ class DN(object):
return True
i += 1
return False
-
- elif isinstance(other, RDN):
- return other in self.rdns
else:
raise TypeError("expected DN or RDN but got %s" % (other.__class__.__name__))
def find(self, pattern, start=None, end=None):
'''
- Return the lowest index in the DN where pattern DN (or RDN) is found,
+ Return the lowest index in the DN where pattern DN is found,
such that pattern is contained in the range [start, end]. Optional
arguments start and end are interpreted as in slice notation. Return
-1 if pattern is not found.
@@ -1460,10 +1468,8 @@ class DN(object):
if isinstance(pattern, DN):
pat_len = len(pattern)
- elif isinstance(pattern, RDN):
- pat_len = 1
else:
- raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__))
+ raise TypeError("expected DN but got %s" % (pattern.__class__.__name__))
self_len = len(self)
@@ -1476,19 +1482,14 @@ class DN(object):
i = start
stop = max(start, end - pat_len)
- if isinstance(pattern, DN):
- while i <= stop:
- result = self._cmp_sequence(pattern, i, pat_len)
- if result == 0:
- return i
- i += 1
- return -1
- else:
- while i <= stop:
- if self.rdns[i] == pattern:
- return i
- i += 1
- return -1
+
+ while i <= stop:
+ result = self._cmp_sequence(pattern, i, pat_len)
+ if result == 0:
+ return i
+ i += 1
+ return -1
+
def index(self, pattern, start=None, end=None):
'''
@@ -1502,7 +1503,7 @@ class DN(object):
def rfind(self, pattern, start=None, end=None):
'''
- Return the highest index in the DN where pattern DN (or RDN) is found,
+ Return the highest index in the DN where pattern DN is found,
such that pattern is contained in the range [start, end]. Optional
arguments start and end are interpreted as in slice notation. Return
-1 if pattern is not found.
@@ -1510,10 +1511,8 @@ class DN(object):
if isinstance(pattern, DN):
pat_len = len(pattern)
- elif isinstance(pattern, RDN):
- pat_len = 1
else:
- raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__))
+ raise TypeError("expected DN but got %s" % (pattern.__class__.__name__))
self_len = len(self)
@@ -1526,19 +1525,13 @@ class DN(object):
i = max(start, min(end, self_len - pat_len))
stop = start
- if isinstance(pattern, DN):
- while i >= stop:
- result = self._cmp_sequence(pattern, i, pat_len)
- if result == 0:
- return i
- i -= 1
- return -1
- else:
- while i >= stop:
- if self.rdns[i] == pattern:
- return i
- i -= 1
- return -1
+
+ while i >= stop:
+ result = self._cmp_sequence(pattern, i, pat_len)
+ if result == 0:
+ return i
+ i -= 1
+ return -1
def rindex(self, pattern, start=None, end=None):
'''
@@ -1573,23 +1566,23 @@ class EditableDN(DN):
def __setitem__(self, key, value):
if isinstance(key, (int, long)):
- new_rdn = self._rdn_from_value(value)
- if isinstance(new_rdn, list):
+ new_rdns = self._rdns_from_value(value)
+ if len(new_rdns) > 1:
raise TypeError("cannot assign multiple RDN's to single entry")
- self.rdns[key] = new_rdn
+ self.rdns[key] = new_rdns[0]
elif isinstance(key, slice):
rdns = self._rdns_from_sequence(value)
self.rdns[key] = rdns
elif isinstance(key, basestring):
- new_rdn = self._rdn_from_value(value)
- if isinstance(new_rdn, list):
+ new_rdns = self._rdns_from_value(value)
+ if len(new_rdns) > 1:
raise TypeError("cannot assign multiple values to single entry")
found = False
i = 0
while i < len(self.rdns):
- if key == self.rdns[i].attr:
+ if key == self.rdns[i][0][0].decode('utf-8'):
found = True
- self.rdns[i] = new_rdn
+ self.rdns[i] = new_rdns[0]
break
i += 1
if not found:
@@ -1602,10 +1595,9 @@ class EditableDN(DN):
# If __iadd__ is not available Python will emulate += by
# replacing the lhs object with the result of __add__ (if available).
if isinstance(other, DN):
- for rdn in other.rdns:
- self.rdns.append(self.RDN_type(rdn))
+ self.rdns.extend(other._copy_rdns())
elif isinstance(other, RDN):
- self.rdns.append(self.RDN_type(other))
+ self.rdns.append(other.to_openldap())
elif isinstance(other, basestring):
dn = self.__class__(other)
self.__iadd__(dn)
@@ -1627,7 +1619,11 @@ class EditableDN(DN):
for slice indices.
'''
- self.rdns.insert(i, self._rdn_from_value(x))
+ rdns = self._rdns_from_value(x)
+ if len(rdns) > 1:
+ raise TypeError("cannot assign multiple RDN's to single entry")
+
+ self.rdns.insert(i, rdns[0])
def replace(self, old, new, count=sys.maxsize):
'''
@@ -1656,3 +1652,4 @@ class EditableDN(DN):
start = index + pat_len
return n_replaced
+