# nim-eth
# Copyright (c) 2018-2021 Status Research & Development GmbH
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

{.push raises: [Defect].}

import
  std/[tables, hashes, times, algorithm, sets, sequtils, random],
  chronos, bearssl, chronicles, stint, nimcrypto/keccak,
  ../keys, ./enode

export sets # TODO: This should not be needed, but compilation fails otherwise

logScope:
  topics = "kademlia"

type
  KademliaProtocol* [Wire] = ref object
    ## Protocol state, generic over the `Wire` transport used to send packets.
    wire: Wire
    thisNode: Node
    routing: RoutingTable
    # Pending pong waits, keyed by ping id (see `pingId`).
    pongFutures: Table[seq[byte], Future[bool]]
    # Pending "remote pinged us" waits, keyed by the remote node.
    pingFutures: Table[Node, Future[bool]]
    # Per-remote handlers fed by `recvNeighbours` while a find-node is pending.
    neighboursCallbacks: Table[Node, proc(n: seq[Node]) {.gcsafe, raises: [Defect].}]
    rng: ref BrHmacDrbgContext

  NodeId* = UInt256

  Node* = ref object
    node*: ENode
    id*: NodeId

  RoutingTable = object
    thisNode: Node
    buckets: seq[KBucket]

  KBucket = ref object
    istart, iend: UInt256 # inclusive id range covered by the bucket
    nodes: seq[Node]
    replacementCache: seq[Node]
    lastUpdated: float # epochTime

const
  BUCKET_SIZE = 16
  BITS_PER_HOP = 8
  REQUEST_TIMEOUT = chronos.milliseconds(5000) # timeout of message round trips
  FIND_CONCURRENCY = 3 # parallel find node lookups
  ID_SIZE = 256

proc toNodeId*(pk: PublicKey): NodeId =
  ## Node id: the keccak256 digest of the raw public key, read as a
  ## big-endian 256-bit integer.
  readUintBE[256](keccak256.digest(pk.toRaw()).data)

proc newNode*(pk: PublicKey, address: Address): Node =
  ## Creates a node from a public key and address; the id is derived from `pk`.
  result.new()
  result.node = ENode(pubkey: pk, address: address)
  result.id = pk.toNodeId()

proc newNode*(uriString: string): Node =
  ## Creates a node from an enode URI string.
  # NOTE(review): the parse result is unwrapped with `[]`; presumably a
  # malformed URI fails hard here - confirm against `ENode.fromString`.
  result.new()
  result.node = ENode.fromString(uriString)[]
  result.id = result.node.pubkey.toNodeId()

proc newNode*(enode: ENode): Node =
  ## Creates a node from an already parsed `ENode`.
  result.new()
  result.node = enode
  result.id = result.node.pubkey.toNodeId()

proc distanceTo(n: Node, id: NodeId): UInt256 =
  ## XOR distance between the node's id and `id`.
  n.id xor id

proc `$`*(n: Node): string =
  ## Renders the node as `Node[ip:udpPort]`, or `Node[local]` for nil.
  if n == nil:
    "Node[local]"
  else:
    "Node[" & $n.node.address.ip & ":" & $n.node.address.udpPort & "]"

proc hash*(n: Node): hashes.Hash =
  ## Hashes by raw public key, consistent with `==` below.
  hash(n.node.pubkey.toRaw)

proc `==`*(a, b: Node): bool =
  ## Nodes are equal when both are nil, or both are non-nil and share the
  ## same public key.
  (a.isNil and b.isNil) or
    (not a.isNil and not b.isNil and a.node.pubkey == b.node.pubkey)

proc newKBucket(istart, iend: NodeId): KBucket =
  ## Creates an empty bucket covering the inclusive id range istart..iend.
  result.new()
  result.istart = istart
  result.iend = iend
  result.nodes = @[]
  result.replacementCache = @[]

proc midpoint(k: KBucket): NodeId =
  ## Middle of the bucket's id range.
  k.istart + (k.iend - k.istart) div 2.u256

proc distanceTo(k: KBucket, id: NodeId): UInt256 =
  ## XOR distance between the bucket's midpoint and `id`.
  k.midpoint xor id

proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] =
  ## The bucket's nodes sorted by ascending XOR distance to `id`.
  sortedByIt(k.nodes, it.distanceTo(id))

proc len(k: KBucket): int =
  ## Number of nodes in the bucket (the replacement cache is not counted).
  k.nodes.len

proc head(k: KBucket): Node =
  ## First node of the list, i.e. the least recently seen one.
  k.nodes[0]

proc add(k: KBucket, n: Node): Node =
  ## Try to add the given node to this bucket.
  ##
  ## If the node is already present, it is moved to the tail of the list, and
  ## we return nil.
  ##
  ## If the node is not already present and the bucket has fewer than
  ## BUCKET_SIZE entries, it is inserted at the tail of the list, and we
  ## return nil.
  ##
  ## If the bucket is full, we add the node to the bucket's replacement cache
  ## and return the node at the head of the list (i.e. the least recently
  ## seen), which should be evicted if it fails to respond to a ping.
  k.lastUpdated = epochTime()
  let nodeIdx = k.nodes.find(n)
  if nodeIdx != -1:
    k.nodes.delete(nodeIdx)
    k.nodes.add(n)
  elif k.len < BUCKET_SIZE:
    k.nodes.add(n)
  else:
    k.replacementCache.add(n)
    return k.head
  return nil

proc removeNode(k: KBucket, n: Node) =
  ## Removes `n` from the bucket's node list if it is present.
  let i = k.nodes.find(n)
  if i != -1:
    k.nodes.delete(i)

proc split(k: KBucket): tuple[lower, upper: KBucket] =
  ## Split at the median id. Nodes and replacement-cache entries are
  ## distributed over the two new buckets according to their id.
  let splitid = k.midpoint
  result.lower = newKBucket(k.istart, splitid)
  result.upper = newKBucket(splitid + 1.u256, k.iend)
  for node in k.nodes:
    let bucket = if node.id <= splitid: result.lower else: result.upper
    discard bucket.add(node)
  for node in k.replacementCache:
    let bucket = if node.id <= splitid: result.lower else: result.upper
    bucket.replacementCache.add(node)

proc inRange(k: KBucket, n: Node): bool =
  ## True when the node's id falls inside the bucket's inclusive range.
  k.istart <= n.id and n.id <= k.iend

proc isFull(k: KBucket): bool =
  ## True when the bucket holds exactly BUCKET_SIZE nodes.
  k.len == BUCKET_SIZE

proc contains(k: KBucket, n: Node): bool =
  ## True when `n` is in the bucket's node list.
  n in k.nodes

proc binaryGetBucketForNode(
    buckets: openArray[KBucket],
    n: Node): KBucket {.raises: [ValueError, Defect].} =
  ## Given a list of ordered buckets, returns the bucket for a given node.
  ## Raises `ValueError` when no bucket covers the node's id.
  let bucketPos = lowerBound(buckets, n.id) do(a: KBucket, b: NodeId) -> int:
    cmp(a.iend, b)
  # Prevents edge cases where bisect_left returns an out of range index
  if bucketPos < buckets.len:
    let bucket = buckets[bucketPos]
    if bucket.istart <= n.id and n.id <= bucket.iend:
      result = bucket

  if result.isNil:
    raise newException(ValueError, "No bucket found for node with id " & $n.id)

proc computeSharedPrefixBits(nodes: openArray[Node]): int =
  ## Count the number of prefix bits shared by all nodes.
  ## Returns ID_SIZE when fewer than two nodes are given.
  if nodes.len < 2:
    return ID_SIZE

  var mask = zero(UInt256)
  let one = one(UInt256)

  # Grow the mask one bit at a time from the most significant end and stop at
  # the first bit position where some node differs from the first node.
  for i in 1 .. ID_SIZE:
    mask = mask or (one shl (ID_SIZE - i))
    let reference = nodes[0].id and mask
    for j in 1 .. nodes.high:
      if (nodes[j].id and mask) != reference:
        return i - 1

  doAssert(false, "Unable to calculate number of shared prefix bits")

proc init(r: var RoutingTable, thisNode: Node) =
  ## Resets the table to a single bucket spanning the whole id space.
  r.thisNode = thisNode
  r.buckets = @[newKBucket(0.u256, high(UInt256))]
  randomize() # for later `randomNodes` selection

proc splitBucket(r: var RoutingTable, index: int) =
  ## Replaces the bucket at `index` by its lower and upper halves.
  let bucket = r.buckets[index]
  let (a, b) = bucket.split()
  r.buckets[index] = a
  r.buckets.insert(b, index + 1)

proc bucketForNode(
    r: RoutingTable, n: Node): KBucket {.raises: [ValueError, Defect].} =
  ## Bucket whose range covers the node's id.
  binaryGetBucketForNode(r.buckets, n)

proc removeNode(
    r: var RoutingTable, n: Node) {.raises: [ValueError, Defect].} =
  ## Removes `n` from the bucket responsible for its id.
  r.bucketForNode(n).removeNode(n)

proc addNode(
    r: var RoutingTable, n: Node): Node {.raises: [ValueError, Defect].} =
  ## Adds `n` to the table. Returns nil when the node was stored (or when it
  ## is the local node, which is never added), otherwise the eviction
  ## candidate of the full bucket, which the caller is expected to ping.
  if n == r.thisNode:
    warn "Trying to add ourselves to the routing table", node = n
    return
  let bucket = r.bucketForNode(n)
  let evictionCandidate = bucket.add(n)
  if not evictionCandidate.isNil:
    # Split if the bucket has the local node in its range or if the depth is
    # not congruent to 0 mod BITS_PER_HOP
    let depth = computeSharedPrefixBits(bucket.nodes)
    if bucket.inRange(r.thisNode) or
        (depth mod BITS_PER_HOP != 0 and depth != ID_SIZE):
      r.splitBucket(r.buckets.find(bucket))
      return r.addNode(n) # retry

    # Nothing added, ping evictionCandidate
    return evictionCandidate

proc contains(
    r: RoutingTable, n: Node): bool {.raises: [ValueError, Defect].} =
  ## True when `n` is stored in the bucket responsible for its id.
  n in r.bucketForNode(n)

proc bucketsByDistanceTo(r: RoutingTable, id: NodeId): seq[KBucket] =
  ## Buckets sorted by ascending XOR distance of their midpoint to `id`.
  sortedByIt(r.buckets, it.distanceTo(id))

proc notFullBuckets(r: RoutingTable): seq[KBucket] =
  ## Buckets that currently hold fewer than BUCKET_SIZE nodes.
  r.buckets.filterIt(not it.isFull)

proc neighbours(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE): seq[Node] =
  ## Return up to k neighbours of the given node.
  ##
  ## A shortlist of at most `k * 2` nodes is gathered from the buckets nearest
  ## to `id` (by midpoint), then sorted by actual distance and cut to `k`.
  ## Nodes whose id equals `id` are skipped.
  result = newSeqOfCap[Node](k * 2)
  # The labelled block is required: a bare `break` would only leave the inner
  # loop, so collection would resume with the next bucket and the `k * 2` cap
  # would never take effect.
  block addNodes:
    for bucket in r.bucketsByDistanceTo(id):
      for n in bucket.nodesByDistanceTo(id):
        if n.id != id:
          result.add(n)
          if result.len == k * 2:
            break addNodes

  result = sortedByIt(result, it.distanceTo(id))
  if result.len > k:
    result.setLen(k)

proc len(r: RoutingTable): int =
  ## Total number of nodes over all buckets.
  for b in r.buckets:
    result += b.len

proc newKademliaProtocol*[Wire](
    thisNode: Node, wire: Wire, rng = newRng()): KademliaProtocol[Wire] =
  ## Creates the protocol state for `thisNode` on top of `wire`.
  ## A nil `rng` is rejected with an assertion failure.
  if rng == nil: raiseAssert "Need an RNG" # doAssert gives compile error on mac

  result.new()
  result.thisNode = thisNode
  result.wire = wire
  result.routing.init(thisNode)
  result.rng = rng

proc bond(k: KademliaProtocol, n: Node): Future[bool] {.async.}
proc bondDiscard(k: KademliaProtocol, n: Node) {.async.}

proc updateRoutingTable(
    k: KademliaProtocol, n: Node) {.raises: [ValueError, Defect], gcsafe.} =
  ## Update the routing table entry for the given node.
  let evictionCandidate = k.routing.addNode(n)
  if not evictionCandidate.isNil:
    # This means we couldn't add the node because its bucket is full, so
    # schedule a bond() with the least recently seen node on that bucket. If
    # the bonding fails the node will be removed from the bucket and a new one
    # will be picked from the bucket's replacement cache.
    asyncSpawn k.bondDiscard(evictionCandidate)

proc doSleep(p: proc() {.gcsafe, raises: [Defect].}) {.async.} =
  ## Runs `p` once REQUEST_TIMEOUT has elapsed.
  await sleepAsync(REQUEST_TIMEOUT)
  p()

template onTimeout(b: untyped) =
  ## Schedules `b` to run after REQUEST_TIMEOUT without awaiting it.
  asyncSpawn doSleep() do():
    b

proc pingId(n: Node, token: seq[byte]): seq[byte] =
  ## Key used in `pongFutures`: the ping token followed by the node's raw
  ## public key.
  result = token & @(n.node.pubkey.toRaw)

proc waitPong(k: KademliaProtocol, n: Node, pingid: seq[byte]): Future[bool] =
  ## Registers a future for `pingid`. If it is still unfinished after
  ## REQUEST_TIMEOUT, it is unregistered and completed with false.
  doAssert(pingid notin k.pongFutures, "Already waiting for pong from " & $n)
  result = newFuture[bool]("waitPong")
  let fut = result
  k.pongFutures[pingid] = result
  onTimeout:
    if not fut.finished:
      k.pongFutures.del(pingid)
      fut.complete(false)

proc ping(k: KademliaProtocol, n: Node): seq[byte] =
  ## Sends a ping to `n` (which must not be the local node) and returns what
  ## the wire's `sendPing` returns.
  doAssert(n != k.thisNode)
  result = k.wire.sendPing(n)

proc waitPing(k: KademliaProtocol, n: Node): Future[bool] =
  ## Registers a future for a ping from `n`. If it is still unfinished after
  ## REQUEST_TIMEOUT, it is unregistered and completed with false.
  result = newFuture[bool]("waitPing")
  doAssert(n notin k.pingFutures)
  k.pingFutures[n] = result
  let fut = result
  onTimeout:
    if not fut.finished:
      k.pingFutures.del(n)
      fut.complete(false)

proc waitNeighbours(
    k: KademliaProtocol,
    remote: Node): Future[seq[Node]] {.raises: [Defect].} =
  ## Registers a neighbours callback for `remote` and returns a future that
  ## completes with the collected nodes, either once BUCKET_SIZE of them have
  ## arrived or when REQUEST_TIMEOUT elapses (with whatever was received).
  doAssert(remote notin k.neighboursCallbacks)
  result = newFuture[seq[Node]]("waitNeighbours")
  let fut = result
  var neighbours = newSeqOfCap[Node](BUCKET_SIZE)
  k.neighboursCallbacks[remote] = proc(n: seq[Node]) {.gcsafe, raises: [Defect].} =
    # This callback is expected to be called multiple times because nodes
    # usually split the neighbours replies into multiple packets, so we only
    # complete the future once we've received enough neighbours.
    for i in n:
      if i != k.thisNode:
        neighbours.add(i)
        if neighbours.len == BUCKET_SIZE:
          k.neighboursCallbacks.del(remote)
          doAssert(not fut.finished)
          fut.complete(neighbours)
          # The future is done; any further entries of this packet would only
          # be appended to a list nobody reads any more.
          return

  onTimeout:
    if not fut.finished:
      k.neighboursCallbacks.del(remote)
      fut.complete(neighbours)

# Exported for test.
# NOTE(review): the remainder of this file has lost its original line breaks
# and indentation. In addition, text appears to be missing on the first line
# below, between `for i in 0..` and `maxResults: nodes.setLen(maxResults)`:
# the end of `findNode`, the bodies of the forward-declared `bond` and
# `bondDiscard`, and the head of `sortByDistance` (called by `lookup`) are not
# present. Restore that span from version control before changing this code.
#
# What the visible definitions do:
# * findNode       - returns an empty seq if a neighbours callback for
#                    `remote` is already registered; otherwise sends a
#                    find-node request, awaits `waitNeighbours`, drops
#                    candidates already recorded in `nodesSeen` (recording the
#                    new ones) and starts `bond` on each remaining candidate.
#                    The handling of the bond results is in the missing span.
# * lookup         - repeatedly calls `findNode` on not-yet-asked nodes taken
#                    from the current closest set, merges the replies into
#                    `closest`, and stops when no un-asked node is left.
#                    `sortByDistance(..., FIND_CONCURRENCY / BUCKET_SIZE)`
#                    presumably sorts by distance and trims to that many
#                    entries - confirm once its definition is restored.
# * lookupRandom   - fills a NodeId with bytes from `k.rng` and looks it up.
# * resolve        - runs `lookup(id)` and returns the node whose id equals
#                    `id`; falls through with the default (nil) result
#                    otherwise.
# * bootstrap      - bonds with every bootstrap node; while none bonded it
#                    sleeps and retries, doubling the interval up to 10 s,
#                    until `retries` attempts were made (forever if `retries`
#                    is 0); on success it performs one `lookupRandom`.
# * recvPong       - completes (with true) and removes the pending pong future
#                    keyed by `token & pubkey`, if any.
# * recvPing       - updates the routing table, sends a pong, and completes
#                    (with true) the pending ping future for `n`, if any.
# * recvNeighbours - forwards the nodes to the callback registered for
#                    `remote`; only traces when none is registered.
# * recvFindNode   - ignores remotes that are not in the routing table;
#                    otherwise updates the table and replies with the
#                    neighbours of `nodeId`, sorted by id.
# * randomNodes    - returns `count` distinct nodes sampled at random from the
#                    buckets, with `count` clamped to the table size.
# * nodesDiscovered - number of nodes in the routing table.
# * when isMainModule - assertions exercising `computeSharedPrefixBits`.
proc findNode*(k: KademliaProtocol, nodesSeen: ref HashSet[Node], nodeId: NodeId, remote: Node): Future[seq[Node]] {.async.} = if remote in k.neighboursCallbacks: # Sometimes findNode is called while another findNode is already in flight. # It's a bug when this happens, and the logic should probably be fixed # elsewhere. However, this small fix has been tested and proven adequate. debug "Ignoring peer already in k.neighboursCallbacks", peer = remote result = newSeq[Node]() return k.wire.sendFindNode(remote, nodeId) var candidates = await k.waitNeighbours(remote) if candidates.len == 0: trace "Got no candidates from peer, returning", peer = remote result = candidates else: # The following line: # 1. Add new candidates to nodesSeen so that we don't attempt to bond with failing ones # in the future # 2. Removes all previously seen nodes from candidates # 3. Deduplicates candidates candidates.keepItIf(not nodesSeen[].containsOrIncl(it)) trace "Got new candidates", count = candidates.len var bondedNodes: seq[Future[bool]] = @[] for node in candidates: bondedNodes.add(k.bond(node)) await allFutures(bondedNodes) for i in 0.. maxResults: nodes.setLen(maxResults) proc lookup*(k: KademliaProtocol, nodeId: NodeId): Future[seq[Node]] {.async.} = ## Lookup performs a network search for nodes close to the given target. ## It approaches the target by querying nodes that are closer to it on each iteration. The ## given target does not need to be an actual node identifier. 
var nodesAsked = initHashSet[Node]() let nodesSeen = new(HashSet[Node]) proc excludeIfAsked(nodes: seq[Node]): seq[Node] = result = toSeq(items(nodes.toHashSet() - nodesAsked)) sortByDistance(result, nodeId, FIND_CONCURRENCY) var closest = k.routing.neighbours(nodeId) trace "Starting lookup; initial neighbours: ", closest var nodesToAsk = excludeIfAsked(closest) while nodesToAsk.len != 0: trace "Node lookup; querying ", nodesToAsk nodesAsked.incl(nodesToAsk.toHashSet()) var findNodeRequests: seq[Future[seq[Node]]] = @[] for node in nodesToAsk: findNodeRequests.add(k.findNode(nodesSeen, nodeId, node)) await allFutures(findNodeRequests) for candidates in findNodeRequests: # `findNode` will not raise so there should be no failures, # and for cancellation this should be fine to raise for now. doAssert(candidates.finished() and not(candidates.failed())) closest.add(candidates.read()) sortByDistance(closest, nodeId, BUCKET_SIZE) nodesToAsk = excludeIfAsked(closest) trace "Kademlia lookup finished", target = nodeId.toHex, closest result = closest proc lookupRandom*(k: KademliaProtocol): Future[seq[Node]] = var id: NodeId var buf: array[sizeof(id), byte] brHmacDrbgGenerate(k.rng[], buf) copyMem(addr id, addr buf[0], sizeof(id)) k.lookup(id) proc resolve*(k: KademliaProtocol, id: NodeId): Future[Node] {.async.} = let closest = await k.lookup(id) for n in closest: if n.id == id: return n proc bootstrap*(k: KademliaProtocol, bootstrapNodes: seq[Node], retries = 0) {.async.} = ## Bond with bootstrap nodes and do initial lookup. Retry `retries` times ## in case of failure, or indefinitely if `retries` is 0. var retryInterval = chronos.milliseconds(2) var numTries = 0 if bootstrapNodes.len != 0: while true: var bondedNodes: seq[Future[bool]] = @[] for node in bootstrapNodes: bondedNodes.add(k.bond(node)) await allFutures(bondedNodes) # `bond` will not raise so there should be no failures, # and for cancellation this should be fine to raise for now. 
let bonded = bondedNodes.mapIt(it.read()) if true notin bonded: inc numTries if retries == 0 or numTries < retries: info "Failed to bond with bootstrap nodes, retrying" retryInterval = min(chronos.seconds(10), retryInterval * 2) await sleepAsync(retryInterval) else: info "Failed to bond with bootstrap nodes" return else: break discard await k.lookupRandom() # Prepopulate the routing table else: info "Skipping discovery bootstrap, no bootnodes provided" proc recvPong*(k: KademliaProtocol, n: Node, token: seq[byte]) = trace "<<< pong from ", n let pingid = token & @(n.node.pubkey.toRaw) var future: Future[bool] if k.pongFutures.take(pingid, future): future.complete(true) proc recvPing*(k: KademliaProtocol, n: Node, msgHash: any) {.raises: [ValueError, Defect].} = trace "<<< ping from ", n k.updateRoutingTable(n) k.wire.sendPong(n, msgHash) var future: Future[bool] if k.pingFutures.take(n, future): future.complete(true) proc recvNeighbours*(k: KademliaProtocol, remote: Node, neighbours: seq[Node]) = ## Process a neighbours response. ## ## Neighbours responses should only be received as a reply to a find_node, and that is only ## done as part of node lookup, so the actual processing is left to the callback from ## neighbours_callbacks, which is added (and removed after it's done or timed out) in ## wait_neighbours(). trace "Received neighbours", remote, neighbours let cb = k.neighboursCallbacks.getOrDefault(remote) if not cb.isNil: cb(neighbours) else: trace "Unexpected neighbours, probably came too late", remote proc recvFindNode*(k: KademliaProtocol, remote: Node, nodeId: NodeId) {.raises: [ValueError, Defect].} = if remote notin k.routing: # FIXME: This is not correct; a node we've bonded before may have become unavailable # and thus removed from self.routing, but once it's back online we should accept # find_nodes from them. 
trace "Ignoring find_node request from unknown node ", remote return k.updateRoutingTable(remote) var found = k.routing.neighbours(nodeId) found.sort() do(x, y: Node) -> int: cmp(x.id, y.id) k.wire.sendNeighbours(remote, found) proc randomNodes*(k: KademliaProtocol, count: int): seq[Node] = var count = count let sz = k.routing.len if count > sz: debug "Looking for peers", requested = count, present = sz count = sz result = newSeqOfCap[Node](count) var seen = initHashSet[Node]() # This is a rather inefficient way of randomizing nodes from all buckets, but even if we # iterate over all nodes in the routing table, the time it takes would still be # insignificant compared to the time it takes for the network roundtrips when connecting # to nodes. while len(seen) < count: let bucket = k.routing.buckets.sample() if bucket.nodes.len != 0: let node = bucket.nodes.sample() if node notin seen: result.add(node) seen.incl(node) proc nodesDiscovered*(k: KademliaProtocol): int = k.routing.len when isMainModule: proc randomNode(): Node = newNode("enode://aa36fdf33dd030378a0168efe6ed7d5cc587fafa3cdd375854fe735a2e11ea3650ba29644e2db48368c46e1f60e716300ba49396cd63778bf8a818c09bded46f@13.93.211.84:30303") var nodes = @[randomNode()] doAssert(computeSharedPrefixBits(nodes) == ID_SIZE) nodes.add(randomNode()) nodes[0].id = 0b1.u256 nodes[1].id = 0b0.u256 doAssert(computeSharedPrefixBits(nodes) == ID_SIZE - 1) nodes[0].id = 0b010.u256 nodes[1].id = 0b110.u256 doAssert(computeSharedPrefixBits(nodes) == ID_SIZE - 3)