summaryrefslogtreecommitdiffstats
path: root/jwcrypto/jwe.py
diff options
context:
space:
mode:
Diffstat (limited to 'jwcrypto/jwe.py')
-rw-r--r--jwcrypto/jwe.py181
1 files changed, 174 insertions, 7 deletions
diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py
index 6bea310..40f8edb 100644
--- a/jwcrypto/jwe.py
+++ b/jwcrypto/jwe.py
@@ -31,6 +31,22 @@ JWEHeaderRegistry = {'alg': ('Algorithm', True),
'crit': ('Critical', True)}
+# Note: l is the number of bits, which should be a multiple of 16
+def encode_int(n, l):
+ e = hex(n).rstrip("L").lstrip("0x")
+ el = len(e)
+ L = ((l + 7) // 8) * 2 # number of bytes rounded up times 2 chars/bytes
+ if el > L:
+ e = e[:L]
+ else:
+ e = '0' * (L - el) + e # pad as necessary
+ return e.decode('hex')
+
+
+def decode_int(n):
+ return int(n.encode('hex'), 16)
+
+
class InvalidJWEData(Exception):
def __init__(self, message=None, exception=None):
msg = None
@@ -61,6 +77,12 @@ class InvalidJWEOperation(Exception):
super(InvalidJWEOperation, self).__init__(msg)
+class InvalidJWEKeyType(Exception):
+ def __init__(self, expected, obtained):
+ msg = 'Expected key type %s, got %s' % (expected, obtained)
+ super(InvalidJWEKeyType, self).__init__(msg)
+
+
class _raw_key_mgmt(object):
def wrap(self, key, keylen, cek):
@@ -75,7 +97,13 @@ class _rsa(_raw_key_mgmt):
def __init__(self, padfn):
self.padfn = padfn
+ def check_key(self, key):
+ if key.key_type != 'RSA':
+ raise InvalidJWEKeyType('RSA', key.key_type)
+
+ # FIXME: get key size and insure > 2048 bits
def wrap(self, key, keylen, cek):
+ self.check_key(key)
if not cek:
cek = os.urandom(keylen)
rk = key.encrypt_key()
@@ -83,14 +111,82 @@ class _rsa(_raw_key_mgmt):
return (cek, ek)
def unwrap(self, key, ek):
+ self.check_key(key)
rk = key.decrypt_key()
cek = rk.decrypt(ek, self.padfn)
return cek
+class _aes_kw(_raw_key_mgmt):
+
+ def __init__(self, keysize):
+ self.backend = default_backend()
+ self.keysize = keysize
+
+ def check_key(self, key):
+ if key.key_type != 'oct':
+ raise InvalidJWEKeyType('oct', key.key_type)
+
+ def wrap(self, key, keylen, cek):
+ self.check_key(key)
+ if not cek:
+ cek = os.urandom(keylen)
+ rk = base64url_decode(key.encrypt_key())
+
+ # Implement RFC 3994 Key Unwrap - 2.2.2
+ # TODO: Use cryptography once issue #1733 is resolved
+ iv = 'a6a6a6a6a6a6a6a6'
+ A = iv.decode('hex')
+ R = [cek[i:i+8] for i in range(0, len(cek), 8)]
+ n = len(R)
+ for j in range(0, 6):
+ for i in range(0, n):
+ e = Cipher(algorithms.AES(rk), modes.ECB(),
+ backend=self.backend).encryptor()
+ B = e.update(A + R[i]) + e.finalize()
+ A = encode_int(decode_int(B[:8]) ^ ((n*j)+i+1), 64)
+ R[i] = B[-8:]
+ ek = A
+ for i in range(0, n):
+ ek += R[i]
+ return (cek, ek)
+
+ def unwrap(self, key, ek):
+ self.check_key(key)
+ rk = base64url_decode(key.decrypt_key())
+
+ # Implement RFC 3994 Key Unwrap - 2.2.3
+ # TODO: Use cryptography once issue #1733 is resolved
+ iv = 'a6a6a6a6a6a6a6a6'
+ Aiv = iv.decode('hex')
+
+ R = [ek[i:i+8] for i in range(0, len(ek), 8)]
+ A = R.pop(0)
+ n = len(R)
+ for j in range(5, -1, -1):
+ for i in range(n - 1, -1, -1):
+ AtR = encode_int((decode_int(A) ^ ((n*j)+i+1)), 64) + R[i]
+ d = Cipher(algorithms.AES(rk), modes.ECB(),
+ backend=self.backend).decryptor()
+ B = d.update(AtR) + d.finalize()
+ A = B[:8]
+ R[i] = B[-8:]
+
+ if A != Aiv:
+ raise InvalidJWEData('Decryption Failed')
+
+ cek = ''.join(R)
+ return cek
+
+
class _direct(_raw_key_mgmt):
+ def check_key(self, key):
+ if key.key_type != 'oct':
+ raise InvalidJWEKeyType('oct', key.key_type)
+
def wrap(self, key, keylen, cek):
+ self.check_key(key)
if cek:
return (cek, None)
k = base64url_decode(key.encrypt_key())
@@ -99,6 +195,7 @@ class _direct(_raw_key_mgmt):
return (k, '')
def unwrap(self, key, ek):
+ self.check_key(key)
if ek != '':
raise InvalidJWEData('Invalid Encryption Key.')
return base64url_decode(key.decrypt_key())
@@ -106,12 +203,6 @@ class _direct(_raw_key_mgmt):
class _raw_jwe(object):
- def encode_int(self, n, l):
- e = hex(n).rstrip("L").lstrip("0x")
- L = (l + 7) / 8 # number of bytes rounded up
- e = '0' * (L * 2 - len(e)) + e # pad as necessary
- return e.decode('hex')
-
def encrypt(self, k, a, m):
raise NotImplementedError
@@ -131,7 +222,7 @@ class _aes_cbc_hmac_sha2(_raw_jwe):
return self.blocksize * 2
def _mac(self, k, a, iv, e):
- al = self.encode_int(len(a * 8), 64)
+ al = encode_int(len(a * 8), 64)
h = hmac.HMAC(k, self.hashfn, backend=self.backend)
h.update(a)
h.update(iv)
@@ -195,6 +286,54 @@ class _aes_cbc_hmac_sha2(_raw_jwe):
return unpadder.update(d) + unpadder.finalize()
+class _aes_gcm(_raw_jwe):
+
+ def __init__(self, keybits):
+ self.backend = default_backend()
+ self.blocksize = keybits / 8
+
+ @property
+ def key_size(self):
+ return self.blocksize
+
+ # draft-ietf-jose-json-web-algorithms-40 - 5.2.2
+ def encrypt(self, k, a, m):
+ """ Encrypt accoriding to the selected encryption and hashing
+ functions.
+
+ :param k: Encryption key (optional)
+ :param a: Additional Authentication Data
+ :param m: Plaintext
+
+ Returns a dictionary with the computed data.
+ """
+ iv = os.urandom(96 / 8)
+ cipher = Cipher(algorithms.AES(k), modes.GCM(iv),
+ backend=self.backend)
+ encryptor = cipher.encryptor()
+ encryptor.authenticate_additional_data(a)
+ e = encryptor.update(m) + encryptor.finalize()
+
+ return (iv, e, encryptor.tag)
+
+ def decrypt(self, k, a, iv, e, t):
+ """ Decrypt accoriding to the selected encryption and hashing
+ functions.
+ :param k: Encryption key (optional)
+ :param a: Additional Authenticated Data
+ :param iv: Initialization Vector
+ :param e: Ciphertext
+ :param t: Authentication Tag
+
+ Returns plaintext or raises an error
+ """
+ cipher = Cipher(algorithms.AES(k), modes.GCM(iv, t),
+ backend=self.backend)
+ decryptor = cipher.decryptor()
+ decryptor.authenticate_additional_data(a)
+ return decryptor.update(e) + decryptor.finalize()
+
+
class JWE(object):
def __init__(self, plaintext=None, protected=None, unprotected=None,
@@ -224,6 +363,19 @@ class JWE(object):
def _jwa_RSA1_5(self):
return _rsa(padding.PKCS1v15())
+ def _jwa_RSA_OAEP(self):
+ return _rsa(padding.OAEP(padding.MGF1(hashes.SHA1()),
+ hashes.SHA1(),
+ None))
+
+ def _jwa_RSA_OAEP_256(self):
+ return _rsa(padding.OAEP(padding.MGF1(hashes.SHA256()),
+ hashes.SHA256(),
+ None))
+
+ def _jwa_A128KW(self):
+ return _aes_kw(128)
+
def _jwa_dir(self):
return _direct()
@@ -231,6 +383,21 @@ class JWE(object):
def _jwa_A128CBC_HS256(self):
return _aes_cbc_hmac_sha2(hashes.SHA256(), 128)
+ def _jwa_A192CBC_HS384(self):
+ return _aes_cbc_hmac_sha2(hashes.SHA384(), 192)
+
+ def _jwa_A256CBC_HS512(self):
+ return _aes_cbc_hmac_sha2(hashes.SHA512(), 256)
+
+ def _jwa_A128GCM(self):
+ return _aes_gcm(128)
+
+ def _jwa_A192GCM(self):
+ return _aes_gcm(192)
+
+ def _jwa_A256GCM(self):
+ return _aes_gcm(256)
+
def _jwa(self, name):
attr = '_jwa_%s' % name.replace('-', '_').replace('+', '_')
try: