diff options
-rw-r--r-- | custodia/message/kem.py | 42 |
1 files changed, 24 insertions, 18 deletions
diff --git a/custodia/message/kem.py b/custodia/message/kem.py index 4f2d70d..489f657 100644 --- a/custodia/message/kem.py +++ b/custodia/message/kem.py @@ -95,6 +95,20 @@ class KEMKeysStore(SimplePathAuthz): return self._alg +def check_kem_claims(claims, name): + if 'sub' not in claims: + raise InvalidMessage('Missing subject in payload') + if claims['sub'] != name: + raise InvalidMessage('Key name %s does not match subject %s' % ( + name, claims['sub'])) + if 'exp' not in claims: + raise InvalidMessage('Missing expiration time in payload') + if claims['exp'] - (10 * 60) > int(time.time()): + raise InvalidMessage('Message expiration too far in the future') + if claims['exp'] < int(time.time()): + raise InvalidMessage('Message Expired') + + class KEMHandler(MessageHandler): """Handles 'kem' messages""" @@ -157,18 +171,8 @@ class KEMHandler(MessageHandler): except Exception as e: raise InvalidMessage('Failed to validate message: %s' % str(e)) - # FIXME: check name/time - if 'sub' not in claims: - raise InvalidMessage('Missing subject in payload') - if claims['sub'] != name: - raise InvalidMessage('Key name %s does not match subject %s' % ( - name, claims['sub'])) - if 'exp' not in claims: - raise InvalidMessage('Missing request time in payload') - if claims['exp'] - (10 * 60) > int(time.time()): - raise InvalidMessage('Message expiration too far in the future') - if claims['exp'] < int(time.time()): - raise InvalidMessage('Message Expired') + check_kem_claims(claims, name) + self.name = name return {'type': 'kem', 'value': {'kid': self.client_keys[KEY_USAGE_ENC].key_id, @@ -207,12 +211,14 @@ class KEMClient(object): self.client_keys[KEY_USAGE_SIG], alg, self.server_keys[KEY_USAGE_ENC], encalg) - def parse_reply(self, message): + def parse_reply(self, name, message): E = JWT(jwt=message, key=self.client_keys[KEY_USAGE_ENC]) S = JWT(jwt=E.claims, key=self.server_keys[KEY_USAGE_SIG]) - return S.claims + claims = json_decode(S.claims) + check_kem_claims(claims, name) + return claims['value'] def make_sig_kem(name, value, key, alg): @@ -381,8 +387,8 @@ class KEMTests(unittest.TestCase): req = cli.make_request("key name") kem.parse(req, "key name") msg = json_decode(kem.reply('key value')) - rep = json_decode(cli.parse_reply(msg['value'])) - self.assertEqual(rep['value'], 'key value') + rep = cli.parse_reply("key name", msg['value']) + self.assertEqual(rep, 'key value') def test_3_KEMClient(self): server_keys = [JWK(**test_keys[KEY_USAGE_SIG]), @@ -394,5 +400,5 @@ class KEMTests(unittest.TestCase): req = cli.make_request("key name", encalg=('RSA1_5', 'A256CBC-HS512')) kem.parse(req, "key name") msg = json_decode(kem.reply('key value')) - rep = json_decode(cli.parse_reply(msg['value'])) - self.assertEqual(rep['value'], 'key value') + rep = cli.parse_reply("key name", msg['value']) + self.assertEqual(rep, 'key value') |