summaryrefslogtreecommitdiffstats
path: root/ipapython/nsslib.py
diff options
context:
space:
mode:
Diffstat (limited to 'ipapython/nsslib.py')
-rw-r--r--ipapython/nsslib.py108
1 files changed, 92 insertions, 16 deletions
diff --git a/ipapython/nsslib.py b/ipapython/nsslib.py
index 129f1a0c5..3c42b61ad 100644
--- a/ipapython/nsslib.py
+++ b/ipapython/nsslib.py
@@ -21,12 +21,14 @@
import sys
import httplib
import getpass
+import socket
import logging
from nss.error import NSPRError
import nss.io as io
import nss.nss as nss
import nss.ssl as ssl
+import nss.error as error
def auth_certificate_callback(sock, check_sig, is_server, certdb):
cert_is_valid = False
@@ -113,11 +115,84 @@ def client_auth_data_callback(ca_names, chosen_nickname, password, certdb):
return False
return False
-class NSSConnection(httplib.HTTPConnection):
+class NSSAddressFamilyFallback(object):
+ def __init__(self, family):
+ self.sock_family = family
+ self.family = self._get_nss_family(self.sock_family)
+
+ def _get_nss_family(self, sock_family):
+ """
+ Translate a family from python socket module to nss family.
+ """
+ if sock_family in [ socket.AF_INET, socket.AF_UNSPEC ]:
+ return io.PR_AF_INET
+ elif sock_family == socket.AF_INET6:
+ return io.PR_AF_INET6
+ else:
+ raise ValueError('Uknown socket family %d\n', sock_family)
+
+ def _get_next_family(self):
+ if self.sock_family == socket.AF_UNSPEC and \
+ self.family == io.PR_AF_INET:
+ return io.PR_AF_INET6
+
+ return None
+
+ def _create_socket(self):
+ self.sock = io.Socket(family=self.family)
+
+ def _connect_socket_family(self, host, port, family):
+ logging.debug("connect_socket_family: host=%s port=%s family=%s",
+ host, port, io.addr_family_name(family))
+ try:
+ addr_info = [ ai for ai in io.AddrInfo(host) if ai.family == family ]
+ # No suitable families
+ if len(addr_info) == 0:
+ raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR,
+ "Cannot resolve %s using family %s" % (host, io.addr_family_name(family)))
+
+ # Try connecting to the NetworkAddresses
+ for net_addr in addr_info:
+ net_addr.port = port
+ logging.debug("connecting: %s", net_addr)
+ try:
+ self.sock.connect(net_addr, family)
+ except Exception, e:
+ logging.debug("Could not connect socket to %s, error: %s, retrying..",
+ net_addr, str(e))
+ continue
+ else:
+ return
+
+ # Could not connect with any of NetworkAddresses
+ raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR,
+ "Could not connect to %s using any address" % host)
+ except ValueError, e:
+ raise NSPRError(error.PR_ADDRESS_NOT_SUPPORTED_ERROR, e.message)
+
+ def connect_socket(self, host, port):
+ try:
+ self._connect_socket_family(host, port, self.family)
+ except NSPRError, e:
+ if e.errno == error.PR_ADDRESS_NOT_SUPPORTED_ERROR:
+ next_family = self._get_next_family()
+ if next_family:
+ self.family = next_family
+ self._create_socket()
+ self._connect_socket_family(host, port, self.family)
+ else:
+ logging.debug('No next family to try..')
+ raise e
+ else:
+ raise e
+
+class NSSConnection(httplib.HTTPConnection, NSSAddressFamilyFallback):
default_port = httplib.HTTPSConnection.default_port
- def __init__(self, host, port=None, strict=None, dbdir=None):
+ def __init__(self, host, port=None, strict=None,
+ dbdir=None, family=socket.AF_UNSPEC):
httplib.HTTPConnection.__init__(self, host, port, strict)
+ NSSAddressFamilyFallback.__init__(self, family)
if not dbdir:
raise RuntimeError("dbdir is required")
@@ -130,10 +205,12 @@ class NSSConnection(httplib.HTTPConnection):
nss.nss_init(dbdir)
ssl.set_domestic_policy()
nss.set_password_callback(self.password_callback)
+ self._create_socket()
+ def _create_socket(self):
# Create the socket here so we can do things like let the caller
# override the NSS callbacks
- self.sock = ssl.SSLSocket()
+ self.sock = ssl.SSLSocket(family=self.family)
self.sock.set_ssl_option(ssl.SSL_SECURITY, True)
self.sock.set_ssl_option(ssl.SSL_HANDSHAKE_AS_CLIENT, True)
@@ -142,7 +219,8 @@ class NSSConnection(httplib.HTTPConnection):
# Provide a callback to verify the servers certificate
self.sock.set_auth_certificate_callback(auth_certificate_callback,
- nss.get_default_certdb())
+ nss.get_default_certdb())
+ self.sock.set_hostname(self.host)
def password_callback(self, slot, retry, password):
if not retry and password: return password
@@ -156,11 +234,7 @@ class NSSConnection(httplib.HTTPConnection):
pass
def connect(self):
- logging.debug("connect: host=%s port=%s", self.host, self.port)
- self.sock.set_hostname(self.host)
- net_addr = io.NetworkAddress(self.host, self.port)
- logging.debug("connect: %s", net_addr)
- self.sock.connect(net_addr)
+ self.connect_socket(self.host, self.port)
def endheaders(self, message=None):
"""
@@ -206,20 +280,22 @@ class NSSHTTPS(httplib.HTTP):
port = None
self._setup(self._connection_class(host, port, strict, dbdir=dbdir))
-class NSPRConnection(httplib.HTTPConnection):
+class NSPRConnection(httplib.HTTPConnection, NSSAddressFamilyFallback):
default_port = httplib.HTTPConnection.default_port
- def __init__(self, host, port=None, strict=None):
+ def __init__(self, host, port=None, strict=None, family=socket.AF_UNSPEC):
httplib.HTTPConnection.__init__(self, host, port, strict)
+ NSSAddressFamilyFallback.__init__(self, family)
logging.debug('%s init %s', self.__class__.__name__, host)
+ self._create_socket()
+
+ def _create_socket(self):
+ super(NSPRConnection, self)._create_socket()
+ self.sock.set_hostname(self.host)
- self.sock = io.Socket()
def connect(self):
- logging.debug("connect: host=%s port=%s", self.host, self.port)
- net_addr = io.NetworkAddress(self.host, self.port)
- logging.debug("connect: %s", net_addr)
- self.sock.connect(net_addr)
+ self.connect_socket(self.host, self.port)
class NSPRHTTP(httplib.HTTP):
_http_vsn = 11