Use custom distance function in state network (#831)

* Use custom distance calculator in state network
This commit is contained in:
KonradStaniec 2021-09-16 16:13:36 +02:00 committed by GitHub
parent 8f683bd318
commit 6192cd7dc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 5 deletions

View File

@ -1,4 +1,6 @@
import stint import
eth/p2p/discoveryv5/routing_table,
stint
const MID* = u256(2).pow(u256(255)) const MID* = u256(2).pow(u256(255))
const MAX* = high(Uint256) const MAX* = high(Uint256)
@ -17,7 +19,7 @@ const MAX* = high(Uint256)
# Raw difference is: 5 - 0 = 5, which is larger than mid point which is equal to 4. # Raw difference is: 5 - 0 = 5, which is larger than mid point which is equal to 4.
# From this we know that the shorter distance is the one wraping around 0, which # From this we know that the shorter distance is the one wraping around 0, which
# is equal to 3 # is equal to 3
proc distance*(node_id: UInt256, content_id: UInt256): UInt256 = func distance*(node_id: UInt256, content_id: UInt256): UInt256 =
let rawDiff = let rawDiff =
if node_id > content_id: if node_id > content_id:
node_id - content_id node_id - content_id
@ -30,3 +32,41 @@ proc distance*(node_id: UInt256, content_id: UInt256): UInt256 =
MAX - rawDiff + UInt256.one MAX - rawDiff + UInt256.one
else: else:
rawDiff rawDiff
# TODO we do not have Uint256 log2 implementation. It would be nice to implement
# it in stint library in some more performant way. This version has O(n) complexity.
func myLog2Distance(value: UInt256): uint16 =
# Logarithm is not defined for zero values. Implementation in stew for builtin
# types return -1 in that case, but here it is just internal function so just make sure
# 0 is never provided.
doAssert(not value.isZero())
if value == UInt256.one:
return 0'u16
var comp = value
var ret = 0'u16
while (comp > 1):
comp = comp shr 1
ret = ret + 1
return ret
func atDistance*(id: UInt256, dist: uint16): UInt256 =
# TODO With current distance function there are always two ids at given distance
# so we might as well do: id - u256(dist), maybe it is worth discussing if every client
# should use the same id in this case.
id + u256(2).pow(dist)
func logDistance*(a, b: UInt256): uint16 =
let distance = distance(a, b)
if distance.isZero():
return 0
else:
return myLog2Distance(distance)
const customDistanceCalculator* =
DistanceCalculator(
calculateDistance: distance,
calculateLogDistance: logDistance,
calculateIdAtDistance: atDistance
)

View File

@ -11,7 +11,7 @@ import
std/[sequtils, sets, algorithm], std/[sequtils, sets, algorithm],
stew/[results, byteutils], chronicles, chronos, nimcrypto/hash, stew/[results, byteutils], chronicles, chronos, nimcrypto/hash,
eth/rlp, eth/p2p/discoveryv5/[protocol, node, enr, routing_table, random2, nodes_verification], eth/rlp, eth/p2p/discoveryv5/[protocol, node, enr, routing_table, random2, nodes_verification],
./messages ./messages, ./custom_distance
export messages export messages
@ -180,7 +180,7 @@ proc new*(T: type PortalProtocol, baseProtocol: protocol.Protocol,
dataRadius = UInt256.high()): T = dataRadius = UInt256.high()): T =
let proto = PortalProtocol( let proto = PortalProtocol(
routingTable: RoutingTable.init(baseProtocol.localNode, DefaultBitsPerHop, routingTable: RoutingTable.init(baseProtocol.localNode, DefaultBitsPerHop,
DefaultTableIpLimits, baseProtocol.rng), DefaultTableIpLimits, baseProtocol.rng, customDistanceCalculator),
protocolHandler: messageHandler, protocolHandler: messageHandler,
baseProtocol: baseProtocol, baseProtocol: baseProtocol,
dataRadius: dataRadius, dataRadius: dataRadius,

View File

@ -8,7 +8,7 @@
{.used.} {.used.}
import import
std/unittest, std/[unittest, sequtils],
stint, stint,
../network/state/custom_distance ../network/state/custom_distance
@ -28,3 +28,32 @@ suite "State network custom distance function":
# Additional test cases to check some basic properties # Additional test cases to check some basic properties
distance(UInt256.zero, MID + MID) == UInt256.zero distance(UInt256.zero, MID + MID) == UInt256.zero
distance(UInt256.zero, UInt256.one) == distance(UInt256.zero, high(UInt256)) distance(UInt256.zero, UInt256.one) == distance(UInt256.zero, high(UInt256))
test "Calculate logarithimic distance":
check:
logDistance(u256(0), u256(0)) == 0
logDistance(u256(0), u256(1)) == 0
logDistance(u256(0), u256(2)) == 1
logDistance(u256(0), u256(4)) == 2
logDistance(u256(0), u256(8)) == 3
logDistance(u256(8), u256(0)) == 3
logDistance(UInt256.zero, MID) == 255
logDistance(UInt256.zero, MID + UInt256.one) == 254
test "Calculate id at log distance":
let logDistances = @[
0'u16, 1, 2, 3, 4, 5, 6, 7, 8
]
# for each log distance, calulate node-id at given distance from node zero, and then
# log distance from calculate node-id to node zero. The results should equal
# starting log distances
let logCalculated = logDistances.map(
proc (x: uint16): uint16 =
let nodeAtDist = atDistance(Uint256.zero, x)
return logDistance(Uint256.zero, nodeAtDist)
)
check:
logDistances == logCalculated