diff options
-rw-r--r-- | openstack/common/rpc/securemessage.py | 127 | ||||
-rw-r--r-- | tests/unit/rpc/test_securemessage.py | 2 |
2 files changed, 104 insertions, 25 deletions
diff --git a/openstack/common/rpc/securemessage.py b/openstack/common/rpc/securemessage.py index 523bb2e..c5e4738 100644 --- a/openstack/common/rpc/securemessage.py +++ b/openstack/common/rpc/securemessage.py @@ -129,6 +129,15 @@ class InvalidExpiredTicket(SecureMessageException): super(InvalidExpiredTicket, self).__init__(self.msg % (src, dst)) +class InvalidKDSReply(SecureMessageException): + """The KDS Reply could not be successfully verified.""" + + msg = "Invalid KDS Reply (source=%s, destination=%s)" + + def __init__(self, src, dst): + super(InvalidKDSReply, self).__init__(self.msg % (src, dst)) + + class CommunicationError(SecureMessageException): """The Communication with the KDS failed.""" @@ -147,7 +156,8 @@ class InvalidArgument(SecureMessageException): super(InvalidArgument, self).__init__(self.msg % errmsg) -Ticket = collections.namedtuple('Ticket', ['skey', 'ekey', 'esek']) +Ticket = collections.namedtuple('Ticket', ['target', 'skey', 'ekey', 'esek']) +GroupKey = collections.namedtuple('GroupKey', ['generation', 'key']) class KeyStore(object): @@ -192,7 +202,7 @@ class KeyStore(object): :param esek: The token encrypted with the target key :param expiration: Expiration time in seconds since Epoch """ - keys = Ticket(skey, ekey, esek) + keys = Ticket(target, skey, ekey, esek) self._put(source, target, 'ticket', expiration, keys) def get_ticket(self, source, target): @@ -201,6 +211,24 @@ class KeyStore(object): """ return self._get(source, target, 'ticket') + def put_group_key(self, source, target, generation, key, expiration): + """Puts a sek pair in the cache. + + :param source: Client name + :param target: Target name + :param generation: The Generation number. + :param key: The Group Key. + :param expiration: Expiration time in seconds since Epoch + """ + keys = GroupKey(generation, key) + self._put(source, target, 'group_key', expiration, keys) + + def get_group_key(self, source, target): + """Returns a GroupKey (generation, key) namedtuple for the + source/target pair. + """ + return self._get(source, target, 'group_key') + _KEY_STORE = KeyStore() @@ -249,15 +277,15 @@ class _KDSClient(object): return reply - def _get_ticket(self, request, url=None, redirects=10): + def _make_request(self, request, url=None, redirects=10): """Send an HTTP request. Wraps around 'requests' to handle redirects and common errors. """ - if url is None: + if url.startswith('/'): if not self._endpoint: raise CommunicationError(url, 'Endpoint not configured') - url = self._endpoint + '/kds/ticket/' + request['signature'] + url = self._endpoint + url while redirects: resp = self._do_post(url, request) @@ -275,9 +303,7 @@ class _KDSClient(object): raise CommunicationError(url, "Too many redirections, giving up!") - def get_ticket(self, source, target, crypto, key): - - # prepare metadata + def _signed_request(self, source, target, crypto, key): md = {'requestor': source, 'target': target, 'timestamp': time.time(), @@ -287,25 +313,57 @@ class _KDSClient(object): # sign metadata signature = crypto.sign(key, metadata) - # HTTP request - reply = self._get_ticket({'metadata': metadata, - 'signature': signature}) + return {'metadata': metadata, 'signature': signature} + + def _check_signature(self, crypto, key, metadata, payload, signature): + sig = crypto.sign(key, metadata + payload) + if sig != signature: + raise InvalidKDSReply(metadata['source'], metadata['destination']) + + def get_ticket(self, source, target, crypto, key): + + request = self._signed_request(source, target, crypto, key) + reply = self._make_request(request, + url='/kds/ticket/' + request['signature']) + + self._check_signature(crypto, key, + reply['metadata'], + reply['ticket'], + reply['signature']) - # verify reply - signature = crypto.sign(key, (reply['metadata'] + reply['ticket'])) - if signature != reply['signature']: - raise InvalidEncryptedTicket(md['source'], md['destination']) md = jsonutils.loads(base64.b64decode(reply['metadata'])) - if ((md['source'] != source or - md['destination'] != target or - md['expiration'] < time.time())): - raise InvalidEncryptedTicket(md['source'], md['destination']) + if (md['source'] != source or + md['expiration'] < time.time() or + (md['destination'] != target and + md['destination'].split(':')[0] != target)): + raise InvalidKDSReply(md['source'], md['destination']) # return ticket data tkt = jsonutils.loads(crypto.decrypt(key, reply['ticket'])) return tkt, md['expiration'] + def get_group_key(self, source, target, crypto, key): + + request = self._signed_request(source, target, crypto, key) + reply = self._make_request(request, + url='/kds/group_key/' + target) + + self._check_signature(crypto, key, + reply['metadata'], + reply['group_key'], + reply['signature']) + + md = jsonutils.loads(base64.b64decode(reply['metadata'])) + if ((md['source'] != self._name or + md['destination'] != target or + md['expiration'] < time.time())): + raise InvalidKDSReply(md['source'], md['destination']) + + group_key = crypto.decrypt(key, reply['group_key']) + + return group_key, md['expiration'] + # we need to keep a global nonce, as this value should never repeat non # matter how many SecureMessage objects we create @@ -349,12 +407,13 @@ class SecureMessage(object): :param key: (optional) explicitly pass in endpoint private key. If not provided it will be sourced from the service config :param key_store: (optional) Storage class for local caching + :param group: (optional) Group Name used to retrieve group keys :param encrypt: (defaults to False) Whether to encrypt messages :param enctype: (defaults to AES) Cipher to use :param hashtype: (defaults to SHA256) Hash function to use for signatures """ - def __init__(self, topic, host, conf, key=None, key_store=None, + def __init__(self, topic, host, conf, key=None, key_store=None, group=None, encrypt=None, enctype='AES', hashtype='SHA256'): conf.register_group(secure_message_group) @@ -364,6 +423,7 @@ class SecureMessage(object): self._key = key self._conf = conf.secure_messages self._encrypt = self._conf.encrypt if (encrypt is None) else encrypt + self._group = group if (group is not None) else topic self._crypto = cryptoutils.SymmetricCrypto(enctype, hashtype) self._hkdf = cryptoutils.HKDF(hashtype) self._kds = _KDSClient(self._conf.kds_endpoint) @@ -412,6 +472,7 @@ class SecureMessage(object): :param traget: The name of the target service :param timestamp: The incoming message timestamp :param esek: a base64 encoded encrypted block containing a JSON string + :param generation: Key generation number, for group keys """ rkey = None @@ -454,6 +515,20 @@ class SecureMessage(object): tkt['esek'], expiration) return self._key_store.get_ticket(self._name, target) + def _get_group_key(self, target): + gk = self._key_store.get_group_key(self._name, target) + if gk is not None: + return gk.key + + group_key, expiration = self._kds.get_group_key(self._name, target, + self._crypto, + self._key) + + self._key_store.put_group_key(self._name, target, + long(target.split(':')[1]), + group_key, expiration) + return group_key + def encode(self, version, target, json_msg): """This is the main encoding function. @@ -468,7 +543,7 @@ class SecureMessage(object): ticket = self._get_ticket(target) metadata = jsonutils.dumps({'source': self._name, - 'destination': target, + 'destination': ticket.target, 'timestamp': time.time(), 'nonce': _get_nonce(), 'esek': ticket.esek, @@ -503,12 +578,16 @@ class SecureMessage(object): if arg not in md: raise InvalidMetadata('Missing metadata "%s"' % arg) - if md['destination'] != self._name: - # TODO(simo) handle group keys by checking target + dkey = None + if md['destination'] == self._name: + dkey = self._key + elif md['destination'].split(':')[0] == self._group: + dkey = self._get_group_key(md['destination']) + else: raise UnknownDestinationName(md['destination']) try: - skey, ekey = self._decode_esek(self._key, + skey, ekey = self._decode_esek(dkey, md['source'], md['destination'], md['timestamp'], md['esek']) except InvalidExpiredTicket: diff --git a/tests/unit/rpc/test_securemessage.py b/tests/unit/rpc/test_securemessage.py index 8c07df1..eb1755b 100644 --- a/tests/unit/rpc/test_securemessage.py +++ b/tests/unit/rpc/test_securemessage.py @@ -40,7 +40,7 @@ class RpcCryptoTestCase(test_utils.BaseTestCase): keys = store.get_ticket('foo', 'bar') self.assertIsNone(keys) - ticket = rpc_secmsg.Ticket('skey', 'ekey', 'esek') + ticket = rpc_secmsg.Ticket('bar', 'skey', 'ekey', 'esek') #add entry in the cache store.put_ticket('foo', 'bar', 'skey', 'ekey', 'esek', 2000000000) |