summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--custodia/message/formats.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/custodia/message/formats.py b/custodia/message/formats.py
index 48d5955..00845a3 100644
--- a/custodia/message/formats.py
+++ b/custodia/message/formats.py
@@ -22,9 +22,7 @@ class Validator(object):
:param allowed: list of allowed message types (optional)
"""
self.allowed = allowed or default_types
- self.types = dict()
- for t in self.allowed:
- self.types[t] = key_types[t]
+ self.types = key_types.copy()
def add_types(self, types):
self.types.update(types)
@@ -36,16 +34,30 @@ class Validator(object):
if 'type' not in msg:
raise InvalidMessage('The type is missing')
+ if isinstance(msg['type'], list):
+ if len(msg['type']) != 1:
+ raise InvalidMessage('Type is multivalued: %s' % msg['type'])
+ msg_type = msg['type'][0]
+ else:
+ msg_type = msg['type']
+
if 'value' not in msg:
raise InvalidMessage('The value is missing')
- if msg['type'] not in list(self.types.keys()):
- raise UnknownMessageType("Type '%s' is unknown" % msg['type'])
+ if isinstance(msg['value'], list):
+ if len(msg['value']) != 1:
+ raise InvalidMessage('Value is multivalued: %s' % msg['value'])
+ msg_value = msg['value'][0]
+ else:
+ msg_value = msg['value']
+
+ if msg_type not in list(self.types.keys()):
+ raise UnknownMessageType("Type '%s' is unknown" % msg_type)
- if msg['type'] not in self.allowed:
+ if msg_type not in self.allowed:
raise UnallowedMessage("Message type '%s' not allowed" % (
- msg['type'],))
+ msg_type,))
- handler = self.types[msg['type']](request)
- handler.parse(msg['value'])
+ handler = self.types[msg_type](request)
+ handler.parse(msg_value)
return handler