diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index b9b65dadc..c00e11ca6 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -24,6 +24,8 @@ const FloodSubCodec* = "/floodsub/1.0.0" type FloodSub = ref object of PubSub + peers*: Table[string, PubSubPeer] # peerid to peer map + peerTopics*: Table[string, HashSet[string]] # topic to remote peer map proc sendSubs(f: FloodSub, peer: PubSubPeer, @@ -38,6 +40,19 @@ proc sendSubs(f: FloodSub, await peer.send(@[msg]) +proc subscribeTopic(f: FloodSub, topic: string, subscribe: bool, peerId: string) {.gcsafe.} = + if not f.peerTopics.contains(topic): + f.peerTopics[topic] = initSet[string]() + + if subscribe: + trace "adding subscription for topic", peer = peerId, name = topic + # subscribe the peer to the topic + f.peerTopics[topic].incl(peerId) + else: + trace "removing subscription for topic", peer = peerId, name = topic + # unsubscribe the peer from the topic + f.peerTopics[topic].excl(peerId) + proc rpcHandler(f: FloodSub, peer: PubSubPeer, rpcMsgs: seq[RPCMsg]) {.async, gcsafe.} = @@ -55,22 +70,12 @@ proc rpcHandler(f: FloodSub, for s in m.subscriptions: # subscribe/unsubscribe the peer for each topic let id = peer.id - if not f.peerTopics.contains(s.topic): - f.peerTopics[s.topic] = initSet[string]() - - if s.subscribe: - trace "adding subscription for topic", peer = id, subscriptions = m.subscriptions, topic = s.topic - # subscribe the peer to the topic - f.peerTopics[s.topic].incl(id) - else: - trace "removing subscription for topic", peer = id, subscriptions = m.subscriptions, topic = s.topic - # unsubscribe the peer from the topic - f.peerTopics[s.topic].excl(id) + f.subscribeTopic(s.topic, s.subscribe, id) # send subscriptions to every peer for p in f.peers.values: - # if p.id != peer.id: - await p.send(@[RPCMsg(subscriptions: m.subscriptions)]) + if p.id != peer.id: + await p.send(@[RPCMsg(subscriptions: m.subscriptions)]) var toSendPeers: HashSet[string] = initSet[string]() if m.messages.len > 0: # if there are any messages @@ -95,22 +100,32 @@ proc handleConn(f: FloodSub, ## 2) register a handler with the peer; ## this handler gets called on every rpc message ## that the peer receives - ## 3) ask the peer to subscribe us to every topic + ## 3) ask the peer to subscribe us to every topic ## that we're interested in ## - proc handleRpc(peer: PubSubPeer, msgs: seq[RPCMsg]) {.async, gcsafe.} = - await f.rpcHandler(peer, msgs) - - var peer = newPubSubPeer(conn, handleRpc) - if peer.peerInfo.peerId.isNone: - debug "no valid PeerInfo for peer" + if conn.peerInfo.peerId.isNone: + debug "no valid PeerId for peer" return + # create new pubsub peer + var peer = newPubSubPeer(conn, proc (peer: PubSubPeer, + msgs: seq[RPCMsg]) {.async, gcsafe.} = + # call floodsub rpc handler + await f.rpcHandler(peer, msgs)) + + trace "created new pubsub peer", id = peer.id + f.peers[peer.id] = peer let topics = toSeq(f.topics.keys) await f.sendSubs(peer, topics, true) - asyncCheck peer.handle() + let handlerFut = peer.handle() # spawn peer read loop + handlerFut.addCallback( + proc(udata: pointer = nil) {.gcsafe.} = + trace "pubsub peer handler ended, cleaning up", + peer = conn.peerInfo.peerId.get().pretty + f.peers.del(peer.id) + ) method init(f: FloodSub) = proc handler(conn: Connection, proto: string) {.async, gcsafe.} = @@ -134,15 +149,17 @@ method publish*(f: FloodSub, if data.len > 0 and topic.len > 0: let msg = makeMessage(f.peerInfo.peerId.get(), data, topic) if topic in f.peerTopics: - trace "processing topic", name = topic + trace "publishing on topic", name = topic for p in f.peerTopics[topic]: - trace "pubslishing message", topic = topic, peer = p, data = data + trace "publishing message", name = topic, peer = p, data = data await f.peers[p].send(@[RPCMsg(messages: @[msg])]) method subscribe*(f: FloodSub, topic: string, handler: TopicHandler) {.async, gcsafe.} = await procCall PubSub(f).subscribe(topic, handler) + + f.subscribeTopic(topic, true, f.peerInfo.peerId.get().pretty) for p in f.peers.values: await f.sendSubs(p, @[topic], true) diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 66ad58fed..4b5d78909 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -33,8 +33,6 @@ type PubSub* = ref object of LPProtocol peerInfo*: PeerInfo topics*: Table[string, Topic] # local topics - peers*: Table[string, PubSubPeer] # peerid to peer map - peerTopics*: Table[string, HashSet[string]] # topic to remote peer map method subscribeToPeer*(p: PubSub, conn: Connection) {.base, async, gcsafe.} = ## subscribe to a peer to send/receive pubsub messages diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index e6074135d..745dd01ec 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -35,9 +35,13 @@ proc handle*(p: PubSubPeer) {.async, gcsafe.} = try: while not p.conn.closed: let data = await p.conn.readLp() - trace "Read data from peer", peer = p.peerInfo, data = data.toHex() + trace "Read data from peer", peer = p.id, data = data.toHex() + if data.toHex() in p.seen: + trace "Message already received, skipping", peer = p.id + continue + let msg = decodeRpcMsg(data) - trace "Decoded msg from peer", peer = p.peerInfo, msg = msg + trace "Decoded msg from peer", peer = p.id, msg = msg await p.handler(p, @[msg]) except: error "An exception occured while processing pubsub rpc requests", exc = getCurrentExceptionMsg() @@ -54,11 +58,11 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async, gcsafe.} = return let encodedHex = encoded.buffer.toHex() - trace "sending encoded msgs to peer", peer = p.id, encoded = encodedHex - if p.seen.contains(encodedHex): + if encodedHex in p.seen: trace "message already sent to peer, skipping", peer = p.id continue - + + trace "sending encoded msgs to peer", peer = p.id, encoded = encodedHex await p.conn.writeLp(encoded.buffer) p.seen.incl(encodedHex) diff --git a/tests/testpubsub.nim b/tests/testpubsub.nim index dd94153dc..bf25000f2 100644 --- a/tests/testpubsub.nim +++ b/tests/testpubsub.nim @@ -9,23 +9,24 @@ import unittest, options, tables, sugar, sequtils import chronos, chronicles -import ../libp2p/switch, - ../libp2p/multistream, - ../libp2p/protocols/identify, - ../libp2p/connection, - ../libp2p/transports/[transport, tcptransport], - ../libp2p/multiaddress, - ../libp2p/peerinfo, - ../libp2p/crypto/crypto, - ../libp2p/peer, - ../libp2p/protocols/protocol, - ../libp2p/muxers/muxer, - ../libp2p/muxers/mplex/mplex, - ../libp2p/muxers/mplex/types, - ../libp2p/protocols/secure/secure, - ../libp2p/protocols/secure/secio, - ../libp2p/protocols/pubsub/pubsub, - ../libp2p/protocols/pubsub/floodsub +import ../libp2p/[switch, + multistream, + protocols/identify, + connection, + transports/transport, + transports/tcptransport, + multiaddress, + peerinfo, + crypto/crypto, + peer, + protocols/protocol, + muxers/muxer, + muxers/mplex/mplex, + muxers/mplex/types, + protocols/secure/secure, + protocols/secure/secio, + protocols/pubsub/pubsub, + protocols/pubsub/floodsub] proc createMplex(conn: Connection): Muxer = result = newMplex(conn) @@ -101,10 +102,10 @@ suite "PubSub": await nodes[0].subscribeToPeer(nodes[1].peerInfo) await nodes[0].subscribe("foobar", handler) - await sleepAsync(100.millis) + await sleepAsync(10.millis) await nodes[1].publish("foobar", cast[seq[byte]]("Hello!")) - await sleepAsync(100.millis) + await sleepAsync(10.millis) await nodes[1].stop() await allFutures(wait) @@ -115,20 +116,20 @@ suite "PubSub": test "basic FloodSub": proc testBasicFloodSub(): Future[bool] {.async.} = - var passed: bool + var passed: int proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" - passed = true + passed.inc() - var nodes: seq[Switch] = generateNodes(4) + var nodes: seq[Switch] = generateNodes(20) var awaitters: seq[Future[void]] for node in nodes: awaitters.add(await node.start()) await node.subscribe("foobar", handler) - await sleepAsync(100.millis) + await sleepAsync(10.millis) await subscribeNodes(nodes) - await sleepAsync(500.millis) + await sleepAsync(50.millis) for node in nodes: await node.publish("foobar", cast[seq[byte]]("Hello!")) @@ -137,7 +138,7 @@ suite "PubSub": await allFutures(nodes.mapIt(it.stop())) await allFutures(awaitters) - result = passed + result = passed == 20 check: waitFor(testBasicFloodSub()) == true diff --git a/tests/testsecio.nim b/tests/testsecio.nim deleted file mode 100644 index c5c2eee96..000000000 --- a/tests/testsecio.nim +++ /dev/null @@ -1,2 +0,0 @@ -import unittest -