diff options
Diffstat (limited to 'khashmir/khashmir.py')
-rw-r--r-- | khashmir/khashmir.py | 442 |
1 files changed, 442 insertions, 0 deletions
diff --git a/khashmir/khashmir.py b/khashmir/khashmir.py new file mode 100644 index 0000000..d939f67 --- /dev/null +++ b/khashmir/khashmir.py @@ -0,0 +1,442 @@ +# The contents of this file are subject to the BitTorrent Open Source License +# Version 1.1 (the License). You may not copy or use this file, in either +# source code or executable form, except in compliance with the License. You +# may obtain a copy of the License at http://www.bittorrent.com/license/. +# +# Software distributed under the License is distributed on an AS IS basis, +# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License +# for the specific language governing rights and limitations under the +# License. + +import const +from socket import gethostbyname + +from BitTorrent.platform import bttime as time + +from sha import sha +import re +from BitTorrent.defaultargs import common_options, rare_options +from BitTorrent.RawServer_magic import RawServer + +from ktable import KTable, K +from knode import * +from kstore import KStore +from khash import newID, newIDInRange + +from util import packNodes +from actions import FindNode, GetValue, KeyExpirer, StoreValue +import krpc + +import sys +import os +import traceback + +from BitTorrent.bencode import bencode, bdecode + +from BitTorrent.defer import Deferred +from random import randrange +from kstore import sample + +from threading import Event, Thread + +ip_pat = re.compile('[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}') + +class KhashmirDBExcept(Exception): + pass + +def foo(bytes): + pass + +# this is the base class, has base functionality and find node, no key-value mappings +class KhashmirBase: + _Node = KNodeBase + def __init__(self, host, port, data_dir, rawserver=None, max_ul_rate=1024, checkpoint=True, errfunc=None, rlcount=foo, config={'pause':False, 'max_rate_period':20}): + if rawserver: + self.rawserver = rawserver + else: + self.flag = Event() + d = dict([(x[0],x[1]) for x in common_options + rare_options]) + self.rawserver = RawServer(self.flag, d) + self.max_ul_rate = max_ul_rate + self.socket = None + self.config = config + self.setup(host, port, data_dir, rlcount, checkpoint) + + def setup(self, host, port, data_dir, rlcount, checkpoint=True): + self.host = host + self.port = port + self.ddir = data_dir + self.store = KStore() + self.pingcache = {} + self.socket = self.rawserver.create_udpsocket(self.port, self.host, False) + self.udp = krpc.hostbroker(self, (self.host, self.port), self.socket, self.rawserver.add_task, self.max_ul_rate, self.config, rlcount) + self._load() + self.rawserver.start_listening_udp(self.socket, self.udp) + self.last = time() + KeyExpirer(self.store, self.rawserver.add_task) + self.refreshTable(force=1) + if checkpoint: + self.rawserver.add_task(self.findCloseNodes, 30, (lambda a: a, True)) + self.rawserver.add_task(self.checkpoint, 60, (1,)) + + def Node(self): + n = self._Node(self.udp.connectionForAddr) + n.table = self + return n + + def __del__(self): + if self.socket is not None: + self.rawserver.stop_listening_udp(self.socket) + self.socket.close() + + def _load(self): + do_load = False + try: + s = open(os.path.join(self.ddir, "routing_table"), 'r').read() + dict = bdecode(s) + except: + id = newID() + else: + id = dict['id'] + do_load = True + + self.node = self._Node(self.udp.connectionForAddr).init(id, self.host, self.port) + self.table = KTable(self.node) + if do_load: + self._loadRoutingTable(dict['rt']) + + + def checkpoint(self, auto=0): + d = {} + d['id'] = self.node.id + d['rt'] = self._dumpRoutingTable() + try: + f = open(os.path.join(self.ddir, "routing_table"), 'wb') + f.write(bencode(d)) + f.close() + except Exception, e: + #XXX real error here + print ">>> unable to dump routing table!", str(e) + pass + + + if auto: + self.rawserver.add_task(self.checkpoint, + randrange(int(const.CHECKPOINT_INTERVAL * .9), + int(const.CHECKPOINT_INTERVAL * 1.1)), + (1,)) + + def _loadRoutingTable(self, nodes): + """ + load routing table nodes from database + it's usually a good idea to call refreshTable(force=1) after loading the table + """ + for rec in nodes: + n = self.Node().initWithDict(rec) + self.table.insertNode(n, contacted=0, nocheck=True) + + def _dumpRoutingTable(self): + """ + save routing table nodes to the database + """ + l = [] + for bucket in self.table.buckets: + for node in bucket.l: + l.append({'id':node.id, 'host':node.host, 'port':node.port, 'age':int(node.age)}) + return l + + + def _addContact(self, host, port, callback=None): + """ + ping this node and add the contact info to the table on pong! + """ + n =self.Node().init(const.NULL_ID, host, port) + try: + self.sendPing(n, callback=callback) + except krpc.KRPCSelfNodeError: + # our own node + pass + + + ####### + ####### LOCAL INTERFACE - use these methods! + def addContact(self, ip, port, callback=None): + """ + ping this node and add the contact info to the table on pong! + """ + if ip_pat.match(ip): + self._addContact(ip, port) + else: + def go(ip=ip, port=port): + ip = gethostbyname(ip) + self.rawserver.external_add_task(self._addContact, 0, (ip, port)) + t = Thread(target=go) + t.start() + + + ## this call is async! + def findNode(self, id, callback, errback=None): + """ returns the contact info for node, or the k closest nodes, from the global table """ + # get K nodes out of local table/cache, or the node we want + nodes = self.table.findNodes(id, invalid=True) + l = [x for x in nodes if x.invalid] + if len(l) > 4: + nodes = sample(l , 4) + self.table.findNodes(id, invalid=False)[:4] + + d = Deferred() + if errback: + d.addCallbacks(callback, errback) + else: + d.addCallback(callback) + if len(nodes) == 1 and nodes[0].id == id : + d.callback(nodes) + else: + # create our search state + state = FindNode(self, id, d.callback, self.rawserver.add_task) + self.rawserver.external_add_task(state.goWithNodes, 0, (nodes,)) + + def insertNode(self, n, contacted=1): + """ + insert a node in our local table, pinging oldest contact in bucket, if necessary + + If all you have is a host/port, then use addContact, which calls this method after + receiving the PONG from the remote node. The reason for the seperation is we can't insert + a node into the table without it's peer-ID. That means of course the node passed into this + method needs to be a properly formed Node object with a valid ID. + """ + old = self.table.insertNode(n, contacted=contacted) + if old and old != n: + if not old.inPing(): + self.checkOldNode(old, n, contacted) + else: + l = self.pingcache.get(old.id, []) + if len(l) < 10 or contacted: + l.append((n, contacted)) + self.pingcache[old.id] = l + + + + def checkOldNode(self, old, new, contacted=False): + ## these are the callbacks used when we ping the oldest node in a bucket + + def cmp(a, b): + if a[1] == 1 and b[1] == 0: + return -1 + elif b[1] == 1 and a[1] == 0: + return 1 + else: + return 0 + + def _staleNodeHandler(dict, old=old, new=new, contacted=contacted): + """ called if the pinged node never responds """ + if old.fails >= 2: + l = self.pingcache.get(old.id, []) + l.sort(cmp) + if l: + n, nc = l[0] + if (not contacted) and nc: + l = l[1:] + [(new, contacted)] + new = n + contacted = nc + o = self.table.replaceStaleNode(old, new) + if o and o != new: + self.checkOldNode(o, new) + try: + self.pingcache[o.id] = self.pingcache[old.id] + del(self.pingcache[old.id]) + except KeyError: + pass + else: + if l: + del(self.pingcache[old.id]) + l.sort(cmp) + for node in l: + self.insertNode(node[0], node[1]) + else: + l = self.pingcache.get(old.id, []) + if l: + del(self.pingcache[old.id]) + self.insertNode(new, contacted) + for node in l: + self.insertNode(node[0], node[1]) + + def _notStaleNodeHandler(dict, old=old, new=new, contacted=contacted): + """ called when we get a pong from the old node """ + self.table.insertNode(old, True) + self.insertNode(new, contacted) + l = self.pingcache.get(old.id, []) + l.sort(cmp) + for node in l: + self.insertNode(node[0], node[1]) + try: + del(self.pingcache[old.id]) + except KeyError: + pass + try: + df = old.ping(self.node.id) + except krpc.KRPCSelfNodeError: + pass + df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler) + + def sendPing(self, node, callback=None): + """ + ping a node + """ + try: + df = node.ping(self.node.id) + except krpc.KRPCSelfNodeError: + pass + else: + ## these are the callbacks we use when we issue a PING + def _pongHandler(dict, node=node, table=self.table, callback=callback): + _krpc_sender = dict['_krpc_sender'] + dict = dict['rsp'] + sender = {'id' : dict['id']} + sender['host'] = _krpc_sender[0] + sender['port'] = _krpc_sender[1] + n = self.Node().initWithDict(sender) + table.insertNode(n) + if callback: + callback() + def _defaultPong(err, node=node, table=self.table, callback=callback): + if callback: + callback() + + df.addCallbacks(_pongHandler,_defaultPong) + + def findCloseNodes(self, callback=lambda a: a, auto=False): + """ + This does a findNode on the ID one away from our own. + This will allow us to populate our table with nodes on our network closest to our own. + This is called as soon as we start up with an empty table + """ + if not self.config['pause']: + id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256) + self.findNode(id, callback) + if auto: + if not self.config['pause']: + self.refreshTable() + self.rawserver.external_add_task(self.findCloseNodes, randrange(int(const.FIND_CLOSE_INTERVAL *0.9), + int(const.FIND_CLOSE_INTERVAL *1.1)), (lambda a: True, True)) + + def refreshTable(self, force=0): + """ + force=1 will refresh table regardless of last bucket access time + """ + def callback(nodes): + pass + + refresh = [bucket for bucket in self.table.buckets if force or (len(bucket.l) < K) or len(filter(lambda a: a.invalid, bucket.l)) or (time() - bucket.lastAccessed > const.BUCKET_STALENESS)] + for bucket in refresh: + id = newIDInRange(bucket.min, bucket.max) + self.findNode(id, callback) + + def stats(self): + """ + Returns (num_contacts, num_nodes) + num_contacts: number contacts in our routing table + num_nodes: number of nodes estimated in the entire dht + """ + num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0) + num_nodes = const.K * (2**(len(self.table.buckets) - 1)) + return {'num_contacts':num_contacts, 'num_nodes':num_nodes} + + def krpc_ping(self, id, _krpc_sender): + sender = {'id' : id} + sender['host'] = _krpc_sender[0] + sender['port'] = _krpc_sender[1] + n = self.Node().initWithDict(sender) + self.insertNode(n, contacted=0) + return {"id" : self.node.id} + + def krpc_find_node(self, target, id, _krpc_sender): + nodes = self.table.findNodes(target, invalid=False) + nodes = map(lambda node: node.senderDict(), nodes) + sender = {'id' : id} + sender['host'] = _krpc_sender[0] + sender['port'] = _krpc_sender[1] + n = self.Node().initWithDict(sender) + self.insertNode(n, contacted=0) + return {"nodes" : packNodes(nodes), "id" : self.node.id} + + +## This class provides read-only access to the DHT, valueForKey +## you probably want to use this mixin and provide your own write methods +class KhashmirRead(KhashmirBase): + _Node = KNodeRead + def retrieveValues(self, key): + try: + l = self.store[key] + except KeyError: + l = [] + return l + ## also async + def valueForKey(self, key, callback, searchlocal = 1): + """ returns the values found for key in global table + callback will be called with a list of values for each peer that returns unique values + final callback will be an empty list - probably should change to 'more coming' arg + """ + nodes = self.table.findNodes(key) + + # get locals + if searchlocal: + l = self.retrieveValues(key) + if len(l) > 0: + self.rawserver.external_add_task(callback, 0, (l,)) + else: + l = [] + + # create our search state + state = GetValue(self, key, callback, self.rawserver.add_task) + self.rawserver.external_add_task(state.goWithNodes, 0, (nodes, l)) + + def krpc_find_value(self, key, id, _krpc_sender): + sender = {'id' : id} + sender['host'] = _krpc_sender[0] + sender['port'] = _krpc_sender[1] + n = self.Node().initWithDict(sender) + self.insertNode(n, contacted=0) + + l = self.retrieveValues(key) + if len(l) > 0: + return {'values' : l, "id": self.node.id} + else: + nodes = self.table.findNodes(key, invalid=False) + nodes = map(lambda node: node.senderDict(), nodes) + return {'nodes' : packNodes(nodes), "id": self.node.id} + +### provides a generic write method, you probably don't want to deploy something that allows +### arbitrary value storage +class KhashmirWrite(KhashmirRead): + _Node = KNodeWrite + ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor) + def storeValueForKey(self, key, value, callback=None): + """ stores the value for key in the global table, returns immediately, no status + in this implementation, peers respond but don't indicate status to storing values + a key can have many values + """ + def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table): + if not response: + # default callback + def _storedValueHandler(sender): + pass + response=_storedValueHandler + action = StoreValue(self, key, value, response, self.rawserver.add_task) + self.rawserver.external_add_task(action.goWithNodes, 0, (nodes,)) + + # this call is asynch + self.findNode(key, _storeValueForKey) + + def krpc_store_value(self, key, value, id, _krpc_sender): + t = "%0.6f" % time() + self.store[key] = value + sender = {'id' : id} + sender['host'] = _krpc_sender[0] + sender['port'] = _krpc_sender[1] + n = self.Node().initWithDict(sender) + self.insertNode(n, contacted=0) + return {"id" : self.node.id} + +# the whole shebang, for testing +class Khashmir(KhashmirWrite): + _Node = KNodeWrite |