From 167e8a3c21cbd0a38887e6c29824aa5ead6a6b10 Mon Sep 17 00:00:00 2001 From: Simo Sorce Date: Sun, 8 Mar 2015 22:26:29 -0400 Subject: Implement JWE JSON Deserialization Also fix JWE JSON Serialization bug --- jwcrypto/jwe.py | 169 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 118 insertions(+), 51 deletions(-) diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py index 3e7ea19..6bea310 100644 --- a/jwcrypto/jwe.py +++ b/jwcrypto/jwe.py @@ -207,9 +207,10 @@ class JWE(object): :param unprotected(json): The shared unprotected header :param aad(bytes): Arbitrary additional authenticated data """ - self.objects = {'recipients': list()} + self.objects = dict() self.plaintext = plaintext self.cek = None + self.decryptlog = None if aad: self.objects['aad'] = aad if protected: @@ -263,7 +264,7 @@ class JWE(object): uh = json.loads(self.objects['unprotected']) ph = self.merge_headers(ph, uh) if header: - rh = json.loads('header') + rh = json.loads(header) ph = self.merge_headers(ph, rh) alg = self._jwa(ph.get('alg', None)) @@ -295,7 +296,21 @@ class JWE(object): self.objects['ciphertext'] = ciphertext self.objects['tag'] = tag - self.objects['recipients'].append(rec) + if 'recipients' in self.objects: + self.objects['recipients'].append(rec) + elif 'encrypted_key' in self.objects or 'header' in self.objects: + self.objects['recipients'] = list() + n = dict() + if 'encrypted_key' in self.objects: + n['encrypted_key'] = self.objects['encrypted_key'] + del self.objects['encrypted_key'] + if 'header' in self.objects: + n['header'] = self.objects['header'] + del self.objects['header'] + self.objects['recipients'].append(n) + self.objects['recipients'].append(rec) + else: + self.objects.update(rec) def serialize(self, compact=False): @@ -318,21 +333,29 @@ class JWE(object): obj = self.objects enc = {'ciphertext': base64url_encode(obj['ciphertext']), 'iv': base64url_encode(obj['iv']), - 'tag': base64url_encode(self.objects['tag']), - 'recipients': list()} + 'tag': base64url_encode(self.objects['tag'])} if 'protected' in obj: enc['protected'] = base64url_encode(obj['protected']) if 'unprotected' in obj: enc['unprotected'] = json.loads(obj['unprotected']) if 'aad' in obj: enc['aad'] = base64url_encode(obj['aad']) - for rec in obj['recipients']: - e = dict() - if 'encrypted_key' in rec: - e['encrypted_key'] = base64url_encode(rec['encrypted_key']) - if 'header' in rec: - e['header'] = json.loads(rec['header']) - rec['recipients'].append(e) + if 'recipients' in obj: + enc['recipients'] = list() + for rec in obj['recipients']: + e = dict() + if 'encrypted_key' in rec: + e['encrypted_key'] = \ + base64url_encode(rec['encrypted_key']) + if 'header' in rec: + e['header'] = json.loads(rec['header']) + enc['recipients'].append(e) + else: + if 'encrypted_key' in obj: + enc['encrypted_key'] = \ + base64url_encode(obj['encrypted_key']) + if 'header' in obj: + enc['header'] = json.loads(obj['header']) return json.dumps(enc) def check_crit(self, crit): @@ -345,56 +368,80 @@ class JWE(object): '"%s"' % k) # FIXME: allow to specify which algorithms to accept as valid - def decrypt(self, key): - if not isinstance(key, JWK): - raise ValueError('key is not a JWK object') - if 'ciphertext' not in self.objects: - raise InvalidJWEOperation("No available ciphertext") + def decrypt(self, key, ppe): - for rec in self.objects['recipients']: + ph = json.loads(self.objects['protected']) + if 'unprotected' in self.objects: + uh = json.loads(self.objects['unprotected']) + ph = self.merge_headers(ph, uh) + if 'header' in ppe: + rh = json.loads(ppe['header']) + ph = self.merge_headers(ph, rh) + # TODO: allow caller to specify list of headers it understands + if 'crit' in ph: + self.check_crit(ph['crit']) - ph = json.loads(self.objects['protected']) - if 'unprotected' in self.objects: - uh = json.loads(self.objects['unprotected']) - ph = self.merge_headers(ph, uh) - if 'header' in rec: - rh = json.loads(rec['header']) - ph = self.merge_headers(ph, rh) - # TODO: allow caller to specify list of headers it understands - if 'crit' in ph: - self.check_crit(ph['crit']) + alg = self._jwa(ph.get('alg', None)) + enc = self._jwa(ph.get('enc', None)) - alg = self._jwa(ph.get('alg', None)) - enc = self._jwa(ph.get('enc', None)) + aad = base64url_encode(self.objects.get('protected', '')) + if 'aad' in self.objects: + aad += '.' + base64url_encode(self.objects['aad']) - aad = base64url_encode(self.objects.get('protected', '')) - if 'aad' in self.objects: - aad += '.' + base64url_encode(self.objects['aad']) + cek = alg.unwrap(key, ppe.get('encrypted_key', None)) + data = enc.decrypt(cek, aad, self.objects['iv'], + self.objects['ciphertext'], + self.objects['tag']) - cek = alg.unwrap(key, rec['encrypted_key']) - data = enc.decrypt(cek, aad, self.objects['iv'], - self.objects['ciphertext'], - self.objects['tag']) + self.decryptlog.append('Success') + self.cek = cek - compress = ph.get('zip', None) - if compress == 'DEF': - self.plaintext = zlib.decompress(data, -zlib.MAX_WBITS) - elif compress is None: - self.plaintext = data - else: - raise ValueError('Unknown compression') + compress = ph.get('zip', None) + if compress == 'DEF': + self.plaintext = zlib.decompress(data, -zlib.MAX_WBITS) + elif compress is None: + self.plaintext = data + else: + raise ValueError('Unknown compression') def deserialize(self, raw_jwe, key=None): """ Destroys any current status and tries to import the raw JWS provided. """ self.objects = dict() + self.plaintext = None + self.cek = None + o = dict() try: try: djwe = json.loads(raw_jwe) - _ = djwe - raise NotImplementedError + o['iv'] = base64url_decode(str(djwe['iv'])) + o['ciphertext'] = base64url_decode(str(djwe['ciphertext'])) + o['tag'] = base64url_decode(str(djwe['tag'])) + if 'protected' in djwe: + o['protected'] = base64url_decode(str(djwe['protected'])) + if 'unprotected' in djwe: + o['unprotected'] = json.dumps(djwe['unprotected']) + if 'aad' in djwe: + o['aad'] = base64url_decode(str(djwe['aad'])) + if 'recipients' in djwe: + o['recipients'] = list() + for rec in djwe['recipients']: + e = dict() + if 'encrypted_key' in rec: + e['encrypted_key'] = \ + base64url_decode(str(rec['encrypted_key'])) + if 'header' in rec: + e['header'] = json.dumps(rec['header']) + o['recipients'].append(e) + else: + if 'encrypted_key' in djwe: + o['encrypted_key'] = \ + base64url_decode(str(djwe['encrypted_key'])) + if 'header' in djwe: + o['header'] = json.dumps(djwe['header']) + except ValueError: c = raw_jwe.split('.') if len(c) != 5: @@ -403,11 +450,31 @@ class JWE(object): o['iv'] = base64url_decode(str(c[2])) o['ciphertext'] = base64url_decode(str(c[3])) o['tag'] = base64url_decode(str(c[4])) - o['recipients'] = [{'encrypted_key': - base64url_decode(str(c[1]))}] - self.objects = o - if key: - self.decrypt(key) + o['encrypted_key'] = base64url_decode(str(c[1])) + + self.objects = o except Exception, e: # pylint: disable=broad-except raise InvalidJWEData('Invalid format', e) + + if key: + if not isinstance(key, JWK): + raise ValueError('key is not a JWK object') + if 'ciphertext' not in self.objects: + raise InvalidJWEOperation("No available ciphertext") + self.decryptlog = list() + + if 'recipients' in self.objects: + for rec in self.objects['recipients']: + try: + self.decrypt(key, rec) + except Exception, e: # pylint: disable=broad-except + self.decryptlog.append('Failed: [%s]' % str(e)) + else: + try: + self.decrypt(key, self.objects) + except Exception, e: # pylint: disable=broad-except + self.decryptlog.append('Failed: [%s]' % str(e)) + + if not self.plaintext: + raise InvalidJWEData('No recipient matches the provided key') -- cgit