summaryrefslogtreecommitdiffstats
path: root/custodia/forwarder.py
blob: 03fcfef1cfcbd462ae3a062f4f9ce39a22324d5a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (C) 2015  Custodia Project Contributors - see LICENSE file

import json
import uuid

from custodia import log
from custodia.client import CustodiaHTTPClient
from custodia.httpd.consumer import HTTPConsumer
from custodia.httpd.server import HTTPError


class Forwarder(HTTPConsumer):
    """Consumer that relays GET/PUT/DELETE/POST requests to another server.

    Each instance tags the requests it sends out with its own random id in
    the ``X-LOOP-CUSTODIA`` header and refuses any incoming request whose
    header already contains that id, so that a chain of forwarders pointing
    back at itself is cut instead of looping.
    """

    def __init__(self, *args, **kwargs):
        """Set up the upstream client and the static forwarding headers.

        Configuration keys read from ``self.config`` (presumably populated
        by the base class initializer -- confirm in HTTPConsumer):

        * ``forward_uri`` (required): passed to ``CustodiaHTTPClient``.
        * ``forward_headers`` (optional): JSON text decoded into the
          headers added to every forwarded request; defaults to ``{}``.
        """
        super(Forwarder, self).__init__(*args, **kwargs)
        self._auditlog = log.AuditLog(self.config)
        self.client = CustodiaHTTPClient(self.config['forward_uri'])
        self.headers = json.loads(self.config.get('forward_headers', '{}'))
        # Per-instance id used for loop detection; it always overrides any
        # X-LOOP-CUSTODIA value supplied through the configuration.
        self.uuid = str(uuid.uuid4())
        self.headers['X-LOOP-CUSTODIA'] = self.uuid

    def _loop_header(self, request):
        """Return the incoming ``X-LOOP-CUSTODIA`` value, or None if absent.

        A request without a ``headers`` entry (or with an empty one) is
        treated as carrying no loop header instead of raising KeyError,
        matching how the other request fields are read with defaults.
        """
        incoming = request.get('headers') or {}
        return incoming.get('X-LOOP-CUSTODIA', None)

    def _path(self, request):
        """Build the upstream path: ``<remote_user>/<trail...>``.

        ``remote_user`` defaults to ``'guest'`` and has trailing slashes
        stripped; ``trail`` defaults to an empty list.
        """
        trail = request.get('trail', [])
        prefix = request.get('remote_user', 'guest')
        return '/'.join([prefix.rstrip('/')] + trail)

    def _headers(self, request):
        """Return a fresh dict of headers to send with the forwarded request.

        Starts from a copy of the configured headers (so ``self.headers`` is
        never mutated) and appends the incoming loop header, if any, after
        this instance's own id, separated by a comma.
        """
        headers = dict(self.headers)
        loop = self._loop_header(request)
        if loop is not None:
            headers['X-LOOP-CUSTODIA'] += ',' + loop
        return headers

    def _response(self, reply, response):
        """Copy the upstream ``reply`` into the ``response`` dict.

        Raises:
            HTTPError: with the upstream status code when it is outside
                the 200-299 range.
        """
        if reply.status_code < 200 or reply.status_code > 299:
            raise HTTPError(reply.status_code)
        response['code'] = reply.status_code
        # 'output' is only set when the upstream reply has a non-empty body.
        if reply.content:
            response['output'] = reply.content

    def _request(self, cmd, request, response, path, **kwargs):
        """Run the client call ``cmd(path, **kwargs)`` unless a loop is seen.

        Raises:
            HTTPError: 502 when this instance's id already appears in the
                incoming loop header; otherwise whatever ``_response``
                raises for the upstream reply.
        """
        if self.uuid in (self._loop_header(request) or ''):
            raise HTTPError(502, "Loop detected")
        reply = cmd(path, **kwargs)
        self._response(reply, response)

    def _forward(self, cmd, request, response, with_body=False):
        """Assemble the common client arguments and dispatch via _request.

        ``data`` is only passed to the client for verbs that carry a body.
        """
        path = self._path(request)
        kwargs = {}
        if with_body:
            kwargs['data'] = request.get('body', None)
        kwargs['params'] = request.get('query', None)
        kwargs['headers'] = self._headers(request)
        self._request(cmd, request, response, path, **kwargs)

    def GET(self, request, response):
        """Forward a GET request (query parameters only)."""
        self._forward(self.client.get, request, response)

    def PUT(self, request, response):
        """Forward a PUT request, including the request body."""
        self._forward(self.client.put, request, response, with_body=True)

    def DELETE(self, request, response):
        """Forward a DELETE request (query parameters only)."""
        self._forward(self.client.delete, request, response)

    def POST(self, request, response):
        """Forward a POST request, including the request body."""
        self._forward(self.client.post, request, response, with_body=True)