diff options
Diffstat (limited to 'BitTorrent/NatCheck.py')
-rw-r--r-- | BitTorrent/NatCheck.py | 175 |
1 files changed, 110 insertions, 65 deletions
diff --git a/BitTorrent/NatCheck.py b/BitTorrent/NatCheck.py index 6363ded..fe59a15 100644 --- a/BitTorrent/NatCheck.py +++ b/BitTorrent/NatCheck.py @@ -17,85 +17,130 @@ protocol_name = 'BitTorrent protocol' # header, reserved, download id, my id, [length, message] +from twisted.internet.protocol import Protocol, ClientFactory +from twisted.internet import reactor +from twisted.python import log class NatCheck(object): - def __init__(self, resultfunc, downloadid, peerid, ip, port, rawserver): + def __init__(self, resultfunc, downloadid, peerid, ip, port): self.resultfunc = resultfunc self.downloadid = downloadid self.peerid = peerid self.ip = ip self.port = port - self.closed = False - self.buffer = StringIO() - self.next_len = 1 - self.next_func = self.read_header_len - rawserver.async_start_connection((ip, port), self) + self.answered = False - def connection_started(self, s): - self.connection = s - self.connection.write(chr(len(protocol_name)) + protocol_name + - (chr(0) * 8) + self.downloadid) + factory = NatCheckProtocolFactory(self, downloadid, peerid) - def connection_failed(self, addr, exception): - self.answer(False) + reactor.connectTCP(ip, port, factory) def answer(self, result): - self.closed = True - try: - self.connection.close() - except AttributeError: - pass - self.resultfunc(result, self.downloadid, self.peerid, self.ip, self.port) - - def read_header_len(self, s): - if ord(s) != len(protocol_name): - return None - return len(protocol_name), self.read_header - - def read_header(self, s): - if s != protocol_name: - return None - return 8, self.read_reserved - - def read_reserved(self, s): - return 20, self.read_download_id - - def read_download_id(self, s): - if s != self.downloadid: - return None - return 20, self.read_peer_id - - def read_peer_id(self, s): - if s != self.peerid: - return None - self.answer(True) - return None - - def data_came_in(self, connection, s): - while True: - if self.closed: + if not self.answered: + self.answered = True + log.msg('NAT check for %s:%i is %s' % (self.ip, self.port, result)) + self.resultfunc(result, self.downloadid, self.peerid, self.ip, self.port) + +class NatCheckProtocolFactory(ClientFactory): + def __init__(self, natcheck, downloadid, peerid): + self.natcheck = natcheck + self.downloadid = downloadid + self.peerid = peerid + + def startedConnecting(self, connector): + log.msg('Started to connect.') + + def buildProtocol(self, addr): + return NatCheckProtocol(self, self.downloadid, self.peerid) + + def clientConnectionLost(self, connector, reason): + self.natcheck.answer(False) + log.msg('Lost connection. Reason: %s' % reason) + + def clientConnectionFailed(self, connector, reason): + self.natcheck.answer(False) + log.msg('Connection failed. Reason: %s' % reason) + +class NatCheckProtocol(Protocol): + def __init__(self, factory, downloadid, peerid): + self.factory = factory + self.downloadid = downloadid + self.peerid = peerid + self.data = '' + self.received_protocol_name_len = None + self.received_protocol_name = None + self.received_reserved = None + self.received_downloadid = None + self.received_peerid = None + + def connectionMade(self): + self.transport.write(chr(len(protocol_name))) + self.transport.write(protocol_name) + self.transport.write(chr(0) * 8) + self.transport.write(self.downloadid) + + def dataReceived(self, data): + self.data += data + + if self.received_protocol_name_len is None: + if len(self.data) >= 1: + self.received_protocol_name_len = ord(self.data[0]) + self.data = self.data[1:] + if self.received_protocol_name_len != len(protocol_name): + self.factory.natcheck.answer(False) + self.transport.loseConnection() + return + else: return - i = self.next_len - self.buffer.tell() - if i > len(s): - self.buffer.write(s) + + if self.received_protocol_name is None: + if len(self.data) >= self.received_protocol_name_len: + self.received_protocol_name = self.data[:self.received_protocol_name_len] + self.data = self.data[self.received_protocol_name_len:] + if self.received_protocol_name != protocol_name: + log.err('Received protocol name did not match!') + self.factory.natcheck.answer(False) + self.transport.loseConnection() + return + else: return - self.buffer.write(s[:i]) - s = s[i:] - m = self.buffer.getvalue() - self.buffer.reset() - self.buffer.truncate() - x = self.next_func(m) - if x is None: - if not self.closed: - self.answer(False) + + if self.received_reserved is None: + if len(self.data) >= 8: + self.received_reserved = self.data[:8] + self.data = self.data[8:] + else: return - self.next_len, self.next_func = x - def connection_lost(self, connection): - if not self.closed: - self.closed = True - self.resultfunc(False, self.downloadid, self.peerid, self.ip, self.port) + if self.received_downloadid is None: + if len(self.data) >= 20: + self.received_downloadid = self.data[:20] + self.data = self.data[20:] + if self.received_downloadid != self.downloadid: + log.err('Received download id did not match!') + self.factory.natcheck.answer(False) + self.transport.loseConnection() + return + else: + return + + if self.received_peerid is None: + if len(self.data) >= 20: + log.msg('Peerid length: %i' % len(self.peerid)) + self.received_peerid = self.data[:20] + self.data = self.data[20:] + log.msg('Received: %s' % self.received_peerid.encode('hex')) + log.msg('Received: %s' % self.received_peerid.encode('quoted-printable')) + log.msg('Expected: %s' % self.peerid.encode('hex')) + log.msg('Expected: %s' % self.peerid.encode('quoted-printable')) + if self.received_peerid != self.peerid: + log.err('Received peer id did not match!') + self.factory.natcheck.answer(False) + self.transport.loseConnection() + return + else: + return - def connection_flushed(self, connection): - pass + if self.received_protocol_name == protocol_name and self.received_downloadid == self.downloadid and self.received_peerid == self.peerid: + self.factory.natcheck.answer(True) + self.transport.loseConnection() |