diff --git a/codexdht/private/eth/p2p/discoveryv5/protocol.nim b/codexdht/private/eth/p2p/discoveryv5/protocol.nim index c3f0672..4253059 100644 --- a/codexdht/private/eth/p2p/discoveryv5/protocol.nim +++ b/codexdht/private/eth/p2p/discoveryv5/protocol.nim @@ -452,6 +452,11 @@ proc replaceNode(d: Protocol, n: Node) = # peers in the routing table. debug "Message request to bootstrap node failed", src=d.localNode, dst=n +proc waitResponse*[T: SomeMessage](d: Protocol, node: Node, msg: T): + Future[Option[Message]] = + let reqId = RequestId.init(d.rng[]) + result = d.waitMessage(node, reqId) + sendRequest(d, node, msg, reqId) proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId): Future[Option[Message]] = @@ -464,6 +469,12 @@ proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId): res.complete(none(Message)) d.awaitedMessages[key] = result +proc waitNodeResponses*[T: SomeMessage](d: Protocol, node: Node, msg: T): + Future[DiscResult[seq[SignedPeerRecord]]] = + let reqId = RequestId.init(d.rng[]) + result = d.waitNodes(node, reqId) + sendRequest(d, node, msg, reqId) + proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[DiscResult[seq[SignedPeerRecord]]] {.async.} = ## Wait for one or more nodes replies. @@ -492,23 +503,20 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): discovery_message_requests_outgoing.inc(labelValues = ["no_response"]) return err("Nodes message not received in time") -proc sendRequest*[T: SomeMessage](d: Protocol, toId: NodeId, toAddr: Address, m: T): - RequestId = +proc sendRequest*[T: SomeMessage](d: Protocol, toId: NodeId, toAddr: Address, m: T, + reqId: RequestId) = let - reqId = RequestId.init(d.rng[]) message = encodeMessage(m, reqId) trace "Send message packet", dstId = toId, toAddr, kind = messageKind(T) discovery_message_requests_outgoing.inc() d.transport.sendMessage(toId, toAddr, message) - return reqId -proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T): - RequestId = +proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T, + reqId: RequestId) = doAssert(toNode.address.isSome()) let - reqId = RequestId.init(d.rng[]) message = encodeMessage(m, reqId) trace "Send message packet", dstId = toNode.id, @@ -516,16 +524,15 @@ proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T): discovery_message_requests_outgoing.inc() d.transport.sendMessage(toNode, message) - return reqId proc ping*(d: Protocol, toNode: Node): Future[DiscResult[PongMessage]] {.async.} = ## Send a discovery ping message. ## ## Returns the received pong message or an error. - let reqId = d.sendRequest(toNode, - PingMessage(sprSeq: d.localNode.record.seqNum)) - let resp = await d.waitMessage(toNode, reqId) + let + msg = PingMessage(sprSeq: d.localNode.record.seqNum) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == pong: @@ -546,8 +553,9 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]): ## ## Returns the received nodes or an error. ## Received SPRs are already validated and converted to `Node`. - let reqId = d.sendRequest(toNode, FindNodeMessage(distances: distances)) - let nodes = await d.waitNodes(toNode, reqId) + let + msg = FindNodeMessage(distances: distances) + nodes = await d.waitNodeResponses(toNode, msg) if nodes.isOk: let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit, distances) @@ -564,8 +572,9 @@ proc findNodeFast*(d: Protocol, toNode: Node, target: NodeId): ## ## Returns the received nodes or an error. ## Received SPRs are already validated and converted to `Node`. - let reqId = d.sendRequest(toNode, FindNodeFastMessage(target: target)) - let nodes = await d.waitNodes(toNode, reqId) + let + msg = FindNodeFastMessage(target: target) + nodes = await d.waitNodeResponses(toNode, msg) if nodes.isOk: let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit) @@ -581,9 +590,9 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): ## Send a discovery talkreq message. ## ## Returns the received talkresp message or an error. - let reqId = d.sendRequest(toNode, - TalkReqMessage(protocol: protocol, request: request)) - let resp = await d.waitMessage(toNode, reqId) + let + msg = TalkReqMessage(protocol: protocol, request: request) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == talkResp: @@ -707,7 +716,8 @@ proc addProvider*( res.add(d.localNode) for toNode in res: if toNode != d.localNode: - discard d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr)) + let reqId = RequestId.init(d.rng[]) + d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr), reqId) else: asyncSpawn d.addProviderLocal(cId, pr) @@ -720,8 +730,7 @@ proc sendGetProviders(d: Protocol, toNode: Node, trace "sendGetProviders", toNode, msg let - reqId = d.sendRequest(toNode, msg) - resp = await d.waitMessage(toNode, reqId) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == MessageKind.providers: