diff options
Diffstat (limited to 'ipapython/dn.py')
-rw-r--r-- | ipapython/dn.py | 591 |
1 files changed, 294 insertions, 297 deletions
diff --git a/ipapython/dn.py b/ipapython/dn.py index 834291fbe..5b6570770 100644 --- a/ipapython/dn.py +++ b/ipapython/dn.py @@ -497,6 +497,97 @@ def _adjust_indices(start, end, length): return start, end + +def _normalize_ava_input(val): + if not isinstance(val, basestring): + val = unicode(val).encode('utf-8') + elif isinstance(val, unicode): + val = val.encode('utf-8') + return val + + +def str2rdn(value): + try: + rdns = str2dn(value.encode('utf-8')) + except DECODING_ERROR: + raise ValueError("malformed AVA string = \"%s\"" % value) + if len(rdns) != 1: + raise ValueError("multiple RDN's specified by \"%s\"" % (value)) + return rdns[0] + + +def get_ava(*args, **kwds): + """ + Get AVA from args in open ldap format(raw). Optimized for construction + from openldap format. + + Allowed formats of argument list: + 1) three args - open ldap format (attr and value have to be utf-8 encoded): + a) ['attr', 'value', 0] + 2) two args: + a) ['attr', 'value'] + 3) one arg: + a) [('attr', 'value')] + b) [['attr', 'value']] + c) [AVA(..)] + d) ['attr=value'] + """ + ava = None + l = len(args) + if l == 3: # raw values - constructed FROM RDN + if kwds.get('mutable', False): + ava = args + else: + ava = (args[0], args[1], args[2]) + elif l == 2: # user defined values + ava = [_normalize_ava_input(args[0]), _normalize_ava_input(args[1]), 0] + elif l == 1: # slow mode, tuple, string, + arg = args[0] + if isinstance(arg, AVA): + ava = arg.to_openldap() + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (arg)) + ava = [_normalize_ava_input(arg[0]), _normalize_ava_input(arg[1]), 0] + elif isinstance(arg, basestring): + rdn = str2rdn(arg) + if len(rdn) > 1: + raise TypeError("multiple AVA's specified by \"%s\"" % (arg)) + ava = list(rdn[0]) + else: + raise TypeError("with 1 argument, argument must be str, unicode, tuple or list, got %s instead" % + arg.__class__.__name__) + else: + raise TypeError("invalid number of arguments. 1-3 allowed") + return ava + + +def sort_avas(rdn): + if len(rdn) <= 1: + return + rdn.sort(cmp=cmp_avas) + + +def cmp_avas(a, b): + r = cmp(a[0].lower(), b[0].lower()) + if r == 0: + r = cmp(a[1].lower(), b[1].lower()) + return r + + +def cmp_rdns(a, b): + + l = len(a) + r = cmp(l, len(b)) + if r != 0: + return r + + for i, ava_a in enumerate(a): + r = cmp_avas(ava_a, b[i]) + if r != 0: + return r + return 0 + class AVA(object): ''' AVA(arg0, ...) @@ -552,100 +643,51 @@ class AVA(object): syntax with proper escaping. ''' is_mutable = False - flags = 0 def __init__(self, *args, **kwds): - if len(args) == 1: - arg = args[0] - if isinstance(arg, AVA): - ava = (arg.attr, arg.value) - elif isinstance(arg, basestring): - try: - rdns = str2dn(arg.encode('utf-8')) - except DECODING_ERROR: - raise ValueError("malformed AVA string = \"%s\"" % arg) - if len(rdns) != 1: - raise ValueError("multiple RDN's specified by \"%s\"" % (arg)) - rdn = rdns[0] - if len(rdn) != 1: - raise ValueError("multiple AVA's specified by \"%s\"" % (arg)) - ava = rdn[0] - elif isinstance(arg, (tuple, list)): - ava = arg - if len(ava) != 2: - raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (ava)) - else: - raise TypeError("with 1 argument, argument must be str,unicode,tuple or list, got %s instead" % \ - arg.__class__.__name__) - - attr = ava[0] - value = ava[1] - elif len(args) == 2: - attr = args[0] - value = args[1] - else: - raise TypeError("takes 1 or 2 arguments (%d given)" % (len(args))) - - self._set_attr(attr) - self._set_value(value) + self._ava = get_ava(*args, **{'mutable': self.is_mutable}) def _get_attr(self): - return self._attr_unicode + return self._ava[0].decode('utf-8') def _set_attr(self, new_attr): - # Scalars only - if isinstance(new_attr, (tuple, list)): - raise TypeError("attr must be scalar, got %s" % type(new_attr)) - try: - if isinstance(new_attr, unicode): - self._attr_unicode = new_attr - elif isinstance(new_attr, str): - self._attr_unicode = new_attr.decode('utf-8') - else: - self._attr_unicode = unicode(new_attr) + self._ava[0] = _normalize_ava_input(new_attr) except Exception, e: - raise ValueError('unable to convert attr "%s" to unicode: %s' % (new_attr, e)) + raise ValueError('unable to convert attr "%s": %s' % (new_attr, e)) - attr = property(_get_attr) + attr = property(_get_attr) def _get_value(self): - return self._value_unicode + return self._ava[1].decode('utf-8') def _set_value(self, new_value): - # Scalars only - if isinstance(new_value, (tuple, list)): - raise TypeError("value must be scalar, got %s" % type(new_value)) - try: - if isinstance(new_value, unicode): - self._value_unicode = new_value - elif isinstance(new_value, str): - self._value_unicode = new_value.decode('utf-8') - else: - self._value_unicode = unicode(new_value) + self._ava[1] = _normalize_ava_input(new_value) except Exception, e: - raise ValueError('unable to convert value "%s" to unicode: %s' % (new_value, e)) + raise ValueError('unable to convert value "%s": %s' % (new_value, e)) value = property(_get_value) - def _to_openldap(self): - return [[(self._attr_unicode.encode('utf-8'), self._value_unicode.encode('utf-8'), self.flags)]] + def to_openldap(self): + return list(self._ava) def __str__(self): - return dn2str(self._to_openldap()) + return dn2str([[self.to_openldap()]]) def __repr__(self): return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__()) def __getitem__(self, key): - if isinstance(key, basestring): - if key == self._attr_unicode: - return self._value_unicode - raise KeyError("\"%s\" not found in %s" % (key, self.__str__())) + + if key == 0: + return self.attr + elif key == 1: + return self.value + elif key == self.attr: + return self.value else: - raise TypeError("unsupported type for AVA indexing, must be basestring; not %s" % \ - (key.__class__.__name__)) + raise KeyError("\"%s\" not found in %s" % (key, self.__str__())) def __hash__(self): # Hash is computed from AVA's string representation because it's immutable. @@ -682,8 +724,7 @@ class AVA(object): return False # Perform comparison between objects of same type - return self._attr_unicode.lower() == other.attr.lower() and \ - self._value_unicode.lower() == other.value.lower() + return cmp_avas(self._ava, other._ava) == 0 def __ne__(self, other): return not self.__eq__(other) @@ -694,11 +735,7 @@ class AVA(object): if not isinstance(other, AVA): raise TypeError("expected AVA but got %s" % (other.__class__.__name__)) - result = cmp(self._attr_unicode.lower(), other.attr.lower()) - if result != 0: - return result - result = cmp(self._value_unicode.lower(), other.value.lower()) - return result + return cmp_avas(self._ava, other._ava) class EditableAVA(AVA): ''' @@ -826,113 +863,96 @@ class RDN(object): ''' is_mutable = False - flags = 0 AVA_type = AVA def __init__(self, *args, **kwds): - self.avas = self._avas_from_sequence(args) - self.avas.sort() - - def _ava_from_value(self, value): - if isinstance(value, AVA): - return self.AVA_type(value.attr, value.value) - elif isinstance(value, RDN): - avas = [] - for ava in value.avas: - avas.append(self.AVA_type(ava.attr, ava.value)) - if len(avas) == 1: - return avas[0] - else: - return avas - elif isinstance(value, basestring): - try: - rdns = str2dn(value.encode('utf-8')) - if len(rdns) != 1: - raise ValueError("multiple RDN's specified by \"%s\"" % (value)) - rdn = rdns[0] - if len(rdn) == 1: - return self.AVA_type(rdn[0][0], rdn[0][1]) - else: - avas = [] - for ava_tuple in rdn: - avas.append(self.AVA_type(ava_tuple[0], ava_tuple[1])) - return avas - except DECODING_ERROR: - raise ValueError("malformed RDN string = \"%s\"" % value) - elif isinstance(value, (tuple, list)): - if len(value) != 2: - raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (value)) - return self.AVA_type(value) - else: - raise TypeError("must be str,unicode,tuple, or AVA, got %s instead" % \ - value.__class__.__name__) - + self._avas = self._avas_from_sequence(args, kwds.get('raw', False)) - def _avas_from_sequence(self, seq): + def _avas_from_sequence(self, args, raw=False): avas = [] + sort = 0 + ava_count = len(args) - for item in seq: - ava = self._ava_from_value(item) - if isinstance(ava, list): - avas.extend(ava) - else: - avas.append(ava) + if raw: # fast raw mode + try: + if self.is_mutable: + avas = args + else: + for arg in args: + avas.append((arg[0], arg[1], arg[2])) + except KeyError as e: + raise TypeError('all AVA values in RAW mode must be in open ldap format') + elif ava_count == 1 and isinstance(args[0], basestring): + avas = str2rdn(args[0]) + sort = 1 + elif ava_count == 1 and isinstance(args[0], RDN): + avas = args[0].to_openldap() + elif ava_count > 0: + sort = 1 + for arg in args: + avas.append(get_ava(arg)) + if sort: + sort_avas(avas) return avas - def _to_openldap(self): - return [[(ava.attr.encode('utf-8'), ava.value.encode('utf-8'), self.flags) for ava in self.avas]] + def to_openldap(self): + return [list(a) for a in self._avas] def __str__(self): - return dn2str(self._to_openldap()) + return dn2str([self.to_openldap()]) def __repr__(self): return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__()) + def _get_ava(self, ava): + return self.AVA_type(*ava) + def _next(self): - for ava in self.avas: - yield ava + for ava in self._avas: + yield self._get_ava(ava) def __iter__(self): return self._next() def __len__(self): - return len(self.avas) + return len(self._avas) def __getitem__(self, key): - if isinstance(key, (int, long, slice)): - return self.avas[key] + if isinstance(key, (int, long)): + return self._get_ava(self._avas[key]) + if isinstance(key, slice): + return [self._get_ava(ava) for ava in self._avas[key]] elif isinstance(key, basestring): - for ava in self.avas: - if key == ava.attr: - return ava.value + for ava in self._avas: + if key == ava[0].decode('utf-8'): + return ava[1].decode('utf-8') raise KeyError("\"%s\" not found in %s" % (key, self.__str__())) else: raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \ (key.__class__.__name__)) def _get_attr(self): - if len(self.avas) == 0: + if len(self._avas) == 0: raise IndexError("No AVA's in this RDN") - return self.avas[0].attr + return self._avas[0][0].decode('utf-8') def _set_attr(self, new_attr): - if len(self.avas) == 0: + if len(self._avas) == 0: raise IndexError("No AVA's in this RDN") - self.avas[0].attr = new_attr + self._avas[0][0] = unicode(new_attr).encode('utf-8') attr = property(_get_attr) def _get_value(self): - if len(self.avas) == 0: + if len(self._avas) == 0: raise IndexError("No AVA's in this RDN") - return self.avas[0].value + return self._avas[0][1].decode('utf-8') def _set_value(self, new_value): - if len(self.avas) == 0: + if len(self._avas) == 0: raise IndexError("No AVA's in this RDN") - - self.avas[0].value = new_value + self._avas[0][1] = unicode(new_value).encode('utf-8') value = property(_get_value) @@ -959,7 +979,7 @@ class RDN(object): return False # Perform comparison between objects of same type - return self.avas == other.avas + return cmp_rdns(self._avas, other._avas) == 0 def __ne__(self, other): return not self.__eq__(other) @@ -968,32 +988,23 @@ class RDN(object): if not isinstance(other, RDN): raise TypeError("expected RDN but got %s" % (other.__class__.__name__)) - result = cmp(len(self), len(other)) - if result != 0: - return result - i = 0 - while i < len(self): - result = cmp(self[i], other[i]) - if result != 0: - return result - i += 1 - return 0 + return cmp_rdns(self._avas, other._avas) def __add__(self, other): result = self.__class__(self) if isinstance(other, RDN): - for ava in other.avas: - result.avas.append(self.AVA_type(ava.attr, ava.value)) + for ava in other._avas: + result._avas.append((ava[0], ava[1], ava[2])) elif isinstance(other, AVA): - result.avas.append(self.AVA_type(other.attr, other.value)) + result._avas.append(other.to_openldap()) elif isinstance(other, basestring): rdn = self.__class__(other) - for ava in rdn.avas: - result.avas.append(self.AVA_type(ava.attr, ava.value)) + for ava in rdn._avas: + result._avas.append((ava[0], ava[1], ava[2])) else: raise TypeError("expected RDN, AVA or basestring but got %s" % (other.__class__.__name__)) - result.avas.sort() + sort_avas(result._avas) return result class EditableRDN(RDN): @@ -1016,24 +1027,22 @@ class EditableRDN(RDN): AVA_type = EditableAVA def __setitem__(self, key, value): + if isinstance(key, (int, long)): - new_ava = self._ava_from_value(value) - if isinstance(new_ava, list): - raise TypeError("cannot assign multiple AVA's to single entry") - self.avas[key] = new_ava + self._avas[key] = get_ava(value) elif isinstance(key, slice): avas = self._avas_from_sequence(value) - self.avas[key] = avas + self._avas[key] = avas elif isinstance(key, basestring): - new_ava = self._ava_from_value(value) - if isinstance(new_ava, list): + if isinstance(value, list): raise TypeError("cannot assign multiple AVA's to single entry") + new_ava = get_ava(value) found = False i = 0 - while i < len(self.avas): - if key == self.avas[i].attr: + while i < len(self._avas): + if key == self._avas[i][0].decode('utf-8'): found = True - self.avas[i] = new_ava + self._avas[i] = new_ava break i += 1 if not found: @@ -1041,7 +1050,7 @@ class EditableRDN(RDN): else: raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \ (key.__class__.__name__)) - self.avas.sort() + sort_avas(self._avas) attr = property(RDN._get_attr, RDN._set_attr) value = property(RDN._get_value, RDN._set_value) @@ -1051,18 +1060,15 @@ class EditableRDN(RDN): # If __iadd__ is not available Python will emulate += by # replacing the lhs object with the result of __add__ (if available). if isinstance(other, RDN): - for ava in other.avas: - self.avas.append(self.AVA_type(ava.attr, ava.value)) + self._avas.extend(other.to_openldap()) elif isinstance(other, AVA): - self.avas.append(self.AVA_type(other.attr, other.value)) + self._avas.append(other.to_openldap()) elif isinstance(other, basestring): - rdn = self.__class__(other) - for ava in rdn.avas: - self.avas.append(self.AVA_type(ava.attr, ava.value)) + self._avas.extend(self._avas_from_sequence([other])) else: raise TypeError("expected RDN, AVA or basestring but got %s" % (other.__class__.__name__)) - self.avas.sort() + sort_avas(self._avas) return self class DN(object): @@ -1213,72 +1219,74 @@ class DN(object): ''' is_mutable = False - flags = 0 AVA_type = AVA RDN_type = RDN def __init__(self, *args, **kwds): self.rdns = self._rdns_from_sequence(args) - def _rdn_from_value(self, value): - if isinstance(value, RDN): - return self.RDN_type(value) - elif isinstance(value, DN): - rdns = [] - for rdn in value.rdns: - rdns.append(self.RDN_type(rdn)) - if len(rdns) == 1: - return rdns[0] - else: - return rdns - elif isinstance(value, basestring): - rdns = [] + def _copy_rdns(self, rdns=None): + if not rdns: + rdns = self.rdns + return [[list(a) for a in rdn] for rdn in rdns] + + def _rdns_from_value(self, value): + if isinstance(value, basestring): try: - dn_list = str2dn(value.encode('utf-8')) - for rdn_list in dn_list: - avas = [] - for ava_tuple in rdn_list: - avas.append(self.AVA_type(ava_tuple[0], ava_tuple[1])) - rdn = self.RDN_type(*avas) - rdns.append(rdn) + if isinstance(value, unicode): + value = value.encode('utf-8') + rdns = str2dn(value) + if self.is_mutable: + self._copy_rdns(rdns) # AVAs to be list instead of tuple except DECODING_ERROR: raise ValueError("malformed RDN string = \"%s\"" % value) - if len(rdns) == 1: - return rdns[0] - else: - return rdns - elif isinstance(value, (tuple, list)): - if len(value) != 2: - raise ValueError("tuple or list must be 2-valued, not \"%s\"" % (value)) - rdn = self.RDN_type(value) - return rdn + for rdn in rdns: + sort_avas(rdn) + elif isinstance(value, DN): + rdns = value._copy_rdns() + elif isinstance(value, (tuple, list, AVA)): + ava = get_ava(value) + rdns = [[ava]] + elif isinstance(value, RDN): + rdns = [value.to_openldap()] else: - raise TypeError("must be str,unicode,tuple, or RDN, got %s instead" % \ - value.__class__.__name__) + raise TypeError("must be str, unicode, tuple, or RDN or DN, got %s instead" % + type(value)) + return rdns def _rdns_from_sequence(self, seq): rdns = [] for item in seq: - rdn = self._rdn_from_value(item) - if isinstance(rdn, list): - rdns.extend(rdn) - else: - rdns.append(rdn) + rdn = self._rdns_from_value(item) + rdns.extend(rdn) return rdns - def _to_openldap(self): - return [[(ava.attr.encode('utf-8'), ava.value.encode('utf-8'), self.flags) for ava in rdn] for rdn in self.rdns] + def __deepcopy__(self, memo): + if self.is_mutable: + cls = self.__class__ + clone = cls.__new__(cls) + clone.rdns = self._copy_rdns() + return clone + return self + + def _get_rdn(self, rdn): + return self.RDN_type(*rdn, **{'raw': True}) def __str__(self): - return dn2str(self._to_openldap()) + try: + return dn2str(self.rdns) + except Exception, e: + print len(self.rdns) + print self.rdns + raise def __repr__(self): return "%s.%s('%s')" % (self.__module__, self.__class__.__name__, self.__str__()) def _next(self): for rdn in self.rdns: - yield rdn + yield self._get_rdn(rdn) def __iter__(self): return self._next() @@ -1287,12 +1295,20 @@ class DN(object): return len(self.rdns) def __getitem__(self, key): - if isinstance(key, (int, long, slice)): - return self.rdns[key] + if isinstance(key, (int, long)): + return self._get_rdn(self.rdns[key]) + if isinstance(key, slice): + cls = self.__class__ + new_dn = cls.__new__(cls) + new_dn.rdns = self.rdns[key] + if self.is_mutable: + new_dn.rdns = self._copy_rdns(new_dn.rdns) + return new_dn elif isinstance(key, basestring): for rdn in self.rdns: - if key == rdn.attr: - return rdn.value + for ava in rdn: + if key == ava[0].decode('utf-8'): + return ava[1].decode('utf-8') raise KeyError("\"%s\" not found in %s" % (key, self.__str__())) else: raise TypeError("unsupported type for DN indexing, must be int, basestring or slice; not %s" % \ @@ -1305,11 +1321,16 @@ class DN(object): # hash value between two objects which compare as equal but # differ in case must yield the same hash value. - return hash(str(self).lower()) + str_dn = ';,'.join([ + '++'.join( + ['=='.join((atype, avalue or '')) for atype,avalue,dummy in rdn] + ) for rdn in self.rdns + ]) + return hash(str_dn.lower()) def __eq__(self, other): - # Try coercing string to DN, if successful compare to coerced object - if isinstance(other, basestring): + # Try coercing to DN, if successful compare to coerced object + if isinstance(other, (basestring, RDN, AVA)): try: other_dn = DN(other) return self.__eq__(other_dn) @@ -1321,7 +1342,7 @@ class DN(object): return False # Perform comparison between objects of same type - return self.rdns == other.rdns + return self.__cmp__(other) == 0 def __ne__(self, other): return not self.__eq__(other) @@ -1336,31 +1357,21 @@ class DN(object): return self._cmp_sequence(other, 0, len(self)) def _cmp_sequence(self, pattern, self_start, pat_len): + self_idx = self_start + self_len = len(self) pat_idx = 0 + # and self_idx < self_len while pat_idx < pat_len: - result = cmp(self[self_idx], pattern[pat_idx]) - if result != 0: - return result + r = cmp_rdns(self.rdns[self_idx], pattern.rdns[pat_idx]) + if r != 0: + return r self_idx += 1 pat_idx += 1 return 0 def __add__(self, other): - result = self.__class__(self) - if isinstance(other, DN): - for rdn in other.rdns: - result.rdns.append(self.RDN_type(rdn)) - elif isinstance(other, RDN): - result.rdns.append(self.RDN_type(other)) - elif isinstance(other, basestring): - dn = self.__class__(other) - for rdn in dn.rdns: - result.rdns.append(rdn) - else: - raise TypeError("expected DN, RDN or basestring but got %s" % (other.__class__.__name__)) - - return result + return self.__class__(self, other) # The implementation of startswith, endswith, tailmatch, adjust_indices # was based on the Python's stringobject.c implementation @@ -1402,10 +1413,10 @@ class DN(object): arguments. Returns 0 if not found and 1 if found. ''' + if isinstance(pattern, RDN): + pattern = DN(pattern) if isinstance(pattern, DN): pat_len = len(pattern) - elif isinstance(pattern, RDN): - pat_len = 1 else: raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__)) @@ -1423,16 +1434,16 @@ class DN(object): if end-pat_len >= start: start = end - pat_len - if isinstance(pattern, DN): - if end-start >= pat_len: - return not self._cmp_sequence(pattern, start, pat_len) - return 0 - else: - return self.rdns[start] == pattern + if end-start >= pat_len: + return not self._cmp_sequence(pattern, start, pat_len) + return 0 + def __contains__(self, other): 'Return the outcome of the test other in self. Note the reversed operands.' + if isinstance(other, RDN): + other = DN(other) if isinstance(other, DN): other_len = len(other) end = len(self) - other_len @@ -1443,16 +1454,13 @@ class DN(object): return True i += 1 return False - - elif isinstance(other, RDN): - return other in self.rdns else: raise TypeError("expected DN or RDN but got %s" % (other.__class__.__name__)) def find(self, pattern, start=None, end=None): ''' - Return the lowest index in the DN where pattern DN (or RDN) is found, + Return the lowest index in the DN where pattern DN is found, such that pattern is contained in the range [start, end]. Optional arguments start and end are interpreted as in slice notation. Return -1 if pattern is not found. @@ -1460,10 +1468,8 @@ class DN(object): if isinstance(pattern, DN): pat_len = len(pattern) - elif isinstance(pattern, RDN): - pat_len = 1 else: - raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__)) + raise TypeError("expected DN but got %s" % (pattern.__class__.__name__)) self_len = len(self) @@ -1476,19 +1482,14 @@ class DN(object): i = start stop = max(start, end - pat_len) - if isinstance(pattern, DN): - while i <= stop: - result = self._cmp_sequence(pattern, i, pat_len) - if result == 0: - return i - i += 1 - return -1 - else: - while i <= stop: - if self.rdns[i] == pattern: - return i - i += 1 - return -1 + + while i <= stop: + result = self._cmp_sequence(pattern, i, pat_len) + if result == 0: + return i + i += 1 + return -1 + def index(self, pattern, start=None, end=None): ''' @@ -1502,7 +1503,7 @@ class DN(object): def rfind(self, pattern, start=None, end=None): ''' - Return the highest index in the DN where pattern DN (or RDN) is found, + Return the highest index in the DN where pattern DN is found, such that pattern is contained in the range [start, end]. Optional arguments start and end are interpreted as in slice notation. Return -1 if pattern is not found. @@ -1510,10 +1511,8 @@ class DN(object): if isinstance(pattern, DN): pat_len = len(pattern) - elif isinstance(pattern, RDN): - pat_len = 1 else: - raise TypeError("expected DN or RDN but got %s" % (pattern.__class__.__name__)) + raise TypeError("expected DN but got %s" % (pattern.__class__.__name__)) self_len = len(self) @@ -1526,19 +1525,13 @@ class DN(object): i = max(start, min(end, self_len - pat_len)) stop = start - if isinstance(pattern, DN): - while i >= stop: - result = self._cmp_sequence(pattern, i, pat_len) - if result == 0: - return i - i -= 1 - return -1 - else: - while i >= stop: - if self.rdns[i] == pattern: - return i - i -= 1 - return -1 + + while i >= stop: + result = self._cmp_sequence(pattern, i, pat_len) + if result == 0: + return i + i -= 1 + return -1 def rindex(self, pattern, start=None, end=None): ''' @@ -1573,23 +1566,23 @@ class EditableDN(DN): def __setitem__(self, key, value): if isinstance(key, (int, long)): - new_rdn = self._rdn_from_value(value) - if isinstance(new_rdn, list): + new_rdns = self._rdns_from_value(value) + if len(new_rdns) > 1: raise TypeError("cannot assign multiple RDN's to single entry") - self.rdns[key] = new_rdn + self.rdns[key] = new_rdns[0] elif isinstance(key, slice): rdns = self._rdns_from_sequence(value) self.rdns[key] = rdns elif isinstance(key, basestring): - new_rdn = self._rdn_from_value(value) - if isinstance(new_rdn, list): + new_rdns = self._rdns_from_value(value) + if len(new_rdns) > 1: raise TypeError("cannot assign multiple values to single entry") found = False i = 0 while i < len(self.rdns): - if key == self.rdns[i].attr: + if key == self.rdns[i][0][0].decode('utf-8'): found = True - self.rdns[i] = new_rdn + self.rdns[i] = new_rdns[0] break i += 1 if not found: @@ -1602,10 +1595,9 @@ class EditableDN(DN): # If __iadd__ is not available Python will emulate += by # replacing the lhs object with the result of __add__ (if available). if isinstance(other, DN): - for rdn in other.rdns: - self.rdns.append(self.RDN_type(rdn)) + self.rdns.extend(other._copy_rdns()) elif isinstance(other, RDN): - self.rdns.append(self.RDN_type(other)) + self.rdns.append(other.to_openldap()) elif isinstance(other, basestring): dn = self.__class__(other) self.__iadd__(dn) @@ -1627,7 +1619,11 @@ class EditableDN(DN): for slice indices. ''' - self.rdns.insert(i, self._rdn_from_value(x)) + rdns = self._rdns_from_value(x) + if len(rdns) > 1: + raise TypeError("cannot assign multiple RDN's to single entry") + + self.rdns.insert(i, rdns[0]) def replace(self, old, new, count=sys.maxsize): ''' @@ -1656,3 +1652,4 @@ class EditableDN(DN): start = index + pat_len return n_replaced + |