summaryrefslogtreecommitdiffstats
path: root/custodia
diff options
context:
space:
mode:
Diffstat (limited to 'custodia')
-rw-r--r--custodia/client.py124
-rw-r--r--custodia/message/common.py1
-rw-r--r--custodia/message/kem.py14
-rw-r--r--custodia/message/simple.py5
-rw-r--r--custodia/secrets.py89
5 files changed, 213 insertions, 20 deletions
diff --git a/custodia/client.py b/custodia/client.py
index 221080a..9647d68 100644
--- a/custodia/client.py
+++ b/custodia/client.py
@@ -2,6 +2,9 @@
import socket
+from jwcrypto.common import json_decode
+from jwcrypto.jwk import JWK
+
import requests
from requests.adapters import HTTPAdapter
@@ -10,6 +13,10 @@ from requests.compat import unquote, urlparse
from requests.packages.urllib3.connection import HTTPConnection
from requests.packages.urllib3.connectionpool import HTTPConnectionPool
+from custodia.message.kem import (
+ check_kem_claims, decode_enc_kem, make_enc_kem
+)
+
class HTTPUnixConnection(HTTPConnection):
@@ -150,3 +157,120 @@ class CustodiaSimpleClient(CustodiaHTTPClient):
def del_secret(self, name):
r = self.delete(name)
r.raise_for_status()
+
+
+class CustodiaKEMClient(CustodiaHTTPClient):
+ def __init__(self, *args, **kwargs):
+ super(CustodiaKEMClient, self).__init__(*args, **kwargs)
+ self._cli_signing_key = None
+ self._cli_decryption_key = None
+ self._srv_verifying_key = None
+ self._srv_encryption_key = None
+ self._sig_alg = None
+ self._enc_alg = None
+
+ def _decode_key(self, key):
+ if key is None:
+ return None
+ elif isinstance(key, JWK):
+ return key
+ elif isinstance(key, dict):
+ return JWK(**key)
+ elif isinstance(key, str):
+ return JWK(**(json_decode(key)))
+ else:
+ raise TypeError("Invalid key type")
+
+ def set_server_public_keys(self, sig, enc):
+ self._srv_verifying_key = self._decode_key(sig)
+ self._srv_encryption_key = self._decode_key(enc)
+
+ def set_client_keys(self, sig, enc):
+ self._cli_signing_key = self._decode_key(sig)
+ self._cli_decryption_key = self._decode_key(enc)
+
+ def set_algorithms(self, sig, enc):
+ self._sig_alg = sig
+ self._enc_alg = enc
+
+ def _signing_algorithm(self, key):
+ if self._sig_alg is not None:
+ return self._sig_alg
+ elif key.key_type == 'RSA':
+ return 'RS256'
+ elif key.key_type == 'EC':
+ return 'ES256'
+ else:
+ raise ValueError('Unsupported key type')
+
+ def _encryption_algorithm(self, key):
+ if self._enc_alg is not None:
+ return self._enc_alg
+ elif key.key_type == 'RSA':
+ return ('RSA1_5', 'A256CBC-HS512')
+ elif key.key_type == 'EC':
+ return ('ECDH-ES+A256KW', 'A256CBC-HS512')
+ else:
+ raise ValueError('Unsupported key type')
+
+ def _kem_wrap(self, name, value):
+ if self._cli_signing_key is None:
+ raise KeyError("Client Signing key is not available")
+ if self._srv_encryption_key is None:
+ raise KeyError("Server Encryption key is not available")
+ sig_alg = self._signing_algorithm(self._cli_signing_key)
+ enc_alg = self._encryption_algorithm(self._srv_encryption_key)
+ return make_enc_kem(name, value,
+ self._cli_signing_key, sig_alg,
+ self._srv_encryption_key, enc_alg)
+
+ def _kem_unwrap(self, name, message):
+ if message.get("type", None) != "kem":
+ raise TypeError("Invalid token type, expected 'kem', got %s" % (
+ message.get("type", None),))
+
+ if self._cli_decryption_key is None:
+ raise KeyError("Client Decryption key is not available")
+ if self._srv_verifying_key is None:
+ raise KeyError("Server Verifying key is not available")
+ claims = decode_enc_kem(message["value"],
+ self._cli_decryption_key,
+ self._srv_verifying_key)
+ check_kem_claims(claims, name)
+ return claims
+
+ def create_container(self, name):
+ cname = self.container_name(name)
+ message = self._kem_wrap(cname, None)
+ r = self.post(cname, json={"type": "kem", "value": message})
+ r.raise_for_status()
+ self._kem_unwrap(cname, r.json())
+
+ def delete_container(self, name):
+ cname = self.container_name(name)
+ message = self._kem_wrap(cname, None)
+ r = self.delete(cname, json={"type": "kem", "value": message})
+ r.raise_for_status()
+ self._kem_unwrap(cname, r.json())
+
+ def list_container(self, name):
+ return json_decode(self.get_secret(self.container_name(name)))
+
+ def get_secret(self, name):
+ message = self._kem_wrap(name, None)
+ r = self.get(name, params={"type": "kem", "value": message})
+ r.raise_for_status()
+ claims = self._kem_unwrap(name, r.json())
+ return claims['value']
+
+ def set_secret(self, name, value):
+ message = self._kem_wrap(name, value)
+ r = self.put(name, json={"type": "kem", "value": message})
+ r.raise_for_status()
+ self._kem_unwrap(name, r.json())
+
+ def del_secret(self, name):
+ message = self._kem_wrap(name, None)
+ r = self.delete(name, json={"type": "kem", "value": message})
+ r.raise_for_status()
+ self._kem_unwrap(name, r.json())
diff --git a/custodia/message/common.py b/custodia/message/common.py
index d774e3c..bbcfb2b 100644
--- a/custodia/message/common.py
+++ b/custodia/message/common.py
@@ -42,6 +42,7 @@ class MessageHandler(object):
def __init__(self, request):
self.req = request
+ self.name = None
self.payload = None
def parse(self, msg, name):
diff --git a/custodia/message/kem.py b/custodia/message/kem.py
index 48b756b..add1c72 100644
--- a/custodia/message/kem.py
+++ b/custodia/message/kem.py
@@ -215,11 +215,9 @@ class KEMClient(object):
self.server_keys[KEY_USAGE_ENC], encalg)
def parse_reply(self, name, message):
- jwe = JWT(jwt=message,
- key=self.client_keys[KEY_USAGE_ENC])
- jws = JWT(jwt=jwe.claims,
- key=self.server_keys[KEY_USAGE_SIG])
- claims = json_decode(jws.claims)
+ claims = decode_enc_kem(message,
+ self.client_keys[KEY_USAGE_ENC],
+ self.server_keys[KEY_USAGE_SIG])
check_kem_claims(claims, name)
return claims['value']
@@ -242,6 +240,12 @@ def make_enc_kem(name, value, sig_key, alg, enc_key, enc):
return jwe.serialize(compact=True)
+def decode_enc_kem(message, enc_key, sig_key):
+ jwe = JWT(jwt=message, key=enc_key)
+ jws = JWT(jwt=jwe.claims, key=sig_key)
+ return json_decode(jws.claims)
+
+
# unit tests
test_keys = ({
"kty": "RSA",
diff --git a/custodia/message/simple.py b/custodia/message/simple.py
index 7186d12..6482c53 100644
--- a/custodia/message/simple.py
+++ b/custodia/message/simple.py
@@ -28,8 +28,13 @@ class SimpleKey(MessageHandler):
if not isinstance(msg, string_types):
raise InvalidMessage("The 'value' attribute is not a string")
+ self.name = name
self.payload = msg
def reply(self, output):
+ if self.name.endswith('/'):
+ # directory listings are pass-through with simple messages
+ return output
+
return json.dumps({'type': 'simple', 'value': output},
separators=(',', ':'))
diff --git a/custodia/secrets.py b/custodia/secrets.py
index 1c3248d..7735941 100644
--- a/custodia/secrets.py
+++ b/custodia/secrets.py
@@ -46,8 +46,30 @@ class Secrets(HTTPConsumer):
f = self._db_key([default, ''])
return f
- def _parse(self, request, value, name):
- return self._validator.parse(request, value, name)
+ def _parse(self, request, query, name):
+ return self._validator.parse(request, query, name)
+
+ def _parse_query(self, request, name):
+ # default to simple
+ query = request.get('query', '')
+ if len(query) == 0:
+ query = {'type': 'simple', 'value': ''}
+ return self._parse(request, query, name)
+
+ def _parse_body(self, request, name):
+ body = request.get('body')
+ if body is None:
+ raise HTTPError(400)
+ value = json.loads(bytes(body).decode('utf-8'))
+ return self._parse(request, value, name)
+
+ def _parse_maybe_body(self, request, name):
+ body = request.get('body')
+ if body is None:
+ value = {'type': 'simple', 'value': ''}
+ else:
+ value = json.loads(bytes(body).decode('utf-8'))
+ return self._parse(request, value, name)
def _parent_exists(self, default, trail):
# check that the containers exist
@@ -102,6 +124,11 @@ class Secrets(HTTPConsumer):
raise HTTPError(405)
def _list(self, trail, request, response):
+ try:
+ name = '/'.join(trail)
+ msg = self._parse_query(request, name)
+ except Exception as e:
+ raise HTTPError(406, str(e))
default = request.get('default_namespace', None)
basename = self._db_container_key(default, trail)
try:
@@ -109,11 +136,16 @@ class Secrets(HTTPConsumer):
self.logger.debug('list %s returned %r', basename, keylist)
if keylist is None:
raise HTTPError(404)
- response['output'] = json.dumps(keylist)
+ response['output'] = msg.reply(json.dumps(keylist))
except CSStoreError:
raise HTTPError(500)
def _create(self, trail, request, response):
+ try:
+ name = '/'.join(trail)
+ msg = self._parse_maybe_body(request, name)
+ except Exception as e:
+ raise HTTPError(406, str(e))
default = request.get('default_namespace', None)
basename = self._db_container_key(None, trail)
try:
@@ -128,9 +160,17 @@ class Secrets(HTTPConsumer):
except CSStoreError:
raise HTTPError(500)
+ output = msg.reply(None)
+ if output is not None:
+ response['output'] = output
response['code'] = 201
def _destroy(self, trail, request, response):
+ try:
+ name = '/'.join(trail)
+ msg = self._parse_maybe_body(request, name)
+ except Exception as e:
+ raise HTTPError(406, str(e))
basename = self._db_container_key(None, trail)
try:
keylist = self.root.store.list(basename)
@@ -145,7 +185,12 @@ class Secrets(HTTPConsumer):
if ret is False:
raise HTTPError(404)
- response['code'] = 204
+ output = msg.reply(None)
+ if output is None:
+ response['code'] = 204
+ else:
+ response['output'] = output
+ response['code'] = 200
def _client_name(self, request):
if 'remote_user' in request:
@@ -171,13 +216,9 @@ class Secrets(HTTPConsumer):
self._int_get_key, trail, request, response)
def _int_get_key(self, trail, request, response):
- # default to simple
- query = request.get('query', '')
- if len(query) == 0:
- query = {'type': 'simple', 'value': ''}
try:
name = '/'.join(trail)
- msg = self._parse(request, query, name)
+ msg = self._parse_query(request, name)
except Exception as e:
raise HTTPError(406, str(e))
key = self._db_key(trail)
@@ -198,13 +239,9 @@ class Secrets(HTTPConsumer):
dict()).get('Content-Type', '')
if content_type.split(';')[0].strip() != 'application/json':
raise HTTPError(400, 'Invalid Content-Type')
- body = request.get('body')
- if body is None:
- raise HTTPError(400)
- value = bytes(body).decode('utf-8')
try:
name = '/'.join(trail)
- msg = self._parse(request, json.loads(value), name)
+ msg = self._parse_body(request, name)
except UnknownMessageType as e:
raise HTTPError(406, str(e))
except UnallowedMessage as e:
@@ -229,6 +266,9 @@ class Secrets(HTTPConsumer):
except CSStoreError:
raise HTTPError(500)
+ output = msg.reply(None)
+ if output is not None:
+ response['output'] = output
response['code'] = 201
def _del_key(self, trail, request, response):
@@ -236,6 +276,11 @@ class Secrets(HTTPConsumer):
self._int_del_key, trail, request, response)
def _int_del_key(self, trail, request, response):
+ try:
+ name = '/'.join(trail)
+ msg = self._parse_maybe_body(request, name)
+ except Exception as e:
+ raise HTTPError(406, str(e))
key = self._db_key(trail)
try:
ret = self.root.store.cut(key)
@@ -245,7 +290,12 @@ class Secrets(HTTPConsumer):
if ret is False:
raise HTTPError(404)
- response['code'] = 204
+ output = msg.reply(None)
+ if output is None:
+ response['code'] = 204
+ else:
+ response['output'] = output
+ response['code'] = 200
# unit tests
@@ -427,6 +477,15 @@ class SecretsTests(unittest.TestCase):
self.GET(req, rep)
self.assertEqual(err.exception.code, 404)
+ def test_6_LISTkeys_errors_406_1(self):
+ req = {'remote_user': 'test',
+ 'query': {'type': 'invalid'},
+ 'trail': ['test', '']}
+ rep = {}
+ with self.assertRaises(HTTPError) as err:
+ self.GET(req, rep)
+ self.assertEqual(err.exception.code, 406)
+
def test_7_DELETEKey(self):
req = {'remote_user': 'test',
'trail': ['test', 'key1']}