Merge pull request #44 from status-im/providers-lru

Providers lru
Dmitriy Ryajov 2022-09-12 23:34:02 -04:00 committed by GitHub
commit 94b75f141c
4 changed files with 101 additions and 60 deletions

View File

@@ -2,13 +2,16 @@ import std/[tables, lists, options]
 {.push raises: [Defect].}
+export tables, lists, options
 type
   LRUCache*[K, V] = object of RootObj
     list: DoublyLinkedList[(K, V)] # Head is MRU k:v and tail is LRU k:v
-    table: Table[K, DoublyLinkedNode[(K, V)]] # DoublyLinkedNode is alraedy ref
+    table: Table[K, DoublyLinkedNode[(K, V)]] # DoublyLinkedNode is already ref
     capacity: int
 func init*[K, V](T: type LRUCache[K, V], capacity: int): LRUCache[K, V] =
+  doAssert capacity > 0, "Capacity should be greater than 0!"
   LRUCache[K, V](capacity: capacity) # Table and list init is done default
 func get*[K, V](lru: var LRUCache[K, V], key: K): Option[V] =
@@ -25,7 +28,7 @@ func put*[K, V](lru: var LRUCache[K, V], key: K, value: V) =
   if not node.isNil:
     lru.list.remove(node)
   else:
-    if lru.table.len >= lru.capacity:
+    if lru.len > 0 and lru.table.len >= lru.capacity:
       lru.table.del(lru.list.tail.value[0])
       lru.list.remove(lru.list.tail)
@@ -39,3 +42,16 @@ func del*[K, V](lru: var LRUCache[K, V], key: K) =
 func len*[K, V](lru: LRUCache[K, V]): int =
   lru.table.len
+
+proc contains*[K, V](lru: LRUCache[K, V], k: K): bool =
+  ## Check for cached item - this doesn't touch the cache
+  ##
+  k in lru.table
+
+iterator items*[K, V](lru: LRUCache[K, V]): V =
+  ## Get cached items - this doesn't touch the cache
+  ##
+  for item in lru.list:
+    yield item[1]

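For orientation, a minimal usage sketch of the cache above (assuming the module is importable as `lru`, and that `get` refreshes recency, as is usual for an LRU):

import std/options
import lru  # assumed import path

var cache = LRUCache[int, string].init(capacity = 2)
cache.put(1, "a")
cache.put(2, "b")
discard cache.get(1)   # touching key 1 makes it the MRU head
cache.put(3, "c")      # at capacity: evicts key 2, the current LRU tail
doAssert cache.get(2).isNone
doAssert cache.get(1).get() == "a"
doAssert cache.len == 2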
View File

@@ -22,7 +22,7 @@ import
 export providers_messages
 type
-  MessageKind* = enum
+  MessageKind* {.pure.} = enum
     # TODO This is needed only to make Nim 1.2.6 happy
     # Without it, the `MessageKind` type cannot be used as
     # a discriminator in case objects.
@@ -116,16 +116,16 @@ type
     discard
 template messageKind*(T: typedesc[SomeMessage]): MessageKind =
-  when T is PingMessage: ping
-  elif T is PongMessage: pong
-  elif T is FindNodeMessage: findNode
-  elif T is FindNodeFastMessage: findNodeFast
-  elif T is NodesMessage: nodes
-  elif T is TalkReqMessage: talkReq
-  elif T is TalkRespMessage: talkResp
-  elif T is AddProviderMessage: addProvider
-  elif T is GetProvidersMessage: getProviders
-  elif T is ProvidersMessage: providers
+  when T is PingMessage: MessageKind.ping
+  elif T is PongMessage: MessageKind.pong
+  elif T is FindNodeMessage: MessageKind.findNode
+  elif T is FindNodeFastMessage: MessageKind.findNodeFast
+  elif T is NodesMessage: MessageKind.nodes
+  elif T is TalkReqMessage: MessageKind.talkReq
+  elif T is TalkRespMessage: MessageKind.talkResp
+  elif T is AddProviderMessage: MessageKind.addProvider
+  elif T is GetProvidersMessage: MessageKind.getProviders
+  elif T is ProvidersMessage: MessageKind.providers
 proc hash*(reqId: RequestId): Hash =
   hash(reqId.id)

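Marking the enum `{.pure.}` is what forces the fully qualified `MessageKind.*` spellings in the rest of this diff; a standalone illustration (not repo code):

type
  MessageKind {.pure.} = enum
    ping, pong, nodes

# with a pure enum the call site qualifies the value, avoiding
# collisions with other symbols named `nodes`, `ping`, etc.
let kind = MessageKind.nodes
doAssert kind != MessageKind.ping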
View File

@@ -79,7 +79,7 @@ import
   stew/[base64, endians2, results], chronicles, chronicles/chronos_tools, chronos, chronos/timer, stint, bearssl,
   metrics,
   libp2p/[crypto/crypto, routing_record],
-  "."/[transport, messages, messages_encoding, node, routing_table, spr, random2, ip_vote, nodes_verification]
+  "."/[transport, messages, messages_encoding, node, routing_table, spr, random2, ip_vote, nodes_verification, lru]
 import nimcrypto except toHex
@@ -98,20 +98,22 @@ logScope:
   topics = "discv5"
 const
-  alpha = 3 ## Kademlia concurrency factor
-  lookupRequestLimit = 3 ## Amount of distances requested in a single Findnode
-  ## message for a lookup or query
-  findNodeResultLimit = 16 ## Maximum amount of SPRs in the total Nodes messages
-  ## that will be processed
-  maxNodesPerMessage = 3 ## Maximum amount of SPRs per individual Nodes message
-  refreshInterval = 5.minutes ## Interval of launching a random query to
-  ## refresh the routing table.
-  revalidateMax = 10000 ## Revalidation of a peer is done between 0 and this
-  ## value in milliseconds
-  ipMajorityInterval = 5.minutes ## Interval for checking the latest IP:Port
-  ## majority and updating this when SPR auto update is set.
-  initialLookups = 1 ## Amount of lookups done when populating the routing table
-  responseTimeout* = 4.seconds ## timeout for the response of a request-response
+  Alpha = 3 ## Kademlia concurrency factor
+  LookupRequestLimit = 3 ## Amount of distances requested in a single Findnode
+  ## message for a lookup or query
+  FindNodeResultLimit = 16 ## Maximum amount of SPRs in the total Nodes messages
+  ## that will be processed
+  MaxNodesPerMessage = 3 ## Maximum amount of SPRs per individual Nodes message
+  RefreshInterval = 5.minutes ## Interval of launching a random query to
+  ## refresh the routing table.
+  RevalidateMax = 10000 ## Revalidation of a peer is done between 0 and this
+  ## value in milliseconds
+  IpMajorityInterval = 5.minutes ## Interval for checking the latest IP:Port
+  ## majority and updating this when SPR auto update is set.
+  InitialLookups = 1 ## Amount of lookups done when populating the routing table
+  ResponseTimeout* = 4.seconds ## timeout for the response of a request-response
+  MaxProvidersEntries* = 1_000_000 # one million records
+  MaxProvidersPerEntry* = 20 # providers per entry
   ## call
 type
@@ -119,6 +121,9 @@ type
     tableIpLimits*: TableIpLimits
     bitsPerHop*: int
+  ProvidersCache = LRUCache[PeerId, SignedPeerRecord]
+  ItemsCache = LRUCache[NodeId, ProvidersCache]
+
   Protocol* = ref object
     localNode*: Node
     privateKey: PrivateKey
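Together with the new constants, these two aliases give a bounded two-level store: at most MaxProvidersEntries content IDs, each capped at MaxProvidersPerEntry provider records, with LRU eviction at both levels. A sketch of the shape (cId and spr are placeholders, not values from this PR):

var providers = ItemsCache.init(MaxProvidersEntries)   # outer: NodeId -> ProvidersCache
var entry = ProvidersCache.init(MaxProvidersPerEntry)  # inner: PeerId -> SignedPeerRecord
entry.put(spr.data.peerId, spr)   # spr: placeholder SignedPeerRecord
providers.put(cId, entry)         # caches are value objects, so re-put after mutating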
@@ -135,7 +140,7 @@ type
     talkProtocols*: Table[seq[byte], TalkProtocol] # TODO: Table is a bit of
                                                    # overkill here, use sequence
     rng*: ref BrHmacDrbgContext
-    providers: Table[NodeId, seq[SignedPeerRecord]]
+    providers: ItemsCache
   TalkProtocolHandler* = proc(p: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte]
     {.gcsafe, raises: [Defect].}
@@ -267,11 +272,11 @@ proc sendNodes(d: Protocol, toId: NodeId, toAddr: Address, reqId: RequestId,
   # TODO: Do the total calculation based on the max UDP packet size we want to
   # send and the SPR size of all (max 16) nodes.
   # Which UDP packet size to take? 1280? 576?
-  message.total = ceil(nodes.len / maxNodesPerMessage).uint32
+  message.total = ceil(nodes.len / MaxNodesPerMessage).uint32
   for i in 0 ..< nodes.len:
     message.sprs.add(nodes[i].record)
-    if message.sprs.len == maxNodesPerMessage:
+    if message.sprs.len == MaxNodesPerMessage:
       d.sendNodes(toId, toAddr, message, reqId)
       message.sprs.setLen(0)
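As a sanity check on the chunking: a maximal response of FindNodeResultLimit = 16 SPRs splits into five full Nodes messages of MaxNodesPerMessage = 3 plus one final message carrying the remaining SPR:

import std/math
doAssert ceil(16 / 3).uint32 == 6  # 5 messages x 3 SPRs + 1 message x 1 SPR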
@@ -329,23 +334,36 @@ proc handleTalkReq(d: Protocol, fromId: NodeId, fromAddr: Address,
   d.sendResponse(fromId, fromAddr, talkresp, reqId)
 proc addProviderLocal(p: Protocol, cId: NodeId, prov: SignedPeerRecord) =
-  trace "adding provider to local db", cid=cId, spr=prov.data
-  p.providers.mgetOrPut(cId, @[]).add(prov)
+  trace "adding provider to local db", n=p.localNode, cId, prov
+
+  var providers =
+    if cId notin p.providers:
+      ProvidersCache.init(MaxProvidersPerEntry)
+    else:
+      p.providers.get(cId).get()
+
+  providers.put(prov.data.peerId, prov)
+  p.providers.put(cId, providers)
-proc handleAddProvider(d: Protocol, fromId: NodeId, fromAddr: Address,
-    addProvider: AddProviderMessage, reqId: RequestId) =
+proc handleAddProvider(
+    d: Protocol,
+    fromId: NodeId,
+    fromAddr: Address,
+    addProvider: AddProviderMessage,
+    reqId: RequestId) =
   d.addProviderLocal(addProvider.cId, addProvider.prov)
 proc handleGetProviders(d: Protocol, fromId: NodeId, fromAddr: Address,
     getProviders: GetProvidersMessage, reqId: RequestId) =
   #TODO: add checks, add signed version
-  let provs = d.providers.getOrDefault(getProviders.cId)
-  trace "providers:", prov=provs.mapIt(it.data)
+  let provs = d.providers.get(getProviders.cId)
+  if provs.isSome:
+    trace "providers:", provs
-  ##TODO: handle multiple messages
-  let response = ProvidersMessage(total: 1, provs: provs)
-  d.sendResponse(fromId, fromAddr, response, reqId)
+    ##TODO: handle multiple messages
+    let response = ProvidersMessage(total: 1, provs: toSeq(provs.get()))
+    d.sendResponse(fromId, fromAddr, response, reqId)
 proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address,
     message: Message) =
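Worth noting: `toSeq(provs.get())` only compiles because of the `items` iterator this PR adds to lru.nim; `toSeq` (from std/sequtils) materializes whatever `items` yields. A minimal sketch with a placeholder inner cache:

import std/sequtils
var entry = ProvidersCache.init(MaxProvidersPerEntry)
# entry.put(peerId, signedPeerRecord) calls elided
let records = toSeq(entry)  # seq[SignedPeerRecord], in MRU-to-LRU order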
@@ -406,7 +424,7 @@ proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId):
   result = newFuture[Option[Message]]("waitMessage")
   let res = result
   let key = (fromNode.id, reqId)
-  sleepAsync(responseTimeout).addCallback() do(data: pointer):
+  sleepAsync(ResponseTimeout).addCallback() do(data: pointer):
     d.awaitedMessages.del(key)
     if not res.finished:
       res.complete(none(Message))
@@ -422,12 +440,12 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
   ## Same counts for out of order receival.
   var op = await d.waitMessage(fromNode, reqId)
   if op.isSome:
-    if op.get.kind == nodes:
+    if op.get.kind == MessageKind.nodes:
       var res = op.get.nodes.sprs
       let total = op.get.nodes.total
       for i in 1 ..< total:
         op = await d.waitMessage(fromNode, reqId)
-        if op.isSome and op.get.kind == nodes:
+        if op.isSome and op.get.kind == MessageKind.nodes:
           res.add(op.get.nodes.sprs)
         else:
           # No error on this as we received some nodes.
@@ -498,7 +516,7 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]):
   let nodes = await d.waitNodes(toNode, reqId)
   if nodes.isOk:
-    let res = verifyNodesRecords(nodes.get(), toNode, findNodeResultLimit, distances)
+    let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit, distances)
     d.routingTable.setJustSeen(toNode)
     return ok(res)
   else:
@@ -515,7 +533,7 @@ proc findNodeFast*(d: Protocol, toNode: Node, target: NodeId):
   let nodes = await d.waitNodes(toNode, reqId)
   if nodes.isOk:
-    let res = verifyNodesRecords(nodes.get(), toNode, findNodeResultLimit)
+    let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit)
     d.routingTable.setJustSeen(toNode)
     return ok(res)
   else:
@@ -550,7 +568,7 @@ proc lookupDistances*(target, dest: NodeId): seq[uint16] =
   let tdAsInt = int(td)
   result.add(td)
   var i = 1
-  while result.len < lookupRequestLimit:
+  while result.len < LookupRequestLimit:
     if tdAsInt + i < 256:
       result.add(td + uint16(i))
     if tdAsInt - i > 0:
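A worked example of the fan-out, using a hypothetical log-distance (the value is illustrative, not from the PR):

# td = 253 and LookupRequestLimit = 3:
#   result = @[253]                    -- the exact distance first
#   i = 1: add td + 1 = 254, then td - 1 = 252; result.len == 3, loop exits
# so a single Findnode request carries distances @[253'u16, 254'u16, 252'u16]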
@@ -561,7 +579,7 @@ proc lookupWorker(d: Protocol, destNode: Node, target: NodeId):
     Future[seq[Node]] {.async.} =
   let dists = lookupDistances(target, destNode.id)
-  # Instead of doing max `lookupRequestLimit` findNode requests, make use
+  # Instead of doing max `LookupRequestLimit` findNode requests, make use
   # of the discv5.1 functionality to request nodes for multiple distances.
   let r = await d.findNode(destNode, dists)
   if r.isOk:
@@ -597,13 +615,13 @@ proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]]
   for node in closestNodes:
     seen.incl(node.id)
-  var pendingQueries = newSeqOfCap[Future[seq[Node]]](alpha)
+  var pendingQueries = newSeqOfCap[Future[seq[Node]]](Alpha)
   while true:
     var i = 0
-    # Doing `alpha` amount of requests at once as long as closer non queried
+    # Doing `Alpha` amount of requests at once as long as closer non queried
     # nodes are discovered.
-    while i < closestNodes.len and pendingQueries.len < alpha:
+    while i < closestNodes.len and pendingQueries.len < Alpha:
       let n = closestNodes[i]
       if not asked.containsOrIncl(n.id):
         if fast:
@@ -693,7 +711,7 @@ proc getProvidersLocal*(
   ): seq[SignedPeerRecord] {.raises: [KeyError,Defect].} =
   return
-    if (cId in d.providers): d.providers[cId]
+    if (cId in d.providers): toSeq(d.providers.get(cId).get())
     else: @[]
 proc getProviders*(
@@ -752,11 +770,11 @@ proc query*(d: Protocol, target: NodeId, k = BUCKET_SIZE): Future[seq[Node]]
   for node in queryBuffer:
     seen.incl(node.id)
-  var pendingQueries = newSeqOfCap[Future[seq[Node]]](alpha)
+  var pendingQueries = newSeqOfCap[Future[seq[Node]]](Alpha)
   while true:
     var i = 0
-    while i < min(queryBuffer.len, k) and pendingQueries.len < alpha:
+    while i < min(queryBuffer.len, k) and pendingQueries.len < Alpha:
       let n = queryBuffer[i]
       if not asked.containsOrIncl(n.id):
         pendingQueries.add(d.lookupWorker(n, target))
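Both lookup and query use the same Alpha-bounded fan-out. A stripped-down sketch of the pattern (drive, nodes and worker are placeholders; the real loops also maintain the seen/asked sets and merge results back into the candidate list; assumes chronos' `one`, which completes when any of the given futures does):

import chronos

proc drive(nodes: seq[Node],
           worker: proc(n: Node): Future[seq[Node]] {.gcsafe.}) {.async.} =
  var candidates = nodes
  var pendingQueries = newSeqOfCap[Future[seq[Node]]](Alpha)
  while candidates.len > 0 or pendingQueries.len > 0:
    # top up to Alpha requests in flight
    while candidates.len > 0 and pendingQueries.len < Alpha:
      pendingQueries.add(worker(candidates.pop()))
    # recycle the slot of whichever request finishes first
    let query = await one(pendingQueries)
    pendingQueries.del(pendingQueries.find(query))
    discard await query  # fold the returned nodes here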
@@ -847,8 +865,8 @@ proc populateTable*(d: Protocol) {.async.} =
   let selfQuery = await d.query(d.localNode.id)
   trace "Discovered nodes in self target query", nodes = selfQuery.len
-  # `initialLookups` random queries
-  for i in 0..<initialLookups:
+  # `InitialLookups` random queries
+  for i in 0..<InitialLookups:
     let randomQuery = await d.queryRandom()
     trace "Discovered nodes in random target query", nodes = randomQuery.len
@@ -875,7 +893,7 @@ proc revalidateLoop(d: Protocol) {.async.} =
   ## message.
   try:
     while true:
-      await sleepAsync(milliseconds(d.rng[].rand(revalidateMax)))
+      await sleepAsync(milliseconds(d.rng[].rand(RevalidateMax)))
       let n = d.routingTable.nodeToRevalidate()
       if not n.isNil:
         traceAsyncErrors d.revalidateNode(n)
@@ -884,19 +902,19 @@ proc revalidateLoop(d: Protocol) {.async.} =
 proc refreshLoop(d: Protocol) {.async.} =
   ## Loop that refreshes the routing table by starting a random query in case
-  ## no queries were done since `refreshInterval` or more.
+  ## no queries were done since `RefreshInterval` or more.
   ## It also refreshes the majority address voted for via pong responses.
   try:
     await d.populateTable()
     while true:
       let currentTime = now(chronos.Moment)
-      if currentTime > (d.lastLookup + refreshInterval):
+      if currentTime > (d.lastLookup + RefreshInterval):
         let randomQuery = await d.queryRandom()
         trace "Discovered nodes in random target query", nodes = randomQuery.len
         debug "Total nodes in discv5 routing table", total = d.routingTable.len()
-      await sleepAsync(refreshInterval)
+      await sleepAsync(RefreshInterval)
   except CancelledError:
     trace "refreshLoop canceled"
@@ -944,7 +962,7 @@ proc ipMajorityLoop(d: Protocol) {.async.} =
       debug "Discovered external address matches current address", majority,
         current = d.localNode.address
-    await sleepAsync(ipMajorityInterval)
+    await sleepAsync(IpMajorityInterval)
   except CancelledError:
     trace "ipMajorityLoop canceled"
@@ -1009,15 +1027,22 @@ proc newProtocol*(
   # TODO Consider whether this should be a Defect
   doAssert rng != nil, "RNG initialization failed"
+  let
+    routingTable = RoutingTable.init(
+      node,
+      config.bitsPerHop,
+      config.tableIpLimits,
+      rng)
+
   result = Protocol(
     privateKey: privKey,
     localNode: node,
     bootstrapRecords: @bootstrapRecords,
     ipVote: IpVote.init(),
     enrAutoUpdate: enrAutoUpdate,
-    routingTable: RoutingTable.init(
-      node, config.bitsPerHop, config.tableIpLimits, rng),
-    rng: rng)
+    routingTable: routingTable,
+    rng: rng,
+    providers: ItemsCache.init(MaxProvidersEntries))
   result.transport = newTransport(result, privKey, node, bindPort, bindIp, rng)

View File

@@ -89,7 +89,7 @@ suite "Discovery v5.1 Protocol Message Encodings":
     let message = decoded.get()
     check:
       message.reqId == reqId
-      message.kind == nodes
+      message.kind == MessageKind.nodes
       message.nodes.total == total
      message.nodes.sprs.len() == 0
@@ -111,7 +111,7 @@ suite "Discovery v5.1 Protocol Message Encodings":
     let message = decoded.get()
     check:
       message.reqId == reqId
-      message.kind == nodes
+      message.kind == MessageKind.nodes
       message.nodes.total == total
       message.nodes.sprs.len() == 2
       message.nodes.sprs[0] == s1
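The commit adds no dedicated unit test for the new `contains`/`items` helpers; a sketch of one in the same suite/check style (hypothetical, assuming the lru import path):

import std/[options, sequtils, unittest]
import lru

suite "LRUCache helpers":
  test "contains and items leave recency untouched":
    var cache = LRUCache[int, string].init(capacity = 2)
    cache.put(1, "a")
    cache.put(2, "b")
    check:
      1 in cache                   # `contains` does not promote key 1
      toSeq(cache) == @["b", "a"]  # `items` walks MRU -> LRU
    cache.put(3, "c")              # key 1 is still the LRU tail: evicted
    check 1 notin cache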