diff options
Diffstat (limited to 'jwcrypto/jwe.py')
-rw-r--r-- | jwcrypto/jwe.py | 181 |
1 files changed, 174 insertions, 7 deletions
diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py index 6bea310..40f8edb 100644 --- a/jwcrypto/jwe.py +++ b/jwcrypto/jwe.py @@ -31,6 +31,22 @@ JWEHeaderRegistry = {'alg': ('Algorithm', True), 'crit': ('Critical', True)} +# Note: l is the number of bits, which should be a multiple of 16 +def encode_int(n, l): + e = hex(n).rstrip("L").lstrip("0x") + el = len(e) + L = ((l + 7) // 8) * 2 # number of bytes rounded up times 2 chars/bytes + if el > L: + e = e[:L] + else: + e = '0' * (L - el) + e # pad as necessary + return e.decode('hex') + + +def decode_int(n): + return int(n.encode('hex'), 16) + + class InvalidJWEData(Exception): def __init__(self, message=None, exception=None): msg = None @@ -61,6 +77,12 @@ class InvalidJWEOperation(Exception): super(InvalidJWEOperation, self).__init__(msg) +class InvalidJWEKeyType(Exception): + def __init__(self, expected, obtained): + msg = 'Expected key type %s, got %s' % (expected, obtained) + super(InvalidJWEKeyType, self).__init__(msg) + + class _raw_key_mgmt(object): def wrap(self, key, keylen, cek): @@ -75,7 +97,13 @@ class _rsa(_raw_key_mgmt): def __init__(self, padfn): self.padfn = padfn + def check_key(self, key): + if key.key_type != 'RSA': + raise InvalidJWEKeyType('RSA', key.key_type) + + # FIXME: get key size and insure > 2048 bits def wrap(self, key, keylen, cek): + self.check_key(key) if not cek: cek = os.urandom(keylen) rk = key.encrypt_key() @@ -83,14 +111,82 @@ class _rsa(_raw_key_mgmt): return (cek, ek) def unwrap(self, key, ek): + self.check_key(key) rk = key.decrypt_key() cek = rk.decrypt(ek, self.padfn) return cek +class _aes_kw(_raw_key_mgmt): + + def __init__(self, keysize): + self.backend = default_backend() + self.keysize = keysize + + def check_key(self, key): + if key.key_type != 'oct': + raise InvalidJWEKeyType('oct', key.key_type) + + def wrap(self, key, keylen, cek): + self.check_key(key) + if not cek: + cek = os.urandom(keylen) + rk = base64url_decode(key.encrypt_key()) + + # Implement RFC 3994 Key Unwrap - 2.2.2 + # TODO: Use cryptography once issue #1733 is resolved + iv = 'a6a6a6a6a6a6a6a6' + A = iv.decode('hex') + R = [cek[i:i+8] for i in range(0, len(cek), 8)] + n = len(R) + for j in range(0, 6): + for i in range(0, n): + e = Cipher(algorithms.AES(rk), modes.ECB(), + backend=self.backend).encryptor() + B = e.update(A + R[i]) + e.finalize() + A = encode_int(decode_int(B[:8]) ^ ((n*j)+i+1), 64) + R[i] = B[-8:] + ek = A + for i in range(0, n): + ek += R[i] + return (cek, ek) + + def unwrap(self, key, ek): + self.check_key(key) + rk = base64url_decode(key.decrypt_key()) + + # Implement RFC 3994 Key Unwrap - 2.2.3 + # TODO: Use cryptography once issue #1733 is resolved + iv = 'a6a6a6a6a6a6a6a6' + Aiv = iv.decode('hex') + + R = [ek[i:i+8] for i in range(0, len(ek), 8)] + A = R.pop(0) + n = len(R) + for j in range(5, -1, -1): + for i in range(n - 1, -1, -1): + AtR = encode_int((decode_int(A) ^ ((n*j)+i+1)), 64) + R[i] + d = Cipher(algorithms.AES(rk), modes.ECB(), + backend=self.backend).decryptor() + B = d.update(AtR) + d.finalize() + A = B[:8] + R[i] = B[-8:] + + if A != Aiv: + raise InvalidJWEData('Decryption Failed') + + cek = ''.join(R) + return cek + + class _direct(_raw_key_mgmt): + def check_key(self, key): + if key.key_type != 'oct': + raise InvalidJWEKeyType('oct', key.key_type) + def wrap(self, key, keylen, cek): + self.check_key(key) if cek: return (cek, None) k = base64url_decode(key.encrypt_key()) @@ -99,6 +195,7 @@ class _direct(_raw_key_mgmt): return (k, '') def unwrap(self, key, ek): + self.check_key(key) if ek != '': raise InvalidJWEData('Invalid Encryption Key.') return base64url_decode(key.decrypt_key()) @@ -106,12 +203,6 @@ class _direct(_raw_key_mgmt): class _raw_jwe(object): - def encode_int(self, n, l): - e = hex(n).rstrip("L").lstrip("0x") - L = (l + 7) / 8 # number of bytes rounded up - e = '0' * (L * 2 - len(e)) + e # pad as necessary - return e.decode('hex') - def encrypt(self, k, a, m): raise NotImplementedError @@ -131,7 +222,7 @@ class _aes_cbc_hmac_sha2(_raw_jwe): return self.blocksize * 2 def _mac(self, k, a, iv, e): - al = self.encode_int(len(a * 8), 64) + al = encode_int(len(a * 8), 64) h = hmac.HMAC(k, self.hashfn, backend=self.backend) h.update(a) h.update(iv) @@ -195,6 +286,54 @@ class _aes_cbc_hmac_sha2(_raw_jwe): return unpadder.update(d) + unpadder.finalize() +class _aes_gcm(_raw_jwe): + + def __init__(self, keybits): + self.backend = default_backend() + self.blocksize = keybits / 8 + + @property + def key_size(self): + return self.blocksize + + # draft-ietf-jose-json-web-algorithms-40 - 5.2.2 + def encrypt(self, k, a, m): + """ Encrypt accoriding to the selected encryption and hashing + functions. + + :param k: Encryption key (optional) + :param a: Additional Authentication Data + :param m: Plaintext + + Returns a dictionary with the computed data. + """ + iv = os.urandom(96 / 8) + cipher = Cipher(algorithms.AES(k), modes.GCM(iv), + backend=self.backend) + encryptor = cipher.encryptor() + encryptor.authenticate_additional_data(a) + e = encryptor.update(m) + encryptor.finalize() + + return (iv, e, encryptor.tag) + + def decrypt(self, k, a, iv, e, t): + """ Decrypt accoriding to the selected encryption and hashing + functions. + :param k: Encryption key (optional) + :param a: Additional Authenticated Data + :param iv: Initialization Vector + :param e: Ciphertext + :param t: Authentication Tag + + Returns plaintext or raises an error + """ + cipher = Cipher(algorithms.AES(k), modes.GCM(iv, t), + backend=self.backend) + decryptor = cipher.decryptor() + decryptor.authenticate_additional_data(a) + return decryptor.update(e) + decryptor.finalize() + + class JWE(object): def __init__(self, plaintext=None, protected=None, unprotected=None, @@ -224,6 +363,19 @@ class JWE(object): def _jwa_RSA1_5(self): return _rsa(padding.PKCS1v15()) + def _jwa_RSA_OAEP(self): + return _rsa(padding.OAEP(padding.MGF1(hashes.SHA1()), + hashes.SHA1(), + None)) + + def _jwa_RSA_OAEP_256(self): + return _rsa(padding.OAEP(padding.MGF1(hashes.SHA256()), + hashes.SHA256(), + None)) + + def _jwa_A128KW(self): + return _aes_kw(128) + def _jwa_dir(self): return _direct() @@ -231,6 +383,21 @@ class JWE(object): def _jwa_A128CBC_HS256(self): return _aes_cbc_hmac_sha2(hashes.SHA256(), 128) + def _jwa_A192CBC_HS384(self): + return _aes_cbc_hmac_sha2(hashes.SHA384(), 192) + + def _jwa_A256CBC_HS512(self): + return _aes_cbc_hmac_sha2(hashes.SHA512(), 256) + + def _jwa_A128GCM(self): + return _aes_gcm(128) + + def _jwa_A192GCM(self): + return _aes_gcm(192) + + def _jwa_A256GCM(self): + return _aes_gcm(256) + def _jwa(self, name): attr = '_jwa_%s' % name.replace('-', '_').replace('+', '_') try: |