summaryrefslogtreecommitdiffstats
path: root/func/SSLConnection.py
blob: 98ed8a01843ea1125c89431299b5f26358aec395 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Higher-level SSL objects used by rpclib
#
# Copyright (c) 2002 Red Hat, Inc.
#
# Author: Mihai Ibanescu <misa@redhat.com>
# Modifications by Dan Williams <dcbw@redhat.com>


from OpenSSL import SSL
import time, socket, select
from func.CommonErrors import canIgnoreSSLError


class SSLConnection:
    """
    This whole class exists just to filter out a parameter
    passed in to the shutdown() method in SimpleXMLRPC.doPOST()
    """

    DEFAULT_TIMEOUT = 20

    def __init__(self, conn):
        """
        Connection is not yet a new-style class,
        so I'm making a proxy instead of subclassing.
        """
        self.__dict__["conn"] = conn
        self.__dict__["close_refcount"] = 0
        self.__dict__["closed"] = False
        self.__dict__["timeout"] = self.DEFAULT_TIMEOUT

    def __del__(self):
        self.__dict__["conn"].close()

    def __getattr__(self,name):
        return getattr(self.__dict__["conn"], name)

    def __setattr__(self,name, value):
        setattr(self.__dict__["conn"], name, value)

    def settimeout(self, timeout):
        if timeout == None:
            self.__dict__["timeout"] = self.DEFAULT_TIMEOUT
        else:
            self.__dict__["timeout"] = timeout
        self.__dict__["conn"].settimeout(timeout)

    def shutdown(self, how=1):
        """
        SimpleXMLRpcServer.doPOST calls shutdown(1),
        and Connection.shutdown() doesn't take
        an argument. So we just discard the argument.
        """
        self.__dict__["conn"].shutdown()

    def accept(self):
        """
        This is the other part of the shutdown() workaround.
        Since servers create new sockets, we have to infect
        them with our magic. :)
        """
        c, a = self.__dict__["conn"].accept()
        return (SSLConnection(c), a)

    def makefile(self, mode, bufsize):
        """
        We need to use socket._fileobject Because SSL.Connection
        doesn't have a 'dup'. Not exactly sure WHY this is, but
        this is backed up by comments in socket.py and SSL/connection.c

        Since httplib.HTTPSResponse/HTTPConnection depend on the
        socket being duplicated when they close it, we refcount the
        socket object and don't actually close until its count is 0.
        """
        self.__dict__["close_refcount"] = self.__dict__["close_refcount"] + 1
        return PlgFileObject(self, mode, bufsize)

    def close(self):
        if self.__dict__["closed"]:
            return
        self.__dict__["close_refcount"] = self.__dict__["close_refcount"] - 1
        if self.__dict__["close_refcount"] == 0:
            self.shutdown()
            self.__dict__["conn"].close()
            self.__dict__["closed"] = True

    def sendall(self, data, flags=0):
        """
        - Use select() to simulate a socket timeout without setting the socket
            to non-blocking mode.
        - Don't use pyOpenSSL's sendall() either, since it just loops on WantRead
            or WantWrite, consuming 100% CPU, and never times out.
        """
        timeout = self.__dict__["timeout"]
        con = self.__dict__["conn"]
        (read, write, excpt) = select.select([], [con], [], timeout)
        if not con in write:
            raise socket.timeout((110, "Operation timed out."))

        starttime = time.time()
        origlen = len(data)
        sent = -1
        while len(data):
            curtime = time.time()
            if curtime - starttime > timeout:
                raise socket.timeout((110, "Operation timed out."))

            try:
                sent = con.send(data, flags)
            except SSL.SysCallError, e:
                if e[0] == 32:      # Broken Pipe
                    self.close()
                    sent = 0
                else:
                    raise socket.error(e)
            except (SSL.WantWriteError, SSL.WantReadError):
                time.sleep(0.2)
                continue

            data = data[sent:]
        return origlen - len(data)

    def recv(self, bufsize, flags=0):
        """
        Use select() to simulate a socket timeout without setting the socket
        to non-blocking mode
        """
        timeout = self.__dict__["timeout"]
        con = self.__dict__["conn"]
        (read, write, excpt) = select.select([con], [], [], timeout)
        if not con in read:
            raise socket.timeout((110, "Operation timed out."))

        starttime = time.time()
        while True:
            curtime = time.time()
            if curtime - starttime > timeout:
                raise socket.timeout((110, "Operation timed out."))

            try:
                return con.recv(bufsize, flags)
            except SSL.ZeroReturnError:
                return None
            except SSL.WantReadError:
                time.sleep(0.2)
            except Exception, e:
                if canIgnoreSSLError(e):
                    return None
                else:
                    raise e
        return None


class PlgFileObject(socket._fileobject):
    """
    File object returned by SSLConnection.makefile(). Unlike the base
    class, closing it also closes the socket it wraps.
    """

    def close(self):
        """
        socket._fileobject doesn't actually _close_ the socket,
        which we want it to do, so we have to override.

        Flush pending writes, then close the wrapped socket. The socket is
        closed even if the flush raises. The flush error then propagates.
        Afterwards this object is detached from the socket, so repeated
        close() calls do nothing.
        """
        try:
            if self._sock:
                try:
                    self.flush()
                finally:
                    # Without this, a failing flush (e.g. broken pipe) would
                    # skip _sock.close() and the socket would never be
                    # released by this file object.
                    self._sock.close()
        finally:
            self._sock = None