mirror of
https://github.com/status-im/nim-eth.git
synced 2025-02-16 16:06:35 +00:00
Make Routing table distance function configurable (#392)
This commit is contained in:
parent
aa3fbbd95d
commit
bfadcfbfaf
@ -726,7 +726,7 @@ proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]] {.async.} =
|
|||||||
# If it wasn't seen before, insert node while remaining sorted
|
# If it wasn't seen before, insert node while remaining sorted
|
||||||
closestNodes.insert(n, closestNodes.lowerBound(n,
|
closestNodes.insert(n, closestNodes.lowerBound(n,
|
||||||
proc(x: Node, n: Node): int =
|
proc(x: Node, n: Node): int =
|
||||||
cmp(distanceTo(x, target), distanceTo(n, target))
|
cmp(distanceTo(x.id, target), distanceTo(n.id, target))
|
||||||
))
|
))
|
||||||
|
|
||||||
if closestNodes.len > BUCKET_SIZE:
|
if closestNodes.len > BUCKET_SIZE:
|
||||||
|
@ -19,6 +19,15 @@ declarePublicGauge routing_table_nodes,
|
|||||||
"Discovery routing table nodes", labels = ["state"]
|
"Discovery routing table nodes", labels = ["state"]
|
||||||
|
|
||||||
type
|
type
|
||||||
|
DistanceProc* = proc(a, b: NodeId): NodeId {.raises: [Defect], gcsafe, noSideEffect.}
|
||||||
|
LogDistanceProc* = proc(a, b: NodeId): uint16 {.raises: [Defect], gcsafe, noSideEffect.}
|
||||||
|
IdAtDistanceProc* = proc (id: NodeId, dist: uint16): NodeId {.raises: [Defect], gcsafe, noSideEffect.}
|
||||||
|
|
||||||
|
DistanceCalculator* = object
|
||||||
|
calculateDistance*: DistanceProc
|
||||||
|
calculateLogDistance*: LogDistanceProc
|
||||||
|
calculateIdAtDistance*: IdAtDistanceProc
|
||||||
|
|
||||||
RoutingTable* = object
|
RoutingTable* = object
|
||||||
thisNode: Node
|
thisNode: Node
|
||||||
buckets: seq[KBucket]
|
buckets: seq[KBucket]
|
||||||
@ -32,6 +41,7 @@ type
|
|||||||
## will result in an improvement of log base(2^b) n hops per lookup.
|
## will result in an improvement of log base(2^b) n hops per lookup.
|
||||||
ipLimits: IpLimits ## IP limits for total routing table: all buckets and
|
ipLimits: IpLimits ## IP limits for total routing table: all buckets and
|
||||||
## replacement caches.
|
## replacement caches.
|
||||||
|
distanceCalculator: DistanceCalculator
|
||||||
rng: ref BrHmacDrbgContext
|
rng: ref BrHmacDrbgContext
|
||||||
|
|
||||||
KBucket = ref object
|
KBucket = ref object
|
||||||
@ -82,22 +92,21 @@ type
|
|||||||
ReplacementExisting
|
ReplacementExisting
|
||||||
NoAddress
|
NoAddress
|
||||||
|
|
||||||
const
|
# xor distance functions
|
||||||
BUCKET_SIZE* = 16 ## Maximum amount of nodes per bucket
|
#
|
||||||
REPLACEMENT_CACHE_SIZE* = 8 ## Maximum amount of nodes per replacement cache
|
func distanceTo*(a, b: NodeId): Uint256 =
|
||||||
## of a bucket
|
|
||||||
ID_SIZE = 256
|
|
||||||
DefaultBitsPerHop* = 5
|
|
||||||
DefaultBucketIpLimit* = 2'u
|
|
||||||
DefaultTableIpLimit* = 10'u
|
|
||||||
DefaultTableIpLimits* = TableIpLimits(tableIpLimit: DefaultTableIpLimit,
|
|
||||||
bucketIpLimit: DefaultBucketIpLimit)
|
|
||||||
|
|
||||||
proc distanceTo*(n: Node, id: NodeId): UInt256 =
|
|
||||||
## Calculate the distance to a NodeId.
|
## Calculate the distance to a NodeId.
|
||||||
n.id xor id
|
a xor b
|
||||||
|
|
||||||
proc logDist*(a, b: NodeId): uint16 =
|
func idAtDistance*(id: NodeId, dist: uint16): NodeId =
|
||||||
|
## Calculate the "lowest" `NodeId` for given logarithmic distance.
|
||||||
|
## A logarithmic distance obviously covers a whole range of distances and thus
|
||||||
|
## potential `NodeId`s.
|
||||||
|
# xor the NodeId with 2^(d - 1) or one could say, calculate back the leading
|
||||||
|
# zeroes and xor those` with the id.
|
||||||
|
id xor (1.stuint(256) shl (dist.int - 1))
|
||||||
|
|
||||||
|
func logDist*(a, b: NodeId): uint16 =
|
||||||
## Calculate the logarithmic distance between two `NodeId`s.
|
## Calculate the logarithmic distance between two `NodeId`s.
|
||||||
##
|
##
|
||||||
## According the specification, this is the log base 2 of the distance. But it
|
## According the specification, this is the log base 2 of the distance. But it
|
||||||
@ -115,6 +124,20 @@ proc logDist*(a, b: NodeId): uint16 =
|
|||||||
lz += bitops.countLeadingZeroBits(x)
|
lz += bitops.countLeadingZeroBits(x)
|
||||||
break
|
break
|
||||||
return uint16(a.len * 8 - lz)
|
return uint16(a.len * 8 - lz)
|
||||||
|
#
|
||||||
|
|
||||||
|
const
|
||||||
|
BUCKET_SIZE* = 16 ## Maximum amount of nodes per bucket
|
||||||
|
REPLACEMENT_CACHE_SIZE* = 8 ## Maximum amount of nodes per replacement cache
|
||||||
|
## of a bucket
|
||||||
|
ID_SIZE = 256
|
||||||
|
DefaultBitsPerHop* = 5
|
||||||
|
DefaultBucketIpLimit* = 2'u
|
||||||
|
DefaultTableIpLimit* = 10'u
|
||||||
|
DefaultTableIpLimits* = TableIpLimits(tableIpLimit: DefaultTableIpLimit,
|
||||||
|
bucketIpLimit: DefaultBucketIpLimit)
|
||||||
|
XorDistanceCalculator* = DistanceCalculator(calculateDistance: distanceTo,
|
||||||
|
calculateLogDistance: logDist, calculateIdAtDistance: idAtDistance)
|
||||||
|
|
||||||
proc newKBucket(istart, iend: NodeId, bucketIpLimit: uint): KBucket =
|
proc newKBucket(istart, iend: NodeId, bucketIpLimit: uint): KBucket =
|
||||||
result.new()
|
result.new()
|
||||||
@ -127,11 +150,8 @@ proc newKBucket(istart, iend: NodeId, bucketIpLimit: uint): KBucket =
|
|||||||
proc midpoint(k: KBucket): NodeId =
|
proc midpoint(k: KBucket): NodeId =
|
||||||
k.istart + (k.iend - k.istart) div 2.u256
|
k.istart + (k.iend - k.istart) div 2.u256
|
||||||
|
|
||||||
proc distanceTo(k: KBucket, id: NodeId): UInt256 = k.midpoint xor id
|
|
||||||
proc nodesByDistanceTo(k: KBucket, id: NodeId): seq[Node] =
|
|
||||||
sortedByIt(k.nodes, it.distanceTo(id))
|
|
||||||
|
|
||||||
proc len(k: KBucket): int = k.nodes.len
|
proc len(k: KBucket): int = k.nodes.len
|
||||||
|
|
||||||
proc tail(k: KBucket): Node = k.nodes[high(k.nodes)]
|
proc tail(k: KBucket): Node = k.nodes[high(k.nodes)]
|
||||||
|
|
||||||
proc ipLimitInc(r: var RoutingTable, b: KBucket, n: Node): bool =
|
proc ipLimitInc(r: var RoutingTable, b: KBucket, n: Node): bool =
|
||||||
@ -233,13 +253,14 @@ proc computeSharedPrefixBits(nodes: openarray[NodeId]): int =
|
|||||||
doAssert(false, "Unable to calculate number of shared prefix bits")
|
doAssert(false, "Unable to calculate number of shared prefix bits")
|
||||||
|
|
||||||
proc init*(r: var RoutingTable, thisNode: Node, bitsPerHop = DefaultBitsPerHop,
|
proc init*(r: var RoutingTable, thisNode: Node, bitsPerHop = DefaultBitsPerHop,
|
||||||
ipLimits = DefaultTableIpLimits, rng: ref BrHmacDrbgContext) =
|
ipLimits = DefaultTableIpLimits, rng: ref BrHmacDrbgContext, distanceCalculator = XorDistanceCalculator) =
|
||||||
## Initialize the routing table for provided `Node` and bitsPerHop value.
|
## Initialize the routing table for provided `Node` and bitsPerHop value.
|
||||||
## `bitsPerHop` is default set to 5 as recommended by original Kademlia paper.
|
## `bitsPerHop` is default set to 5 as recommended by original Kademlia paper.
|
||||||
r.thisNode = thisNode
|
r.thisNode = thisNode
|
||||||
r.buckets = @[newKBucket(0.u256, high(Uint256), ipLimits.bucketIpLimit)]
|
r.buckets = @[newKBucket(0.u256, high(Uint256), ipLimits.bucketIpLimit)]
|
||||||
r.bitsPerHop = bitsPerHop
|
r.bitsPerHop = bitsPerHop
|
||||||
r.ipLimits.limit = ipLimits.tableIpLimit
|
r.ipLimits.limit = ipLimits.tableIpLimit
|
||||||
|
r.distanceCalculator = distanceCalculator
|
||||||
r.rng = rng
|
r.rng = rng
|
||||||
|
|
||||||
proc splitBucket(r: var RoutingTable, index: int) =
|
proc splitBucket(r: var RoutingTable, index: int) =
|
||||||
@ -397,7 +418,10 @@ proc contains*(r: RoutingTable, n: Node): bool = n in r.bucketForNode(n.id)
|
|||||||
# Check if the routing table contains node `n`.
|
# Check if the routing table contains node `n`.
|
||||||
|
|
||||||
proc bucketsByDistanceTo(r: RoutingTable, id: NodeId): seq[KBucket] =
|
proc bucketsByDistanceTo(r: RoutingTable, id: NodeId): seq[KBucket] =
|
||||||
sortedByIt(r.buckets, it.distanceTo(id))
|
sortedByIt(r.buckets, r.distanceCalculator.calculateDistance(it.midpoint, id))
|
||||||
|
|
||||||
|
proc nodesByDistanceTo(r: RoutingTable, k: KBucket, id: NodeId): seq[Node] =
|
||||||
|
sortedByIt(k.nodes, r.distanceCalculator.calculateDistance(it.id, id))
|
||||||
|
|
||||||
proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE,
|
proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE,
|
||||||
seenOnly = false): seq[Node] =
|
seenOnly = false): seq[Node] =
|
||||||
@ -407,7 +431,7 @@ proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE,
|
|||||||
result = newSeqOfCap[Node](k * 2)
|
result = newSeqOfCap[Node](k * 2)
|
||||||
block addNodes:
|
block addNodes:
|
||||||
for bucket in r.bucketsByDistanceTo(id):
|
for bucket in r.bucketsByDistanceTo(id):
|
||||||
for n in bucket.nodesByDistanceTo(id):
|
for n in r.nodesByDistanceTo(bucket, id):
|
||||||
# Only provide actively seen nodes when `seenOnly` set.
|
# Only provide actively seen nodes when `seenOnly` set.
|
||||||
if not seenOnly or n.seen:
|
if not seenOnly or n.seen:
|
||||||
result.add(n)
|
result.add(n)
|
||||||
@ -416,25 +440,17 @@ proc neighbours*(r: RoutingTable, id: NodeId, k: int = BUCKET_SIZE,
|
|||||||
|
|
||||||
# TODO: is this sort still needed? Can we get nodes closer from the "next"
|
# TODO: is this sort still needed? Can we get nodes closer from the "next"
|
||||||
# bucket?
|
# bucket?
|
||||||
result = sortedByIt(result, it.distanceTo(id))
|
result = sortedByIt(result, r.distanceCalculator.calculateDistance(it.id, id))
|
||||||
if result.len > k:
|
if result.len > k:
|
||||||
result.setLen(k)
|
result.setLen(k)
|
||||||
|
|
||||||
proc idAtDistance*(id: NodeId, dist: uint16): NodeId =
|
|
||||||
## Calculate the "lowest" `NodeId` for given logarithmic distance.
|
|
||||||
## A logarithmic distance obviously covers a whole range of distances and thus
|
|
||||||
## potential `NodeId`s.
|
|
||||||
# xor the NodeId with 2^(d - 1) or one could say, calculate back the leading
|
|
||||||
# zeroes and xor those` with the id.
|
|
||||||
id xor (1.stuint(256) shl (dist.int - 1))
|
|
||||||
|
|
||||||
proc neighboursAtDistance*(r: RoutingTable, distance: uint16,
|
proc neighboursAtDistance*(r: RoutingTable, distance: uint16,
|
||||||
k: int = BUCKET_SIZE, seenOnly = false): seq[Node] =
|
k: int = BUCKET_SIZE, seenOnly = false): seq[Node] =
|
||||||
## Return up to k neighbours at given logarithmic distance.
|
## Return up to k neighbours at given logarithmic distance.
|
||||||
result = r.neighbours(idAtDistance(r.thisNode.id, distance), k, seenOnly)
|
result = r.neighbours(r.distanceCalculator.calculateIdAtDistance(r.thisNode.id, distance), k, seenOnly)
|
||||||
# This is a bit silly, first getting closest nodes then to only keep the ones
|
# This is a bit silly, first getting closest nodes then to only keep the ones
|
||||||
# that are exactly the requested distance.
|
# that are exactly the requested distance.
|
||||||
keepIf(result, proc(n: Node): bool = logDist(n.id, r.thisNode.id) == distance)
|
keepIf(result, proc(n: Node): bool = r.distanceCalculator.calculateLogDistance(n.id, r.thisNode.id) == distance)
|
||||||
|
|
||||||
proc neighboursAtDistances*(r: RoutingTable, distances: seq[uint16],
|
proc neighboursAtDistances*(r: RoutingTable, distances: seq[uint16],
|
||||||
k: int = BUCKET_SIZE, seenOnly = false): seq[Node] =
|
k: int = BUCKET_SIZE, seenOnly = false): seq[Node] =
|
||||||
@ -443,12 +459,12 @@ proc neighboursAtDistances*(r: RoutingTable, distances: seq[uint16],
|
|||||||
# first one prioritize. It might end up not including all the node distances
|
# first one prioritize. It might end up not including all the node distances
|
||||||
# requested. Need to rework the logic here and not use the neighbours call.
|
# requested. Need to rework the logic here and not use the neighbours call.
|
||||||
if distances.len > 0:
|
if distances.len > 0:
|
||||||
result = r.neighbours(idAtDistance(r.thisNode.id, distances[0]), k,
|
result = r.neighbours(r.distanceCalculator.calculateIdAtDistance(r.thisNode.id, distances[0]), k,
|
||||||
seenOnly)
|
seenOnly)
|
||||||
# This is a bit silly, first getting closest nodes then to only keep the ones
|
# This is a bit silly, first getting closest nodes then to only keep the ones
|
||||||
# that are exactly the requested distances.
|
# that are exactly the requested distances.
|
||||||
keepIf(result, proc(n: Node): bool =
|
keepIf(result, proc(n: Node): bool =
|
||||||
distances.contains(logDist(n.id, r.thisNode.id)))
|
distances.contains(r.distanceCalculator.calculateLogDistance(n.id, r.thisNode.id)))
|
||||||
|
|
||||||
proc len*(r: RoutingTable): int =
|
proc len*(r: RoutingTable): int =
|
||||||
for b in r.buckets: result += b.len
|
for b in r.buckets: result += b.len
|
||||||
|
@ -42,6 +42,13 @@ proc generateNode*(privKey: PrivateKey, port: int = 20302,
|
|||||||
some(port), some(port), localEnrFields).expect("Properly intialized private key")
|
some(port), some(port), localEnrFields).expect("Properly intialized private key")
|
||||||
result = newNode(enr).expect("Properly initialized node")
|
result = newNode(enr).expect("Properly initialized node")
|
||||||
|
|
||||||
|
proc generateNRandomNodes*(rng: ref BrHmacDrbgContext, n: int): seq[Node] =
|
||||||
|
var res = newSeq[Node]()
|
||||||
|
for i in 1..n:
|
||||||
|
let node = generateNode(PrivateKey.random(rng[]))
|
||||||
|
res.add(node)
|
||||||
|
res
|
||||||
|
|
||||||
proc nodeAndPrivKeyAtDistance*(n: Node, rng: var BrHmacDrbgContext, d: uint32,
|
proc nodeAndPrivKeyAtDistance*(n: Node, rng: var BrHmacDrbgContext, d: uint32,
|
||||||
ip: ValidIpAddress = ValidIpAddress.init("127.0.0.1")): (Node, PrivateKey) =
|
ip: ValidIpAddress = ValidIpAddress.init("127.0.0.1")): (Node, PrivateKey) =
|
||||||
while true:
|
while true:
|
||||||
|
@ -6,6 +6,20 @@ import
|
|||||||
../../eth/keys, ../../eth/p2p/discoveryv5/[routing_table, node, enr],
|
../../eth/keys, ../../eth/p2p/discoveryv5/[routing_table, node, enr],
|
||||||
./discv5_test_helper
|
./discv5_test_helper
|
||||||
|
|
||||||
|
func customDistance*(a, b: NodeId): Uint256 =
|
||||||
|
if a >= b:
|
||||||
|
a - b
|
||||||
|
else:
|
||||||
|
b - a
|
||||||
|
|
||||||
|
func customLogDistance*(a, b: NodeId): uint16 =
|
||||||
|
let distance = customDistance(a, b)
|
||||||
|
let modulo = distance mod (u256(uint8.high))
|
||||||
|
cast[uint16](modulo)
|
||||||
|
|
||||||
|
func customIdAdDist*(id: NodeId, dist: uint16): NodeId =
|
||||||
|
id + u256(dist)
|
||||||
|
|
||||||
suite "Routing Table Tests":
|
suite "Routing Table Tests":
|
||||||
let rng = newRng()
|
let rng = newRng()
|
||||||
|
|
||||||
@ -16,6 +30,11 @@ suite "Routing Table Tests":
|
|||||||
let ipLimits = TableIpLimits(tableIpLimit: 200,
|
let ipLimits = TableIpLimits(tableIpLimit: 200,
|
||||||
bucketIpLimit: BUCKET_SIZE + REPLACEMENT_CACHE_SIZE + 1)
|
bucketIpLimit: BUCKET_SIZE + REPLACEMENT_CACHE_SIZE + 1)
|
||||||
|
|
||||||
|
let customDistanceCalculator = DistanceCalculator(
|
||||||
|
calculateDistance: customDistance,
|
||||||
|
calculateLogDistance: customLogDistance,
|
||||||
|
calculateIdAtDistance: customIdAdDist)
|
||||||
|
|
||||||
test "Add local node":
|
test "Add local node":
|
||||||
let node = generateNode(PrivateKey.random(rng[]))
|
let node = generateNode(PrivateKey.random(rng[]))
|
||||||
var table: RoutingTable
|
var table: RoutingTable
|
||||||
@ -540,3 +559,45 @@ suite "Routing Table Tests":
|
|||||||
check table.addNode(n) == Added
|
check table.addNode(n) == Added
|
||||||
|
|
||||||
check table.len == int(DefaultTableIpLimits.bucketIpLimit) + 1
|
check table.len == int(DefaultTableIpLimits.bucketIpLimit) + 1
|
||||||
|
|
||||||
|
test "Custom distance calculator: distance":
|
||||||
|
let numNodes = 10
|
||||||
|
let local = generateNode(PrivateKey.random(rng[]))
|
||||||
|
var table: RoutingTable
|
||||||
|
table.init(local, 1, ipLimits, rng = rng, distanceCalculator = customDistanceCalculator)
|
||||||
|
|
||||||
|
let nodes = generateNRandomNodes(rng, numNodes)
|
||||||
|
|
||||||
|
for n in nodes:
|
||||||
|
check table.addNode(n) == Added
|
||||||
|
|
||||||
|
let neighbours = table.neighbours(local.id)
|
||||||
|
check len(neighbours) == numNodes
|
||||||
|
|
||||||
|
# check that neighbours are sorted by provdied custom distance funciton
|
||||||
|
for i in 0..numNodes-2:
|
||||||
|
let prevDist = customDistance(local.id, neighbours[i].id)
|
||||||
|
let nextDist = customDistance(local.id, neighbours[i + 1].id)
|
||||||
|
check prevDist <= nextDist
|
||||||
|
|
||||||
|
test "Custom distance calculator: at log distance":
|
||||||
|
let numNodes = 10
|
||||||
|
let local = generateNode(PrivateKey.random(rng[]))
|
||||||
|
var table: RoutingTable
|
||||||
|
table.init(local, 1, ipLimits, rng = rng, distanceCalculator = customDistanceCalculator)
|
||||||
|
|
||||||
|
let nodes = generateNRandomNodes(rng, numNodes)
|
||||||
|
|
||||||
|
for n in nodes:
|
||||||
|
check table.addNode(n) == Added
|
||||||
|
|
||||||
|
let neighbours = table.neighbours(local.id)
|
||||||
|
check len(neighbours) == numNodes
|
||||||
|
|
||||||
|
for n in neighbours:
|
||||||
|
let cLogDist = customLogDistance(local.id, n.id)
|
||||||
|
let neighboursAtLogDist = table.neighboursAtDistance(cLogDist)
|
||||||
|
# there may be more than one node at provided distance
|
||||||
|
check len(neighboursAtLogDist) >= 1
|
||||||
|
check neighboursAtLogDist.contains(n)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user