diff options
| author | Petr Viktorin <pviktori@redhat.com> | 2015-09-18 15:28:23 +0200 |
|---|---|---|
| committer | Jan Cholasta <jcholast@redhat.com> | 2015-10-07 10:27:20 +0200 |
| commit | ed96f8d9ba399d2c6909806692321d638b11ba8b (patch) | |
| tree | 4c4175d101efea17aeedea847f43b1305e4ba95d /ipapython | |
| parent | c9ca8de7a25af063a50d6735bec6a5181e96d758 (diff) | |
ipapython.dn: Use rich comparisons
__cmp__ and cmp were removed from Python 3.
Reviewed-By: David Kupka <dkupka@redhat.com>
Reviewed-By: Jan Cholasta <jcholast@redhat.com>
Reviewed-By: Martin Basti <mbasti@redhat.com>
Diffstat (limited to 'ipapython')
| -rw-r--r-- | ipapython/dn.py | 117 |
1 files changed, 71 insertions, 46 deletions
diff --git a/ipapython/dn.py b/ipapython/dn.py index 5a42ab37e..8585ee95f 100644 --- a/ipapython/dn.py +++ b/ipapython/dn.py @@ -420,6 +420,7 @@ to the constructor. The result may share underlying structure. from __future__ import print_function import sys +import functools from ldap.dn import str2dn, dn2str from ldap import DECODING_ERROR @@ -449,9 +450,11 @@ def _adjust_indices(start, end, length): def _normalize_ava_input(val): - if not isinstance(val, six.string_types): - val = unicode(val).encode('utf-8') - elif isinstance(val, unicode): + if six.PY3 and isinstance(val, bytes): + raise TypeError('expected str, got bytes: %s' % val) + elif not isinstance(val, six.string_types): + val = val_encode(six.text_type(val)) + elif six.PY2 and isinstance(val, unicode): val = val.encode('utf-8') return val @@ -512,29 +515,47 @@ def get_ava(*args): def sort_avas(rdn): if len(rdn) <= 1: return - rdn.sort(cmp=cmp_avas) + rdn.sort(key=ava_key) -def cmp_avas(a, b): - r = cmp(a[0].lower(), b[0].lower()) - if r == 0: - r = cmp(a[1].lower(), b[1].lower()) - return r +def ava_key(ava): + return ava[0].lower(), ava[1].lower() def cmp_rdns(a, b): + key_a = rdn_key(a) + key_b = rdn_key(b) + if key_a == key_b: + return 0 + elif key_a < key_b: + return -1 + else: + return 1 + + +def rdn_key(rdn): + return (len(rdn),) + tuple(ava_key(k) for k in rdn) + - l = len(a) - r = cmp(l, len(b)) - if r != 0: - return r +if six.PY2: + # Python 2: Input/output is unicode; we store UTF-8 bytes + def val_encode(s): + return s.encode('utf-8') - for i, ava_a in enumerate(a): - r = cmp_avas(ava_a, b[i]) - if r != 0: - return r - return 0 + def val_decode(s): + return s.decode('utf-8') +else: + # Python 3: Everything is unicode (str) + def val_encode(s): + if isinstance(s, bytes): + raise TypeError('expected str, got bytes: %s' % s) + return s + def val_decode(s): + return s + + +@functools.total_ordering class AVA(object): ''' AVA(arg0, ...) @@ -593,23 +614,23 @@ class AVA(object): self._ava = get_ava(*args) def _get_attr(self): - return self._ava[0].decode('utf-8') + return val_decode(self._ava[0]) def _set_attr(self, new_attr): try: self._ava[0] = _normalize_ava_input(new_attr) - except Exception, e: + except Exception as e: raise ValueError('unable to convert attr "%s": %s' % (new_attr, e)) attr = property(_get_attr) def _get_value(self): - return self._ava[1].decode('utf-8') + return val_decode(self._ava[1]) def _set_value(self, new_value): try: self._ava[1] = _normalize_ava_input(new_value) - except Exception, e: + except Exception as e: raise ValueError('unable to convert value "%s": %s' % (new_value, e)) value = property(_get_value) @@ -669,20 +690,21 @@ class AVA(object): return False # Perform comparison between objects of same type - return cmp_avas(self._ava, other._ava) == 0 + return ava_key(self._ava) == ava_key(other._ava) def __ne__(self, other): return not self.__eq__(other) - def __cmp__(self, other): + def __lt__(self, other): 'comparison is case insensitive, see __eq__ doc for explanation' if not isinstance(other, AVA): raise TypeError("expected AVA but got %s" % (other.__class__.__name__)) - return cmp_avas(self._ava, other._ava) + return ava_key(self._ava) < ava_key(other._ava) +@functools.total_ordering class RDN(object): ''' RDN(arg0, ...) @@ -843,8 +865,8 @@ class RDN(object): return [self._get_ava(ava) for ava in self._avas[key]] elif isinstance(key, six.string_types): for ava in self._avas: - if key == ava[0].decode('utf-8'): - return ava[1].decode('utf-8') + if key == val_decode(ava[0]): + return val_decode(ava[1]) raise KeyError("\"%s\" not found in %s" % (key, self.__str__())) else: raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \ @@ -853,25 +875,25 @@ class RDN(object): def _get_attr(self): if len(self._avas) == 0: raise IndexError("No AVA's in this RDN") - return self._avas[0][0].decode('utf-8') + return val_decode(self._avas[0][0]) def _set_attr(self, new_attr): if len(self._avas) == 0: raise IndexError("No AVA's in this RDN") - self._avas[0][0] = unicode(new_attr).encode('utf-8') + self._avas[0][0] = val_encode(six.text_type(new_attr)) attr = property(_get_attr) def _get_value(self): if len(self._avas) == 0: raise IndexError("No AVA's in this RDN") - return self._avas[0][1].decode('utf-8') + return val_decode(self._avas[0][1]) def _set_value(self, new_value): if len(self._avas) == 0: raise IndexError("No AVA's in this RDN") - self._avas[0][1] = unicode(new_value).encode('utf-8') + self._avas[0][1] = val_encode(six.text_type(new_value)) value = property(_get_value) @@ -898,16 +920,16 @@ class RDN(object): return False # Perform comparison between objects of same type - return cmp_rdns(self._avas, other._avas) == 0 + return rdn_key(self._avas) == rdn_key(other._avas) def __ne__(self, other): return not self.__eq__(other) - def __cmp__(self, other): + def __lt__(self, other): if not isinstance(other, RDN): raise TypeError("expected RDN but got %s" % (other.__class__.__name__)) - return cmp_rdns(self._avas, other._avas) + return rdn_key(self._avas) < rdn_key(other._avas) def __add__(self, other): result = self.__class__(self) @@ -927,6 +949,7 @@ class RDN(object): return result +@functools.total_ordering class DN(object): ''' DN(arg0, ...) @@ -1088,8 +1111,8 @@ class DN(object): def _rdns_from_value(self, value): if isinstance(value, six.string_types): try: - if isinstance(value, unicode): - value = value.encode('utf-8') + if isinstance(value, six.text_type): + value = val_encode(value) rdns = str2dn(value) except DECODING_ERROR: raise ValueError("malformed RDN string = \"%s\"" % value) @@ -1124,7 +1147,7 @@ class DN(object): def __str__(self): try: return dn2str(self.rdns) - except Exception, e: + except Exception as e: print(len(self.rdns)) print(self.rdns) raise @@ -1153,8 +1176,8 @@ class DN(object): elif isinstance(key, six.string_types): for rdn in self.rdns: for ava in rdn: - if key == ava[0].decode('utf-8'): - return ava[1].decode('utf-8') + if key == val_decode(ava[0]): + return val_decode(ava[1]) raise KeyError("\"%s\" not found in %s" % (key, self.__str__())) else: raise TypeError("unsupported type for DN indexing, must be int, basestring or slice; not %s" % \ @@ -1187,23 +1210,25 @@ class DN(object): if not isinstance(other, DN): return False + if len(self) != len(other): + return False + # Perform comparison between objects of same type - return self.__cmp__(other) == 0 + return self._cmp_sequence(other, 0, len(self)) == 0 def __ne__(self, other): return not self.__eq__(other) - def __cmp__(self, other): + def __lt__(self, other): if not isinstance(other, DN): raise TypeError("expected DN but got %s" % (other.__class__.__name__)) - result = cmp(len(self), len(other)) - if result != 0: - return result - return self._cmp_sequence(other, 0, len(self)) + if len(self) != len(other): + return len(self) < len(other) - def _cmp_sequence(self, pattern, self_start, pat_len): + return self._cmp_sequence(other, 0, len(self)) < 0 + def _cmp_sequence(self, pattern, self_start, pat_len): self_idx = self_start self_len = len(self) pat_idx = 0 |
