diff --git a/libp2p/protocols/pubsub/rpc/protobuf.nim b/libp2p/protocols/pubsub/rpc/protobuf.nim index b2d6a0e2f..602b0fbd6 100644 --- a/libp2p/protocols/pubsub/rpc/protobuf.nim +++ b/libp2p/protocols/pubsub/rpc/protobuf.nim @@ -304,14 +304,15 @@ proc decodeMessages*(pb: ProtoBuffer): ProtoResult[seq[Message]] {.inline.} = if ? pb.getRepeatedField(2, msgpbs): trace "decodeMessages: read messages", count = len(msgpbs) for item in msgpbs: - msgs.add(? decodeMessage(initProtoBuffer(item))) + # size is constrained at the network level + msgs.add(? decodeMessage(initProtoBuffer(item, maxSize = uint.high))) else: trace "decodeMessages: no messages found" ok(msgs) proc encodeRpcMsg*(msg: RPCMsg, anonymize: bool): seq[byte] = trace "encodeRpcMsg: encoding message", msg = msg.shortLog() - var pb = initProtoBuffer() + var pb = initProtoBuffer(maxSize = uint.high) for item in msg.subscriptions: pb.write(1, item) for item in msg.messages: @@ -324,7 +325,7 @@ proc encodeRpcMsg*(msg: RPCMsg, anonymize: bool): seq[byte] = proc decodeRpcMsg*(msg: seq[byte]): ProtoResult[RPCMsg] {.inline.} = trace "decodeRpcMsg: decoding message", msg = msg.shortLog() - var pb = initProtoBuffer(msg) + var pb = initProtoBuffer(msg, maxSize = uint.high) var rpcMsg = ok(RPCMsg()) assign(rpcMsg.get().messages, ? pb.decodeMessages()) assign(rpcMsg.get().subscriptions, ? pb.decodeSubscriptions()) diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index 0606618a6..ecac40554 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -362,3 +362,35 @@ suite "FloodSub": ) await allFuturesThrowing(nodesFut) + + asyncTest "FloodSub message size validation 2": + var messageReceived = 0 + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + inc(messageReceived) + + let + bigNode1 = generateNodes(1, maxMessageSize = 20000000) + bigNode2 = generateNodes(1, maxMessageSize = 20000000) + + # start switches + nodesFut = await allFinished( + bigNode1[0].switch.start(), + bigNode2[0].switch.start(), + ) + + await subscribeNodes(bigNode1 & bigNode2) + bigNode2[0].subscribe("foo", handler) + await waitSub(bigNode1[0], bigNode2[0], "foo") + + let bigMessage = newSeq[byte](19000000) + + check (await bigNode1[0].publish("foo", bigMessage)) > 0 + + checkExpiring: messageReceived == 1 + + await allFuturesThrowing( + bigNode1[0].switch.stop(), + bigNode2[0].switch.stop() + ) + + await allFuturesThrowing(nodesFut)