diff --git a/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim b/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim
index 11130de..d131cb9 100644
--- a/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim
+++ b/libp2pdht/private/eth/p2p/discoveryv5/protocol.nim
@@ -162,6 +162,7 @@ type
     transport*: Transport[Protocol] # exported for tests
     routingTable*: RoutingTable
     awaitedMessages: Table[(NodeId, RequestId), Future[Option[Message]]]
+    awaitedNodesMessages: Table[(NodeId, RequestId), (Future[Result[seq[SignedPeerRecord],cstring]], uint32, seq[SignedPeerRecord])]
     refreshLoop: Future[void]
     revalidateLoop: Future[void]
     ipMajorityLoop: Future[void]
@@ -424,6 +425,38 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address,
     discovery_message_requests_incoming.inc(labelValues = ["no_response"])
     trace "Received unimplemented message kind", kind = message.kind,
       origin = fromAddr
+  of nodes:
+    # A nodes response may consist of several messages carrying the same
+    # request id; collect them until the advertised total has been reached.
+    trace "node-response message received"
+
+    let
+      key = (srcId, message.reqId)
+      total = message.nodes.total
+    trace "waiting for more nodes messages", me=d.localNode, srcId, total
+    # Look the waiter up with `take` instead of catching `KeyError`, so an
+    # unsolicited or late reply is handled as an ordinary branch.
+    var entry: (Future[Result[seq[SignedPeerRecord], cstring]], uint32,
+      seq[SignedPeerRecord])
+    if d.awaitedNodesMessages.take(key, entry):
+      var (waiter, cnt, s) = entry
+      cnt += 1
+      s.add(message.nodes.sprs)
+      trace "nodes collected", me=d.localNode, srcId, cnt, s
+      # `>=` rather than `==`: a reply advertising a total of 0, or a peer
+      # sending more messages than announced, must still complete the request
+      # instead of leaving it pending until the timeout fires.
+      if cnt >= total:
+        trace "all nodes responses received", me=d.localNode, srcId
+        if not waiter.finished:
+          waiter.complete(DiscResult[seq[SignedPeerRecord]].ok(s))
+      else:
+        # More messages are expected: store the updated state again.
+        d.awaitedNodesMessages[key] = (waiter, cnt, s)
+    else:
+      discovery_unsolicited_messages.inc()
+      warn "Timed out or unrequested message", kind = message.kind,
+        origin = fromAddr
   else:
     var waiter: Future[Option[Message]]
     if d.awaitedMessages.take((srcId, message.reqId), waiter):
@@ -495,33 +528,38 @@ proc waitNodeResponses*[T: SomeMessage](d: Protocol, node: Node, msg: T):
   result = d.waitNodes(node, reqId)
   sendRequest(d, node, msg, reqId)
 
-proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId):
-    Future[DiscResult[seq[SignedPeerRecord]]] {.async.} =
+proc waitNodes(d: Protocol, fromNode: Node, reqId: RequestId, timeout = ResponseTimeout):
+    Future[DiscResult[seq[SignedPeerRecord]]] =
   ## Wait for one or more nodes replies.
   ##
   ## The first reply will hold the total number of replies expected, and based
   ## on that, more replies will be awaited.
   ## If one reply is lost here (timed out), others are ignored too.
   ## Same counts for out of order receival.
-  var op = await d.waitMessage(fromNode, reqId)
-  if op.isSome:
-    if op.get.kind == MessageKind.nodes:
-      var res = op.get.nodes.sprs
-      let total = op.get.nodes.total
-      for i in 1 ..< total:
-        op = await d.waitMessage(fromNode, reqId)
-        if op.isSome and op.get.kind == MessageKind.nodes:
-          res.add(op.get.nodes.sprs)
-        else:
-          # No error on this as we received some nodes.
-          break
-      return ok(res)
-    else:
-      discovery_message_requests_outgoing.inc(labelValues = ["invalid_response"])
-      return err("Invalid response to find node message")
-  else:
-    discovery_message_requests_outgoing.inc(labelValues = ["no_response"])
-    return err("Nodes message not received in time")
+  ##
+  ## The returned future is completed by `handleMessage` once all expected
+  ## replies have arrived, or by the timer below once `timeout` has elapsed.
+  ## On timeout the records collected so far are returned when at least one
+  ## reply arrived; only a request without any reply yields an error.
+  ## TODO: these are VERY optimistic assumptions here. A shorter timeout is
+  ## presumably wanted while replies are being collected - confirm intent.
+
+  let
+    res = newFuture[DiscResult[seq[SignedPeerRecord]]]("waitNodesMessages")
+    key = (fromNode.id, reqId)
+  d.awaitedNodesMessages[key] = (res, 0.uint32, newSeq[SignedPeerRecord]())
+  sleepAsync(timeout).addCallback() do(data: pointer):
+    var entry: (Future[Result[seq[SignedPeerRecord], cstring]], uint32,
+      seq[SignedPeerRecord])
+    # `take` removes the entry; it is absent if the request already completed.
+    if d.awaitedNodesMessages.take(key, entry) and not res.finished:
+      if entry[1] > 0:
+        # No error on this as we received some nodes.
+        res.complete(DiscResult[seq[SignedPeerRecord]].ok(entry[2]))
+      else:
+        discovery_message_requests_outgoing.inc(labelValues = ["no_response"])
+        res.complete(DiscResult[seq[SignedPeerRecord]].err("waitNodeMessages timed out"))
+  res
 
 proc ping*(d: Protocol, toNode: Node):
     Future[DiscResult[PongMessage]] {.async.} =