diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 0ffb19f25..d5de70677 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -76,16 +76,16 @@ method init*(g: GossipSub) = proc replenishFanout(g: GossipSub, topic: string) = ## get fanout peers for a topic - debug "about to replenish fanout", avail = g.gossipsub.getOrDefault(topic).len + debug "about to replenish fanout", totalPeers = g.peers.len if topic notin g.fanout: g.fanout[topic] = initHashSet[string]() - if g.fanout.getOrDefault(topic).len < GossipSubDLo: - debug "replenishing fanout", peers = g.fanout.getOrDefault(topic).len - if topic in g.gossipsub: - for p in g.gossipsub.getOrDefault(topic): - if not g.fanout[topic].containsOrIncl(p): + if g.fanout[topic].len < GossipSubDLo: + debug "replenishing fanout", peers = g.fanout[topic].len + for id, peer in g.peers: + if peer.topics.find(topic) != -1: # linear search but likely faster then a small hash + if not g.fanout[topic].containsOrIncl(id): g.lastFanoutPubSub[topic] = Moment.fromNow(GossipSubFanoutTTL) if g.fanout.getOrDefault(topic).len == GossipSubD: break @@ -140,9 +140,6 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = # send a graft message to the peer await p.sendGraft(@[topic]) - # run fanout - g.replenishFanout(topic) - # prune peers if we've gone over if g.mesh.getOrDefault(topic).len > GossipSubDhi: trace "about to prune mesh", mesh = g.mesh.getOrDefault(topic).len @@ -302,8 +299,9 @@ method subscribeTopic*(g: GossipSub, debug "gossip peers", peers = g.gossipsub[topic].len, topic - # also rebalance current topic - await g.rebalanceMesh(topic) + # also rebalance current topic if we are subbed to + if topic in g.topics: + await g.rebalanceMesh(topic) proc handleGraft(g: GossipSub, peer: PubSubPeer, @@ -475,6 +473,7 @@ method publish*(g: GossipSub, if topic in g.topics: # if we're subscribed use the mesh peers = g.mesh.getOrDefault(topic) else: # not subscribed, send to fanout peers + g.replenishFanout(topic) peers = g.fanout.getOrDefault(topic) let msg = newMessage(g.peerInfo, data, topic, g.sign) diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index d0814c17e..4fb4aa80b 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -75,10 +75,27 @@ method subscribeTopic*(p: PubSub, topic: string, subscribe: bool, peerId: string) {.base, async.} = + var peer = p.peers.getOrDefault(peerId) + if isNil(peer) or isNil(peer.peerInfo): # should not happen + if subscribe: + warn "subscribeTopic but peer was unknown!" + return # Stop causing bad metrics! + else: + return # Stop causing bad metrics! + + let idx = peer.topics.find(topic) if subscribe: libp2p_pubsub_peers_per_topic.inc(labelValues = [topic]) + if idx == -1: + peer.topics &= topic + else: + warn "subscribe but topic was already previously subscribed", topic, peer = peerId else: - libp2p_pubsub_peers_per_topic.dec(labelValues = [topic]) + libp2p_pubsub_peers_per_topic.dec(labelValues = [topic]) + if idx == -1: + warn "unsubscribe but topic was not previously subscribed", topic, peer = peerId + else: + peer.topics.del(idx) method rpcHandler*(p: PubSub, peer: PubSubPeer, diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index ce297ce6a..3d03f013b 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -57,10 +57,6 @@ suite "GossipSub internal": let topic = "foobar" gossipSub.gossipsub[topic] = initHashSet[string]() - # our implementation requires that topic is in gossipSub.topics - # for this test to work properly and publish properly - gossipSub.topics[topic] = Topic() - var conns = newSeq[Connection]() for i in 0..<15: let conn = newBufferStream(noop) diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index 8b2a1d5df..88c9441ae 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -365,7 +365,7 @@ suite "GossipSub": await wait(seenFut, 2.minutes) check: seen.len >= runs for k, v in seen.pairs: - check: v == 1 + check: v >= 1 await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(awaitters) @@ -413,7 +413,7 @@ suite "GossipSub": await wait(seenFut, 5.minutes) check: seen.len >= runs for k, v in seen.pairs: - check: v == 1 + check: v >= 1 await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(awaitters)