From bfadcfbfaf18916a2ee1b77b896e0d01c6a0ca4f Mon Sep 17 00:00:00 2001 From: KonradStaniec Date: Thu, 2 Sep 2021 14:00:36 +0200 Subject: [PATCH] Make Routing table distance function configurable (#392) --- eth/p2p/discoveryv5/protocol.nim | 2 +- eth/p2p/discoveryv5/routing_table.nim | 84 ++++++++++++++++----------- tests/p2p/discv5_test_helper.nim | 7 +++ tests/p2p/test_routing_table.nim | 61 +++++++++++++++++++ 4 files changed, 119 insertions(+), 35 deletions(-) diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index 4dbc1c5..a1a0852 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -726,7 +726,7 @@ proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]] {.async.} = # If it wasn't seen before, insert node while remaining sorted closestNodes.insert(n, closestNodes.lowerBound(n, proc(x: Node, n: Node): int = - cmp(distanceTo(x, target), distanceTo(n, target)) + cmp(distanceTo(x.id, target), distanceTo(n.id, target)) )) if closestNodes.len > BUCKET_SIZE: diff --git a/eth/p2p/discoveryv5/routing_table.nim b/eth/p2p/discoveryv5/routing_table.nim index 674ac44..cefe7df 100644 --- a/eth/p2p/discoveryv5/routing_table.nim +++ b/eth/p2p/discoveryv5/routing_table.nim @@ -19,6 +19,15 @@ declarePublicGauge routing_table_nodes, "Discovery routing table nodes", labels = ["state"] type + DistanceProc* = proc(a, b: NodeId): NodeId {.raises: [Defect], gcsafe, noSideEffect.} + LogDistanceProc* = proc(a, b: NodeId): uint16 {.raises: [Defect], gcsafe, noSideEffect.} + IdAtDistanceProc* = proc (id: NodeId, dist: uint16): NodeId {.raises: [Defect], gcsafe, noSideEffect.} + + DistanceCalculator* = object + calculateDistance*: DistanceProc + calculateLogDistance*: LogDistanceProc + calculateIdAtDistance*: IdAtDistanceProc + RoutingTable* = object thisNode: Node buckets: seq[KBucket] @@ -32,6 +41,7 @@ type ## will result in an improvement of log base(2^b) n hops per lookup. 
ipLimits: IpLimits ## IP limits for total routing table: all buckets and ## replacement caches. + distanceCalculator: DistanceCalculator rng: ref BrHmacDrbgContext KBucket = ref object @@ -82,22 +92,21 @@ type ReplacementExisting NoAddress -const - BUCKET_SIZE* = 16 ## Maximum amount of nodes per bucket - REPLACEMENT_CACHE_SIZE* = 8 ## Maximum amount of nodes per replacement cache - ## of a bucket - ID_SIZE = 256 - DefaultBitsPerHop* = 5 - DefaultBucketIpLimit* = 2'u - DefaultTableIpLimit* = 10'u - DefaultTableIpLimits* = TableIpLimits(tableIpLimit: DefaultTableIpLimit, - bucketIpLimit: DefaultBucketIpLimit) - -proc distanceTo*(n: Node, id: NodeId): UInt256 = +# xor distance functions +# +func distanceTo*(a, b: NodeId): Uint256 = ## Calculate the distance to a NodeId. - n.id xor id + a xor b -proc logDist*(a, b: NodeId): uint16 = +func idAtDistance*(id: NodeId, dist: uint16): NodeId = + ## Calculate the "lowest" `NodeId` for given logarithmic distance. + ## A logarithmic distance obviously covers a whole range of distances and thus + ## potential `NodeId`s. + # xor the NodeId with 2^(d - 1) or one could say, calculate back the leading + # zeroes and xor those` with the id. + id xor (1.stuint(256) shl (dist.int - 1)) + +func logDist*(a, b: NodeId): uint16 = ## Calculate the logarithmic distance between two `NodeId`s. ## ## According the specification, this is the log base 2 of the distance. 
But it @@ -115,6 +124,20 @@ proc logDist*(a, b: NodeId): uint16 = lz += bitops.countLeadingZeroBits(x) break return uint16(a.len * 8 - lz) +# + +const + BUCKET_SIZE* = 16 ## Maximum amount of nodes per bucket + REPLACEMENT_CACHE_SIZE* = 8 ## Maximum amount of nodes per replacement cache + ## of a bucket + ID_SIZE = 256 + DefaultBitsPerHop* = 5 + DefaultBucketIpLimit* = 2'u + DefaultTableIpLimit* = 10'u + DefaultTableIpLimits* = TableIpLimits(tableIpLimit: DefaultTableIpLimit, + bucketIpLimit: DefaultBucketIpLimit) + XorDistanceCalculator* = DistanceCalculator(calculateDistance: distanceTo, + calculateLogDistance: logDist, calculateIdAtDistance: idAtDistance) proc newKBucket(istart, iend: NodeId, bucketIpLimit: uint): KBucket = result.new() @@ -127,11 +150,8 @@ proc newKBucket(istart, iend: NodeId, bucketIpLimit: uint): KBucket = proc midpoint(k: KBucket): NodeId = k.istart + (k.iend - k.istart) div 2.u256 -proc distanceTo(k: KBucket, id: NodeId): UInt256 = k.midpoint xor id -proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] = - sortedByIt(k.nodes, it.distanceTo(id)) - proc len(k: KBucket): int = k.nodes.len + proc tail(k: KBucket): Node = k.nodes[high(k.nodes)] proc ipLimitInc(r: var RoutingTable, b: KBucket, n: Node): bool = @@ -233,13 +253,14 @@ proc computeSharedPrefixBits(nodes: openarray[NodeId]): int = doAssert(false, "Unable to calculate number of shared prefix bits") proc init*(r: var RoutingTable, thisNode: Node, bitsPerHop = DefaultBitsPerHop, - ipLimits = DefaultTableIpLimits, rng: ref BrHmacDrbgContext) = + ipLimits = DefaultTableIpLimits, rng: ref BrHmacDrbgContext, distanceCalculator = XorDistanceCalculator) = ## Initialize the routing table for provided `Node` and bitsPerHop value. ## `bitsPerHop` is default set to 5 as recommended by original Kademlia paper. 
r.thisNode = thisNode r.buckets = @[newKBucket(0.u256, high(Uint256), ipLimits.bucketIpLimit)] r.bitsPerHop = bitsPerHop r.ipLimits.limit = ipLimits.tableIpLimit + r.distanceCalculator = distanceCalculator r.rng = rng proc splitBucket(r: var RoutingTable, index: int) = @@ -397,7 +418,10 @@ proc contains*(r: RoutingTable, n: Node): bool = n in r.bucketForNode(n.id) # Check if the routing table contains node `n`. proc bucketsByDistanceTo(r: RoutingTable, id: NodeId): seq[KBucket] = - sortedByIt(r.buckets, it.distanceTo(id)) + sortedByIt(r.buckets, r.distanceCalculator.calculateDistance(it.midpoint, id)) + +proc nodesByDistanceTo(r: RoutingTable, k: KBucket, id: NodeId): seq[Node] = + sortedByIt(k.nodes, r.distanceCalculator.calculateDistance(it.id, id)) proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE, seenOnly = false): seq[Node] = @@ -407,7 +431,7 @@ proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE, result = newSeqOfCap[Node](k * 2) block addNodes: for bucket in r.bucketsByDistanceTo(id): - for n in bucket.nodesByDistanceTo(id): + for n in r.nodesByDistanceTo(bucket, id): # Only provide actively seen nodes when `seenOnly` set. if not seenOnly or n.seen: result.add(n) @@ -416,25 +440,17 @@ proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE, # TODO: is this sort still needed? Can we get nodes closer from the "next" # bucket? - result = sortedByIt(result, it.distanceTo(id)) + result = sortedByIt(result, r.distanceCalculator.calculateDistance(it.id, id)) if result.len > k: result.setLen(k) -proc idAtDistance*(id: NodeId, dist: uint16): NodeId = - ## Calculate the "lowest" `NodeId` for given logarithmic distance. - ## A logarithmic distance obviously covers a whole range of distances and thus - ## potential `NodeId`s. - # xor the NodeId with 2^(d - 1) or one could say, calculate back the leading - # zeroes and xor those` with the id. 
- id xor (1.stuint(256) shl (dist.int - 1)) - proc neighboursAtDistance*(r: RoutingTable, distance: uint16, k: int = BUCKET_SIZE, seenOnly = false): seq[Node] = ## Return up to k neighbours at given logarithmic distance. - result = r.neighbours(idAtDistance(r.thisNode.id, distance), k, seenOnly) + result = r.neighbours(r.distanceCalculator.calculateIdAtDistance(r.thisNode.id, distance), k, seenOnly) # This is a bit silly, first getting closest nodes then to only keep the ones # that are exactly the requested distance. - keepIf(result, proc(n: Node): bool = logDist(n.id, r.thisNode.id) == distance) + keepIf(result, proc(n: Node): bool = r.distanceCalculator.calculateLogDistance(n.id, r.thisNode.id) == distance) proc neighboursAtDistances*(r: RoutingTable, distances: seq[uint16], k: int = BUCKET_SIZE, seenOnly = false): seq[Node] = @@ -443,12 +459,12 @@ proc neighboursAtDistances*(r: RoutingTable, distances: seq[uint16], # first one prioritize. It might end up not including all the node distances # requested. Need to rework the logic here and not use the neighbours call. if distances.len > 0: - result = r.neighbours(idAtDistance(r.thisNode.id, distances[0]), k, + result = r.neighbours(r.distanceCalculator.calculateIdAtDistance(r.thisNode.id, distances[0]), k, seenOnly) # This is a bit silly, first getting closest nodes then to only keep the ones # that are exactly the requested distances. 
keepIf(result, proc(n: Node): bool = - distances.contains(logDist(n.id, r.thisNode.id))) + distances.contains(r.distanceCalculator.calculateLogDistance(n.id, r.thisNode.id))) proc len*(r: RoutingTable): int = for b in r.buckets: result += b.len diff --git a/tests/p2p/discv5_test_helper.nim b/tests/p2p/discv5_test_helper.nim index 22bdd96..7b5751f 100644 --- a/tests/p2p/discv5_test_helper.nim +++ b/tests/p2p/discv5_test_helper.nim @@ -42,6 +42,13 @@ proc generateNode*(privKey: PrivateKey, port: int = 20302, some(port), some(port), localEnrFields).expect("Properly intialized private key") result = newNode(enr).expect("Properly initialized node") +proc generateNRandomNodes*(rng: ref BrHmacDrbgContext, n: int): seq[Node] = + var res = newSeq[Node]() + for i in 1..n: + let node = generateNode(PrivateKey.random(rng[])) + res.add(node) + res + proc nodeAndPrivKeyAtDistance*(n: Node, rng: var BrHmacDrbgContext, d: uint32, ip: ValidIpAddress = ValidIpAddress.init("127.0.0.1")): (Node, PrivateKey) = while true: diff --git a/tests/p2p/test_routing_table.nim b/tests/p2p/test_routing_table.nim index 68dc474..26946bb 100644 --- a/tests/p2p/test_routing_table.nim +++ b/tests/p2p/test_routing_table.nim @@ -6,6 +6,20 @@ import ../../eth/keys, ../../eth/p2p/discoveryv5/[routing_table, node, enr], ./discv5_test_helper +func customDistance*(a, b: NodeId): Uint256 = + if a >= b: + a - b + else: + b - a + +func customLogDistance*(a, b: NodeId): uint16 = + let distance = customDistance(a, b) + let modulo = distance mod (u256(uint8.high)) + cast[uint16](modulo) + +func customIdAdDist*(id: NodeId, dist: uint16): NodeId = + id + u256(dist) + suite "Routing Table Tests": let rng = newRng() @@ -16,6 +30,11 @@ suite "Routing Table Tests": let ipLimits = TableIpLimits(tableIpLimit: 200, bucketIpLimit: BUCKET_SIZE + REPLACEMENT_CACHE_SIZE + 1) + let customDistanceCalculator = DistanceCalculator( + calculateDistance: customDistance, + calculateLogDistance: customLogDistance, + 
calculateIdAtDistance: customIdAdDist) + test "Add local node": let node = generateNode(PrivateKey.random(rng[])) var table: RoutingTable @@ -540,3 +559,45 @@ suite "Routing Table Tests": check table.addNode(n) == Added check table.len == int(DefaultTableIpLimits.bucketIpLimit) + 1 + + test "Custom distance calculator: distance": + let numNodes = 10 + let local = generateNode(PrivateKey.random(rng[])) + var table: RoutingTable + table.init(local, 1, ipLimits, rng = rng, distanceCalculator = customDistanceCalculator) + + let nodes = generateNRandomNodes(rng, numNodes) + + for n in nodes: + check table.addNode(n) == Added + + let neighbours = table.neighbours(local.id) + check len(neighbours) == numNodes + + # check that neighbours are sorted by provided custom distance function + for i in 0..numNodes-2: + let prevDist = customDistance(local.id, neighbours[i].id) + let nextDist = customDistance(local.id, neighbours[i + 1].id) + check prevDist <= nextDist + + test "Custom distance calculator: at log distance": + let numNodes = 10 + let local = generateNode(PrivateKey.random(rng[])) + var table: RoutingTable + table.init(local, 1, ipLimits, rng = rng, distanceCalculator = customDistanceCalculator) + + let nodes = generateNRandomNodes(rng, numNodes) + + for n in nodes: + check table.addNode(n) == Added + + let neighbours = table.neighbours(local.id) + check len(neighbours) == numNodes + + for n in neighbours: + let cLogDist = customLogDistance(local.id, n.id) + let neighboursAtLogDist = table.neighboursAtDistance(cLogDist) + # there may be more than one node at provided distance + check len(neighboursAtLogDist) >= 1 + check neighboursAtLogDist.contains(n) +