diff --git a/eth/p2p/discoveryv5/encoding.nim b/eth/p2p/discoveryv5/encoding.nim index ef22522..19e9e0f 100644 --- a/eth/p2p/discoveryv5/encoding.nim +++ b/eth/p2p/discoveryv5/encoding.nim @@ -43,6 +43,11 @@ const staticHeaderSize = protocolId.len + 2 + 2 + 1 + gcmNonceSize authdataHeadSize = sizeof(NodeId) + 1 + 1 whoareyouSize = ivSize + staticHeaderSize + idNonceSize + 8 + # It's mentioned in the specification that 1280 is the maximum size for the + # discovery v5 packet, not for the UDP datagram. Thus this limit is applied on + # the UDP payload and the UDP header is not taken into account. + # https://github.com/ethereum/devp2p/blob/26e380b1f3a57db16fbdd4528dde82104c77fa38/discv5/discv5-wire.md#udp-communication + maxDiscv5PacketSize* = 1280 type AESGCMNonce* = array[gcmNonceSize, byte] @@ -579,7 +584,10 @@ proc decodePacket*(c: var Codec, fromAddr: Address, input: openArray[byte]): ## WHOAREYOU packet. In case of the latter a `newNode` might be provided. # Smallest packet is Whoareyou packet so that is the minimum size if input.len() < whoareyouSize: - return err("Packet size too short") + return err("Packet size too small") + + if input.len() > maxDiscv5PacketSize: + return err("Packet size too big") # TODO: Just pass in the full input? Makes more sense perhaps. let (staticHeader, header) = ? decodeHeader(c.localNode.id, diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index a190e6e..a3bb3b2 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -90,7 +90,8 @@ import import nimcrypto except toHex -export options, results, node, enr +export + options, results, node, enr, encoding.maxDiscv5PacketSize declareCounter discovery_message_requests_outgoing, "Discovery protocol outgoing message requests", labels = ["response"] diff --git a/tests/p2p/test_discoveryv5.nim b/tests/p2p/test_discoveryv5.nim index 1e0d204..d97e684 100644 --- a/tests/p2p/test_discoveryv5.nim +++ b/tests/p2p/test_discoveryv5.nim @@ -1,11 +1,12 @@ {.used.} import - std/tables, + std/[tables, sequtils], chronos, chronicles, stint, testutils/unittests, stew/shims/net, stew/byteutils, bearssl, ../../eth/keys, - ../../eth/p2p/discoveryv5/[enr, node, routing_table, encoding, sessions, messages, nodes_verification], + ../../eth/p2p/discoveryv5/[enr, node, routing_table, encoding, sessions, + messages, nodes_verification], ../../eth/p2p/discoveryv5/protocol as discv5_protocol, ./discv5_test_helper @@ -706,8 +707,10 @@ suite "Discovery v5 Tests": rng, PrivateKey.random(rng[]), localAddress(20303)) talkProtocol = "echo".toBytes() - proc handler(protocol: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte] - {.gcsafe, raises: [Defect].} = + proc handler( + protocol: TalkProtocol, request: seq[byte], + fromId: NodeId, fromUdpAddress: Address): + seq[byte] {.gcsafe, raises: [Defect].} = request let echoProtocol = TalkProtocol(protocolHandler: handler) @@ -731,8 +734,10 @@ suite "Discovery v5 Tests": rng, PrivateKey.random(rng[]), localAddress(20303)) talkProtocol = "echo".toBytes() - proc handler(protocol: TalkProtocol, request: seq[byte], fromId: NodeId, fromUdpAddress: Address): seq[byte] - {.gcsafe, raises: [Defect].} = + proc handler( + protocol: TalkProtocol, request: seq[byte], + fromId: NodeId, fromUdpAddress: Address): + seq[byte] {.gcsafe, raises: [Defect].} = request let echoProtocol = TalkProtocol(protocolHandler: handler) @@ -744,3 +749,83 @@ suite "Discovery v5 Tests": await node1.closeWait() await node2.closeWait() + + asyncTest "Max packet size: Request": + let + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + talkProtocol = "echo".toBytes() + + proc handler( + protocol: TalkProtocol, request: seq[byte], + fromId: NodeId, fromUdpAddress: Address): + seq[byte] {.gcsafe, raises: [Defect].} = + request + + let echoProtocol = TalkProtocol(protocolHandler: handler) + + check node2.registerTalkProtocol(talkProtocol, echoProtocol).isOk() + # Do a ping first so a session is created, that makes the next message to + # be an ordinary message and more easy to reverse calculate packet sizes for + # than for a handshake message. + check (await node1.ping(node2.localNode)).isOk() + + block: # 1172 = 1280 - 103 - 4 - 1 = max - talkreq - "echo" - rlp blob + let talkresp = await discv5_protocol.talkReq(node1, node2.localNode, + talkProtocol, repeat(byte 6, 1172)) + + check: + talkresp.isOk() + + block: # > 1280 -> should fail + let talkresp = await discv5_protocol.talkReq(node1, node2.localNode, + talkProtocol, repeat(byte 6, 1173)) + + check: + talkresp.isErr() + + await node1.closeWait() + await node2.closeWait() + + asyncTest "Max packet size: Response": + let + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + talkProtocol = "echo".toBytes() + + proc handler( + protocol: TalkProtocol, request: seq[byte], + fromId: NodeId, fromUdpAddress: Address): + seq[byte] {.gcsafe, raises: [Defect].} = + # Return the request + same protocol id + 2 bytes, to make it 1 byte + # bigger than the request + request & "echo12".toBytes() + + let echoProtocol = TalkProtocol(protocolHandler: handler) + + check node2.registerTalkProtocol(talkProtocol, echoProtocol).isOk() + # Do a ping first so a session is created, that makes the next message to + # be an ordinary message and more easy to reverse calculate packet sizes for + # than for a handshake message. + check (await node1.ping(node2.localNode)).isOk() + + block: # 1171 -> response will be 1 byte bigger thus this should pass + let talkresp = await discv5_protocol.talkReq(node1, node2.localNode, + talkProtocol, repeat(byte 6, 1171)) + + check: + talkresp.isOk() + + block: # 1172 -> response will be 1 byte bigger thus this should fail + let talkresp = await discv5_protocol.talkReq(node1, node2.localNode, + talkProtocol, repeat(byte 6, 1172)) + + check: + talkresp.isErr() + + await node1.closeWait() + await node2.closeWait()