diff --git a/codexdht/private/eth/p2p/discoveryv5/protocol.nim b/codexdht/private/eth/p2p/discoveryv5/protocol.nim index 2e2c41c..fadf9b2 100644 --- a/codexdht/private/eth/p2p/discoveryv5/protocol.nim +++ b/codexdht/private/eth/p2p/discoveryv5/protocol.nim @@ -139,6 +139,13 @@ const NoreplyRemoveThreshold = 0.5 ## remove node on no reply if 'seen' is below this value clientModeProtocolId* = toBytes("clientMode") ## Protocol ID for clientMode check over TalkProtocol +type DhtMode = enum + Server = 0.byte + Client = 1.byte + +func `==`(response: seq[byte], mode: DhtMode): bool = + response.len == 1 and response[0] == mode.byte + func shortLog*(record: SignedPeerRecord): string = ## Returns compact string representation of ``SignedPeerRecord``. ## @@ -182,7 +189,7 @@ type rng*: ref HmacDrbgContext providers: ProvidersManager clientMode*: bool - trackedFutures: seq[Future[bool]] + trackedFutures: Table[uint, Future[bool].Raising([CancelledError])] TalkProtocolHandler* = proc(p: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte] {.gcsafe, raises: [Defect].} @@ -639,7 +646,7 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): dht_message_requests_outgoing.inc(labelValues = ["no_response"]) return err("Talk response message not received in time") -proc removeIfClientMode*(d: Protocol, node: Node): Future[bool] {.async.} = +proc removeIfClientMode*(d: Protocol, node: Node): Future[bool] {.async: (raises: [CancelledError]).} = ## Remove node from routing table if it responds as a client. ## Returns true if the node was removed, false otherwise. ## The TalkProtocol is used because it is a plug and use solution. @@ -649,11 +656,15 @@ proc removeIfClientMode*(d: Protocol, node: Node): Future[bool] {.async.} = ## it has to be propagated over the nodes. ## Note that if the talk protocol fails (timeout or error), ## the node is not removed in order to keep backward compatibility. - let resp = await d.talkReq(node, clientModeProtocolId, @[]) - if resp.isOk() and resp.get() == @[byte 1]: - d.routingTable.removeNode(node) - return true - return false + try: + let resp = await d.talkReq(node, clientModeProtocolId, @[]) + if resp.isOk() and resp.get() == DhtMode.Client: + d.routingTable.removeNode(node) + return true + return false + except CatchableError as e: + error "Failed to get the TalkProtocol response when checking the client mode", error = e.msg + return false proc lookupDistances*(target, dest: NodeId): seq[uint16] = let td = logDistance(target, dest) @@ -685,11 +696,11 @@ proc lookupWorker(d: Protocol, destNode: Node, target: NodeId, fast: bool): # Attempt to add all nodes discovered for n in result: if d.addNode(n): - let fut = d.removeIfClientMode(n) + let fut: Future[bool].Raising([CancelledError]) = d.removeIfClientMode(n) fut.addCallback(proc(data: pointer) = - d.trackedFutures.remove(fut) + d.trackedFutures.del(fut.id) ) - d.trackedFutures.add(fut) + d.trackedFutures[fut.id] = fut proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]] {.async.} = ## Perform a lookup for the given target, return the closest n nodes to the @@ -1226,7 +1237,7 @@ proc open*(d: Protocol) {.raises: [Defect, CatchableError].} = else: @[byte 0] ) - discard d.registerTalkProtocol(clientModeProtocolId, clientModeProtocol).expect( + d.registerTalkProtocol(clientModeProtocolId, clientModeProtocol).expect( "Only one protocol should have this id" ) @@ -1253,6 +1264,10 @@ proc closeWait*(d: Protocol) {.async.} = if not d.ipMajorityLoop.isNil: await d.ipMajorityLoop.cancelAndWait() - d.trackedFutures.cancelTracked() + let cancellations = d.trackedFutures.values.toSeq.mapIt(it.cancelAndWait()) + await noCancel allFutures cancellations + d.trackedFutures.clear() + + d.talkProtocols.del(clientModeProtocolId) await d.transport.closeWait() diff --git a/codexdht/private/eth/p2p/discoveryv5/transport.nim b/codexdht/private/eth/p2p/discoveryv5/transport.nim index afd4d06..69034eb 100644 --- a/codexdht/private/eth/p2p/discoveryv5/transport.nim +++ b/codexdht/private/eth/p2p/discoveryv5/transport.nim @@ -240,8 +240,8 @@ proc receive*(t: Transport, a: Address, packet: openArray[byte]) = # The operation is async because the check is done over TalkProtocol. let fut = t.client.removeIfClientMode(node) fut.addCallback(proc(data: pointer) = - t.client.trackedFutures.remove(fut)) - t.client.trackedFutures.add(fut) + t.client.trackedFutures.del(fut.id)) + t.client.trackedFutures[fut.id] = fut discard t.sendPending(node) else: