summaryrefslogtreecommitdiffstats
path: root/ipapython
diff options
context:
space:
mode:
authorPetr Viktorin <pviktori@redhat.com>2015-09-18 15:28:23 +0200
committerJan Cholasta <jcholast@redhat.com>2015-10-07 10:27:20 +0200
commited96f8d9ba399d2c6909806692321d638b11ba8b (patch)
tree4c4175d101efea17aeedea847f43b1305e4ba95d /ipapython
parentc9ca8de7a25af063a50d6735bec6a5181e96d758 (diff)
ipapython.dn: Use rich comparisons
__cmp__ and cmp were removed from Python 3. Reviewed-By: David Kupka <dkupka@redhat.com> Reviewed-By: Jan Cholasta <jcholast@redhat.com> Reviewed-By: Martin Basti <mbasti@redhat.com>
Diffstat (limited to 'ipapython')
-rw-r--r--ipapython/dn.py117
1 files changed, 71 insertions, 46 deletions
diff --git a/ipapython/dn.py b/ipapython/dn.py
index 5a42ab37e..8585ee95f 100644
--- a/ipapython/dn.py
+++ b/ipapython/dn.py
@@ -420,6 +420,7 @@ to the constructor. The result may share underlying structure.
from __future__ import print_function
import sys
+import functools
from ldap.dn import str2dn, dn2str
from ldap import DECODING_ERROR
@@ -449,9 +450,11 @@ def _adjust_indices(start, end, length):
def _normalize_ava_input(val):
- if not isinstance(val, six.string_types):
- val = unicode(val).encode('utf-8')
- elif isinstance(val, unicode):
+ if six.PY3 and isinstance(val, bytes):
+ raise TypeError('expected str, got bytes: %s' % val)
+ elif not isinstance(val, six.string_types):
+ val = val_encode(six.text_type(val))
+ elif six.PY2 and isinstance(val, unicode):
val = val.encode('utf-8')
return val
@@ -512,29 +515,47 @@ def get_ava(*args):
def sort_avas(rdn):
if len(rdn) <= 1:
return
- rdn.sort(cmp=cmp_avas)
+ rdn.sort(key=ava_key)
-def cmp_avas(a, b):
- r = cmp(a[0].lower(), b[0].lower())
- if r == 0:
- r = cmp(a[1].lower(), b[1].lower())
- return r
+def ava_key(ava):
+ return ava[0].lower(), ava[1].lower()
def cmp_rdns(a, b):
+ key_a = rdn_key(a)
+ key_b = rdn_key(b)
+ if key_a == key_b:
+ return 0
+ elif key_a < key_b:
+ return -1
+ else:
+ return 1
+
+
+def rdn_key(rdn):
+ return (len(rdn),) + tuple(ava_key(k) for k in rdn)
+
- l = len(a)
- r = cmp(l, len(b))
- if r != 0:
- return r
+if six.PY2:
+ # Python 2: Input/output is unicode; we store UTF-8 bytes
+ def val_encode(s):
+ return s.encode('utf-8')
- for i, ava_a in enumerate(a):
- r = cmp_avas(ava_a, b[i])
- if r != 0:
- return r
- return 0
+ def val_decode(s):
+ return s.decode('utf-8')
+else:
+ # Python 3: Everything is unicode (str)
+ def val_encode(s):
+ if isinstance(s, bytes):
+ raise TypeError('expected str, got bytes: %s' % s)
+ return s
+ def val_decode(s):
+ return s
+
+
+@functools.total_ordering
class AVA(object):
'''
AVA(arg0, ...)
@@ -593,23 +614,23 @@ class AVA(object):
self._ava = get_ava(*args)
def _get_attr(self):
- return self._ava[0].decode('utf-8')
+ return val_decode(self._ava[0])
def _set_attr(self, new_attr):
try:
self._ava[0] = _normalize_ava_input(new_attr)
- except Exception, e:
+ except Exception as e:
raise ValueError('unable to convert attr "%s": %s' % (new_attr, e))
attr = property(_get_attr)
def _get_value(self):
- return self._ava[1].decode('utf-8')
+ return val_decode(self._ava[1])
def _set_value(self, new_value):
try:
self._ava[1] = _normalize_ava_input(new_value)
- except Exception, e:
+ except Exception as e:
raise ValueError('unable to convert value "%s": %s' % (new_value, e))
value = property(_get_value)
@@ -669,20 +690,21 @@ class AVA(object):
return False
# Perform comparison between objects of same type
- return cmp_avas(self._ava, other._ava) == 0
+ return ava_key(self._ava) == ava_key(other._ava)
def __ne__(self, other):
return not self.__eq__(other)
- def __cmp__(self, other):
+ def __lt__(self, other):
'comparison is case insensitive, see __eq__ doc for explanation'
if not isinstance(other, AVA):
raise TypeError("expected AVA but got %s" % (other.__class__.__name__))
- return cmp_avas(self._ava, other._ava)
+ return ava_key(self._ava) < ava_key(other._ava)
+@functools.total_ordering
class RDN(object):
'''
RDN(arg0, ...)
@@ -843,8 +865,8 @@ class RDN(object):
return [self._get_ava(ava) for ava in self._avas[key]]
elif isinstance(key, six.string_types):
for ava in self._avas:
- if key == ava[0].decode('utf-8'):
- return ava[1].decode('utf-8')
+ if key == val_decode(ava[0]):
+ return val_decode(ava[1])
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
else:
raise TypeError("unsupported type for RDN indexing, must be int, basestring or slice; not %s" % \
@@ -853,25 +875,25 @@ class RDN(object):
def _get_attr(self):
if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
- return self._avas[0][0].decode('utf-8')
+ return val_decode(self._avas[0][0])
def _set_attr(self, new_attr):
if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
- self._avas[0][0] = unicode(new_attr).encode('utf-8')
+ self._avas[0][0] = val_encode(six.text_type(new_attr))
attr = property(_get_attr)
def _get_value(self):
if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
- return self._avas[0][1].decode('utf-8')
+ return val_decode(self._avas[0][1])
def _set_value(self, new_value):
if len(self._avas) == 0:
raise IndexError("No AVA's in this RDN")
- self._avas[0][1] = unicode(new_value).encode('utf-8')
+ self._avas[0][1] = val_encode(six.text_type(new_value))
value = property(_get_value)
@@ -898,16 +920,16 @@ class RDN(object):
return False
# Perform comparison between objects of same type
- return cmp_rdns(self._avas, other._avas) == 0
+ return rdn_key(self._avas) == rdn_key(other._avas)
def __ne__(self, other):
return not self.__eq__(other)
- def __cmp__(self, other):
+ def __lt__(self, other):
if not isinstance(other, RDN):
raise TypeError("expected RDN but got %s" % (other.__class__.__name__))
- return cmp_rdns(self._avas, other._avas)
+ return rdn_key(self._avas) < rdn_key(other._avas)
def __add__(self, other):
result = self.__class__(self)
@@ -927,6 +949,7 @@ class RDN(object):
return result
+@functools.total_ordering
class DN(object):
'''
DN(arg0, ...)
@@ -1088,8 +1111,8 @@ class DN(object):
def _rdns_from_value(self, value):
if isinstance(value, six.string_types):
try:
- if isinstance(value, unicode):
- value = value.encode('utf-8')
+ if isinstance(value, six.text_type):
+ value = val_encode(value)
rdns = str2dn(value)
except DECODING_ERROR:
raise ValueError("malformed RDN string = \"%s\"" % value)
@@ -1124,7 +1147,7 @@ class DN(object):
def __str__(self):
try:
return dn2str(self.rdns)
- except Exception, e:
+ except Exception as e:
print(len(self.rdns))
print(self.rdns)
raise
@@ -1153,8 +1176,8 @@ class DN(object):
elif isinstance(key, six.string_types):
for rdn in self.rdns:
for ava in rdn:
- if key == ava[0].decode('utf-8'):
- return ava[1].decode('utf-8')
+ if key == val_decode(ava[0]):
+ return val_decode(ava[1])
raise KeyError("\"%s\" not found in %s" % (key, self.__str__()))
else:
raise TypeError("unsupported type for DN indexing, must be int, basestring or slice; not %s" % \
@@ -1187,23 +1210,25 @@ class DN(object):
if not isinstance(other, DN):
return False
+ if len(self) != len(other):
+ return False
+
# Perform comparison between objects of same type
- return self.__cmp__(other) == 0
+ return self._cmp_sequence(other, 0, len(self)) == 0
def __ne__(self, other):
return not self.__eq__(other)
- def __cmp__(self, other):
+ def __lt__(self, other):
if not isinstance(other, DN):
raise TypeError("expected DN but got %s" % (other.__class__.__name__))
- result = cmp(len(self), len(other))
- if result != 0:
- return result
- return self._cmp_sequence(other, 0, len(self))
+ if len(self) != len(other):
+ return len(self) < len(other)
- def _cmp_sequence(self, pattern, self_start, pat_len):
+ return self._cmp_sequence(other, 0, len(self)) < 0
+ def _cmp_sequence(self, pattern, self_start, pat_len):
self_idx = self_start
self_len = len(self)
pat_idx = 0