diff options
Diffstat (limited to 'jwcrypto/jwk.py')
-rw-r--r-- | jwcrypto/jwk.py | 74 |
1 files changed, 48 insertions, 26 deletions
diff --git a/jwcrypto/jwk.py b/jwcrypto/jwk.py index 25744ef..f689930 100644 --- a/jwcrypto/jwk.py +++ b/jwcrypto/jwk.py @@ -189,29 +189,36 @@ class JWK(object): def _decode_int(self, n): return int(base64url_decode(n).encode('hex'), 16) + def _rsa_pub(self, k): + return rsa.RSAPublicNumbers(self._decode_int(k['e']), + self._decode_int(k['n'])) + + def _rsa_pri(self, k): + return rsa.RSAPrivateNumbers(self._decode_int(k['p']), + self._decode_int(k['q']), + self._decode_int(k['d']), + self._decode_int(k['dp']), + self._decode_int(k['dq']), + self._decode_int(k['qi']), + self._rsa_pub(k)) + + def _ec_pub(self, k, curve): + return ec.EllipticCurvePublicNumbers(self._decode_int(k['x']), + self._decode_int(k['y']), + self.get_curve(curve)) + + def _ec_pri(self, k, curve): + return ec.EllipticCurvePrivateNumbers(self._decode_int(k['d']), + self._ec_pub(k, curve)) + def sign_key(self, arg=None): self._check_constraints('sig', 'sign') if self._params['kty'] == 'oct': return self._key['k'] elif self._params['kty'] == 'RSA': - k = self._key - pub = rsa.RSAPublicNumbers(self._decode_int(k['e']), - self._decode_int(k['n'])) - pri = rsa.RSAPrivateNumbers(self._decode_int(k['p']), - self._decode_int(k['q']), - self._decode_int(k['d']), - self._decode_int(k['dp']), - self._decode_int(k['dq']), - self._decode_int(k['qi']), pub) - return pri.private_key(default_backend()) + return self._rsa_pri(self._key).private_key(default_backend()) elif self._params['kty'] == 'EC': - k = self._key - pub = ec.EllipticCurvePublicNumbers(self._decode_int(k['x']), - self._decode_int(k['y']), - self.get_curve(arg)) - pri = ec.EllipticCurvePrivateNumbers(self._decode_int(k['d']), - pub) - return pri.private_key(default_backend()) + return self._ec_pri(self._key, arg).private_key(default_backend()) else: raise NotImplementedError @@ -220,16 +227,31 @@ class JWK(object): if self._params['kty'] == 'oct': return self._key['k'] elif self._params['kty'] == 'RSA': - k = self._key - pub = rsa.RSAPublicNumbers(self._decode_int(k['e']), - self._decode_int(k['n'])) - return pub.public_key(default_backend()) + return self._rsa_pub(self._key).public_key(default_backend()) + elif self._params['kty'] == 'EC': + return self._ec_pub(self._key, arg).public_key(default_backend()) + else: + raise NotImplementedError + + def encrypt_key(self, arg=None): + self._check_constraints('enc', 'encrypt') + if self._params['kty'] == 'oct': + return self._key['k'] + elif self._params['kty'] == 'RSA': + return self._rsa_pub(self._key).public_key(default_backend()) + elif self._params['kty'] == 'EC': + return self._ec_pub(self._key, arg).public_key(default_backend()) + else: + raise NotImplementedError + + def decrypt_key(self, arg=None): + self._check_constraints('enc', 'decrypt') + if self._params['kty'] == 'oct': + return self._key['k'] + elif self._params['kty'] == 'RSA': + return self._rsa_pri(self._key).private_key(default_backend()) elif self._params['kty'] == 'EC': - k = self._key - pub = ec.EllipticCurvePublicNumbers(self._decode_int(k['x']), - self._decode_int(k['y']), - self.get_curve(arg)) - return pub.public_key(default_backend()) + return self._ec_pri(self._key, arg).private_key(default_backend()) else: raise NotImplementedError |