mirror of
https://github.com/logos-storage/logos-storage-nim-dht.git
synced 2026-05-21 17:19:27 +00:00
Add clientMode in the Message
This commit is contained in:
parent
c7521be1ca
commit
1a35212eee
@ -83,6 +83,7 @@ type
|
||||
|
||||
Message* = object
|
||||
reqId*: RequestId
|
||||
clientMode*: bool
|
||||
case kind*: MessageKind
|
||||
of ping:
|
||||
ping*: PingMessage
|
||||
|
||||
@ -312,7 +312,7 @@ proc encode*(msg: TalkRespMessage): seq[byte] =
|
||||
pb.finish()
|
||||
pb.buffer
|
||||
|
||||
proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] =
|
||||
proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId, clientMode: bool = false): seq[byte] =
|
||||
result = newSeqOfCap[byte](64)
|
||||
result.add(messageKind(T).ord)
|
||||
|
||||
@ -324,6 +324,10 @@ proc encodeMessage*[T: SomeMessage](p: T, reqId: RequestId): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, reqId)
|
||||
pb.write(2, encoded)
|
||||
|
||||
if clientMode:
|
||||
pb.write(3, 1'u64)
|
||||
|
||||
pb.finish()
|
||||
result.add(pb.buffer)
|
||||
trace "Encoded protobuf message", typ = $T
|
||||
@ -344,6 +348,7 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] =
|
||||
var
|
||||
reqId: RequestId
|
||||
encoded: EncodedMessage
|
||||
clientModeField: uint64
|
||||
|
||||
if pb.getRequiredField(1, reqId).isErr:
|
||||
return err("Invalid request-id")
|
||||
@ -353,6 +358,11 @@ proc decodeMessage*(body: openArray[byte]): DecodeResult[Message] =
|
||||
if pb.getRequiredField(2, encoded).isErr:
|
||||
return err("Invalid message encoding")
|
||||
|
||||
if pb.getField(3, clientModeField).isErr:
|
||||
return err("Invalid clientMode field")
|
||||
|
||||
message.clientMode = clientModeField != 0
|
||||
|
||||
case kind
|
||||
of unused: return err("Invalid message type")
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@
|
||||
import
|
||||
std/[net, tables, sets, options, math, sequtils, algorithm, strutils],
|
||||
json_serialization/std/net,
|
||||
stew/[base64, byteutils, endians2],
|
||||
stew/[base64, endians2],
|
||||
pkg/[chronicles, chronicles/chronos_tools],
|
||||
pkg/chronos,
|
||||
pkg/stint,
|
||||
@ -137,15 +137,6 @@ const
|
||||
LookupSeenThreshold = 0.0 ## threshold used for lookup nodeset selection
|
||||
QuerySeenThreshold = 0.0 ## threshold used for query nodeset selection
|
||||
NoreplyRemoveThreshold = 0.5 ## remove node on no reply if 'seen' is below this value
|
||||
clientModeProtocolId* = toBytes("clientMode") ## Protocol ID for clientMode check over TalkProtocol
|
||||
|
||||
type DhtMode = enum
|
||||
Server = 0.byte
|
||||
Client = 1.byte
|
||||
|
||||
func `==`(response: seq[byte], mode: DhtMode): bool =
|
||||
response.len == 1 and response[0] == mode.byte
|
||||
|
||||
func shortLog*(record: SignedPeerRecord): string =
|
||||
## Returns compact string representation of ``SignedPeerRecord``.
|
||||
##
|
||||
@ -189,7 +180,6 @@ type
|
||||
rng*: ref HmacDrbgContext
|
||||
providers: ProvidersManager
|
||||
clientMode*: bool
|
||||
trackedFutures: Table[uint, Future[bool].Raising([CancelledError])]
|
||||
|
||||
TalkProtocolHandler* = proc(p: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte]
|
||||
{.gcsafe, raises: [Defect].}
|
||||
@ -307,7 +297,7 @@ proc updateRecord*(
|
||||
proc sendResponse(d: Protocol, dstId: NodeId, dstAddr: Address,
|
||||
message: SomeMessage, reqId: RequestId) =
|
||||
## send Response using the specifid reqId
|
||||
d.transport.sendMessage(dstId, dstAddr, encodeMessage(message, reqId))
|
||||
d.transport.sendMessage(dstId, dstAddr, encodeMessage(message, reqId, d.clientMode))
|
||||
|
||||
proc sendNodes(d: Protocol, toId: NodeId, toAddr: Address, reqId: RequestId,
|
||||
nodes: openArray[Node]) =
|
||||
@ -455,6 +445,11 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address,
|
||||
trace "Timed out or unrequested message", kind = message.kind,
|
||||
origin = fromAddr
|
||||
|
||||
if message.clientMode:
|
||||
let node = d.routingTable.getNode(srcId)
|
||||
if node.isSome:
|
||||
d.routingTable.removeNode(node.get)
|
||||
|
||||
proc registerTalkProtocol*(d: Protocol, protocolId: seq[byte],
|
||||
protocol: TalkProtocol): DiscResult[void] =
|
||||
# Currently allow only for one handler per talk protocol.
|
||||
@ -476,7 +471,7 @@ proc sendRequest*[T: SomeMessage](d: Protocol, toNode: Node, m: T,
|
||||
reqId: RequestId) =
|
||||
doAssert(toNode.address.isSome())
|
||||
let
|
||||
message = encodeMessage(m, reqId)
|
||||
message = encodeMessage(m, reqId, d.clientMode)
|
||||
|
||||
trace "Send message packet", dstId = toNode.id,
|
||||
address = toNode.address, kind = messageKind(T)
|
||||
@ -646,26 +641,6 @@ proc talkReq*(d: Protocol, toNode: Node, protocol, request: seq[byte]):
|
||||
dht_message_requests_outgoing.inc(labelValues = ["no_response"])
|
||||
return err("Talk response message not received in time")
|
||||
|
||||
proc removeIfClientMode*(d: Protocol, node: Node): Future[bool] {.async: (raises: [CancelledError]).} =
|
||||
## Remove node from routing table if it responds as a client.
|
||||
## Returns true if the node was removed, false otherwise.
|
||||
## The TalkProtocol is used because it is a plug and use solution.
|
||||
## Another solution would be to include clientMode in the SPR,
|
||||
## but it would change every time the clientMode is updated
|
||||
## and it is not really compatible with mode changes because
|
||||
## it has to be propagated over the nodes.
|
||||
## Note that if the talk protocol fails (timeout or error),
|
||||
## the node is not removed in order to keep backward compatibility.
|
||||
try:
|
||||
let resp = await d.talkReq(node, clientModeProtocolId, @[])
|
||||
if resp.isOk() and resp.get() == DhtMode.Client:
|
||||
d.routingTable.removeNode(node)
|
||||
return true
|
||||
return false
|
||||
except CatchableError as e:
|
||||
error "Failed to get the TalkProtocol response when checking the client mode", error = e.msg
|
||||
return false
|
||||
|
||||
proc lookupDistances*(target, dest: NodeId): seq[uint16] =
|
||||
let td = logDistance(target, dest)
|
||||
let tdAsInt = int(td)
|
||||
@ -695,12 +670,7 @@ proc lookupWorker(d: Protocol, destNode: Node, target: NodeId, fast: bool):
|
||||
|
||||
# Attempt to add all nodes discovered
|
||||
for n in result:
|
||||
if d.addNode(n):
|
||||
let fut: Future[bool].Raising([CancelledError]) = d.removeIfClientMode(n)
|
||||
fut.addCallback(proc(data: pointer) =
|
||||
d.trackedFutures.del(fut.id)
|
||||
)
|
||||
d.trackedFutures[fut.id] = fut
|
||||
discard d.addNode(n)
|
||||
|
||||
proc lookup*(d: Protocol, target: NodeId, fast: bool = false): Future[seq[Node]] {.async.} =
|
||||
## Perform a lookup for the given target, return the closest n nodes to the
|
||||
@ -989,10 +959,6 @@ proc revalidateNode*(d: Protocol, n: Node) {.async.} =
|
||||
let pong = await d.ping(n)
|
||||
|
||||
if pong.isOk():
|
||||
if await d.removeIfClientMode(n):
|
||||
debug "Removed client mode node from routing table", node = n
|
||||
return
|
||||
|
||||
let res = pong.get()
|
||||
if res.sprSeq > n.record.seqNum:
|
||||
# Request new SPR
|
||||
@ -1229,18 +1195,6 @@ proc open*(d: Protocol) {.raises: [Defect, CatchableError].} =
|
||||
d.transport.open()
|
||||
trace "Transport open."
|
||||
|
||||
let clientModeProtocol = TalkProtocol(
|
||||
protocolHandler: proc(p: TalkProtocol, request: seq[byte], fromId: NodeId,
|
||||
fromUdpAddress: Address): seq[byte] {.raises: [].} =
|
||||
if d.clientMode:
|
||||
@[DhtMode.Client.byte]
|
||||
else:
|
||||
@[DhtMode.Server.byte]
|
||||
)
|
||||
d.registerTalkProtocol(clientModeProtocolId, clientModeProtocol).expect(
|
||||
"Only one protocol should have this id"
|
||||
)
|
||||
|
||||
d.seedTable()
|
||||
trace "Routing table seeded."
|
||||
|
||||
@ -1264,10 +1218,4 @@ proc closeWait*(d: Protocol) {.async.} =
|
||||
if not d.ipMajorityLoop.isNil:
|
||||
await d.ipMajorityLoop.cancelAndWait()
|
||||
|
||||
let cancellations = d.trackedFutures.values.toSeq.mapIt(it.cancelAndWait())
|
||||
await noCancel allFutures cancellations
|
||||
d.trackedFutures.clear()
|
||||
|
||||
d.talkProtocols.del(clientModeProtocolId)
|
||||
|
||||
await d.transport.closeWait()
|
||||
|
||||
@ -232,16 +232,11 @@ proc receive*(t: Transport, a: Address, packet: openArray[byte]) =
|
||||
# sending the 'whoareyou' message to. In that case, we can set 'seen'
|
||||
# TODO: verify how this works with restrictive NAT and firewall scenarios.
|
||||
node.registerSeen()
|
||||
if t.client.addNode(node):
|
||||
trace "Added new node to routing table after handshake", node, tablesize=t.client.nodesDiscovered()
|
||||
|
||||
# We keep adding the node in the line above in order to not break anything.
|
||||
# Then we remove the node if it using client mode.
|
||||
# The operation is async because the check is done over TalkProtocol.
|
||||
let fut = t.client.removeIfClientMode(node)
|
||||
fut.addCallback(proc(data: pointer) =
|
||||
t.client.trackedFutures.del(fut.id))
|
||||
t.client.trackedFutures[fut.id] = fut
|
||||
if packet.message.clientMode:
|
||||
t.client.routingTable.removeNode(node)
|
||||
else:
|
||||
discard t.client.addNode(node)
|
||||
|
||||
discard t.sendPending(node)
|
||||
else:
|
||||
|
||||
@ -778,34 +778,14 @@ suite "Discovery v5 Tests":
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
||||
|
||||
test "Client mode is detected over TalkProtocol when clientMode is explicitly set to true":
|
||||
test "Node is added to routing table when clientMode is not enabled":
|
||||
let
|
||||
clientNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20310))
|
||||
serverNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20311))
|
||||
node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20310))
|
||||
node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20311))
|
||||
|
||||
clientNode.clientMode = true
|
||||
discard await discv5_protocol.ping(node1, node2.localNode)
|
||||
|
||||
let response = await discv5_protocol.talkReq(
|
||||
serverNode, clientNode.localNode, clientModeProtocolId, @[])
|
||||
|
||||
check:
|
||||
response.isOk()
|
||||
response.get() == @[byte 1]
|
||||
|
||||
await clientNode.closeWait()
|
||||
await serverNode.closeWait()
|
||||
|
||||
test "Client mode is not decteted when clientMode is not set":
|
||||
let
|
||||
node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20312))
|
||||
node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20313))
|
||||
|
||||
let response = await discv5_protocol.talkReq(
|
||||
node1, node2.localNode, clientModeProtocolId, @[])
|
||||
|
||||
check:
|
||||
response.isOk()
|
||||
response.get() == @[byte 0]
|
||||
check node2.routingTable.len() == 1
|
||||
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
||||
@ -817,31 +797,45 @@ suite "Discovery v5 Tests":
|
||||
|
||||
clientNode.clientMode = true
|
||||
|
||||
# Trigger the handshake
|
||||
discard await discv5_protocol.ping(clientNode, serverNode.localNode)
|
||||
|
||||
check serverNode.routingTable.len() == 1
|
||||
|
||||
# Wait for TalkProtocol response
|
||||
await sleepAsync(300.milliseconds)
|
||||
|
||||
check serverNode.routingTable.len() == 0
|
||||
|
||||
await clientNode.closeWait()
|
||||
await serverNode.closeWait()
|
||||
|
||||
test "Node is removed from routing table when clientMode is enabled after session is established":
|
||||
let
|
||||
node1 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20318))
|
||||
node2 = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20319))
|
||||
|
||||
# Establish session: node1 is added to node2's routing table
|
||||
discard await discv5_protocol.ping(node1, node2.localNode)
|
||||
check node2.routingTable.len() == 1
|
||||
|
||||
# node1 switches to client mode
|
||||
node1.clientMode = true
|
||||
|
||||
# Second ping uses the existing session (ordinary message, not handshake)
|
||||
discard await discv5_protocol.ping(node1, node2.localNode)
|
||||
|
||||
check node2.routingTable.len() == 0
|
||||
|
||||
await node1.closeWait()
|
||||
await node2.closeWait()
|
||||
|
||||
test "Node is removed from routing table when clientMode is enabled during validation":
|
||||
let
|
||||
clientNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20316))
|
||||
serverNode = initDiscoveryNode(rng, PrivateKey.example(rng), localAddress(20317))
|
||||
|
||||
clientNode.clientMode = true
|
||||
|
||||
# Add client node directly to server routing table
|
||||
check serverNode.addNode(clientNode.localNode)
|
||||
|
||||
check serverNode.routingTable.len() == 1
|
||||
|
||||
clientNode.clientMode = true
|
||||
|
||||
await serverNode.revalidateNode(clientNode.localNode)
|
||||
|
||||
check serverNode.routingTable.len() == 0
|
||||
|
||||
@ -170,6 +170,30 @@ suite "Discovery v5.1 Protocol Message Encodings":
|
||||
let decoded = decodeMessage(hexToSeqByte(encodedPong))
|
||||
check decoded.isErr()
|
||||
|
||||
test "clientMode flag is correctly encoded and decoded":
|
||||
let
|
||||
p = PingMessage(sprSeq: 1'u64)
|
||||
reqId = RequestId(id: @[1.byte])
|
||||
|
||||
let encodedClient = encodeMessage(p, reqId, clientMode = true)
|
||||
let decodedClient = decodeMessage(encodedClient)
|
||||
check decodedClient.isOk()
|
||||
check decodedClient.get().clientMode == true
|
||||
|
||||
let encodedServer = encodeMessage(p, reqId, clientMode = false)
|
||||
let decodedServer = decodeMessage(encodedServer)
|
||||
check decodedServer.isOk()
|
||||
check decodedServer.get().clientMode == false
|
||||
|
||||
test "Message without clientMode field decodes as server mode":
|
||||
let
|
||||
p = PingMessage(sprSeq: 1'u64)
|
||||
reqId = RequestId(id: @[1.byte])
|
||||
encoded = encodeMessage(p, reqId) # no clientMode field (legacy node)
|
||||
decoded = decodeMessage(encoded)
|
||||
check decoded.isOk()
|
||||
check decoded.get().clientMode == false
|
||||
|
||||
# According to test vectors:
|
||||
# https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire-test-vectors.md#cryptographic-primitives
|
||||
suite "Discovery v5.1 Cryptographic Primitives Test Vectors":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user