diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index 8d67910..c94d9df 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -132,12 +132,16 @@ type bootstrapRecords*: seq[Record] ipVote: IpVote enrAutoUpdate: bool + talkProtocols: Table[seq[byte], TalkProtocolHandler] rng*: ref BrHmacDrbgContext PendingRequest = object node: Node message: seq[byte] + TalkProtocolHandler* = proc(request: seq[byte]): seq[byte] + {.gcsafe, raises: [Defect].} + DiscResult*[T] = Result[T, cstring] proc addNode*(d: Protocol, node: Node): bool = @@ -295,9 +299,15 @@ proc handleFindNode(d: Protocol, fromId: NodeId, fromAddr: Address, proc handleTalkReq(d: Protocol, fromId: NodeId, fromAddr: Address, talkreq: TalkReqMessage, reqId: RequestId) = - # No support for any protocol yet so an empty response is send as per - # specification. - let talkresp = TalkRespMessage(response: @[]) + let protocolHandler = d.talkProtocols.getOrDefault(talkreq.protocol) + + let talkresp = + if protocolHandler.isNil(): + # Protocol identifier that is not registered and thus not supported. An + # empty response is send as per specification. + TalkRespMessage(response: @[]) + else: + TalkRespMessage(response: protocolHandler(talkreq.request)) let (data, _) = encodeMessagePacket(d.rng[], d.codec, fromId, fromAddr, encodeMessage(talkresp, reqId)) @@ -331,6 +341,14 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address, trace "Timed out or unrequested message", kind = message.kind, origin = fromAddr +proc registerTalkProtocol*(d: Protocol, protocol: seq[byte], + handler: TalkProtocolHandler): DiscResult[void] = + # Currently allow only for one handler per talk protocol. + if d.talkProtocols.hasKeyOrPut(protocol, handler): + err("Protocol identifier already registered") + else: + ok() + proc sendWhoareyou(d: Protocol, toId: NodeId, a: Address, requestNonce: AESGCMNonce, node: Option[Node]) = let key = HandShakeKey(nodeId: toId, address: a) diff --git a/tests/p2p/test_discoveryv5.nim b/tests/p2p/test_discoveryv5.nim index 8860db6..9f17bfe 100644 --- a/tests/p2p/test_discoveryv5.nim +++ b/tests/p2p/test_discoveryv5.nim @@ -2,7 +2,8 @@ import std/tables, - chronos, chronicles, stint, testutils/unittests, stew/shims/net, bearssl, + chronos, chronicles, stint, testutils/unittests, stew/shims/net, + stew/byteutils, bearssl, ../../eth/keys, ../../eth/p2p/discoveryv5/[enr, node, routing_table, encoding, sessions, messages], ../../eth/p2p/discoveryv5/protocol as discv5_protocol, @@ -619,3 +620,60 @@ procSuite "Discovery v5 Tests": firstRequestNonce await receiveNode.closeWait() + + asyncTest "Talkreq no protocol": + let + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + talkresp = await discv5_protocol.talkreq(node1, node2.localNode, + @[byte 0x01], @[]) + + check: + talkresp.isOk() + talkresp.get().response.len == 0 + + await node1.closeWait() + await node2.closeWait() + + asyncTest "Talkreq echo protocol": + let + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + talkProtocol = "echo".toBytes() + + proc handler(request: seq[byte]): seq[byte] {.gcsafe, raises: [Defect].} = + request + + check node2.registerTalkProtocol(talkProtocol, handler).isOk() + let talkresp = await discv5_protocol.talkreq(node1, node2.localNode, + talkProtocol, "hello".toBytes()) + + check: + talkresp.isOk() + talkresp.get().response == "hello".toBytes() + + await node1.closeWait() + await node2.closeWait() + + asyncTest "Talkreq register protocols": + let + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + talkProtocol = "echo".toBytes() + + proc handler(request: seq[byte]): seq[byte] {.gcsafe, raises: [Defect].} = + request + + check: + node2.registerTalkProtocol(talkProtocol, handler).isOk() + node2.registerTalkProtocol(talkProtocol, handler).isErr() + node2.registerTalkProtocol("test".toBytes(), handler).isOk() + + await node1.closeWait() + await node2.closeWait() diff --git a/tests/p2p/test_discoveryv5_encoding.nim b/tests/p2p/test_discoveryv5_encoding.nim index 4130b34..873b831 100644 --- a/tests/p2p/test_discoveryv5_encoding.nim +++ b/tests/p2p/test_discoveryv5_encoding.nim @@ -110,6 +110,41 @@ suite "Discovery v5.1 Protocol Message Encodings": message.nodes.enrs[0] == e1 message.nodes.enrs[1] == e2 + test "Talk Request": + let + tr = TalkReqMessage(protocol: "echo".toBytes(), request: "hi".toBytes()) + reqId = RequestId(id: @[1.byte]) + + let encoded = encodeMessage(tr, reqId) + check encoded.toHex == "05c901846563686f826869" + + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.reqId == reqId + message.kind == talkreq + message.talkreq.protocol == "echo".toBytes() + message.talkreq.request == "hi".toBytes() + + test "Talk Response": + let + tr = TalkRespMessage(response: "hi".toBytes()) + reqId = RequestId(id: @[1.byte]) + + let encoded = encodeMessage(tr, reqId) + check encoded.toHex == "06c401826869" + + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.reqId == reqId + message.kind == talkresp + message.talkresp.response == "hi".toBytes() + test "Ping with too large RequestId": let enrSeq = 1'u64