diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index ab603145a..37bb59d01 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -153,7 +153,7 @@ method init*(g: GossipSub) = g.codecs &= GossipSubCodec g.codecs &= GossipSubCodec_10 -method onNewPeer(g: GossipSub, peer: PubSubPeer) = +method onNewPeer*(g: GossipSub, peer: PubSubPeer) = g.withPeerStats(peer.peerId) do (stats: var PeerStats): # Make sure stats and peer information match, even when reloading peer stats # from a previous connection diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 464131ddf..25315e4dc 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -242,7 +242,7 @@ proc sendEncoded*(p: PubSubPeer, msg: seq[byte]) {.raises: [], async.} = return if msg.len > p.maxMessageSize: - info "trying to send a too big for pubsub", maxSize=p.maxMessageSize, msgSize=msg.len + info "trying to send a msg too big for pubsub", maxSize=p.maxMessageSize, msgSize=msg.len return if p.sendConn == nil: @@ -269,9 +269,42 @@ proc sendEncoded*(p: PubSubPeer, msg: seq[byte]) {.raises: [], async.} = await conn.close() # This will clean up the send connection -proc send*(p: PubSubPeer, msg: RPCMsg, anonymize: bool) {.raises: [].} = - trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg) +iterator splitRPCMsg(peer: PubSubPeer, rpcMsg: RPCMsg, maxSize: int, anonymize: bool): seq[byte] = + ## This iterator takes an `RPCMsg` and sequentially repackages its Messages into new `RPCMsg` instances. + ## Each new `RPCMsg` accumulates Messages until reaching the specified `maxSize`. If a single Message + ## exceeds the `maxSize` when trying to fit into an empty `RPCMsg`, the latter is skipped as too large to send. + ## Every constructed `RPCMsg` is then encoded, optionally anonymized, and yielded as a sequence of bytes. + var currentRPCMsg = rpcMsg + currentRPCMsg.messages = newSeq[Message]() + + var currentSize = byteSize(currentRPCMsg) + + for msg in rpcMsg.messages: + let msgSize = byteSize(msg) + + # Check if adding the next message will exceed maxSize + if float(currentSize + msgSize) * 1.1 > float(maxSize): # Guessing 10% protobuf overhead + if currentRPCMsg.messages.len == 0: + trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg) + continue # Skip this message + + trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg) + yield encodeRpcMsg(currentRPCMsg, anonymize) + currentRPCMsg = RPCMsg() + currentSize = 0 + + currentRPCMsg.messages.add(msg) + currentSize += msgSize + + # Check if there is a non-empty currentRPCMsg left to be added + if currentSize > 0 and currentRPCMsg.messages.len > 0: + trace "sending msg to peer", peer, rpcMsg = shortLog(currentRPCMsg) + yield encodeRpcMsg(currentRPCMsg, anonymize) + else: + trace "message too big to sent", peer, rpcMsg = shortLog(currentRPCMsg) + +proc send*(p: PubSubPeer, msg: RPCMsg, anonymize: bool) {.raises: [].} = # When sending messages, we take care to re-encode them with the right # anonymization flag to ensure that we're not penalized for sending invalid # or malicious data on the wire - in particular, re-encoding protects against @@ -289,7 +322,13 @@ proc send*(p: PubSubPeer, msg: RPCMsg, anonymize: bool) {.raises: [].} = sendMetrics(msg) encodeRpcMsg(msg, anonymize) - asyncSpawn p.sendEncoded(encoded) + if encoded.len > p.maxMessageSize and msg.messages.len > 1: + for encodedSplitMsg in splitRPCMsg(p, msg, p.maxMessageSize, anonymize): + asyncSpawn p.sendEncoded(encodedSplitMsg) + else: + # If the message size is within limits, send it as is + trace "sending msg to peer", peer = p, rpcMsg = shortLog(msg) + asyncSpawn p.sendEncoded(encoded) proc canAskIWant*(p: PubSubPeer, msgId: MessageId): bool = for sentIHave in p.sentIHaves.mitems(): diff --git a/libp2p/protocols/pubsub/rpc/messages.nim b/libp2p/protocols/pubsub/rpc/messages.nim index d4cbf85da..77baded78 100644 --- a/libp2p/protocols/pubsub/rpc/messages.nim +++ b/libp2p/protocols/pubsub/rpc/messages.nim @@ -9,7 +9,7 @@ {.push raises: [].} -import options, sequtils +import options, sequtils, sugar import "../../.."/[ peerid, routing_record, @@ -18,6 +18,14 @@ import "../../.."/[ export options +proc expectedFields[T](t: typedesc[T], existingFieldNames: seq[string]) {.raises: [CatchableError].} = + var fieldNames: seq[string] + for name, _ in fieldPairs(T()): + fieldNames &= name + if fieldNames != existingFieldNames: + fieldNames.keepIf(proc(it: string): bool = it notin existingFieldNames) + raise newException(CatchableError, $T & " fields changed, please search for and revise all relevant procs. New fields: " & $fieldNames) + type PeerInfoMsg* = object peerId*: PeerId @@ -117,31 +125,53 @@ func shortLog*(m: RPCMsg): auto = control: m.control.get(ControlMessage()).shortLog ) +static: expectedFields(PeerInfoMsg, @["peerId", "signedPeerRecord"]) +proc byteSize(peerInfo: PeerInfoMsg): int = + peerInfo.peerId.len + peerInfo.signedPeerRecord.len + +static: expectedFields(SubOpts, @["subscribe", "topic"]) +proc byteSize(subOpts: SubOpts): int = + 1 + subOpts.topic.len # 1 byte for the bool + +static: expectedFields(Message, @["fromPeer", "data", "seqno", "topicIds", "signature", "key"]) proc byteSize*(msg: Message): int = - var total = 0 - total += msg.fromPeer.len - total += msg.data.len - total += msg.seqno.len - total += msg.signature.len - total += msg.key.len - for topicId in msg.topicIds: - total += topicId.len - return total + msg.fromPeer.len + msg.data.len + msg.seqno.len + + msg.signature.len + msg.key.len + msg.topicIds.foldl(a + b.len, 0) proc byteSize*(msgs: seq[Message]): int = - msgs.mapIt(byteSize(it)).foldl(a + b, 0) + msgs.foldl(a + b.byteSize, 0) -proc byteSize*(ihave: seq[ControlIHave]): int = - var total = 0 - for item in ihave: - total += item.topicId.len - for msgId in item.messageIds: - total += msgId.len - return total +static: expectedFields(ControlIHave, @["topicId", "messageIds"]) +proc byteSize(controlIHave: ControlIHave): int = + controlIHave.topicId.len + controlIHave.messageIds.foldl(a + b.len, 0) -proc byteSize*(iwant: seq[ControlIWant]): int = - var total = 0 - for item in iwant: - for msgId in item.messageIds: - total += msgId.len - return total +proc byteSize*(ihaves: seq[ControlIHave]): int = + ihaves.foldl(a + b.byteSize, 0) + +static: expectedFields(ControlIWant, @["messageIds"]) +proc byteSize(controlIWant: ControlIWant): int = + controlIWant.messageIds.foldl(a + b.len, 0) + +proc byteSize*(iwants: seq[ControlIWant]): int = + iwants.foldl(a + b.byteSize, 0) + +static: expectedFields(ControlGraft, @["topicId"]) +proc byteSize(controlGraft: ControlGraft): int = + controlGraft.topicId.len + +static: expectedFields(ControlPrune, @["topicId", "peers", "backoff"]) +proc byteSize(controlPrune: ControlPrune): int = + controlPrune.topicId.len + controlPrune.peers.foldl(a + b.byteSize, 0) + 8 # 8 bytes for uint64 + +static: expectedFields(ControlMessage, @["ihave", "iwant", "graft", "prune", "idontwant"]) +proc byteSize(control: ControlMessage): int = + control.ihave.foldl(a + b.byteSize, 0) + control.iwant.foldl(a + b.byteSize, 0) + + control.graft.foldl(a + b.byteSize, 0) + control.prune.foldl(a + b.byteSize, 0) + + control.idontwant.foldl(a + b.byteSize, 0) + +static: expectedFields(RPCMsg, @["subscriptions", "messages", "control", "ping", "pong"]) +proc byteSize*(rpc: RPCMsg): int = + result = rpc.subscriptions.foldl(a + b.byteSize, 0) + byteSize(rpc.messages) + + rpc.ping.len + rpc.pong.len + rpc.control.withValue(ctrl): + result += ctrl.byteSize diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index a764c8e52..42809e7c2 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -1,42 +1,31 @@ -include ../../libp2p/protocols/pubsub/gossipsub +# Nim-LibP2P +# Copyright (c) 2023 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. {.used.} -import std/[options, deques] +import std/[options, deques, sequtils, enumerate, algorithm] import stew/byteutils import ../../libp2p/builders import ../../libp2p/errors import ../../libp2p/crypto/crypto import ../../libp2p/stream/bufferstream +import ../../libp2p/protocols/pubsub/[pubsub, gossipsub, mcache, mcache, peertable] +import ../../libp2p/protocols/pubsub/rpc/[message, messages] import ../../libp2p/switch import ../../libp2p/muxers/muxer import ../../libp2p/protocols/pubsub/rpc/protobuf +import utils import ../helpers -type - TestGossipSub = ref object of GossipSub - proc noop(data: seq[byte]) {.async, gcsafe.} = discard -proc getPubSubPeer(p: TestGossipSub, peerId: PeerId): PubSubPeer = - proc getConn(): Future[Connection] = - p.switch.dial(peerId, GossipSubCodec) - - let pubSubPeer = PubSubPeer.new(peerId, getConn, nil, GossipSubCodec, 1024 * 1024, Opt.some(TokenBucket.new(1024, 500.milliseconds))) - debug "created new pubsub peer", peerId - - p.peers[peerId] = pubSubPeer - - onNewPeer(p, pubSubPeer) - pubSubPeer - -proc randomPeerId(): PeerId = - try: - PeerId.init(PrivateKey.random(ECDSA, rng[]).get()).tryGet() - except CatchableError as exc: - raise newException(Defect, exc.msg) - const MsgIdSuccess = "msg id gen success" suite "GossipSub internal": @@ -826,3 +815,130 @@ suite "GossipSub internal": await allFuturesThrowing(conns.mapIt(it.close())) await gossipSub.switch.stop() + + proc setupTest(): Future[tuple[gossip0: GossipSub, gossip1: GossipSub, receivedMessages: ref HashSet[seq[byte]]]] {.async.} = + let + nodes = generateNodes(2, gossip = true, verifySignature = false) + discard await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start() + ) + + await nodes[1].switch.connect(nodes[0].switch.peerInfo.peerId, nodes[0].switch.peerInfo.addrs) + + var receivedMessages = new(HashSet[seq[byte]]) + + proc handlerA(topic: string, data: seq[byte]) {.async, gcsafe.} = + receivedMessages[].incl(data) + + proc handlerB(topic: string, data: seq[byte]) {.async, gcsafe.} = + discard + + nodes[0].subscribe("foobar", handlerA) + nodes[1].subscribe("foobar", handlerB) + await waitSubGraph(nodes, "foobar") + + var gossip0: GossipSub = GossipSub(nodes[0]) + var gossip1: GossipSub = GossipSub(nodes[1]) + + return (gossip0, gossip1, receivedMessages) + + proc teardownTest(gossip0: GossipSub, gossip1: GossipSub) {.async.} = + await allFuturesThrowing( + gossip0.switch.stop(), + gossip1.switch.stop() + ) + + proc createMessages(gossip0: GossipSub, gossip1: GossipSub, size1: int, size2: int): tuple[iwantMessageIds: seq[MessageId], sentMessages: HashSet[seq[byte]]] = + var iwantMessageIds = newSeq[MessageId]() + var sentMessages = initHashSet[seq[byte]]() + + for i, size in enumerate([size1, size2]): + let data = newSeqWith[byte](size, i.byte) + sentMessages.incl(data) + + let msg = Message.init(gossip1.peerInfo.peerId, data, "foobar", some(uint64(i + 1))) + let iwantMessageId = gossip1.msgIdProvider(msg).expect(MsgIdSuccess) + iwantMessageIds.add(iwantMessageId) + gossip1.mcache.put(iwantMessageId, msg) + + let peer = gossip1.peers[(gossip0.peerInfo.peerId)] + peer.sentIHaves[^1].incl(iwantMessageId) + + return (iwantMessageIds, sentMessages) + + asyncTest "e2e - Split IWANT replies when individual messages are below maxSize but combined exceed maxSize": + # This test checks if two messages, each below the maxSize, are correctly split when their combined size exceeds maxSize. + # Expected: Both messages should be received. + let (gossip0, gossip1, receivedMessages) = await setupTest() + + let messageSize = gossip1.maxMessageSize div 2 + 1 + let (iwantMessageIds, sentMessages) = createMessages(gossip0, gossip1, messageSize, messageSize) + + gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage( + ihave: @[ControlIHave(topicId: "foobar", messageIds: iwantMessageIds)] + )))) + + checkExpiring: receivedMessages[] == sentMessages + check receivedMessages[].len == 2 + + await teardownTest(gossip0, gossip1) + + asyncTest "e2e - Discard IWANT replies when both messages individually exceed maxSize": + # This test checks if two messages, each exceeding the maxSize, are discarded and not sent. + # Expected: No messages should be received. + let (gossip0, gossip1, receivedMessages) = await setupTest() + + let messageSize = gossip1.maxMessageSize + 10 + let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, messageSize, messageSize) + + gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage( + ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)] + )))) + + await sleepAsync(300.milliseconds) + checkExpiring: receivedMessages[].len == 0 + + await teardownTest(gossip0, gossip1) + + asyncTest "e2e - Process IWANT replies when both messages are below maxSize": + # This test checks if two messages, both below the maxSize, are correctly processed and sent. + # Expected: Both messages should be received. + let (gossip0, gossip1, receivedMessages) = await setupTest() + let size1 = gossip1.maxMessageSize div 2 + let size2 = gossip1.maxMessageSize div 3 + let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, size1, size2) + + gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage( + ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)] + )))) + + checkExpiring: receivedMessages[] == sentMessages + check receivedMessages[].len == 2 + + await teardownTest(gossip0, gossip1) + + asyncTest "e2e - Split IWANT replies when one message is below maxSize and the other exceeds maxSize": + # This test checks if, when given two messages where one is below maxSize and the other exceeds it, only the smaller message is processed and sent. + # Expected: Only the smaller message should be received. + let (gossip0, gossip1, receivedMessages) = await setupTest() + let maxSize = gossip1.maxMessageSize + let size1 = maxSize div 2 + let size2 = maxSize + 10 + let (bigIWantMessageIds, sentMessages) = createMessages(gossip0, gossip1, size1, size2) + + gossip1.broadcast(gossip1.mesh["foobar"], RPCMsg(control: some(ControlMessage( + ihave: @[ControlIHave(topicId: "foobar", messageIds: bigIWantMessageIds)] + )))) + + var smallestSet: HashSet[seq[byte]] + let seqs = toSeq(sentMessages) + if seqs[0] < seqs[1]: + smallestSet.incl(seqs[0]) + else: + smallestSet.incl(seqs[1]) + + checkExpiring: receivedMessages[] == smallestSet + check receivedMessages[].len == 1 + + await teardownTest(gossip0, gossip1) diff --git a/tests/pubsub/testmessage.nim b/tests/pubsub/testmessage.nim index f4f062fa3..589920b40 100644 --- a/tests/pubsub/testmessage.nim +++ b/tests/pubsub/testmessage.nim @@ -2,10 +2,10 @@ import unittest2 {.used.} -import options +import options, strutils import stew/byteutils import ../../libp2p/[peerid, peerinfo, - crypto/crypto, + crypto/crypto as crypto, protocols/pubsub/errors, protocols/pubsub/rpc/message, protocols/pubsub/rpc/messages] @@ -28,7 +28,7 @@ suite "Message": """08011240B9EA7F0357B5C1247E4FCB5AD09C46818ECB07318CA84711875F4C6C E6B946186A4EB44E0D714B2A2D48263D75CF52D30BEF9D9AE2A9FEB7DAF1775F E731065A""" - seckey = PrivateKey.init(fromHex(stripSpaces(pkHex))) + seckey = PrivateKey.init(crypto.fromHex(stripSpaces(pkHex))) .expect("valid private key bytes") peer = PeerInfo.new(seckey) msg = Message.init(some(peer), @[], "topic", some(seqno), sign = true) @@ -46,7 +46,7 @@ suite "Message": """08011240B9EA7F0357B5C1247E4FCB5AD09C46818ECB07318CA84711875F4C6C E6B946186A4EB44E0D714B2A2D48263D75CF52D30BEF9D9AE2A9FEB7DAF1775F E731065A""" - seckey = PrivateKey.init(fromHex(stripSpaces(pkHex))) + seckey = PrivateKey.init(crypto.fromHex(stripSpaces(pkHex))) .expect("valid private key bytes") peer = PeerInfo.new(seckey) @@ -64,7 +64,7 @@ suite "Message": """08011240B9EA7F0357B5C1247E4FCB5AD09C46818ECB07318CA84711875F4C6C E6B946186A4EB44E0D714B2A2D48263D75CF52D30BEF9D9AE2A9FEB7DAF1775F E731065A""" - seckey = PrivateKey.init(fromHex(stripSpaces(pkHex))) + seckey = PrivateKey.init(crypto.fromHex(stripSpaces(pkHex))) .expect("valid private key bytes") peer = PeerInfo.new(seckey) msg = Message.init(some(peer), @[], "topic", uint64.none, sign = true) @@ -74,14 +74,54 @@ suite "Message": msgIdResult.isErr msgIdResult.error == ValidationResult.Reject - test "byteSize for Message": + test "byteSize for RPCMsg": var msg = Message( - fromPeer: PeerId(data: @[]), # Empty seq[byte] + fromPeer: PeerId(data: @['a'.byte, 'b'.byte]), # 2 bytes data: @[1'u8, 2, 3], # 3 bytes - seqno: @[1'u8], # 1 byte - signature: @[], # Empty seq[byte] - key: @[1'u8], # 1 byte + seqno: @[4'u8, 5], # 2 bytes + signature: @['c'.byte, 'd'.byte], # 2 bytes + key: @[6'u8, 7], # 2 bytes topicIds: @["abc", "defgh"] # 3 + 5 = 8 bytes ) - check byteSize(msg) == 3 + 1 + 1 + 8 # Total: 13 bytes \ No newline at end of file + var peerInfo = PeerInfoMsg( + peerId: PeerId(data: @['e'.byte]), # 1 byte + signedPeerRecord: @['f'.byte, 'g'.byte] # 2 bytes + ) + + var controlIHave = ControlIHave( + topicId: "ijk", # 3 bytes + messageIds: @[ @['l'.byte], @['m'.byte, 'n'.byte] ] # 1 + 2 = 3 bytes + ) + + var controlIWant = ControlIWant( + messageIds: @[ @['o'.byte, 'p'.byte], @['q'.byte] ] # 2 + 1 = 3 bytes + ) + + var controlGraft = ControlGraft( + topicId: "rst" # 3 bytes + ) + + var controlPrune = ControlPrune( + topicId: "uvw", # 3 bytes + peers: @[peerInfo, peerInfo], # (1 + 2) * 2 = 6 bytes + backoff: 12345678 # 8 bytes for uint64 + ) + + var control = ControlMessage( + ihave: @[controlIHave, controlIHave], # (3 + 3) * 2 = 12 bytes + iwant: @[controlIWant], # 3 bytes + graft: @[controlGraft], # 3 bytes + prune: @[controlPrune], # 3 + 6 + 8 = 17 bytes + idontwant: @[controlIWant] # 3 bytes + ) + + var rpcMsg = RPCMsg( + subscriptions: @[SubOpts(subscribe: true, topic: "a".repeat(12)), SubOpts(subscribe: false, topic: "b".repeat(14))], # 1 + 12 + 1 + 14 = 28 bytes + messages: @[msg, msg], # 19 * 2 = 38 bytes + ping: @[1'u8, 2], # 2 bytes + pong: @[3'u8, 4], # 2 bytes + control: some(control) # 12 + 3 + 3 + 17 + 3 = 38 bytes + ) + + check byteSize(rpcMsg) == 28 + 38 + 2 + 2 + 38 # Total: 108 bytes diff --git a/tests/pubsub/utils.nim b/tests/pubsub/utils.nim index 6ac49b9b8..82209dcc1 100644 --- a/tests/pubsub/utils.nim +++ b/tests/pubsub/utils.nim @@ -9,16 +9,39 @@ import chronos, stew/[byteutils, results] import ../../libp2p/[builders, protocols/pubsub/errors, protocols/pubsub/pubsub, + protocols/pubsub/pubsubpeer, protocols/pubsub/gossipsub, protocols/pubsub/floodsub, protocols/pubsub/rpc/messages, protocols/secure/secure] +import ../helpers import chronicles export builders randomize() +type + TestGossipSub* = ref object of GossipSub + +proc getPubSubPeer*(p: TestGossipSub, peerId: PeerId): PubSubPeer = + proc getConn(): Future[Connection] = + p.switch.dial(peerId, GossipSubCodec) + + let pubSubPeer = PubSubPeer.new(peerId, getConn, nil, GossipSubCodec, 1024 * 1024) + debug "created new pubsub peer", peerId + + p.peers[peerId] = pubSubPeer + + onNewPeer(p, pubSubPeer) + pubSubPeer + +proc randomPeerId*(): PeerId = + try: + PeerId.init(PrivateKey.random(ECDSA, rng[]).get()).tryGet() + except CatchableError as exc: + raise newException(Defect, exc.msg) + func defaultMsgIdProvider*(m: Message): Result[MessageId, ValidationResult] = let mid = if m.seqno.len > 0 and m.fromPeer.data.len > 0: