From 1b417222eb224a27aff98c9443ee1c090b4ddbca Mon Sep 17 00:00:00 2001 From: Csaba Kiraly Date: Wed, 10 May 2023 21:31:18 +0200 Subject: [PATCH] introduce waitResponse wrapper initialize wait for response before sending request. This is needed in cases where the response arrives before moving to the next instruction, such as a directly connected test. Signed-off-by: Csaba Kiraly --- .../private/eth/p2p/discoveryv5/protocol.nim | 51 +++++++++++-------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim b/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim index fb0dd2b..eef48b6 100644 --- a/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim +++ b/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim @@ -450,6 +450,11 @@ proc replaceNode(d: Protocol, n: Node) = # peers in the routing table. debug "Message request to bootstrap node failed", src=d.localNode, dst=n +proc waitResponse*[T: SomeMessage](d: Protocol, node: Node, msg: T): + Future[Option[Message]] = + let reqId = RequestId.init(d.rng[]) + result = d.waitMessage(node, reqId) + sendRequest(d, node, msg, reqId) proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId): Future[Option[Message]] = @@ -462,6 +467,12 @@ proc waitMessage(d: Protocol, fromNode: Node, reqId: RequestId): res.complete(none(Message)) d.awaitedMessages[key] = result +proc waitNodeResponses*[T: SomeMessage](d: Protocol, node: Node, msg: T): + Future[DiscResult[seq[SignedPeerRecord]]] = + let reqId = RequestId.init(d.rng[]) + result = d.waitNodes(node, reqId) + sendRequest(d, node, msg, reqId) + proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): Future[DiscResult[seq[SignedPeerRecord]]] {.async.} = ## Wait for one or more nodes replies. @@ -490,23 +501,20 @@ proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId): discovery_message_requests_outgoing.inc(labelValues = ["no_response"]) return err("Nodes message not received in time") -proc sendRequest*[T: SomeMessage](d: Protocol, toId: NodeId, toAddr: Address, m: T): - RequestId = +proc sendRequest*[T: SomeMessage](d: Protocol, toId: NodeId, toAddr: Address, m: T, + reqId: RequestId) = let - reqId = RequestId.init(d.rng[]) message = encodeMessage(m, reqId) trace "Send message packet", dstId = toId, toAddr, kind = messageKind(T) discovery_message_requests_outgoing.inc() d.transport.sendMessage(toId, toAddr, message) - return reqId -proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T): - RequestId = +proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T, + reqId: RequestId) = doAssert(toNode.address.isSome()) let - reqId = RequestId.init(d.rng[]) message = encodeMessage(m, reqId) trace "Send message packet", dstId = toNode.id, @@ -514,16 +522,15 @@ proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T): discovery_message_requests_outgoing.inc() d.transport.sendMessage(toNode, message) - return reqId proc ping*(d: Protocol, toNode: Node): Future[DiscResult[PongMessage]] {.async.} = ## Send a discovery ping message. ## ## Returns the received pong message or an error. - let reqId = d.sendRequest(toNode, - PingMessage(sprSeq: d.localNode.record.seqNum)) - let resp = await d.waitMessage(toNode, reqId) + let + msg = PingMessage(sprSeq: d.localNode.record.seqNum) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == pong: @@ -544,8 +551,9 @@ proc findNode*(d: Protocol, toNode: Node, distances: seq[uint16]): ## ## Returns the received nodes or an error. ## Received SPRs are already validated and converted to `Node`. - let reqId = d.sendRequest(toNode, FindNodeMessage(distances: distances)) - let nodes = await d.waitNodes(toNode, reqId) + let + msg = FindNodeMessage(distances: distances) + nodes = await d.waitNodeResponses(toNode, msg) if nodes.isOk: let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit, distances) @@ -561,8 +569,9 @@ proc findNodeFast*(d: Protocol, toNode: Node, target: NodeId): ## ## Returns the received nodes or an error. ## Received SPRs are already validated and converted to `Node`. - let reqId = d.sendRequest(toNode, FindNodeFastMessage(target: target)) - let nodes = await d.waitNodes(toNode, reqId) + let + msg = FindNodeFastMessage(target: target) + nodes = await d.waitNodeResponses(toNode, msg) if nodes.isOk: let res = verifyNodesRecords(nodes.get(), toNode, FindNodeResultLimit) @@ -578,9 +587,9 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): ## Send a discovery talkreq message. ## ## Returns the received talkresp message or an error. - let reqId = d.sendRequest(toNode, - TalkReqMessage(protocol: protocol, request: request)) - let resp = await d.waitMessage(toNode, reqId) + let + msg = TalkReqMessage(protocol: protocol, request: request) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == talkResp: @@ -704,7 +713,8 @@ proc addProvider*( res.add(d.localNode) for toNode in res: if toNode != d.localNode: - discard d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr)) + let reqId = RequestId.init(d.rng[]) + d.sendRequest(toNode, AddProviderMessage(cId: cId, prov: pr), reqId) else: asyncSpawn d.addProviderLocal(cId, pr) @@ -717,8 +727,7 @@ proc sendGetProviders(d: Protocol, toNode: Node, trace "sendGetProviders", toNode, msg let - reqId = d.sendRequest(toNode, msg) - resp = await d.waitMessage(toNode, reqId) + resp = await d.waitResponse(toNode, msg) if resp.isSome(): if resp.get().kind == MessageKind.providers: