summaryrefslogtreecommitdiffstats
path: root/khashmir/ktable.py
diff options
context:
space:
mode:
Diffstat (limited to 'khashmir/ktable.py')
-rw-r--r--khashmir/ktable.py338
1 files changed, 338 insertions, 0 deletions
diff --git a/khashmir/ktable.py b/khashmir/ktable.py
new file mode 100644
index 0000000..e0a07b4
--- /dev/null
+++ b/khashmir/ktable.py
@@ -0,0 +1,338 @@
+# The contents of this file are subject to the BitTorrent Open Source License
+# Version 1.1 (the License). You may not copy or use this file, in either
+# source code or executable form, except in compliance with the License. You
+# may obtain a copy of the License at http://www.bittorrent.com/license/.
+#
+# Software distributed under the License is distributed on an AS IS basis,
+# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
+# for the specific language governing rights and limitations under the
+# License.
+
+from BitTorrent.platform import bttime as time
+from bisect import *
+from types import *
+
+import khash as hash
+import const
+from const import K, HASH_LENGTH, NULL_ID, MAX_FAILURES, MIN_PING_INTERVAL
+from node import Node
+
+
+def ls(a, b):
+ return cmp(a.lastSeen, b.lastSeen)
+
+class KTable(object):
+ __slots__ = ('node', 'buckets')
+ """local routing table for a kademlia like distributed hash table"""
+ def __init__(self, node):
+ # this is the root node, a.k.a. US!
+ self.node = node
+ self.buckets = [KBucket([], 0L, 2L**HASH_LENGTH)]
+ self.insertNode(node)
+
+ def _bucketIndexForInt(self, num):
+ """the index of the bucket that should hold int"""
+ return bisect_left(self.buckets, num)
+
+ def bucketForInt(self, num):
+ return self.buckets[self._bucketIndexForInt(num)]
+
+ def findNodes(self, id, invalid=True):
+ """
+ return K nodes in our own local table closest to the ID.
+ """
+
+ if isinstance(id, str):
+ num = hash.intify(id)
+ elif isinstance(id, Node):
+ num = id.num
+ elif isinstance(id, int) or isinstance(id, long):
+ num = id
+ else:
+ raise TypeError, "findNodes requires an int, string, or Node"
+
+ nodes = []
+ i = self._bucketIndexForInt(num)
+
+ # if this node is already in our table then return it
+ try:
+ node = self.buckets[i].getNodeWithInt(num)
+ except ValueError:
+ pass
+ else:
+ return [node]
+
+ # don't have the node, get the K closest nodes
+ nodes = nodes + self.buckets[i].l
+ if not invalid:
+ nodes = [a for a in nodes if not a.invalid]
+ if len(nodes) < K:
+ # need more nodes
+ min = i - 1
+ max = i + 1
+ while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
+ #ASw: note that this requires K be even
+ if min >= 0:
+ nodes = nodes + self.buckets[min].l
+ if max < len(self.buckets):
+ nodes = nodes + self.buckets[max].l
+ min = min - 1
+ max = max + 1
+ if not invalid:
+ nodes = [a for a in nodes if not a.invalid]
+
+ nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
+ return nodes[:K]
+
+ def _splitBucket(self, a):
+ diff = (a.max - a.min) / 2
+ b = KBucket([], a.max - diff, a.max)
+ self.buckets.insert(self.buckets.index(a.min) + 1, b)
+ a.max = a.max - diff
+ # transfer nodes to new bucket
+ for anode in a.l[:]:
+ if anode.num >= a.max:
+ a.removeNode(anode)
+ b.addNode(anode)
+
+ def replaceStaleNode(self, stale, new):
+ """this is used by clients to replace a node returned by insertNode after
+ it fails to respond to a Pong message"""
+ i = self._bucketIndexForInt(stale.num)
+
+ if self.buckets[i].hasNode(stale):
+ self.buckets[i].removeNode(stale)
+ if new and self.buckets[i].hasNode(new):
+ self.buckets[i].seenNode(new)
+ elif new:
+ self.buckets[i].addNode(new)
+
+ return
+
+ def insertNode(self, node, contacted=1, nocheck=False):
+ """
+ this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
+ the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
+ contacted means that yes, we contacted THEM and we know the node is reachable
+ """
+ if node.id == NULL_ID or node.id == self.node.id:
+ return
+
+ if contacted:
+ node.updateLastSeen()
+
+ # get the bucket for this node
+ i = self._bucketIndexForInt(node.num)
+ # check to see if node is in the bucket already
+ if self.buckets[i].hasNode(node):
+ it = self.buckets[i].l.index(node.num)
+ xnode = self.buckets[i].l[it]
+ if contacted:
+ node.age = xnode.age
+ self.buckets[i].seenNode(node)
+ elif xnode.lastSeen != 0 and xnode.port == node.port and xnode.host == node.host:
+ xnode.updateLastSeen()
+ return
+
+ # we don't have this node, check to see if the bucket is full
+ if not self.buckets[i].bucketFull():
+ # no, append this node and return
+ self.buckets[i].addNode(node)
+ return
+
+ # full bucket, check to see if any nodes are invalid
+ t = time()
+ invalid = [x for x in self.buckets[i].invalid.values() if x.invalid]
+ if len(invalid) and not nocheck:
+ invalid.sort(ls)
+ while invalid and not self.buckets[i].hasNode(invalid[0]):
+ del(self.buckets[i].invalid[invalid[0].num])
+ invalid = invalid[1:]
+ if invalid and (invalid[0].lastSeen == 0 and invalid[0].fails < MAX_FAILURES):
+ return invalid[0]
+ elif invalid:
+ self.replaceStaleNode(invalid[0], node)
+ return
+
+ stale = [n for n in self.buckets[i].l if (t - n.lastSeen) > MIN_PING_INTERVAL]
+ if len(stale) and not nocheck:
+ stale.sort(ls)
+ return stale[0]
+
+ # bucket is full and all nodes are valid, check to see if self.node is in the bucket
+ if not (self.buckets[i].min <= self.node < self.buckets[i].max):
+ return
+
+ # this bucket is full and contains our node, split the bucket
+ if len(self.buckets) >= HASH_LENGTH:
+ # our table is FULL, this is really unlikely
+ print "Hash Table is FULL! Increase K!"
+ return
+
+ self._splitBucket(self.buckets[i])
+
+ # now that the bucket is split and balanced, try to insert the node again
+ return self.insertNode(node, contacted)
+
+ def justSeenNode(self, id):
+ """call this any time you get a message from a node
+ it will update it in the table if it's there """
+ try:
+ n = self.findNodes(id)[0]
+ except IndexError:
+ return None
+ else:
+ tstamp = n.lastSeen
+ n.updateLastSeen()
+ bucket = self.bucketForInt(n.num)
+ bucket.seenNode(n)
+ return tstamp
+
+ def invalidateNode(self, n):
+ """
+ forget about node n - use when you know that node is invalid
+ """
+ n.invalid = True
+ bucket = self.bucketForInt(n.num)
+ bucket.invalidateNode(n)
+
+ def nodeFailed(self, node):
+ """ call this when a node fails to respond to a message, to invalidate that node """
+ try:
+ n = self.findNodes(node.num)[0]
+ except IndexError:
+ return None
+ else:
+ if n.msgFailed() >= const.MAX_FAILURES:
+ self.invalidateNode(n)
+
+ def numPeers(self):
+ """ estimated number of connectable nodes in global table """
+ return 8 * (2 ** (len(self.buckets) - 1))
+
+class KBucket(object):
+ __slots__ = ('min', 'max', 'lastAccessed', 'l', 'index', 'invalid')
+ def __init__(self, contents, min, max):
+ self.l = contents
+ self.index = {}
+ self.invalid = {}
+ self.min = min
+ self.max = max
+ self.lastAccessed = time()
+
+ def touch(self):
+ self.lastAccessed = time()
+
+ def lacmp(self, a, b):
+ if a.lastSeen > b.lastSeen:
+ return 1
+ elif b.lastSeen > a.lastSeen:
+ return -1
+ return 0
+
+ def sort(self):
+ self.l.sort(self.lacmp)
+
+ def getNodeWithInt(self, num):
+ try:
+ node = self.index[num]
+ except KeyError:
+ raise ValueError
+ return node
+
+ def addNode(self, node):
+ if len(self.l) >= K:
+ return
+ if self.index.has_key(node.num):
+ return
+ self.l.append(node)
+ self.index[node.num] = node
+ self.touch()
+
+ def removeNode(self, node):
+ assert self.index.has_key(node.num)
+ del(self.l[self.l.index(node.num)])
+ del(self.index[node.num])
+ try:
+ del(self.invalid[node.num])
+ except KeyError:
+ pass
+ self.touch()
+
+ def invalidateNode(self, node):
+ self.invalid[node.num] = node
+
+ def seenNode(self, node):
+ try:
+ del(self.invalid[node.num])
+ except KeyError:
+ pass
+ it = self.l.index(node.num)
+ del(self.l[it])
+ self.l.append(node)
+ self.index[node.num] = node
+
+ def hasNode(self, node):
+ return self.index.has_key(node.num)
+
+ def bucketFull(self):
+ return len(self.l) >= K
+
+ def __repr__(self):
+ return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
+
+ ## Comparators
+ # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
+ # compares integer or node object with the bucket's range
+ def __lt__(self, a):
+ if isinstance(a, Node): a = a.num
+ return self.max <= a
+ def __le__(self, a):
+ if isinstance(a, Node): a = a.num
+ return self.min < a
+ def __gt__(self, a):
+ if isinstance(a, Node): a = a.num
+ return self.min > a
+ def __ge__(self, a):
+ if isinstance(a, Node): a = a.num
+ return self.max >= a
+ def __eq__(self, a):
+ if isinstance(a, Node): a = a.num
+ return self.min <= a and self.max > a
+ def __ne__(self, a):
+ if isinstance(a, Node): a = a.num
+ return self.min >= a or self.max < a
+
+
+### UNIT TESTS ###
+import unittest
+
+class TestKTable(unittest.TestCase):
+ def setUp(self):
+ self.a = Node().init(hash.newID(), 'localhost', 2002)
+ self.t = KTable(self.a)
+
+ def testAddNode(self):
+ self.b = Node().init(hash.newID(), 'localhost', 2003)
+ self.t.insertNode(self.b)
+ self.assertEqual(len(self.t.buckets[0].l), 1)
+ self.assertEqual(self.t.buckets[0].l[0], self.b)
+
+ def testRemove(self):
+ self.testAddNode()
+ self.t.invalidateNode(self.b)
+ self.assertEqual(len(self.t.buckets[0].l), 0)
+
+ def testFail(self):
+ self.testAddNode()
+ for i in range(const.MAX_FAILURES - 1):
+ self.t.nodeFailed(self.b)
+ self.assertEqual(len(self.t.buckets[0].l), 1)
+ self.assertEqual(self.t.buckets[0].l[0], self.b)
+
+ self.t.nodeFailed(self.b)
+ self.assertEqual(len(self.t.buckets[0].l), 0)
+
+
+if __name__ == "__main__":
+ unittest.main()