summaryrefslogtreecommitdiffstats
path: root/custodia/httpd/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'custodia/httpd/server.py')
-rw-r--r--custodia/httpd/server.py94
1 files changed, 68 insertions, 26 deletions
diff --git a/custodia/httpd/server.py b/custodia/httpd/server.py
index 8f02a78..dfc89d6 100644
--- a/custodia/httpd/server.py
+++ b/custodia/httpd/server.py
@@ -11,13 +11,13 @@ import six
try:
# pylint: disable=import-error
from BaseHTTPServer import BaseHTTPRequestHandler
- from SocketServer import ForkingMixIn, UnixStreamServer
+ from SocketServer import ForkingTCPServer
from urlparse import urlparse, parse_qs
from urllib import unquote
except ImportError:
# pylint: disable=import-error,no-name-in-module
from http.server import BaseHTTPRequestHandler
- from socketserver import ForkingMixIn, UnixStreamServer
+ from socketserver import ForkingTCPServer
from urllib.parse import urlparse, parse_qs, unquote
from custodia import log
@@ -39,8 +39,7 @@ class HTTPError(Exception):
super(HTTPError, self).__init__(errstring)
-class ForkingLocalHTTPServer(ForkingMixIn, UnixStreamServer):
-
+class ForkingHTTPServer(ForkingTCPServer):
"""
A forking HTTP Server.
Each request runs into a forked server so that the whole environment
@@ -50,13 +49,12 @@ class ForkingLocalHTTPServer(ForkingMixIn, UnixStreamServer):
When a request is received it is parsed by the handler_class provided
at server initialization.
"""
-
server_string = "Custodia/0.1"
allow_reuse_address = True
socket_file = None
def __init__(self, server_address, handler_class, config):
- UnixStreamServer.__init__(self, server_address, handler_class)
+ ForkingTCPServer.__init__(self, server_address, handler_class)
if 'consumers' not in config:
raise ValueError('Configuration does not provide any consumer')
self.config = config
@@ -64,14 +62,20 @@ class ForkingLocalHTTPServer(ForkingMixIn, UnixStreamServer):
self.server_string = self.config['server_string']
self._auditlog = log.AuditLog(self.config)
+
+class ForkingUnixHTTPServer(ForkingHTTPServer):
+ address_family = socket.AF_UNIX
+
def server_bind(self):
oldmask = os.umask(000)
- UnixStreamServer.server_bind(self)
- os.umask(oldmask)
+ try:
+ ForkingHTTPServer.server_bind(self)
+ finally:
+ os.umask(oldmask)
self.socket_file = self.socket.getsockname()
-class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
+class HTTPRequestHandler(BaseHTTPRequestHandler):
"""
This request handler is a slight modification of BaseHTTPRequestHandler
@@ -107,7 +111,6 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
protocol_version = "HTTP/1.0"
def __init__(self, *args, **kwargs):
- BaseHTTPRequestHandler.__init__(self, *args, **kwargs)
self.requestline = ''
self.request_version = ''
self.command = ''
@@ -118,15 +121,21 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
self.url = None
self.body = None
self.loginuid = None
+ self._creds = False
+ BaseHTTPRequestHandler.__init__(self, *args, **kwargs)
def version_string(self):
return self.server.server_string
def _get_loginuid(self, pid):
loginuid = None
+ # NOTE: Using proc to find the login uid is not reliable
+ # this is why login uid is fetched separately and not stored
+ # into 'creds', to avoid giving the false impression it can be
+ # used to perform access control decisions
try:
- with open("/proc/" + str(pid) + "/loginuid", "r") as f:
- loginuid = int(f.read(), 10)
+ with open("/proc/%i/loginuid" % pid, "r") as f:
+ loginuid = int(f.read())
except IOError as e:
if e.errno != errno.ENOENT:
raise
@@ -136,6 +145,12 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
@property
def peer_creds(self):
+ if self._creds is not False:
+ return self._creds
+ # works only for unix sockets
+ if self.request.family != socket.AF_UNIX:
+ self._creds = None
+ return self._creds
creds = self.request.getsockopt(socket.SOL_SOCKET, SO_PEERCRED,
struct.calcsize('3i'))
pid, uid, gid = struct.unpack('3i', creds)
@@ -147,7 +162,16 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
log.debug("Couldn't retrieve SELinux Context: (%s)" % str(e))
context = None
- return {'pid': pid, 'uid': uid, 'gid': gid, 'context': context}
+ self._creds = {'pid': pid, 'uid': uid, 'gid': gid, 'context': context}
+ return self._creds
+
+ @property
+ def peer_info(self):
+ if self.peer_creds is not None:
+ return self._creds['pid']
+ elif self.request.family in {socket.AF_INET, socket.AF_INET6}:
+ return self.request.getpeername()
+ return None
def parse_request(self, *args, **kwargs):
if not BaseHTTPRequestHandler.parse_request(self, *args, **kwargs):
@@ -155,7 +179,8 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
# grab the loginuid from `/proc` as soon as possible
creds = self.peer_creds
- self.loginuid = self._get_loginuid(creds['pid'])
+ if creds is not None:
+ self.loginuid = self._get_loginuid(creds['pid'])
# after basic parsing also use urlparse to retrieve individual
# elements of a request.
@@ -182,8 +207,9 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
self.body = self.rfile.read(length)
def handle_one_request(self):
- # Set a fake client address to make log functions happy
- self.client_address = ['127.0.0.1', 0]
+ if self.request.family == socket.AF_UNIX:
+ # Set a fake client address to make log functions happy
+ self.client_address = ['127.0.0.1', 0]
try:
if not self.server.config:
self.close_connection = 1
@@ -209,6 +235,7 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
self.wfile.flush()
return
request = {'creds': self.peer_creds,
+ 'client_id': self.peer_info,
'command': self.command,
'path': self.path,
'query': self.query,
@@ -300,7 +327,7 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
valid_once = True
if valid_once is not True:
self.server._auditlog.svc_access(log.AUDIT_SVC_AUTH_FAIL,
- request['creds']['pid'], "MAIN",
+ request['client_id'], "MAIN",
'No auth')
raise HTTPError(403)
@@ -314,7 +341,7 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
break
if valid is not True:
self.server._auditlog.svc_access(log.AUDIT_SVC_AUTHZ_FAIL,
- request['creds']['pid'], "MAIN",
+ request['client_id'], "MAIN",
request.get('path', '/'))
raise HTTPError(403)
@@ -340,15 +367,30 @@ class LocalHTTPRequestHandler(BaseHTTPRequestHandler):
raise HTTPError(404)
-class LocalHTTPServer(object):
+class HTTPServer(object):
+
+ def __init__(self, srvurl, config):
+ url = urlparse(srvurl)
+ address = unquote(url.netloc)
+ if url.scheme == 'http+unix':
+ # Unix socket
+ serverclass = ForkingUnixHTTPServer
+ if address[0] != '/':
+ raise ValueError('Must use absolute unix socket name')
+ if os.path.exists(address):
+ os.remove(address)
+ elif url.scheme == 'http':
+ host, port = address.split(":")
+ address = (host, int(port))
+ serverclass = ForkingHTTPServer
+ elif url.scheme == 'https':
+ raise NotImplementedError
+ else:
+ raise ValueError('Unknown URL Scheme: %s' % url.scheme)
- def __init__(self, address, config):
- if address[0] != '/':
- raise ValueError('Must use absolute unix socket name')
- if os.path.exists(address):
- os.remove(address)
- self.httpd = ForkingLocalHTTPServer(address, LocalHTTPRequestHandler,
- config)
+ self.httpd = serverclass(address,
+ HTTPRequestHandler,
+ config)
def get_socket(self):
return (self.httpd.socket, self.httpd.socket_file)