summaryrefslogtreecommitdiffstats
path: root/BitTorrent/NatCheck.py
diff options
context:
space:
mode:
Diffstat (limited to 'BitTorrent/NatCheck.py')
-rw-r--r--BitTorrent/NatCheck.py175
1 files changed, 110 insertions, 65 deletions
diff --git a/BitTorrent/NatCheck.py b/BitTorrent/NatCheck.py
index 6363ded..fe59a15 100644
--- a/BitTorrent/NatCheck.py
+++ b/BitTorrent/NatCheck.py
@@ -17,85 +17,130 @@ protocol_name = 'BitTorrent protocol'
# header, reserved, download id, my id, [length, message]
+from twisted.internet.protocol import Protocol, ClientFactory
+from twisted.internet import reactor
+from twisted.python import log
class NatCheck(object):
- def __init__(self, resultfunc, downloadid, peerid, ip, port, rawserver):
+ def __init__(self, resultfunc, downloadid, peerid, ip, port):
self.resultfunc = resultfunc
self.downloadid = downloadid
self.peerid = peerid
self.ip = ip
self.port = port
- self.closed = False
- self.buffer = StringIO()
- self.next_len = 1
- self.next_func = self.read_header_len
- rawserver.async_start_connection((ip, port), self)
+ self.answered = False
- def connection_started(self, s):
- self.connection = s
- self.connection.write(chr(len(protocol_name)) + protocol_name +
- (chr(0) * 8) + self.downloadid)
+ factory = NatCheckProtocolFactory(self, downloadid, peerid)
- def connection_failed(self, addr, exception):
- self.answer(False)
+ reactor.connectTCP(ip, port, factory)
def answer(self, result):
- self.closed = True
- try:
- self.connection.close()
- except AttributeError:
- pass
- self.resultfunc(result, self.downloadid, self.peerid, self.ip, self.port)
-
- def read_header_len(self, s):
- if ord(s) != len(protocol_name):
- return None
- return len(protocol_name), self.read_header
-
- def read_header(self, s):
- if s != protocol_name:
- return None
- return 8, self.read_reserved
-
- def read_reserved(self, s):
- return 20, self.read_download_id
-
- def read_download_id(self, s):
- if s != self.downloadid:
- return None
- return 20, self.read_peer_id
-
- def read_peer_id(self, s):
- if s != self.peerid:
- return None
- self.answer(True)
- return None
-
- def data_came_in(self, connection, s):
- while True:
- if self.closed:
+ if not self.answered:
+ self.answered = True
+ log.msg('NAT check for %s:%i is %s' % (self.ip, self.port, result))
+ self.resultfunc(result, self.downloadid, self.peerid, self.ip, self.port)
+
+class NatCheckProtocolFactory(ClientFactory):
+ def __init__(self, natcheck, downloadid, peerid):
+ self.natcheck = natcheck
+ self.downloadid = downloadid
+ self.peerid = peerid
+
+ def startedConnecting(self, connector):
+ log.msg('Started to connect.')
+
+ def buildProtocol(self, addr):
+ return NatCheckProtocol(self, self.downloadid, self.peerid)
+
+ def clientConnectionLost(self, connector, reason):
+ self.natcheck.answer(False)
+ log.msg('Lost connection. Reason: %s' % reason)
+
+ def clientConnectionFailed(self, connector, reason):
+ self.natcheck.answer(False)
+ log.msg('Connection failed. Reason: %s' % reason)
+
+class NatCheckProtocol(Protocol):
+ def __init__(self, factory, downloadid, peerid):
+ self.factory = factory
+ self.downloadid = downloadid
+ self.peerid = peerid
+ self.data = ''
+ self.received_protocol_name_len = None
+ self.received_protocol_name = None
+ self.received_reserved = None
+ self.received_downloadid = None
+ self.received_peerid = None
+
+ def connectionMade(self):
+ self.transport.write(chr(len(protocol_name)))
+ self.transport.write(protocol_name)
+ self.transport.write(chr(0) * 8)
+ self.transport.write(self.downloadid)
+
+ def dataReceived(self, data):
+ self.data += data
+
+ if self.received_protocol_name_len is None:
+ if len(self.data) >= 1:
+ self.received_protocol_name_len = ord(self.data[0])
+ self.data = self.data[1:]
+ if self.received_protocol_name_len != len(protocol_name):
+ self.factory.natcheck.answer(False)
+ self.transport.loseConnection()
+ return
+ else:
return
- i = self.next_len - self.buffer.tell()
- if i > len(s):
- self.buffer.write(s)
+
+ if self.received_protocol_name is None:
+ if len(self.data) >= self.received_protocol_name_len:
+ self.received_protocol_name = self.data[:self.received_protocol_name_len]
+ self.data = self.data[self.received_protocol_name_len:]
+ if self.received_protocol_name != protocol_name:
+ log.err('Received protocol name did not match!')
+ self.factory.natcheck.answer(False)
+ self.transport.loseConnection()
+ return
+ else:
return
- self.buffer.write(s[:i])
- s = s[i:]
- m = self.buffer.getvalue()
- self.buffer.reset()
- self.buffer.truncate()
- x = self.next_func(m)
- if x is None:
- if not self.closed:
- self.answer(False)
+
+ if self.received_reserved is None:
+ if len(self.data) >= 8:
+ self.received_reserved = self.data[:8]
+ self.data = self.data[8:]
+ else:
return
- self.next_len, self.next_func = x
- def connection_lost(self, connection):
- if not self.closed:
- self.closed = True
- self.resultfunc(False, self.downloadid, self.peerid, self.ip, self.port)
+ if self.received_downloadid is None:
+ if len(self.data) >= 20:
+ self.received_downloadid = self.data[:20]
+ self.data = self.data[20:]
+ if self.received_downloadid != self.downloadid:
+ log.err('Received download id did not match!')
+ self.factory.natcheck.answer(False)
+ self.transport.loseConnection()
+ return
+ else:
+ return
+
+ if self.received_peerid is None:
+ if len(self.data) >= 20:
+ log.msg('Peerid length: %i' % len(self.peerid))
+ self.received_peerid = self.data[:20]
+ self.data = self.data[20:]
+ log.msg('Received: %s' % self.received_peerid.encode('hex'))
+ log.msg('Received: %s' % self.received_peerid.encode('quoted-printable'))
+ log.msg('Expected: %s' % self.peerid.encode('hex'))
+ log.msg('Expected: %s' % self.peerid.encode('quoted-printable'))
+ if self.received_peerid != self.peerid:
+ log.err('Received peer id did not match!')
+ self.factory.natcheck.answer(False)
+ self.transport.loseConnection()
+ return
+ else:
+ return
- def connection_flushed(self, connection):
- pass
+ if self.received_protocol_name == protocol_name and self.received_downloadid == self.downloadid and self.received_peerid == self.peerid:
+ self.factory.natcheck.answer(True)
+ self.transport.loseConnection()