diff options
Diffstat (limited to 'ipapython')
-rw-r--r-- | ipapython/nsslib.py | 108 |
1 files changed, 92 insertions, 16 deletions
diff --git a/ipapython/nsslib.py b/ipapython/nsslib.py index 129f1a0c5..3c42b61ad 100644 --- a/ipapython/nsslib.py +++ b/ipapython/nsslib.py @@ -21,12 +21,14 @@ import sys import httplib import getpass +import socket import logging from nss.error import NSPRError import nss.io as io import nss.nss as nss import nss.ssl as ssl +import nss.error as error def auth_certificate_callback(sock, check_sig, is_server, certdb): cert_is_valid = False @@ -113,11 +115,84 @@ def client_auth_data_callback(ca_names, chosen_nickname, password, certdb): return False return False -class NSSConnection(httplib.HTTPConnection): +class NSSAddressFamilyFallback(object): + def __init__(self, family): + self.sock_family = family + self.family = self._get_nss_family(self.sock_family) + + def _get_nss_family(self, sock_family): + """ + Translate a family from python socket module to nss family. + """ + if sock_family in [ socket.AF_INET, socket.AF_UNSPEC ]: + return io.PR_AF_INET + elif sock_family == socket.AF_INET6: + return io.PR_AF_INET6 + else: + raise ValueError('Uknown socket family %d\n', sock_family) + + def _get_next_family(self): + if self.sock_family == socket.AF_UNSPEC and \ + self.family == io.PR_AF_INET: + return io.PR_AF_INET6 + + return None + + def _create_socket(self): + self.sock = io.Socket(family=self.family) + + def _connect_socket_family(self, host, port, family): + logging.debug("connect_socket_family: host=%s port=%s family=%s", + host, port, io.addr_family_name(family)) + try: + addr_info = [ ai for ai in io.AddrInfo(host) if ai.family == family ] + # No suitable families + if len(addr_info) == 0: + raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR, + "Cannot resolve %s using family %s" % (host, io.addr_family_name(family))) + + # Try connecting to the NetworkAddresses + for net_addr in addr_info: + net_addr.port = port + logging.debug("connecting: %s", net_addr) + try: + self.sock.connect(net_addr, family) + except Exception, e: + logging.debug("Could not connect socket to %s, error: %s, retrying..", + net_addr, str(e)) + continue + else: + return + + # Could not connect with any of NetworkAddresses + raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR, + "Could not connect to %s using any address" % host) + except ValueError, e: + raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR, e.message) + + def connect_socket(self, host, port): + try: + self._connect_socket_family(host, port, self.family) + except NSPRError, e: + if e.errno == error.PR_ADDRESS_NOT_SUPPORTED_ERROR: + next_family = self._get_next_family() + if next_family: + self.family = next_family + self._create_socket() + self._connect_socket_family(host, port, self.family) + else: + logging.debug('No next family to try..') + raise e + else: + raise e + +class NSSConnection(httplib.HTTPConnection, NSSAddressFamilyFallback): default_port = httplib.HTTPSConnection.default_port - def __init__(self, host, port=None, strict=None, dbdir=None): + def __init__(self, host, port=None, strict=None, + dbdir=None, family=socket.AF_UNSPEC): httplib.HTTPConnection.__init__(self, host, port, strict) + NSSAddressFamilyFallback.__init__(self, family) if not dbdir: raise RuntimeError("dbdir is required") @@ -130,10 +205,12 @@ class NSSConnection(httplib.HTTPConnection): nss.nss_init(dbdir) ssl.set_domestic_policy() nss.set_password_callback(self.password_callback) + self._create_socket() + def _create_socket(self): # Create the socket here so we can do things like let the caller # override the NSS callbacks - self.sock = ssl.SSLSocket() + self.sock = ssl.SSLSocket(family=self.family) self.sock.set_ssl_option(ssl.SSL_SECURITY, True) self.sock.set_ssl_option(ssl.SSL_HANDSHAKE_AS_CLIENT, True) @@ -142,7 +219,8 @@ class NSSConnection(httplib.HTTPConnection): # Provide a callback to verify the servers certificate self.sock.set_auth_certificate_callback(auth_certificate_callback, - nss.get_default_certdb()) + nss.get_default_certdb()) + self.sock.set_hostname(self.host) def password_callback(self, slot, retry, password): if not retry and password: return password @@ -156,11 +234,7 @@ class NSSConnection(httplib.HTTPConnection): pass def connect(self): - logging.debug("connect: host=%s port=%s", self.host, self.port) - self.sock.set_hostname(self.host) - net_addr = io.NetworkAddress(self.host, self.port) - logging.debug("connect: %s", net_addr) - self.sock.connect(net_addr) + self.connect_socket(self.host, self.port) def endheaders(self, message=None): """ @@ -206,20 +280,22 @@ class NSSHTTPS(httplib.HTTP): port = None self._setup(self._connection_class(host, port, strict, dbdir=dbdir)) -class NSPRConnection(httplib.HTTPConnection): +class NSPRConnection(httplib.HTTPConnection, NSSAddressFamilyFallback): default_port = httplib.HTTPConnection.default_port - def __init__(self, host, port=None, strict=None): + def __init__(self, host, port=None, strict=None, family=socket.AF_UNSPEC): httplib.HTTPConnection.__init__(self, host, port, strict) + NSSAddressFamilyFallback.__init__(self, family) logging.debug('%s init %s', self.__class__.__name__, host) + self._create_socket() + + def _create_socket(self): + super(NSPRConnection, self)._create_socket() + self.sock.set_hostname(self.host) - self.sock = io.Socket() def connect(self): - logging.debug("connect: host=%s port=%s", self.host, self.port) - net_addr = io.NetworkAddress(self.host, self.port) - logging.debug("connect: %s", net_addr) - self.sock.connect(net_addr) + self.connect_socket(self.host, self.port) class NSPRHTTP(httplib.HTTP): _http_vsn = 11 |