diff --git a/libp2pdht/private/eth/p2p/discoveryv5/messages.nim b/libp2pdht/private/eth/p2p/discoveryv5/messages.nim index 7198006..9e3c52f 100644 --- a/libp2pdht/private/eth/p2p/discoveryv5/messages.nim +++ b/libp2pdht/private/eth/p2p/discoveryv5/messages.nim @@ -16,6 +16,7 @@ import std/[hashes, net], eth/[keys], ./spr, + ./node, ../../../../dht/providers_messages export providers_messages @@ -40,6 +41,7 @@ type addProvider = 0x0B getProviders = 0x0C providers = 0x0D + findNodeFast = 0x83 RequestId* = object id*: seq[byte] @@ -55,6 +57,9 @@ type FindNodeMessage* = object distances*: seq[uint16] + FindNodeFastMessage* = object + target*: NodeId + NodesMessage* = object total*: uint32 sprs*: seq[SignedPeerRecord] @@ -74,7 +79,7 @@ type SomeMessage* = PingMessage or PongMessage or FindNodeMessage or NodesMessage or TalkReqMessage or TalkRespMessage or AddProviderMessage or GetProvidersMessage or - ProvidersMessage + ProvidersMessage or FindNodeFastMessage Message* = object reqId*: RequestId @@ -85,6 +90,8 @@ type pong*: PongMessage of findNode: findNode*: FindNodeMessage + of findNodeFast: + findNodeFast*: FindNodeFastMessage of nodes: nodes*: NodesMessage of talkReq: @@ -112,6 +119,7 @@ template messageKind*(T: typedesc[SomeMessage]): MessageKind = when T is PingMessage: ping elif T is PongMessage: pong elif T is FindNodeMessage: findNode + elif T is FindNodeFastMessage: findNodeFast elif T is NodesMessage: nodes elif T is TalkReqMessage: talkReq elif T is TalkRespMessage: talkResp diff --git a/libp2pdht/private/eth/p2p/discoveryv5/messages_encoding.nim b/libp2pdht/private/eth/p2p/discoveryv5/messages_encoding.nim index 882ad80..7938661 100644 --- a/libp2pdht/private/eth/p2p/discoveryv5/messages_encoding.nim +++ b/libp2pdht/private/eth/p2p/discoveryv5/messages_encoding.nim @@ -15,7 +15,7 @@ import chronicles, libp2p/routing_record, libp2p/signed_envelope, - "."/[messages, spr], + "."/[messages, spr, node], ../../../../dht/providers_encoding from stew/objects import checkedEnumAssign @@ -60,6 +60,16 @@ proc append*(writer: var RlpWriter, ip: IpAddress) = writer.append(ip.address_v4) of IpAddressFamily.IPv6: writer.append(ip.address_v6) +proc read*(rlp: var Rlp, T: type NodeId): T + {.raises: [ValueError, RlpError, Defect].} = + mixin read + let nodeId = NodeId.fromBytesBE(rlp.toBytes()) + rlp.skipElem() + nodeId + +proc append*(writer: var RlpWriter, value: NodeId) = + writer.append(value.toBytesBE) + proc numFields(T: typedesc): int = for k, v in fieldPairs(default(T)): inc result @@ -117,6 +127,7 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] = of ping: rlp.decode(message.ping) of pong: rlp.decode(message.pong) of findNode: rlp.decode(message.findNode) + of findNodeFast: rlp.decode(message.findNodeFast) of nodes: rlp.decode(message.nodes) of talkReq: rlp.decode(message.talkReq) of talkResp: rlp.decode(message.talkResp) diff --git a/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim b/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim index 00670fa..85518ea 100644 --- a/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim +++ b/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim @@ -291,6 +291,12 @@ proc handleFindNode(d: Protocol, fromId: NodeId, fromAddr: Address, # with empty nodes. d.sendNodes(fromId, fromAddr, reqId, []) +proc handleFindNodeFast(d: Protocol, fromId: NodeId, fromAddr: Address, + fnf: FindNodeFastMessage, reqId: RequestId) = + d.sendNodes(fromId, fromAddr, reqId, + d.routingTable.neighbours(fnf.target, seenOnly = true)) + # TODO: if known, maybe we should add exact target even if not yet "seen" + proc handleTalkReq(d: Protocol, fromId: NodeId, fromAddr: Address, talkreq: TalkReqMessage, reqId: RequestId) = let talkProtocol = d.talkProtocols.getOrDefault(talkreq.protocol) @@ -336,6 +342,9 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address, of findNode: discovery_message_requests_incoming.inc() d.handleFindNode(srcId, fromAddr, message.findNode, message.reqId) + of findNodeFast: + discovery_message_requests_incoming.inc() + d.handleFindNodeFast(srcId, fromAddr, message.findNodeFast, message.reqId) of talkReq: discovery_message_requests_incoming.inc() d.handleTalkReq(srcId, fromAddr, message.talkReq, message.reqId) @@ -467,7 +476,7 @@ proc ping*(d: Protocol, toNode: Node): proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]): Future[DiscResult[seq[Node]]] {.async.} = - ## Send a discovery findNode message. + ## Send a getNeighbours message. ## ## Returns the received nodes or an error. ## Received SPRs are already validated and converted to `Node`. @@ -482,6 +491,24 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]): d.replaceNode(toNode) return err(nodes.error) +proc findNodeFast*(d: Protocol, toNode: Node, target: NodeId): + Future[DiscResult[seq[Node]]] {.async.} = + ## Send a findNode message. + ## + ## Returns the received nodes or an error. + ## Received SPRs are already validated and converted to `Node`. + let reqId = d.sendRequest(toNode, FindNodeFastMessage(target: target)) + let nodes = await d.waitNodes(toNode, reqId) + + if nodes.isOk: + let res = verifyNodesRecords(nodes.get(), toNode, findNodeResultLimit) + d.routingTable.setJustSeen(toNode) + return ok(res) + else: + d.replaceNode(toNode) + return err(nodes.error) + + proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): Future[DiscResult[seq[byte]]] {.async.} = ## Send a discovery talkreq message. @@ -530,7 +557,19 @@ proc lookupWorker(d: Protocol, destNode: Node, target: NodeId): for n in result: discard d.addNode(n) -proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]] {.async.} = +proc lookupWorkerFast(d: Protocol, destNode: Node, target: NodeId): + Future[seq[Node]] {.async.} = + ## use terget NodeId based find_node + + let r = await d.findNodeFast(destNode, target) + if r.isOk: + result.add(r[]) + + # Attempt to add all nodes discovered + for n in result: + discard d.addNode(n) + +proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]] {.async.} = ## Perform a lookup for the given target, return the closest n nodes to the ## target. Maximum value for n is `BUCKET_SIZE`. # `closestNodes` holds the k closest nodes to target found, sorted by distance @@ -553,7 +592,10 @@ proc lookup*(d: Protocol, target: NodeId): Future[seq[Node]] {.async.} = while i < closestNodes.len and pendingQueries.len < alpha: let n = closestNodes[i] if not asked.containsOrIncl(n.id): - pendingQueries.add(d.lookupWorker(n, target)) + if fast: + pendingQueries.add(d.lookupWorkerFast(n, target)) + else: + pendingQueries.add(d.lookupWorker(n, target)) inc i trace "discv5 pending queries", total = pendingQueries.len