summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSimo Sorce <simo@redhat.com>2015-03-18 14:12:09 -0400
committerSimo Sorce <simo@redhat.com>2015-03-18 14:12:09 -0400
commit3481090ea41b4a04552da580f44d229735f5dd7e (patch)
treee8456b1c881fa039ec7e8bdb32d2974e28400d4e
parent33f36ea10c1db2aaa74818c60933a20a9abe672f (diff)
downloadjwcrypto-3481090ea41b4a04552da580f44d229735f5dd7e.tar.gz
jwcrypto-3481090ea41b4a04552da580f44d229735f5dd7e.tar.xz
jwcrypto-3481090ea41b4a04552da580f44d229735f5dd7e.zip
Change the way operation keys are retrieved
This way we have less confusion about what the function is supposed to do and less code duplication. Signed-off-by: Simo Sorce <simo@redhat.com>
-rw-r--r--jwcrypto/jwe.py12
-rw-r--r--jwcrypto/jwk.py49
-rw-r--r--jwcrypto/jws.py12
-rw-r--r--jwcrypto/tests.py4
4 files changed, 38 insertions, 39 deletions
diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py
index a44b7fe..8a9e4b6 100644
--- a/jwcrypto/jwe.py
+++ b/jwcrypto/jwe.py
@@ -106,13 +106,13 @@ class _rsa(_raw_key_mgmt):
self.check_key(key)
if not cek:
cek = os.urandom(keylen)
- rk = key.encrypt_key()
+ rk = key.get_op_key('encrypt')
ek = rk.encrypt(cek, self.padfn)
return (cek, ek)
def unwrap(self, key, ek):
self.check_key(key)
- rk = key.decrypt_key()
+ rk = key.get_op_key('decrypt')
cek = rk.decrypt(ek, self.padfn)
return cek
@@ -131,7 +131,7 @@ class _aes_kw(_raw_key_mgmt):
self.check_key(key)
if not cek:
cek = os.urandom(keylen)
- rk = base64url_decode(key.encrypt_key())
+ rk = base64url_decode(key.get_op_key('encrypt'))
# Implement RFC 3994 Key Unwrap - 2.2.2
# TODO: Use cryptography once issue #1733 is resolved
@@ -153,7 +153,7 @@ class _aes_kw(_raw_key_mgmt):
def unwrap(self, key, ek):
self.check_key(key)
- rk = base64url_decode(key.decrypt_key())
+ rk = base64url_decode(key.get_op_key('decrypt'))
# Implement RFC 3994 Key Unwrap - 2.2.3
# TODO: Use cryptography once issue #1733 is resolved
@@ -189,7 +189,7 @@ class _direct(_raw_key_mgmt):
self.check_key(key)
if cek:
return (cek, None)
- k = base64url_decode(key.encrypt_key())
+ k = base64url_decode(key.get_op_key('encrypt'))
if len(k) != keylen:
raise InvalidCEKeyLength(keylen, len(k))
return (k, '')
@@ -198,7 +198,7 @@ class _direct(_raw_key_mgmt):
self.check_key(key)
if ek != '':
raise InvalidJWEData('Invalid Encryption Key.')
- return base64url_decode(key.decrypt_key())
+ return base64url_decode(key.get_op_key('decrypt'))
class _raw_jwe(object):
diff --git a/jwcrypto/jwk.py b/jwcrypto/jwk.py
index c989e06..2268728 100644
--- a/jwcrypto/jwk.py
+++ b/jwcrypto/jwk.py
@@ -252,19 +252,7 @@ class JWK(object):
return ec.EllipticCurvePrivateNumbers(self._decode_int(k['d']),
self._ec_pub(k, curve))
- def sign_key(self, arg=None):
- self._check_constraints('sig', 'sign')
- if self._params['kty'] == 'oct':
- return self._key['k']
- elif self._params['kty'] == 'RSA':
- return self._rsa_pri(self._key).private_key(default_backend())
- elif self._params['kty'] == 'EC':
- return self._ec_pri(self._key, arg).private_key(default_backend())
- else:
- raise NotImplementedError
-
- def verify_key(self, arg=None):
- self._check_constraints('sig', 'verify')
+ def _get_public_key(self, arg=None):
if self._params['kty'] == 'oct':
return self._key['k']
elif self._params['kty'] == 'RSA':
@@ -274,25 +262,36 @@ class JWK(object):
else:
raise NotImplementedError
- def encrypt_key(self, arg=None):
- self._check_constraints('enc', 'encrypt')
+ def _get_private_key(self, arg=None):
if self._params['kty'] == 'oct':
return self._key['k']
elif self._params['kty'] == 'RSA':
- return self._rsa_pub(self._key).public_key(default_backend())
+ return self._rsa_pri(self._key).private_key(default_backend())
elif self._params['kty'] == 'EC':
- return self._ec_pub(self._key, arg).public_key(default_backend())
+ return self._ec_pri(self._key, arg).private_key(default_backend())
else:
raise NotImplementedError
- def decrypt_key(self, arg=None):
- self._check_constraints('enc', 'decrypt')
- if self._params['kty'] == 'oct':
- return self._key['k']
- elif self._params['kty'] == 'RSA':
- return self._rsa_pri(self._key).private_key(default_backend())
- elif self._params['kty'] == 'EC':
- return self._ec_pri(self._key, arg).private_key(default_backend())
+ def get_op_key(self, operation=None, arg=None):
+ validops = self._params.get('key_ops', JWKOperationsRegistry.keys())
+ if validops is not list:
+ validops = [validops]
+ if operation is None:
+ if self._params['kty'] == 'oct':
+ return self._key['k']
+ raise InvalidJWKOperation(operation, validops)
+ elif operation == 'sign':
+ self._check_constraints('sig', operation)
+ return self._get_private_key(arg)
+ elif operation == 'verify':
+ self._check_constraints('sig', operation)
+ return self._get_public_key(arg)
+ elif operation == 'encrypt' or operation == 'wrapKey':
+ self._check_constraints('enc', operation)
+ return self._get_public_key(arg)
+ elif operation == 'decrypt' or operation == 'unwrapKey':
+ self._check_constraints('enc', operation)
+ return self._get_private_key(arg)
else:
raise NotImplementedError
diff --git a/jwcrypto/jws.py b/jwcrypto/jws.py
index 63174d6..e2e97c0 100644
--- a/jwcrypto/jws.py
+++ b/jwcrypto/jws.py
@@ -83,12 +83,12 @@ class _raw_hmac(_raw_jws):
return h
def sign(self, key, payload):
- skey = base64url_decode(key.sign_key())
+ skey = base64url_decode(key.get_op_key('sign'))
h = self._hmac_setup(skey, payload)
return h.finalize()
def verify(self, key, payload, signature):
- vkey = base64url_decode(key.verify_key())
+ vkey = base64url_decode(key.get_op_key('verify'))
h = self._hmac_setup(vkey, payload)
try:
h.verify(signature)
@@ -102,13 +102,13 @@ class _raw_rsa(_raw_jws):
self.hashfn = hashfn
def sign(self, key, payload):
- skey = key.sign_key()
+ skey = key.get_op_key('sign')
signer = skey.signer(self.padfn, self.hashfn)
signer.update(payload)
return signer.finalize()
def verify(self, key, payload, signature):
- pkey = key.verify_key()
+ pkey = key.get_op_key('verify')
verifier = pkey.verifier(signature, self.padfn, self.hashfn)
verifier.update(payload)
verifier.verify()
@@ -126,7 +126,7 @@ class _raw_ec(_raw_jws):
return e.decode('hex')
def sign(self, key, payload):
- skey = key.sign_key(self.curve)
+ skey = key.get_op_key('sign', self.curve)
signer = skey.signer(ec.ECDSA(self.hashfn))
signer.update(payload)
signature = signer.finalize()
@@ -135,7 +135,7 @@ class _raw_ec(_raw_jws):
return self.encode_int(r, l) + self.encode_int(s, l)
def verify(self, key, payload, signature):
- pkey = key.verify_key(self.curve)
+ pkey = key.get_op_key('verify', self.curve)
r = signature[:len(signature)/2]
s = signature[len(signature)/2:]
enc_signature = ec_utils.encode_rfc6979_signature(
diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py
index d66a5b4..e8fcd6b 100644
--- a/jwcrypto/tests.py
+++ b/jwcrypto/tests.py
@@ -174,8 +174,8 @@ class TestJWK(unittest.TestCase):
keylist = SymmetricKeys['keys']
for key in keylist:
jwkey = jwk.JWK(**key) # pylint: disable=star-args
- _ = jwkey.sign_key()
- _ = jwkey.verify_key()
+ _ = jwkey.get_op_key('sign')
+ _ = jwkey.get_op_key('verify')
e = jwkey.export()
self.assertEqual(json.loads(e), key)