From 1f8bd71e9d65fd23ac1ba2df7debd217285bb702 Mon Sep 17 00:00:00 2001 From: Simo Sorce Date: Sat, 7 Mar 2015 16:52:14 -0500 Subject: Add JWE implementation Implements: draft-ietf-jose-json-web-encryption-40 plus Tests --- jwcrypto/jwe.py | 413 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ jwcrypto/jwk.py | 74 ++++++---- jwcrypto/tests.py | 63 +++++++++ 3 files changed, 524 insertions(+), 26 deletions(-) create mode 100644 jwcrypto/jwe.py diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py new file mode 100644 index 0000000..7845b26 --- /dev/null +++ b/jwcrypto/jwe.py @@ -0,0 +1,413 @@ +# Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, hmac +from cryptography.hazmat.primitives.padding import PKCS7 +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from jwcrypto.common import base64url_encode, base64url_decode +from jwcrypto.common import InvalidJWAAlgorithm +from jwcrypto.jwk import JWK +import json +import os +import zlib + + +# draft-ietf-jose-json-web-encryption-40 - 4.1 +# name: (description, supported?) +JWEHeaderRegistry = {'alg': ('Algorithm', True), + 'enc': ('Encryption Algorithm', True), + 'zip': ('Compression Algorithm', True), + 'jku': ('JWK Set URL', False), + 'jwk': ('JSON Web Key', False), + 'kid': ('Key ID', True), + 'x5u': ('X.509 URL', False), + 'x5c': ('X.509 Certificate Chain', False), + 'x5t': ('X.509 Certificate SHA-1 Thumbprint', False), + 'x5t#S256': ('X.509 Certificate SHA-256 Thumbprint', + False), + 'typ': ('Type', True), + 'cty': ('Content Type', True), + 'crit': ('Critical', True)} + + +class InvalidJWEData(Exception): + def __init__(self, message=None, exception=None): + msg = None + if message: + msg = message + else: + msg = 'Unknown Data Verification Failure' + if exception: + msg += ' {%s}' % str(exception) + super(InvalidJWEData, self).__init__(msg) + + +class InvalidCEKeyLength(Exception): + def __init__(self, expected, obtained): + msg = 'Expected key og length %d, got %d' % (expected, obtained) + super(InvalidCEKeyLength, self).__init__(msg) + + +class InvalidJWEOperation(Exception): + def __init__(self, message=None, exception=None): + msg = None + if message: + msg = message + else: + msg = 'Unknown Operation Failure' + if exception: + msg += ' {%s}' % str(exception) + super(InvalidJWEOperation, self).__init__(msg) + + +class _raw_key_mgmt(object): + + def wrap(self, key, keylen, cek): + raise NotImplementedError + + def unwrap(self, key, ek): + raise NotImplementedError + + +class _rsa(_raw_key_mgmt): + + def __init__(self, padfn): + self.padfn = padfn + + def wrap(self, key, keylen, cek): + if not cek: + cek = os.urandom(keylen) + rk = key.encrypt_key() + ek = rk.encrypt(cek, self.padfn) + return (cek, ek) + + def unwrap(self, key, ek): + rk = key.decrypt_key() + cek = rk.decrypt(ek, self.padfn) + return cek + + +class _direct(_raw_key_mgmt): + + def wrap(self, key, keylen, cek): + if cek: + return (cek, None) + k = base64url_decode(key.encrypt_key()) + if len(k) != keylen: + raise InvalidCEKeyLength(keylen, len(k)) + return (k, '') + + def unwrap(self, key, ek): + if ek != '': + raise InvalidJWEData('Invalid Encryption Key.') + return base64url_decode(key.decrypt_key()) + + +class _raw_jwe(object): + + def encode_int(self, n, l): + e = hex(n).rstrip("L").lstrip("0x") + L = (l + 7) / 8 # number of bytes rounded up + e = '0' * (L * 2 - len(e)) + e # pad as necessary + return e.decode('hex') + + def encrypt(self, k, a, m): + raise NotImplementedError + + def decrypt(self, k, a, iv, e, t): + raise NotImplementedError + + +class _aes_cbc_hmac_sha2(_raw_jwe): + + def __init__(self, hashfn, keybits): + self.backend = default_backend() + self.hashfn = hashfn + self.blocksize = keybits / 8 + + @property + def key_size(self): + return self.blocksize * 2 + + def _mac(self, k, a, iv, e): + al = self.encode_int(len(a * 8), 64) + h = hmac.HMAC(k, self.hashfn, backend=self.backend) + h.update(a) + h.update(iv) + h.update(e) + h.update(al) + m = h.finalize() + return m[:self.blocksize] + + # draft-ietf-jose-json-web-algorithms-40 - 5.2.2 + def encrypt(self, k, a, m): + """ Encrypt accoriding to the selected encryption and hashing + functions. + + :param k: Encryption key (optional) + :param a: Additional Authentication Data + :param m: Plaintext + + Returns a dictionary with the computed data. + """ + hkey = k[:self.blocksize] + ekey = k[self.blocksize:] + + # encrypt + iv = os.urandom(self.blocksize) + cipher = Cipher(algorithms.AES(ekey), modes.CBC(iv), + backend=self.backend) + encryptor = cipher.encryptor() + padder = PKCS7(self.blocksize * 8).padder() + padded_data = padder.update(m) + padder.finalize() + e = encryptor.update(padded_data) + encryptor.finalize() + + # mac + t = self._mac(hkey, a, iv, e) + + return (iv, e, t) + + def decrypt(self, k, a, iv, e, t): + """ Decrypt accoriding to the selected encryption and hashing + functions. + :param k: Encryption key (optional) + :param a: Additional Authenticated Data + :param iv: Initialization Vector + :param e: Ciphertext + :param t: Authentication Tag + + Returns plaintext or raises an error + """ + hkey = k[:self.blocksize] + dkey = k[self.blocksize:] + + # verify mac + if t != self._mac(hkey, a, iv, e): + raise InvalidJWEData('Failed to verify MAC') + + # decrypt + cipher = Cipher(algorithms.AES(dkey), modes.CBC(iv), + backend=self.backend) + decryptor = cipher.decryptor() + d = decryptor.update(e) + decryptor.finalize() + unpadder = PKCS7(self.blocksize * 8).unpadder() + return unpadder.update(d) + unpadder.finalize() + + +class JWE(object): + + def __init__(self, plaintext=None, protected=None, unprotected=None, + aad=None): + """ Generates or verifies Generic JWE tokens. + See draft-ietf-jose-json-web-signature-41 + + :param plaintext(bytes): An arbitrary plaintext to be encrypted + :param protected(json): The shared protected header + :param unprotected(json): The shared unprotected header + :param aad(bytes): Arbitrary additional authenticated data + """ + self.objects = {'recipients': list()} + self.plaintext = plaintext + self.cek = None + if aad: + self.objects['aad'] = aad + if protected: + _ = json.loads(protected) # check header encoding + self.objects['protected'] = protected + if unprotected: + _ = json.loads(unprotected) # check header encoding + self.objects['unprotected'] = unprotected + + # key wrapping mechanisms + def _jwa_RSA1_5(self): + return _rsa(padding.PKCS1v15()) + + def _jwa_dir(self): + return _direct() + + # content encryption mechanisms + def _jwa_A128CBC_HS256(self): + return _aes_cbc_hmac_sha2(hashes.SHA256(), 128) + + def _jwa(self, name): + attr = '_jwa_%s' % name.replace('-', '_').replace('+', '_') + try: + return getattr(self, attr)() + except (KeyError, AttributeError): + raise InvalidJWAAlgorithm() + + def merge_headers(self, h1, h2): + for k in h1.keys(): + if k in h2: + raise InvalidJWEData('Duplicate header: "%s"' % k) + h1.update(h2) + return h1 + + def add_recipient(self, key, header=None): + """ Encrypt the provided payload with the given key. + + :param key: A JWK key of appropriate type for the "alg" + provided in the 'protected' json string. + See draft-ietf-jose-json-web-key-41 + + :param header: A JSON string representing the per-recipient header. + """ + if self.plaintext is None: + raise ValueError('Missing plaintext') + if not isinstance(key, JWK): + raise ValueError('key is not a JWK object') + + ph = json.loads(self.objects['protected']) + if 'unprotected' in self.objects: + uh = json.loads(self.objects['unprotected']) + ph = self.merge_headers(ph, uh) + if header: + rh = json.loads('header') + ph = self.merge_headers(ph, rh) + + alg = self._jwa(ph.get('alg', None)) + enc = self._jwa(ph.get('enc', None)) + + rec = dict() + if header: + rec['header'] = header + + self.cek, ek = alg.wrap(key, enc.key_size, self.cek) + if ek: + rec['encrypted_key'] = ek + + if 'ciphertext' not in self.objects: + aad = base64url_encode(self.objects.get('protected', '')) + if 'aad' in self.objects: + aad += '.' + base64url_encode(self.objects['aad']) + + compress = ph.get('zip', None) + if compress == 'DEF': + data = zlib.compress(self.plaintext)[2:-4] + elif compress is None: + data = self.plaintext + else: + raise ValueError('Unknown compression') + + iv, ciphertext, tag = enc.encrypt(self.cek, aad, data) + self.objects['iv'] = iv + self.objects['ciphertext'] = ciphertext + self.objects['tag'] = tag + + self.objects['recipients'].append(rec) + + def serialize(self, compact=False): + + if 'ciphertext' not in self.objects: + raise InvalidJWEOperation("No available ciphertext") + + if compact: + for invalid in 'aad', 'unprotected': + if invalid in self.objects: + raise InvalidJWEOperation("Can't use compact encoding") + if len(self.objects['recipients']) != 1: + raise InvalidJWEOperation("Invalid number of recipients") + rec = self.objects['recipients'][0] + return '.'.join([base64url_encode(self.objects['protected']), + base64url_encode(rec['encrypted_key']), + base64url_encode(self.objects['iv']), + base64url_encode(self.objects['ciphertext']), + base64url_encode(self.objects['tag'])]) + else: + obj = self.objects + enc = {'ciphertext': base64url_encode(obj['ciphertext']), + 'iv': base64url_encode(obj['iv']), + 'tag': base64url_encode(self.objects['tag']), + 'recipients': list()} + if 'protected' in obj: + enc['protected'] = base64url_encode(obj['protected']) + if 'unprotected' in obj: + enc['unprotected'] = json.loads(obj['unprotected']) + if 'aad' in obj: + enc['aad'] = base64url_encode(obj['aad']) + for rec in obj['recipients']: + e = dict() + if 'encrypted_key' in rec: + e['encrypted_key'] = base64url_encode(rec['encrypted_key']) + if 'header' in rec: + e['header'] = json.loads(rec['header']) + rec['recipients'].append(e) + return json.dumps(enc) + + def check_crit(self, crit): + for k in crit: + if k not in JWEHeaderRegistry: + raise InvalidJWEData('Unknown critical header: "%s"' % k) + else: + if not JWEHeaderRegistry[k][1]: + raise InvalidJWEData('Unsupported critical header: ' + '"%s"' % k) + + # FIXME: allow to specify which algorithms to accept as valid + def decrypt(self, key): + if not isinstance(key, JWK): + raise ValueError('key is not a JWK object') + if 'ciphertext' not in self.objects: + raise InvalidJWEOperation("No available ciphertext") + + for rec in self.objects['recipients']: + + ph = json.loads(self.objects['protected']) + if 'unprotected' in self.objects: + uh = json.loads(self.objects['unprotected']) + ph = self.merge_headers(ph, uh) + if 'header' in rec: + rh = json.loads(rec['header']) + ph = self.merge_headers(ph, rh) + # TODO: allow caller to specify list of headers it understands + if 'crit' in ph: + self.check_crit(ph['crit']) + + alg = self._jwa(ph.get('alg', None)) + enc = self._jwa(ph.get('enc', None)) + + aad = base64url_encode(self.objects.get('protected', '')) + if 'aad' in self.objects: + aad += '.' + base64url_encode(self.objects['aad']) + + cek = alg.unwrap(key, rec['encrypted_key']) + data = enc.decrypt(cek, aad, self.objects['iv'], + self.objects['ciphertext'], + self.objects['tag']) + + compress = ph.get('zip', None) + if compress == 'DEF': + self.plaintext = zlib.decompress(data, -zlib.MAX_WBITS) + elif compress is None: + self.plaintext = data + else: + raise ValueError('Unknown compression') + + def deserialize(self, raw_jwe, key=None): + """ Destroys any current status and tries to import the raw + JWS provided. + """ + self.objects = dict() + o = dict() + try: + try: + djwe = json.loads(raw_jwe) + _ = djwe + raise NotImplementedError + except ValueError: + c = raw_jwe.split('.') + if len(c) != 5: + raise InvalidJWEData() + o['protected'] = base64url_decode(str(c[0])) + o['iv'] = base64url_decode(str(c[2])) + o['ciphertext'] = base64url_decode(str(c[3])) + o['tag'] = base64url_decode(str(c[4])) + o['recipients'] = [{'encrypted_key': + base64url_decode(str(c[1]))}] + self.objects = o + if key: + self.decrypt(key) + + except Exception, e: # pylint: disable=broad-except + raise InvalidJWEData('Invalid format', e) diff --git a/jwcrypto/jwk.py b/jwcrypto/jwk.py index 25744ef..f689930 100644 --- a/jwcrypto/jwk.py +++ b/jwcrypto/jwk.py @@ -189,29 +189,36 @@ class JWK(object): def _decode_int(self, n): return int(base64url_decode(n).encode('hex'), 16) + def _rsa_pub(self, k): + return rsa.RSAPublicNumbers(self._decode_int(k['e']), + self._decode_int(k['n'])) + + def _rsa_pri(self, k): + return rsa.RSAPrivateNumbers(self._decode_int(k['p']), + self._decode_int(k['q']), + self._decode_int(k['d']), + self._decode_int(k['dp']), + self._decode_int(k['dq']), + self._decode_int(k['qi']), + self._rsa_pub(k)) + + def _ec_pub(self, k, curve): + return ec.EllipticCurvePublicNumbers(self._decode_int(k['x']), + self._decode_int(k['y']), + self.get_curve(curve)) + + def _ec_pri(self, k, curve): + return ec.EllipticCurvePrivateNumbers(self._decode_int(k['d']), + self._ec_pub(k, curve)) + def sign_key(self, arg=None): self._check_constraints('sig', 'sign') if self._params['kty'] == 'oct': return self._key['k'] elif self._params['kty'] == 'RSA': - k = self._key - pub = rsa.RSAPublicNumbers(self._decode_int(k['e']), - self._decode_int(k['n'])) - pri = rsa.RSAPrivateNumbers(self._decode_int(k['p']), - self._decode_int(k['q']), - self._decode_int(k['d']), - self._decode_int(k['dp']), - self._decode_int(k['dq']), - self._decode_int(k['qi']), pub) - return pri.private_key(default_backend()) + return self._rsa_pri(self._key).private_key(default_backend()) elif self._params['kty'] == 'EC': - k = self._key - pub = ec.EllipticCurvePublicNumbers(self._decode_int(k['x']), - self._decode_int(k['y']), - self.get_curve(arg)) - pri = ec.EllipticCurvePrivateNumbers(self._decode_int(k['d']), - pub) - return pri.private_key(default_backend()) + return self._ec_pri(self._key, arg).private_key(default_backend()) else: raise NotImplementedError @@ -220,16 +227,31 @@ class JWK(object): if self._params['kty'] == 'oct': return self._key['k'] elif self._params['kty'] == 'RSA': - k = self._key - pub = rsa.RSAPublicNumbers(self._decode_int(k['e']), - self._decode_int(k['n'])) - return pub.public_key(default_backend()) + return self._rsa_pub(self._key).public_key(default_backend()) + elif self._params['kty'] == 'EC': + return self._ec_pub(self._key, arg).public_key(default_backend()) + else: + raise NotImplementedError + + def encrypt_key(self, arg=None): + self._check_constraints('enc', 'encrypt') + if self._params['kty'] == 'oct': + return self._key['k'] + elif self._params['kty'] == 'RSA': + return self._rsa_pub(self._key).public_key(default_backend()) + elif self._params['kty'] == 'EC': + return self._ec_pub(self._key, arg).public_key(default_backend()) + else: + raise NotImplementedError + + def decrypt_key(self, arg=None): + self._check_constraints('enc', 'decrypt') + if self._params['kty'] == 'oct': + return self._key['k'] + elif self._params['kty'] == 'RSA': + return self._rsa_pri(self._key).private_key(default_backend()) elif self._params['kty'] == 'EC': - k = self._key - pub = ec.EllipticCurvePublicNumbers(self._decode_int(k['x']), - self._decode_int(k['y']), - self.get_curve(arg)) - return pub.public_key(default_backend()) + return self._ec_pri(self._key, arg).private_key(default_backend()) else: raise NotImplementedError diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index 881dfd7..181ca6f 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -3,6 +3,7 @@ from jwcrypto.common import base64url_decode from jwcrypto import jwk from jwcrypto import jws +from jwcrypto import jwe import json import unittest @@ -434,3 +435,65 @@ class TestJWS(unittest.TestCase): S = jws.JWS(A6_example['payload']) S.deserialize(E_negative) self.assertEqual(False, S.objects['valid']) + + +E_A2_plaintext = "Live long and prosper." +E_A2_protected = "eyJhbGciOiJSU0ExXzUiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0" +E_A2_key = \ + {"kty": "RSA", + "n": "sXchDaQebHnPiGvyDOAT4saGEUetSyo9MKLOoWFsueri23bOdgWp4Dy1Wl" + "UzewbgBHod5pcM9H95GQRV3JDXboIRROSBigeC5yjU1hGzHHyXss8UDpre" + "cbAYxknTcQkhslANGRUZmdTOQ5qTRsLAt6BTYuyvVRdhS8exSZEy_c4gs_" + "7svlJJQ4H9_NxsiIoLwAEk7-Q3UXERGYw_75IDrGA84-lA_-Ct4eTlXHBI" + "Y2EaV7t7LjJaynVJCpkv4LKjTTAumiGUIuQhrNhZLuF_RJLqHpM2kgWFLU" + "7-VTdL1VbC2tejvcI2BlMkEpk1BzBZI0KQB0GaDWFLN-aEAw3vRw", + "e": "AQAB", + "d": "VFCWOqXr8nvZNyaaJLXdnNPXZKRaWCjkU5Q2egQQpTBMwhprMzWzpR8Sxq" + "1OPThh_J6MUD8Z35wky9b8eEO0pwNS8xlh1lOFRRBoNqDIKVOku0aZb-ry" + "nq8cxjDTLZQ6Fz7jSjR1Klop-YKaUHc9GsEofQqYruPhzSA-QgajZGPbE_" + "0ZaVDJHfyd7UUBUKunFMScbflYAAOYJqVIVwaYR5zWEEceUjNnTNo_CVSj" + "-VvXLO5VZfCUAVLgW4dpf1SrtZjSt34YLsRarSb127reG_DUwg9Ch-Kyvj" + "T1SkHgUWRVGcyly7uvVGRSDwsXypdrNinPA4jlhoNdizK2zF2CWQ", + "p": "9gY2w6I6S6L0juEKsbeDAwpd9WMfgqFoeA9vEyEUuk4kLwBKcoe1x4HG68" + "ik918hdDSE9vDQSccA3xXHOAFOPJ8R9EeIAbTi1VwBYnbTp87X-xcPWlEP" + "krdoUKW60tgs1aNd_Nnc9LEVVPMS390zbFxt8TN_biaBgelNgbC95sM", + "q": "uKlCKvKv_ZJMVcdIs5vVSU_6cPtYI1ljWytExV_skstvRSNi9r66jdd9-y" + "BhVfuG4shsp2j7rGnIio901RBeHo6TPKWVVykPu1iYhQXw1jIABfw-MVsN" + "-3bQ76WLdt2SDxsHs7q7zPyUyHXmps7ycZ5c72wGkUwNOjYelmkiNS0", + "dp": "w0kZbV63cVRvVX6yk3C8cMxo2qCM4Y8nsq1lmMSYhG4EcL6FWbX5h9yuv" + "ngs4iLEFk6eALoUS4vIWEwcL4txw9LsWH_zKI-hwoReoP77cOdSL4AVcra" + "Hawlkpyd2TWjE5evgbhWtOxnZee3cXJBkAi64Ik6jZxbvk-RR3pEhnCs", + "dq": "o_8V14SezckO6CNLKs_btPdFiO9_kC1DsuUTd2LAfIIVeMZ7jn1Gus_Ff" + "7B7IVx3p5KuBGOVF8L-qifLb6nQnLysgHDh132NDioZkhH7mI7hPG-PYE_" + "odApKdnqECHWw0J-F0JWnUd6D2B_1TvF9mXA2Qx-iGYn8OVV1Bsmp6qU", + "qi": "eNho5yRBEBxhGBtQRww9QirZsB66TrfFReG_CcteI1aCneT0ELGhYlRlC" + "tUkTRclIfuEPmNsNDPbLoLqqCVznFbvdB7x-Tl-m0l_eFTj2KiqwGqE9PZ" + "B9nNTwMVvH3VRRSLWACvPnSiwP8N5Usy-WRXS-V7TbpxIhvepTfE0NNo"} +E_A2_vector = \ + "eyJhbGciOiJSU0ExXzUiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0." \ + "UGhIOguC7IuEvf_NPVaXsGMoLOmwvc1GyqlIKOK1nN94nHPoltGRhWhw7Zx0-kFm" \ + "1NJn8LE9XShH59_i8J0PH5ZZyNfGy2xGdULU7sHNF6Gp2vPLgNZ__deLKxGHZ7Pc" \ + "HALUzoOegEI-8E66jX2E4zyJKx-YxzZIItRzC5hlRirb6Y5Cl_p-ko3YvkkysZIF" \ + "NPccxRU7qve1WYPxqbb2Yw8kZqa2rMWI5ng8OtvzlV7elprCbuPhcCdZ6XDP0_F8" \ + "rkXds2vE4X-ncOIM8hAYHHi29NX0mcKiRaD0-D-ljQTP-cFPgwCp6X-nZZd9OHBv" \ + "-B3oWh2TbqmScqXMR4gp_A." \ + "AxY8DCtDaGlsbGljb3RoZQ." \ + "KDlTtXchhZTGufMYmOYGS4HffxPSUrfmqCHXaI9wOGY." \ + "9hH0vgRfYgPnAHOd8stkvw" + +E_A2_ex = {'key': jwk.JWK(**E_A2_key), # pylint: disable=star-args + 'protected': base64url_decode(E_A2_protected), + 'plaintext': E_A2_plaintext, + 'vector': E_A2_vector} + + +class TestJWE(unittest.TestCase): + def test_A2(self): + E = jwe.JWE(E_A2_ex['plaintext'], E_A2_ex['protected']) + E.add_recipient(E_A2_ex['key']) + # Encrypt and serialize using compact + e = E.serialize(compact=True) + # And test that we can decrypt our own + E.deserialize(e, E_A2_ex['key']) + # Now test the Spec Test Vector + E.deserialize(E_A2_ex['vector'], E_A2_ex['key']) -- cgit