diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim
index 445a9e56f..12bfb02b3 100644
--- a/libp2p/protocols/pubsub/gossipsub.nim
+++ b/libp2p/protocols/pubsub/gossipsub.nim
@@ -263,6 +263,7 @@ proc handleControl(g: GossipSub, peer: PubSubPeer, control: ControlMessage) =
   g.handlePrune(peer, control.prune)
 
   var respControl: ControlMessage
+  g.handleIDontWant(peer, control.idontwant)
   let iwant = g.handleIHave(peer, control.ihave)
   if iwant.messageIds.len > 0:
     respControl.iwant.add(iwant)
@@ -337,6 +338,21 @@ proc validateAndRelay(g: GossipSub,
     toSendPeers.excl(peer)
     toSendPeers.excl(seenPeers)
 
+    # IDontWant is only worth it if the message is substantially
+    # bigger than the messageId
+    if msg.data.len > msgId.len * 10:
+      g.broadcast(toSendPeers, RPCMsg(control: some(ControlMessage(
+        idontwant: @[ControlIWant(messageIds: @[msgId])]
+      ))))
+
+    for peer in toSendPeers:
+      for heDontWant in peer.heDontWants:
+        if msgId in heDontWant:
+          seenPeers.incl(peer)
+          break
+
+    toSendPeers.excl(seenPeers)
+
     # In theory, if topics are the same in all messages, we could batch - we'd
     # also have to be careful to only include validated messages
     g.broadcast(toSendPeers, RPCMsg(messages: @[msg]))
diff --git a/libp2p/protocols/pubsub/gossipsub/behavior.nim b/libp2p/protocols/pubsub/gossipsub/behavior.nim
index d13fac68b..e4b193549 100644
--- a/libp2p/protocols/pubsub/gossipsub/behavior.nim
+++ b/libp2p/protocols/pubsub/gossipsub/behavior.nim
@@ -262,6 +262,15 @@ proc handleIHave*(g: GossipSub,
   g.rng.shuffle(res.messageIds)
   return res
 
+proc handleIDontWant*(g: GossipSub,
+                      peer: PubSubPeer,
+                      iDontWants: seq[ControlIWant]) =
+  for dontWant in iDontWants:
+    for messageId in dontWant.messageIds:
+      if peer.heDontWants[^1].len > 1000: break
+      if messageId.len > 100: continue
+      peer.heDontWants[^1].incl(messageId)
+
 proc handleIWant*(g: GossipSub,
                   peer: PubSubPeer,
                   iwants: seq[ControlIWant]): seq[Message] {.raises: [].} =
@@ -629,6 +638,9 @@ proc onHeartbeat(g: GossipSub) {.raises: [].} =
       peer.sentIHaves.addFirst(default(HashSet[MessageId]))
      if peer.sentIHaves.len > g.parameters.historyLength:
         discard peer.sentIHaves.popLast()
+      peer.heDontWants.addFirst(default(HashSet[MessageId]))
+      if peer.heDontWants.len > g.parameters.historyLength:
+        discard peer.heDontWants.popLast()
       peer.iHaveBudget = IHavePeerBudget
       peer.pingBudget = PingsPeerBudget
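Reviewer note on the two files above: handleIDontWant records the advertised ids in a per-peer window that is bounded in both dimensions (at most 1000 ids per heartbeat slot, ids longer than 100 bytes ignored), so a misbehaving peer cannot inflate memory; onHeartbeat then ages the window out alongside sentIHaves. A self-contained sketch of that bookkeeping, using illustrative names (rotate/record/wasRefused are not part of the patch):

  import std/[deques, sets]

  type MessageId = seq[byte]

  # One HashSet per heartbeat slot; entries age out after historyLength rotations,
  # mirroring the onHeartbeat hunk above.
  proc rotate(window: var Deque[HashSet[MessageId]], historyLength: int) =
    window.addFirst(default(HashSet[MessageId]))
    if window.len > historyLength:
      discard window.popLast()

  # Mirrors handleIDontWant's caps: at most 1000 ids per slot,
  # ids over 100 bytes are ignored.
  proc record(window: var Deque[HashSet[MessageId]], msgId: MessageId) =
    if window[^1].len > 1000: return
    if msgId.len > 100: return
    window[^1].incl(msgId)

  # Lookup scans every slot, like the relay filter in validateAndRelay.
  proc wasRefused(window: Deque[HashSet[MessageId]], msgId: MessageId): bool =
    for slot in window:
      if msgId in slot: return true
    false

  var w = initDeque[HashSet[MessageId]]()
  w.addFirst(default(HashSet[MessageId]))
  w.record(@[1'u8, 2, 3])
  assert w.wasRefused(@[1'u8, 2, 3])
  w.rotate(historyLength = 5)
  assert w.wasRefused(@[1'u8, 2, 3]) # still within the history window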
diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim
index d10ec1e4b..1dcd28286 100644
--- a/libp2p/protocols/pubsub/pubsubpeer.nim
+++ b/libp2p/protocols/pubsub/pubsubpeer.nim
@@ -20,7 +20,7 @@ import
   rpc/[messages, message, protobuf],
   ../../protobuf/minprotobuf,
   ../../utility
-export peerid, connection
+export peerid, connection, deques
 
 logScope:
   topics = "libp2p pubsubpeer"
@@ -60,6 +60,7 @@ type
     score*: float64
     sentIHaves*: Deque[HashSet[MessageId]]
+    heDontWants*: Deque[HashSet[MessageId]]
    iHaveBudget*: int
     pingBudget*: int
     maxMessageSize: int
@@ -317,3 +318,4 @@ proc new*(
     maxMessageSize: maxMessageSize
   )
   result.sentIHaves.addFirst(default(HashSet[MessageId]))
+  result.heDontWants.addFirst(default(HashSet[MessageId]))
diff --git a/libp2p/protocols/pubsub/rpc/messages.nim b/libp2p/protocols/pubsub/rpc/messages.nim
index 6c3ee794e..ce6dd318b 100644
--- a/libp2p/protocols/pubsub/rpc/messages.nim
+++ b/libp2p/protocols/pubsub/rpc/messages.nim
@@ -42,6 +42,7 @@ type
     iwant*: seq[ControlIWant]
     graft*: seq[ControlGraft]
     prune*: seq[ControlPrune]
+    idontwant*: seq[ControlIWant]
 
   ControlIHave* = object
     topicId*: string
diff --git a/libp2p/protocols/pubsub/rpc/protobuf.nim b/libp2p/protocols/pubsub/rpc/protobuf.nim
index 87bc1d1b4..4aa2e5210 100644
--- a/libp2p/protocols/pubsub/rpc/protobuf.nim
+++ b/libp2p/protocols/pubsub/rpc/protobuf.nim
@@ -87,6 +87,8 @@ proc write*(pb: var ProtoBuffer, field: int, control: ControlMessage) =
     ipb.write(3, graft)
   for prune in control.prune:
     ipb.write(4, prune)
+  for idontwant in control.idontwant:
+    ipb.write(5, idontwant)
   if len(ipb.buffer) > 0:
     ipb.finish()
     pb.write(field, ipb)
@@ -210,6 +212,7 @@ proc decodeControl*(pb: ProtoBuffer): ProtoResult[Option[ControlMessage]] {.
     var iwantpbs: seq[seq[byte]]
     var graftpbs: seq[seq[byte]]
     var prunepbs: seq[seq[byte]]
+    var idontwant: seq[seq[byte]]
     if ? cpb.getRepeatedField(1, ihavepbs):
       for item in ihavepbs:
         control.ihave.add(? decodeIHave(initProtoBuffer(item)))
@@ -222,6 +225,9 @@ proc decodeControl*(pb: ProtoBuffer): ProtoResult[Option[ControlMessage]] {.
     if ? cpb.getRepeatedField(4, prunepbs):
       for item in prunepbs:
         control.prune.add(? decodePrune(initProtoBuffer(item)))
+    if ? cpb.getRepeatedField(5, idontwant):
+      for item in idontwant:
+        control.idontwant.add(? decodeIWant(initProtoBuffer(item)))
     trace "decodeControl: message statistics", graft_count = len(control.graft),
                                                prune_count = len(control.prune),
                                                ihave_count = len(control.ihave),
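Wire-format note: IDONTWANT reuses the ControlIWant payload (a bare list of message ids) rather than introducing a new message type, and it takes protobuf field 5 of ControlMessage, after ihave = 1, iwant = 2, graft = 3 and prune = 4 handled above. A rough round-trip sketch, assuming the encodeRpcMsg/decodeRpcMsg entry points this module already exposes (untested, for illustration only):

  import std/options
  import stew/results
  import libp2p/protocols/pubsub/rpc/[messages, protobuf]

  # A control-only RPC carrying a single IDONTWANT entry.
  let rpc = RPCMsg(control: some(ControlMessage(
    idontwant: @[ControlIWant(messageIds: @[@[0x01'u8, 0x02, 0x03]])]
  )))

  # anonymize only affects sender-identity fields of published messages,
  # so it is irrelevant for a control-only RPC.
  let encoded = encodeRpcMsg(rpc, anonymize = false)
  let decoded = decodeRpcMsg(encoded)
  assert decoded.isOk()
  assert decoded.get().control.get().idontwant[0].messageIds[0] == @[0x01'u8, 0x02, 0x03]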
diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim
index 396bccbab..f3d698cbc 100644
--- a/tests/pubsub/testgossipsub.nim
+++ b/tests/pubsub/testgossipsub.nim
@@ -796,3 +796,63 @@ suite "GossipSub":
       )
 
     await allFuturesThrowing(nodesFut.concat())
+
+  asyncTest "e2e - iDontWant":
+    # 3 nodes: A <=> B <=> C
+    # (A & C are NOT connected). We pre-emptively send a dontwant from C to B,
+    # and check that B doesn't relay the message to C.
+    # We also check that B sends IDONTWANT to C, but not A
+    func dumbMsgIdProvider(m: Message): Result[MessageId, ValidationResult] =
+      ok(newSeq[byte](10))
+    let
+      nodes = generateNodes(
+        3,
+        gossip = true,
+        msgIdProvider = dumbMsgIdProvider
+      )
+
+      nodesFut = await allFinished(
+        nodes[0].switch.start(),
+        nodes[1].switch.start(),
+        nodes[2].switch.start(),
+      )
+
+    await nodes[0].switch.connect(nodes[1].switch.peerInfo.peerId, nodes[1].switch.peerInfo.addrs)
+    await nodes[1].switch.connect(nodes[2].switch.peerInfo.peerId, nodes[2].switch.peerInfo.addrs)
+
+    let bFinished = newFuture[void]()
+    proc handlerA(topic: string, data: seq[byte]) {.async, gcsafe.} = discard
+    proc handlerB(topic: string, data: seq[byte]) {.async, gcsafe.} = bFinished.complete()
+    proc handlerC(topic: string, data: seq[byte]) {.async, gcsafe.} = doAssert false
+
+    nodes[0].subscribe("foobar", handlerA)
+    nodes[1].subscribe("foobar", handlerB)
+    nodes[2].subscribe("foobar", handlerC)
+    await waitSubGraph(nodes, "foobar")
+
+    var gossip1: GossipSub = GossipSub(nodes[0])
+    var gossip2: GossipSub = GossipSub(nodes[1])
+    var gossip3: GossipSub = GossipSub(nodes[2])
+
+    check: gossip3.mesh.peers("foobar") == 1
+
+    gossip3.broadcast(gossip3.mesh["foobar"], RPCMsg(control: some(ControlMessage(
+      idontwant: @[ControlIWant(messageIds: @[newSeq[byte](10)])]
+    ))))
+    checkExpiring: gossip2.mesh.getOrDefault("foobar").anyIt(it.heDontWants[^1].len == 1)
+
+    tryPublish await nodes[0].publish("foobar", newSeq[byte](10000)), 1
+
+    await bFinished
+
+    checkExpiring: toSeq(gossip3.mesh.getOrDefault("foobar")).anyIt(it.heDontWants[^1].len == 1)
+    check: toSeq(gossip1.mesh.getOrDefault("foobar")).anyIt(it.heDontWants[^1].len == 0)
+
+    await allFuturesThrowing(
+      nodes[0].switch.stop(),
+      nodes[1].switch.stop(),
+      nodes[2].switch.stop()
+    )
+
+    await allFuturesThrowing(nodesFut.concat())
+
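A note on the magic numbers in this test: dumbMsgIdProvider pins every message id at 10 bytes, so the relay-side gate msg.data.len > msgId.len * 10 only fires for payloads larger than 100 bytes; the 10000-byte publish clears that comfortably, which is what makes B's IDONTWANT to C observable. Extracted for illustration (worthIDontWant is a hypothetical helper, not part of the patch):

  # The size gate from validateAndRelay, pulled out for clarity.
  func worthIDontWant(dataLen, msgIdLen: int): bool =
    dataLen > msgIdLen * 10

  assert worthIDontWant(10_000, 10)    # the test's payload: IDONTWANT is sent
  assert not worthIDontWant(64, 10)    # small payload: cheaper to just relay it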