From 3481090ea41b4a04552da580f44d229735f5dd7e Mon Sep 17 00:00:00 2001 From: Simo Sorce Date: Wed, 18 Mar 2015 14:12:09 -0400 Subject: Change the way operation keys are retrieved This way we have less confusion about what the function is supposed to do and less code duplication. Signed-off-by: Simo Sorce --- jwcrypto/jwe.py | 12 ++++++------ jwcrypto/jwk.py | 49 ++++++++++++++++++++++++------------------------- jwcrypto/jws.py | 12 ++++++------ jwcrypto/tests.py | 4 ++-- 4 files changed, 38 insertions(+), 39 deletions(-) diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py index a44b7fe..8a9e4b6 100644 --- a/jwcrypto/jwe.py +++ b/jwcrypto/jwe.py @@ -106,13 +106,13 @@ class _rsa(_raw_key_mgmt): self.check_key(key) if not cek: cek = os.urandom(keylen) - rk = key.encrypt_key() + rk = key.get_op_key('encrypt') ek = rk.encrypt(cek, self.padfn) return (cek, ek) def unwrap(self, key, ek): self.check_key(key) - rk = key.decrypt_key() + rk = key.get_op_key('decrypt') cek = rk.decrypt(ek, self.padfn) return cek @@ -131,7 +131,7 @@ class _aes_kw(_raw_key_mgmt): self.check_key(key) if not cek: cek = os.urandom(keylen) - rk = base64url_decode(key.encrypt_key()) + rk = base64url_decode(key.get_op_key('encrypt')) # Implement RFC 3994 Key Unwrap - 2.2.2 # TODO: Use cryptography once issue #1733 is resolved @@ -153,7 +153,7 @@ class _aes_kw(_raw_key_mgmt): def unwrap(self, key, ek): self.check_key(key) - rk = base64url_decode(key.decrypt_key()) + rk = base64url_decode(key.get_op_key('decrypt')) # Implement RFC 3994 Key Unwrap - 2.2.3 # TODO: Use cryptography once issue #1733 is resolved @@ -189,7 +189,7 @@ class _direct(_raw_key_mgmt): self.check_key(key) if cek: return (cek, None) - k = base64url_decode(key.encrypt_key()) + k = base64url_decode(key.get_op_key('encrypt')) if len(k) != keylen: raise InvalidCEKeyLength(keylen, len(k)) return (k, '') @@ -198,7 +198,7 @@ class _direct(_raw_key_mgmt): self.check_key(key) if ek != '': raise InvalidJWEData('Invalid Encryption Key.') - return base64url_decode(key.decrypt_key()) + return base64url_decode(key.get_op_key('decrypt')) class _raw_jwe(object): diff --git a/jwcrypto/jwk.py b/jwcrypto/jwk.py index c989e06..2268728 100644 --- a/jwcrypto/jwk.py +++ b/jwcrypto/jwk.py @@ -252,19 +252,7 @@ class JWK(object): return ec.EllipticCurvePrivateNumbers(self._decode_int(k['d']), self._ec_pub(k, curve)) - def sign_key(self, arg=None): - self._check_constraints('sig', 'sign') - if self._params['kty'] == 'oct': - return self._key['k'] - elif self._params['kty'] == 'RSA': - return self._rsa_pri(self._key).private_key(default_backend()) - elif self._params['kty'] == 'EC': - return self._ec_pri(self._key, arg).private_key(default_backend()) - else: - raise NotImplementedError - - def verify_key(self, arg=None): - self._check_constraints('sig', 'verify') + def _get_public_key(self, arg=None): if self._params['kty'] == 'oct': return self._key['k'] elif self._params['kty'] == 'RSA': @@ -274,25 +262,36 @@ class JWK(object): else: raise NotImplementedError - def encrypt_key(self, arg=None): - self._check_constraints('enc', 'encrypt') + def _get_private_key(self, arg=None): if self._params['kty'] == 'oct': return self._key['k'] elif self._params['kty'] == 'RSA': - return self._rsa_pub(self._key).public_key(default_backend()) + return self._rsa_pri(self._key).private_key(default_backend()) elif self._params['kty'] == 'EC': - return self._ec_pub(self._key, arg).public_key(default_backend()) + return self._ec_pri(self._key, arg).private_key(default_backend()) else: raise NotImplementedError - def decrypt_key(self, arg=None): - self._check_constraints('enc', 'decrypt') - if self._params['kty'] == 'oct': - return self._key['k'] - elif self._params['kty'] == 'RSA': - return self._rsa_pri(self._key).private_key(default_backend()) - elif self._params['kty'] == 'EC': - return self._ec_pri(self._key, arg).private_key(default_backend()) + def get_op_key(self, operation=None, arg=None): + validops = self._params.get('key_ops', JWKOperationsRegistry.keys()) + if validops is not list: + validops = [validops] + if operation is None: + if self._params['kty'] == 'oct': + return self._key['k'] + raise InvalidJWKOperation(operation, validops) + elif operation == 'sign': + self._check_constraints('sig', operation) + return self._get_private_key(arg) + elif operation == 'verify': + self._check_constraints('sig', operation) + return self._get_public_key(arg) + elif operation == 'encrypt' or operation == 'wrapKey': + self._check_constraints('enc', operation) + return self._get_public_key(arg) + elif operation == 'decrypt' or operation == 'unwrapKey': + self._check_constraints('enc', operation) + return self._get_private_key(arg) else: raise NotImplementedError diff --git a/jwcrypto/jws.py b/jwcrypto/jws.py index 63174d6..e2e97c0 100644 --- a/jwcrypto/jws.py +++ b/jwcrypto/jws.py @@ -83,12 +83,12 @@ class _raw_hmac(_raw_jws): return h def sign(self, key, payload): - skey = base64url_decode(key.sign_key()) + skey = base64url_decode(key.get_op_key('sign')) h = self._hmac_setup(skey, payload) return h.finalize() def verify(self, key, payload, signature): - vkey = base64url_decode(key.verify_key()) + vkey = base64url_decode(key.get_op_key('verify')) h = self._hmac_setup(vkey, payload) try: h.verify(signature) @@ -102,13 +102,13 @@ class _raw_rsa(_raw_jws): self.hashfn = hashfn def sign(self, key, payload): - skey = key.sign_key() + skey = key.get_op_key('sign') signer = skey.signer(self.padfn, self.hashfn) signer.update(payload) return signer.finalize() def verify(self, key, payload, signature): - pkey = key.verify_key() + pkey = key.get_op_key('verify') verifier = pkey.verifier(signature, self.padfn, self.hashfn) verifier.update(payload) verifier.verify() @@ -126,7 +126,7 @@ class _raw_ec(_raw_jws): return e.decode('hex') def sign(self, key, payload): - skey = key.sign_key(self.curve) + skey = key.get_op_key('sign', self.curve) signer = skey.signer(ec.ECDSA(self.hashfn)) signer.update(payload) signature = signer.finalize() @@ -135,7 +135,7 @@ class _raw_ec(_raw_jws): return self.encode_int(r, l) + self.encode_int(s, l) def verify(self, key, payload, signature): - pkey = key.verify_key(self.curve) + pkey = key.get_op_key('verify', self.curve) r = signature[:len(signature)/2] s = signature[len(signature)/2:] enc_signature = ec_utils.encode_rfc6979_signature( diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index d66a5b4..e8fcd6b 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -174,8 +174,8 @@ class TestJWK(unittest.TestCase): keylist = SymmetricKeys['keys'] for key in keylist: jwkey = jwk.JWK(**key) # pylint: disable=star-args - _ = jwkey.sign_key() - _ = jwkey.verify_key() + _ = jwkey.get_op_key('sign') + _ = jwkey.get_op_key('verify') e = jwkey.export() self.assertEqual(json.loads(e), key) -- cgit