Add clientMode in the Message

This commit is contained in:
Arnaud 2026-05-06 18:19:56 +04:00
parent c7521be1ca
commit 1a35212eee
No known key found for this signature in database
GPG Key ID: A6C7C781817146FA
6 changed files with 76 additions and 104 deletions

View File

@ -83,6 +83,7 @@ type
Message* = object
reqId*: RequestId
clientMode*: bool
case kind*: MessageKind
of ping:
ping*: PingMessage

View File

@ -312,7 +312,7 @@ proc encode*(msg: TalkRespMessage): seq[byte] =
pb.finish()
pb.buffer
proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] =
proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId, clientMode: bool = false): seq[byte] =
result = newSeqOfCap[byte](64)
result.add(messageKind(T).ord)
@ -324,6 +324,10 @@ proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, reqId)
pb.write(2, encoded)
if clientMode:
pb.write(3, 1'u64)
pb.finish()
result.add(pb.buffer)
trace "Encoded protobuf message", typ = $T
@ -344,6 +348,7 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] =
var
reqId: RequestId
encoded: EncodedMessage
clientModeField: uint64
if pb.getRequiredField(1, reqId).isErr:
return err("Invalid request-id")
@ -353,6 +358,11 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] =
if pb.getRequiredField(2, encoded).isErr:
return err("Invalid message encoding")
if pb.getField(3, clientModeField).isErr:
return err("Invalid clientMode field")
message.clientMode = clientModeField != 0
case kind
of unused: return err("Invalid message type")

View File

@ -76,7 +76,7 @@
import
std/[net, tables, sets, options, math, sequtils, algorithm, strutils],
json_serialization/std/net,
stew/[base64, byteutils, endians2],
stew/[base64, endians2],
pkg/[chronicles, chronicles/chronos_tools],
pkg/chronos,
pkg/stint,
@ -137,15 +137,6 @@ const
LookupSeenThreshold = 0.0 ## threshold used for lookup nodeset selection
QuerySeenThreshold = 0.0 ## threshold used for query nodeset selection
NoreplyRemoveThreshold = 0.5 ## remove node on no reply if 'seen' is below this value
clientModeProtocolId* = toBytes("clientMode") ## Protocol ID for clientMode check over TalkProtocol
type DhtMode = enum
Server = 0.byte
Client = 1.byte
func `==`(response: seq[byte], mode: DhtMode): bool =
response.len == 1 and response[0] == mode.byte
func shortLog*(record: SignedPeerRecord): string =
## Returns compact string representation of ``SignedPeerRecord``.
##
@ -189,7 +180,6 @@ type
rng*: ref HmacDrbgContext
providers: ProvidersManager
clientMode*: bool
trackedFutures: Table[uint, Future[bool].Raising([CancelledError])]
TalkProtocolHandler* = proc(p: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte]
{.gcsafe, raises: [Defect].}
@ -307,7 +297,7 @@ proc updateRecord*(
proc sendResponse(d: Protocol, dstId: NodeId, dstAddr: Address,
message: SomeMessage, reqId: RequestId) =
## send Response using the specifid reqId
d.transport.sendMessage(dstId, dstAddr, encodeMessage(message, reqId))
d.transport.sendMessage(dstId, dstAddr, encodeMessage(message, reqId, d.clientMode))
proc sendNodes(d: Protocol, toId: NodeId, toAddr: Address, reqId: RequestId,
nodes: openArray[Node]) =
@ -455,6 +445,11 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address,
trace "Timed out or unrequested message", kind = message.kind,
origin = fromAddr
if message.clientMode:
let node = d.routingTable.getNode(srcId)
if node.isSome:
d.routingTable.removeNode(node.get)
proc registerTalkProtocol*(d: Protocol, protocolId: seq[byte],
protocol: TalkProtocol): DiscResult[void] =
# Currently allow only for one handler per talk protocol.
@ -476,7 +471,7 @@ proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T,
reqId: RequestId) =
doAssert(toNode.address.isSome())
let
message = encodeMessage(m, reqId)
message = encodeMessage(m, reqId, d.clientMode)
trace "Send message packet", dstId = toNode.id,
address = toNode.address, kind = messageKind(T)
@ -646,26 +641,6 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
dht_message_requests_outgoing.inc(labelValues = ["no_response"])
return err("Talk response message not received in time")
proc removeIfClientMode*(d: Protocol, node: Node): Future[bool] {.async: (raises: [CancelledError]).} =
## Remove node from routing table if it responds as a client.
## Returns true if the node was removed, false otherwise.
## The TalkProtocol is used because it is a plug and use solution.
## Another solution would be to include clientMode in the SPR,
## but it would change every time the clientMode is updated
## and it is not really compatible with mode changes because
## it has to be propagated over the nodes.
## Note that if the talk protocol fails (timeout or error),
## the node is not removed in order to keep backward compatibility.
try:
let resp = await d.talkReq(node, clientModeProtocolId, @[])
if resp.isOk() and resp.get() == DhtMode.Client:
d.routingTable.removeNode(node)
return true
return false
except CatchableError as e:
error "Failed to get the TalkProtocol response when checking the client mode", error = e.msg
return false
proc lookupDistances*(target, dest: NodeId): seq[uint16] =
let td = logDistance(target, dest)
let tdAsInt = int(td)
@ -695,12 +670,7 @@ proc lookupWorker(d: Protocol, destNode: Node, target: NodeId, fast: bool):
# Attempt to add all nodes discovered
for n in result:
if d.addNode(n):
let fut: Future[bool].Raising([CancelledError]) = d.removeIfClientMode(n)
fut.addCallback(proc(data: pointer) =
d.trackedFutures.del(fut.id)
)
d.trackedFutures[fut.id] = fut
discard d.addNode(n)
proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]] {.async.} =
## Perform a lookup for the given target, return the closest n nodes to the
@ -989,10 +959,6 @@ proc revalidateNode*(d: Protocol, n: Node) {.async.} =
let pong = await d.ping(n)
if pong.isOk():
if await d.removeIfClientMode(n):
debug "Removed client mode node from routing table", node = n
return
let res = pong.get()
if res.sprSeq > n.record.seqNum:
# Request new SPR
@ -1229,18 +1195,6 @@ proc open*(d: Protocol) {.raises: [Defect, CatchableError].} =
d.transport.open()
trace "Transport open."
let clientModeProtocol = TalkProtocol(
protocolHandler: proc(p: TalkProtocol, request: seq[byte], fromId: NodeId,
fromUdpAddress: Address): seq[byte] {.raises: [].} =
if d.clientMode:
@[DhtMode.Client.byte]
else:
@[DhtMode.Server.byte]
)
d.registerTalkProtocol(clientModeProtocolId, clientModeProtocol).expect(
"Only one protocol should have this id"
)
d.seedTable()
trace "Routing table seeded."
@ -1264,10 +1218,4 @@ proc closeWait*(d: Protocol) {.async.} =
if not d.ipMajorityLoop.isNil:
await d.ipMajorityLoop.cancelAndWait()
let cancellations = d.trackedFutures.values.toSeq.mapIt(it.cancelAndWait())
await noCancel allFutures cancellations
d.trackedFutures.clear()
d.talkProtocols.del(clientModeProtocolId)
await d.transport.closeWait()

View File

@ -232,16 +232,11 @@ proc receive*(t: Transport, a: Address, packet: openArray[byte]) =
# sending the 'whoareyou' message to. In that case, we can set 'seen'
# TODO: verify how this works with restrictive NAT and firewall scenarios.
node.registerSeen()
if t.client.addNode(node):
trace "Added new node to routing table after handshake", node, tablesize=t.client.nodesDiscovered()
# We keep adding the node in the line above in order to not break anything.
# Then we remove the node if it using client mode.
# The operation is async because the check is done over TalkProtocol.
let fut = t.client.removeIfClientMode(node)
fut.addCallback(proc(data: pointer) =
t.client.trackedFutures.del(fut.id))
t.client.trackedFutures[fut.id] = fut
if packet.message.clientMode:
t.client.routingTable.removeNode(node)
else:
discard t.client.addNode(node)
discard t.sendPending(node)
else:

View File

@ -778,34 +778,14 @@ suite "Discovery v5 Tests":
await node1.closeWait()
await node2.closeWait()
test "Client mode is detected over TalkProtocol when clientMode is explicitly set to true":
test "Node is added to routing table when clientMode is not enabled":
let
clientNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20310))
serverNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20311))
node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20310))
node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20311))
clientNode.clientMode = true
discard await discv5_protocol.ping(node1, node2.localNode)
let response = await discv5_protocol.talkReq(
serverNode, clientNode.localNode, clientModeProtocolId, @[])
check:
response.isOk()
response.get() == @[byte 1]
await clientNode.closeWait()
await serverNode.closeWait()
test "Client mode is not decteted when clientMode is not set":
let
node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20312))
node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20313))
let response = await discv5_protocol.talkReq(
node1, node2.localNode, clientModeProtocolId, @[])
check:
response.isOk()
response.get() == @[byte 0]
check node2.routingTable.len() == 1
await node1.closeWait()
await node2.closeWait()
@ -817,31 +797,45 @@ suite "Discovery v5 Tests":
clientNode.clientMode = true
# Trigger the handshake
discard await discv5_protocol.ping(clientNode, serverNode.localNode)
check serverNode.routingTable.len() == 1
# Wait for TalkProtocol response
await sleepAsync(300.milliseconds)
check serverNode.routingTable.len() == 0
await clientNode.closeWait()
await serverNode.closeWait()
test "Node is removed from routing table when clientMode is enabled after session is established":
let
node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20318))
node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20319))
# Establish session: node1 is added to node2's routing table
discard await discv5_protocol.ping(node1, node2.localNode)
check node2.routingTable.len() == 1
# node1 switches to client mode
node1.clientMode = true
# Second ping uses the existing session (ordinary message, not handshake)
discard await discv5_protocol.ping(node1, node2.localNode)
check node2.routingTable.len() == 0
await node1.closeWait()
await node2.closeWait()
test "Node is removed from routing table when clientMode is enabled during validation":
let
clientNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20316))
serverNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20317))
clientNode.clientMode = true
# Add client node directly to server routing table
check serverNode.addNode(clientNode.localNode)
check serverNode.routingTable.len() == 1
clientNode.clientMode = true
await serverNode.revalidateNode(clientNode.localNode)
check serverNode.routingTable.len() == 0

View File

@ -170,6 +170,30 @@ suite "Discovery v5.1 Protocol Message Encodings":
let decoded = decodeMessage(hexToSeqByte(encodedPong))
check decoded.isErr()
test "clientMode flag is correctly encoded and decoded":
let
p = PingMessage(sprSeq: 1'u64)
reqId = RequestId(id: @[1.byte])
let encodedClient = encodeMessage(p, reqId, clientMode = true)
let decodedClient = decodeMessage(encodedClient)
check decodedClient.isOk()
check decodedClient.get().clientMode == true
let encodedServer = encodeMessage(p, reqId, clientMode = false)
let decodedServer = decodeMessage(encodedServer)
check decodedServer.isOk()
check decodedServer.get().clientMode == false
test "Message without clientMode field decodes as server mode":
let
p = PingMessage(sprSeq: 1'u64)
reqId = RequestId(id: @[1.byte])
encoded = encodeMessage(p, reqId) # no clientMode field (legacy node)
decoded = decodeMessage(encoded)
check decoded.isOk()
check decoded.get().clientMode == false
# According to test vectors:
# https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire-test-vectors.md#cryptographic-primitives
suite "Discovery v5.1 Cryptographic Primitives Test Vectors":