diff --git a/codexdht/private/eth/p2p/discoveryv5/messages.nim b/codexdht/private/eth/p2p/discoveryv5/messages.nim index d9842a9..e0de8bb 100644 --- a/codexdht/private/eth/p2p/discoveryv5/messages.nim +++ b/codexdht/private/eth/p2p/discoveryv5/messages.nim @@ -83,6 +83,7 @@ type Message* = object reqId*: RequestId + clientMode*: bool case kind*: MessageKind of ping: ping*: PingMessage diff --git a/codexdht/private/eth/p2p/discoveryv5/messages_encoding.nim b/codexdht/private/eth/p2p/discoveryv5/messages_encoding.nim index 229afc0..b0d1843 100644 --- a/codexdht/private/eth/p2p/discoveryv5/messages_encoding.nim +++ b/codexdht/private/eth/p2p/discoveryv5/messages_encoding.nim @@ -312,7 +312,7 @@ proc encode*(msg: TalkRespMessage): seq[byte] = pb.finish() pb.buffer -proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] = +proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId, clientMode: bool = false): seq[byte] = result = newSeqOfCap[byte](64) result.add(messageKind(T).ord) @@ -324,6 +324,10 @@ proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] = var pb = initProtoBuffer() pb.write(1, reqId) pb.write(2, encoded) + + if clientMode: + pb.write(3, 1'u64) + pb.finish() result.add(pb.buffer) trace "Encoded protobuf message", typ = $T @@ -344,6 +348,7 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] = var reqId: RequestId encoded: EncodedMessage + clientModeField: uint64 if pb.getRequiredField(1, reqId).isErr: return err("Invalid request-id") @@ -353,6 +358,11 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] = if pb.getRequiredField(2, encoded).isErr: return err("Invalid message encoding") + if pb.getField(3, clientModeField).isErr: + return err("Invalid clientMode field") + + message.clientMode = clientModeField != 0 + case kind of unused: return err("Invalid message type") diff --git a/codexdht/private/eth/p2p/discoveryv5/protocol.nim b/codexdht/private/eth/p2p/discoveryv5/protocol.nim index ac4bb4d..f36af17 100644 --- a/codexdht/private/eth/p2p/discoveryv5/protocol.nim +++ b/codexdht/private/eth/p2p/discoveryv5/protocol.nim @@ -76,7 +76,7 @@ import std/[net, tables, sets, options, math, sequtils, algorithm, strutils], json_serialization/std/net, - stew/[base64, byteutils, endians2], + stew/[base64, endians2], pkg/[chronicles, chronicles/chronos_tools], pkg/chronos, pkg/stint, @@ -137,15 +137,6 @@ const LookupSeenThreshold = 0.0 ## threshold used for lookup nodeset selection QuerySeenThreshold = 0.0 ## threshold used for query nodeset selection NoreplyRemoveThreshold = 0.5 ## remove node on no reply if 'seen' is below this value - clientModeProtocolId* = toBytes("clientMode") ## Protocol ID for clientMode check over TalkProtocol - -type DhtMode = enum - Server = 0.byte - Client = 1.byte - -func `==`(response: seq[byte], mode: DhtMode): bool = - response.len == 1 and response[0] == mode.byte - func shortLog*(record: SignedPeerRecord): string = ## Returns compact string representation of ``SignedPeerRecord``. ## @@ -189,7 +180,6 @@ type rng*: ref HmacDrbgContext providers: ProvidersManager clientMode*: bool - trackedFutures: Table[uint, Future[bool].Raising([CancelledError])] TalkProtocolHandler* = proc(p: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte] {.gcsafe, raises: [Defect].} @@ -307,7 +297,7 @@ proc updateRecord*( proc sendResponse(d: Protocol, dstId: NodeId, dstAddr: Address, message: SomeMessage, reqId: RequestId) = ## send Response using the specifid reqId - d.transport.sendMessage(dstId, dstAddr, encodeMessage(message, reqId)) + d.transport.sendMessage(dstId, dstAddr, encodeMessage(message, reqId, d.clientMode)) proc sendNodes(d: Protocol, toId: NodeId, toAddr: Address, reqId: RequestId, nodes: openArray[Node]) = @@ -455,6 +445,11 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address, trace "Timed out or unrequested message", kind = message.kind, origin = fromAddr + if message.clientMode: + let node = d.routingTable.getNode(srcId) + if node.isSome: + d.routingTable.removeNode(node.get) + proc registerTalkProtocol*(d: Protocol, protocolId: seq[byte], protocol: TalkProtocol): DiscResult[void] = # Currently allow only for one handler per talk protocol. @@ -476,7 +471,7 @@ proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T, reqId: RequestId) = doAssert(toNode.address.isSome()) let - message = encodeMessage(m, reqId) + message = encodeMessage(m, reqId, d.clientMode) trace "Send message packet", dstId = toNode.id, address = toNode.address, kind = messageKind(T) @@ -646,26 +641,6 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]): dht_message_requests_outgoing.inc(labelValues = ["no_response"]) return err("Talk response message not received in time") -proc removeIfClientMode*(d: Protocol, node: Node): Future[bool] {.async: (raises: [CancelledError]).} = - ## Remove node from routing table if it responds as a client. - ## Returns true if the node was removed, false otherwise. - ## The TalkProtocol is used because it is a plug and use solution. - ## Another solution would be to include clientMode in the SPR, - ## but it would change every time the clientMode is updated - ## and it is not really compatible with mode changes because - ## it has to be propagated over the nodes. - ## Note that if the talk protocol fails (timeout or error), - ## the node is not removed in order to keep backward compatibility. - try: - let resp = await d.talkReq(node, clientModeProtocolId, @[]) - if resp.isOk() and resp.get() == DhtMode.Client: - d.routingTable.removeNode(node) - return true - return false - except CatchableError as e: - error "Failed to get the TalkProtocol response when checking the client mode", error = e.msg - return false - proc lookupDistances*(target, dest: NodeId): seq[uint16] = let td = logDistance(target, dest) let tdAsInt = int(td) @@ -695,12 +670,7 @@ proc lookupWorker(d: Protocol, destNode: Node, target: NodeId, fast: bool): # Attempt to add all nodes discovered for n in result: - if d.addNode(n): - let fut: Future[bool].Raising([CancelledError]) = d.removeIfClientMode(n) - fut.addCallback(proc(data: pointer) = - d.trackedFutures.del(fut.id) - ) - d.trackedFutures[fut.id] = fut + discard d.addNode(n) proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]] {.async.} = ## Perform a lookup for the given target, return the closest n nodes to the @@ -989,10 +959,6 @@ proc revalidateNode*(d: Protocol, n: Node) {.async.} = let pong = await d.ping(n) if pong.isOk(): - if await d.removeIfClientMode(n): - debug "Removed client mode node from routing table", node = n - return - let res = pong.get() if res.sprSeq > n.record.seqNum: # Request new SPR @@ -1229,18 +1195,6 @@ proc open*(d: Protocol) {.raises: [Defect, CatchableError].} = d.transport.open() trace "Transport open." - let clientModeProtocol = TalkProtocol( - protocolHandler: proc(p: TalkProtocol, request: seq[byte], fromId: NodeId, - fromUdpAddress: Address): seq[byte] {.raises: [].} = - if d.clientMode: - @[DhtMode.Client.byte] - else: - @[DhtMode.Server.byte] - ) - d.registerTalkProtocol(clientModeProtocolId, clientModeProtocol).expect( - "Only one protocol should have this id" - ) - d.seedTable() trace "Routing table seeded." @@ -1264,10 +1218,4 @@ proc closeWait*(d: Protocol) {.async.} = if not d.ipMajorityLoop.isNil: await d.ipMajorityLoop.cancelAndWait() - let cancellations = d.trackedFutures.values.toSeq.mapIt(it.cancelAndWait()) - await noCancel allFutures cancellations - d.trackedFutures.clear() - - d.talkProtocols.del(clientModeProtocolId) - await d.transport.closeWait() diff --git a/codexdht/private/eth/p2p/discoveryv5/transport.nim b/codexdht/private/eth/p2p/discoveryv5/transport.nim index 69034eb..cdeb964 100644 --- a/codexdht/private/eth/p2p/discoveryv5/transport.nim +++ b/codexdht/private/eth/p2p/discoveryv5/transport.nim @@ -232,16 +232,11 @@ proc receive*(t: Transport, a: Address, packet: openArray[byte]) = # sending the 'whoareyou' message to. In that case, we can set 'seen' # TODO: verify how this works with restrictive NAT and firewall scenarios. node.registerSeen() - if t.client.addNode(node): - trace "Added new node to routing table after handshake", node, tablesize=t.client.nodesDiscovered() - # We keep adding the node in the line above in order to not break anything. - # Then we remove the node if it using client mode. - # The operation is async because the check is done over TalkProtocol. - let fut = t.client.removeIfClientMode(node) - fut.addCallback(proc(data: pointer) = - t.client.trackedFutures.del(fut.id)) - t.client.trackedFutures[fut.id] = fut + if packet.message.clientMode: + t.client.routingTable.removeNode(node) + else: + discard t.client.addNode(node) discard t.sendPending(node) else: diff --git a/tests/discv5/test_discoveryv5.nim b/tests/discv5/test_discoveryv5.nim index d38993c..ccef31f 100644 --- a/tests/discv5/test_discoveryv5.nim +++ b/tests/discv5/test_discoveryv5.nim @@ -778,34 +778,14 @@ suite "Discovery v5 Tests": await node1.closeWait() await node2.closeWait() - test "Client mode is detected over TalkProtocol when clientMode is explicitly set to true": + test "Node is added to routing table when clientMode is not enabled": let - clientNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20310)) - serverNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20311)) + node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20310)) + node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20311)) - clientNode.clientMode = true + discard await discv5_protocol.ping(node1, node2.localNode) - let response = await discv5_protocol.talkReq( - serverNode, clientNode.localNode, clientModeProtocolId, @[]) - - check: - response.isOk() - response.get() == @[byte 1] - - await clientNode.closeWait() - await serverNode.closeWait() - - test "Client mode is not decteted when clientMode is not set": - let - node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20312)) - node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20313)) - - let response = await discv5_protocol.talkReq( - node1, node2.localNode, clientModeProtocolId, @[]) - - check: - response.isOk() - response.get() == @[byte 0] + check node2.routingTable.len() == 1 await node1.closeWait() await node2.closeWait() @@ -817,31 +797,45 @@ suite "Discovery v5 Tests": clientNode.clientMode = true - # Trigger the handshake discard await discv5_protocol.ping(clientNode, serverNode.localNode) - check serverNode.routingTable.len() == 1 - - # Wait for TalkProtocol response - await sleepAsync(300.milliseconds) - check serverNode.routingTable.len() == 0 await clientNode.closeWait() await serverNode.closeWait() + test "Node is removed from routing table when clientMode is enabled after session is established": + let + node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20318)) + node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20319)) + + # Establish session: node1 is added to node2's routing table + discard await discv5_protocol.ping(node1, node2.localNode) + check node2.routingTable.len() == 1 + + # node1 switches to client mode + node1.clientMode = true + + # Second ping uses the existing session (ordinary message, not handshake) + discard await discv5_protocol.ping(node1, node2.localNode) + + check node2.routingTable.len() == 0 + + await node1.closeWait() + await node2.closeWait() + test "Node is removed from routing table when clientMode is enabled during validation": let clientNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20316)) serverNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20317)) - clientNode.clientMode = true - # Add client node directly to server routing table check serverNode.addNode(clientNode.localNode) check serverNode.routingTable.len() == 1 + clientNode.clientMode = true + await serverNode.revalidateNode(clientNode.localNode) check serverNode.routingTable.len() == 0 diff --git a/tests/discv5/test_discoveryv5_encoding.nim b/tests/discv5/test_discoveryv5_encoding.nim index 80a01eb..1d926f4 100644 --- a/tests/discv5/test_discoveryv5_encoding.nim +++ b/tests/discv5/test_discoveryv5_encoding.nim @@ -170,6 +170,30 @@ suite "Discovery v5.1 Protocol Message Encodings": let decoded = decodeMessage(hexToSeqByte(encodedPong)) check decoded.isErr() + test "clientMode flag is correctly encoded and decoded": + let + p = PingMessage(sprSeq: 1'u64) + reqId = RequestId(id: @[1.byte]) + + let encodedClient = encodeMessage(p, reqId, clientMode = true) + let decodedClient = decodeMessage(encodedClient) + check decodedClient.isOk() + check decodedClient.get().clientMode == true + + let encodedServer = encodeMessage(p, reqId, clientMode = false) + let decodedServer = decodeMessage(encodedServer) + check decodedServer.isOk() + check decodedServer.get().clientMode == false + + test "Message without clientMode field decodes as server mode": + let + p = PingMessage(sprSeq: 1'u64) + reqId = RequestId(id: @[1.byte]) + encoded = encodeMessage(p, reqId) # no clientMode field (legacy node) + decoded = decodeMessage(encoded) + check decoded.isOk() + check decoded.get().clientMode == false + # According to test vectors: # https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire-test-vectors.md#cryptographic-primitives suite "Discovery v5.1 Cryptographic Primitives Test Vectors":