summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNathaniel McCallum <npmccallum@redhat.com>2013-09-30 12:45:37 -0400
committerPetr Viktorin <pviktori@redhat.com>2013-10-09 18:05:37 +0200
commit4f6580f11ded1c456e0891023232b0d715d8aef7 (patch)
tree3aaf1b65fa94e570d8fb2ef44428aaa16b3e509b
parente05dfbd8b4b4e040266ecfba579bcd64e22b342b (diff)
Allow multiple types in Param type validation
Int already needed to take both int and long. This makes the functionality available for all Param classes.
-rw-r--r--ipalib/parameters.py53
-rw-r--r--ipatests/test_ipalib/test_parameters.py3
2 files changed, 20 insertions, 36 deletions
diff --git a/ipalib/parameters.py b/ipalib/parameters.py
index 79b9062bb..97e449c29 100644
--- a/ipalib/parameters.py
+++ b/ipalib/parameters.py
@@ -363,7 +363,9 @@ class Param(ReadOnly):
# This is a dummy type so that most of the functionality of Param can be
# unit tested directly without always creating a subclass; however, a real
- # (direct) subclass must *always* override this class attribute:
+ # (direct) subclass must *always* override this class attribute.
+ # If multiple types are permitted, set `type` to the canonical type and
+ # `allowed_types` to a tuple of all allowed types.
type = NoneType # Ouch, this wont be very useful in the real world!
# Subclasses should override this with something more specific:
@@ -400,6 +402,11 @@ class Param(ReadOnly):
# ('default', self.type, None),
)
+ @property
+ def allowed_types(self):
+ """The allowed datatypes for this Param"""
+ return (self.type,)
+
def __init__(self, name, *rules, **kw):
# We keep these values to use in __repr__():
self.param_spec = name
@@ -415,7 +422,7 @@ class Param(ReadOnly):
self.nice = '%s(%r)' % (self.__class__.__name__, self.param_spec)
# Add 'default' to self.kwargs and makes sure no unknown kw were given:
- assert type(self.type) is type
+ assert all(type(t) is type for t in self.allowed_types)
if kw.get('multivalue', True):
self.kwargs += (('default', tuple, None),)
else:
@@ -782,7 +789,7 @@ class Param(ReadOnly):
"""
Convert a single scalar value.
"""
- if type(value) is self.type:
+ if type(value) in self.allowed_types:
return value
raise ConversionError(name=self.name, index=index,
error=ugettext(self.type_error),
@@ -816,7 +823,7 @@ class Param(ReadOnly):
self._validate_scalar(value)
def _validate_scalar(self, value, index=None):
- if type(value) is not self.type:
+ if type(value) not in self.allowed_types:
raise TypeError(
TYPE_ERROR % (self.name, self.type, value, type(value))
)
@@ -942,7 +949,7 @@ class Bool(Param):
"""
Convert a single scalar value.
"""
- if type(value) is self.type:
+ if type(value) in self.allowed_types:
return value
if isinstance(value, basestring):
value = value.lower()
@@ -1009,7 +1016,7 @@ class Number(Param):
"""
Convert a single scalar value.
"""
- if type(value) is self.type:
+ if type(value) in self.allowed_types:
return value
if type(value) in (unicode, int, long, float):
try:
@@ -1030,6 +1037,7 @@ class Int(Number):
"""
type = int
+ allowed_types = int, long
type_error = _('must be an integer')
kwargs = Param.kwargs + (
@@ -1095,31 +1103,6 @@ class Int(Number):
maxvalue=self.maxvalue,
)
- def _validate_scalar(self, value, index=None):
- """
- This duplicates _validate_scalar in the Param class with
- the exception that it allows both int and long types. The
- min/max rules handle size enforcement.
- """
- if type(value) not in (int, long):
- raise TypeError(
- TYPE_ERROR % (self.name, self.type, value, type(value))
- )
- if index is not None and type(index) is not int:
- raise TypeError(
- TYPE_ERROR % ('index', int, index, type(index))
- )
- for rule in self.all_rules:
- error = rule(ugettext, value)
- if error is not None:
- raise ValidationError(
- name=self.get_param_name(),
- value=value,
- index=index,
- error=error,
- rule=rule,
- )
-
class Decimal(Number):
"""
@@ -1315,7 +1298,7 @@ class Data(Param):
"""
Check pattern (regex) contraint.
"""
- assert type(value) is self.type
+ assert type(value) in self.allowed_types
if self.re.match(value) is None:
if self.re_errmsg:
return self.re_errmsg % dict(pattern=self.pattern,)
@@ -1418,7 +1401,7 @@ class Str(Data):
"""
Convert a single scalar value.
"""
- if type(value) is self.type:
+ if type(value) in self.allowed_types:
return value
if type(value) in (int, long, float, decimal.Decimal):
return self.type(value)
@@ -1522,7 +1505,7 @@ class Enum(Param):
def __init__(self, name, *rules, **kw):
super(Enum, self).__init__(name, *rules, **kw)
for (i, v) in enumerate(self.values):
- if type(v) is not self.type:
+ if type(v) not in self.allowed_types:
n = '%s values[%d]' % (self.nice, i)
raise TypeError(
TYPE_ERROR % (n, self.type, v, type(v))
@@ -1789,7 +1772,7 @@ class DNParam(Param):
"""
Convert a single scalar value.
"""
- if type(value) is self.type:
+ if type(value) in self.allowed_types:
return value
try:
diff --git a/ipatests/test_ipalib/test_parameters.py b/ipatests/test_ipalib/test_parameters.py
index 71acfce71..22c7b7355 100644
--- a/ipatests/test_ipalib/test_parameters.py
+++ b/ipatests/test_ipalib/test_parameters.py
@@ -1173,7 +1173,8 @@ class test_Int(ClassChecker):
"""
# Test with no kwargs:
o = self.cls('my_number')
- assert o.type is int
+ assert o.type == int
+ assert o.allowed_types == (int, long)
assert isinstance(o, parameters.Int)
assert o.minvalue == int(MININT)
assert o.maxvalue == int(MAXINT)