diff options
Diffstat (limited to 'custodia/httpd/server.py')
-rw-r--r-- | custodia/httpd/server.py | 94 |
1 files changed, 68 insertions, 26 deletions
diff --git a/custodia/httpd/server.py b/custodia/httpd/server.py index 8f02a78..dfc89d6 100644 --- a/custodia/httpd/server.py +++ b/custodia/httpd/server.py @@ -11,13 +11,13 @@ import six try: # pylint: disable=import-error from BaseHTTPServer import BaseHTTPRequestHandler - from SocketServer import ForkingMixIn, UnixStreamServer + from SocketServer import ForkingTCPServer from urlparse import urlparse, parse_qs from urllib import unquote except ImportError: # pylint: disable=import-error,no-name-in-module from http.server import BaseHTTPRequestHandler - from socketserver import ForkingMixIn, UnixStreamServer + from socketserver import ForkingTCPServer from urllib.parse import urlparse, parse_qs, unquote from custodia import log @@ -39,8 +39,7 @@ class HTTPError(Exception): super(HTTPError, self).__init__(errstring) -class ForkingLocalHTTPServer(ForkingMixIn, UnixStreamServer): - +class ForkingHTTPServer(ForkingTCPServer): """ A forking HTTP Server. Each request runs into a forked server so that the whole environment @@ -50,13 +49,12 @@ class ForkingLocalHTTPServer(ForkingMixIn, UnixStreamServer): When a request is received it is parsed by the handler_class provided at server initialization. """ - server_string = "Custodia/0.1" allow_reuse_address = True socket_file = None def __init__(self, server_address, handler_class, config): - UnixStreamServer.__init__(self, server_address, handler_class) + ForkingTCPServer.__init__(self, server_address, handler_class) if 'consumers' not in config: raise ValueError('Configuration does not provide any consumer') self.config = config @@ -64,14 +62,20 @@ class ForkingLocalHTTPServer(ForkingMixIn, UnixStreamServer): self.server_string = self.config['server_string'] self._auditlog = log.AuditLog(self.config) + +class ForkingUnixHTTPServer(ForkingHTTPServer): + address_family = socket.AF_UNIX + def server_bind(self): oldmask = os.umask(000) - UnixStreamServer.server_bind(self) - os.umask(oldmask) + try: + ForkingHTTPServer.server_bind(self) + finally: + os.umask(oldmask) self.socket_file = self.socket.getsockname() -class LocalHTTPRequestHandler(BaseHTTPRequestHandler): +class HTTPRequestHandler(BaseHTTPRequestHandler): """ This request handler is a slight modification of BaseHTTPRequestHandler @@ -107,7 +111,6 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): protocol_version = "HTTP/1.0" def __init__(self, *args, **kwargs): - BaseHTTPRequestHandler.__init__(self, *args, **kwargs) self.requestline = '' self.request_version = '' self.command = '' @@ -118,15 +121,21 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): self.url = None self.body = None self.loginuid = None + self._creds = False + BaseHTTPRequestHandler.__init__(self, *args, **kwargs) def version_string(self): return self.server.server_string def _get_loginuid(self, pid): loginuid = None + # NOTE: Using proc to find the login uid is not reliable + # this is why login uid is fetched separately and not stored + # into 'creds', to avoid giving the false impression it can be + # used to perform access control decisions try: - with open("/proc/" + str(pid) + "/loginuid", "r") as f: - loginuid = int(f.read(), 10) + with open("/proc/%i/loginuid" % pid, "r") as f: + loginuid = int(f.read()) except IOError as e: if e.errno != errno.ENOENT: raise @@ -136,6 +145,12 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): @property def peer_creds(self): + if self._creds is not False: + return self._creds + # works only for unix sockets + if self.request.family != socket.AF_UNIX: + self._creds = None + return self._creds creds = self.request.getsockopt(socket.SOL_SOCKET, SO_PEERCRED, struct.calcsize('3i')) pid, uid, gid = struct.unpack('3i', creds) @@ -147,7 +162,16 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): log.debug("Couldn't retrieve SELinux Context: (%s)" % str(e)) context = None - return {'pid': pid, 'uid': uid, 'gid': gid, 'context': context} + self._creds = {'pid': pid, 'uid': uid, 'gid': gid, 'context': context} + return self._creds + + @property + def peer_info(self): + if self.peer_creds is not None: + return self._creds['pid'] + elif self.request.family in {socket.AF_INET, socket.AF_INET6}: + return self.request.getpeername() + return None def parse_request(self, *args, **kwargs): if not BaseHTTPRequestHandler.parse_request(self, *args, **kwargs): @@ -155,7 +179,8 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): # grab the loginuid from `/proc` as soon as possible creds = self.peer_creds - self.loginuid = self._get_loginuid(creds['pid']) + if creds is not None: + self.loginuid = self._get_loginuid(creds['pid']) # after basic parsing also use urlparse to retrieve individual # elements of a request. @@ -182,8 +207,9 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): self.body = self.rfile.read(length) def handle_one_request(self): - # Set a fake client address to make log functions happy - self.client_address = ['127.0.0.1', 0] + if self.request.family == socket.AF_UNIX: + # Set a fake client address to make log functions happy + self.client_address = ['127.0.0.1', 0] try: if not self.server.config: self.close_connection = 1 @@ -209,6 +235,7 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): self.wfile.flush() return request = {'creds': self.peer_creds, + 'client_id': self.peer_info, 'command': self.command, 'path': self.path, 'query': self.query, @@ -300,7 +327,7 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): valid_once = True if valid_once is not True: self.server._auditlog.svc_access(log.AUDIT_SVC_AUTH_FAIL, - request['creds']['pid'], "MAIN", + request['client_id'], "MAIN", 'No auth') raise HTTPError(403) @@ -314,7 +341,7 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): break if valid is not True: self.server._auditlog.svc_access(log.AUDIT_SVC_AUTHZ_FAIL, - request['creds']['pid'], "MAIN", + request['client_id'], "MAIN", request.get('path', '/')) raise HTTPError(403) @@ -340,15 +367,30 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler): raise HTTPError(404) -class LocalHTTPServer(object): +class HTTPServer(object): + + def __init__(self, srvurl, config): + url = urlparse(srvurl) + address = unquote(url.netloc) + if url.scheme == 'http+unix': + # Unix socket + serverclass = ForkingUnixHTTPServer + if address[0] != '/': + raise ValueError('Must use absolute unix socket name') + if os.path.exists(address): + os.remove(address) + elif url.scheme == 'http': + host, port = address.split(":") + address = (host, int(port)) + serverclass = ForkingHTTPServer + elif url.scheme == 'https': + raise NotImplementedError + else: + raise ValueError('Unknown URL Scheme: %s' % url.scheme) - def __init__(self, address, config): - if address[0] != '/': - raise ValueError('Must use absolute unix socket name') - if os.path.exists(address): - os.remove(address) - self.httpd = ForkingLocalHTTPServer(address, LocalHTTPRequestHandler, - config) + self.httpd = serverclass(address, + HTTPRequestHandler, + config) def get_socket(self): return (self.httpd.socket, self.httpd.socket_file) |