diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index 5fb89eb56..7d5bf3dd0 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -12,6 +12,7 @@ import chronos, chronicles import pubsub, pubsubpeer, rpcmsg, + ../../crypto/crypto, ../../connection, ../../peerinfo, ../../peer @@ -58,11 +59,11 @@ proc rpcHandler(f: FloodSub, f.peerTopics[s.topic] = initSet[string]() if s.subscribe: - trace "subscribing to topic", peer = id, subscriptions = m.subscriptions, topic = s.topic + trace "adding subscription for topic", peer = id, subscriptions = m.subscriptions, topic = s.topic # subscribe the peer to the topic f.peerTopics[s.topic].incl(id) else: - trace "unsubscribing to topic", peer = id, subscriptions = m.subscriptions, topic = s.topic + trace "removing subscription for topic", peer = id, subscriptions = m.subscriptions, topic = s.topic # unsubscribe the peer from the topic f.peerTopics[s.topic].excl(id) @@ -127,12 +128,14 @@ method subscribeToPeer*(f: FloodSub, conn: Connection) {.async, gcsafe.} = method publish*(f: FloodSub, topic: string, data: seq[byte]) {.async, gcsafe.} = - trace "about to publish message on topic", topic = topic, data = data - let msg = makeMessage(f.peerInfo.peerId.get(), data, topic) - if topic in f.peerTopics: - for p in f.peerTopics[topic]: - trace "pubslishing message", topic = topic, peer = p, data = data - await f.peers[p].send(@[RPCMsg(messages: @[msg])]) + trace "about to publish message on topic", name = topic, data = data.toHex() + if data.len > 0 and topic.len > 0: + let msg = makeMessage(f.peerInfo.peerId.get(), data, topic) + if topic in f.peerTopics: + trace "processing topic", name = topic + for p in f.peerTopics[topic]: + trace "pubslishing message", topic = topic, peer = p, data = data + await f.peers[p].send(@[RPCMsg(messages: @[msg])]) method subscribe*(f: FloodSub, topic: string, diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index c6d8e5349..66ad58fed 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -67,8 +67,9 @@ method subscribe*(p: PubSub, ## on every received message ## if not p.topics.contains(topic): + trace "subscribing to topic", name = topic p.topics[topic] = Topic(name: topic) - + p.topics[topic].handler.add(handler) method publish*(p: PubSub, topic: string, data: seq[byte]) {.base, async, gcsafe.} = diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 44a24f19b..e6074135d 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import options +import options, sets, hashes, strutils import chronos, chronicles import rpcmsg, ../../peer, @@ -26,6 +26,7 @@ type handler*: RPCHandler topics*: seq[string] id*: string # base58 peer id string + seen: HashSet[string] # list of messages forwarded to peers RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.} @@ -48,8 +49,18 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async, gcsafe.} = for m in msgs: trace "sending msgs to peer", peer = p.id, msgs = msgs let encoded = encodeRpcMsg(m) - if encoded.buffer.len > 0: - await p.conn.writeLp(encoded.buffer) + if encoded.buffer.len <= 0: + trace "empty message, skipping", peer = p.id + return + + let encodedHex = encoded.buffer.toHex() + trace "sending encoded msgs to peer", peer = p.id, encoded = encodedHex + if p.seen.contains(encodedHex): + trace "message already sent to peer, skipping", peer = p.id + continue + + await p.conn.writeLp(encoded.buffer) + p.seen.incl(encodedHex) proc newPubSubPeer*(conn: Connection, handler: RPCHandler): PubSubPeer = new result @@ -57,3 +68,4 @@ proc newPubSubPeer*(conn: Connection, handler: RPCHandler): PubSubPeer = result.conn = conn result.peerInfo = conn.peerInfo result.id = conn.peerInfo.peerId.get().pretty() + result.seen = initSet[string]() diff --git a/libp2p/protocols/pubsub/rpcmsg.nim b/libp2p/protocols/pubsub/rpcmsg.nim index e11562c08..6fb344476 100644 --- a/libp2p/protocols/pubsub/rpcmsg.nim +++ b/libp2p/protocols/pubsub/rpcmsg.nim @@ -55,7 +55,7 @@ proc encodeSubs(subs: SubOpts, buff: var ProtoBuffer) {.gcsafe.} = buff.write(initProtoField(2, subs.topic)) proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} = - result = initProtoBuffer({WithVarintLength}) + result = initProtoBuffer() trace "encoding msg: ", msg = msg if msg.subscriptions.len > 0: @@ -63,6 +63,7 @@ proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} = for s in msg.subscriptions: encodeSubs(s, subs) + # write subscriptions to protobuf subs.finish() result.write(initProtoField(1, subs)) @@ -71,10 +72,12 @@ proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} = for m in msg.messages: encodeMessage(m, messages) + # write messages to protobuf messages.finish() result.write(initProtoField(2, messages)) - result.finish() + if result.buffer.len > 0: + result.finish() proc decodeRpcMsg*(msg: seq[byte]): RPCMsg {.gcsafe.} = var pb = initProtoBuffer(msg) diff --git a/tests/testpubsub.nim b/tests/testpubsub.nim index 46ae482a0..dd94153dc 100644 --- a/tests/testpubsub.nim +++ b/tests/testpubsub.nim @@ -7,8 +7,137 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import unittest -import chronos +import unittest, options, tables, sugar, sequtils +import chronos, chronicles +import ../libp2p/switch, + ../libp2p/multistream, + ../libp2p/protocols/identify, + ../libp2p/connection, + ../libp2p/transports/[transport, tcptransport], + ../libp2p/multiaddress, + ../libp2p/peerinfo, + ../libp2p/crypto/crypto, + ../libp2p/peer, + ../libp2p/protocols/protocol, + ../libp2p/muxers/muxer, + ../libp2p/muxers/mplex/mplex, + ../libp2p/muxers/mplex/types, + ../libp2p/protocols/secure/secure, + ../libp2p/protocols/secure/secio, + ../libp2p/protocols/pubsub/pubsub, + ../libp2p/protocols/pubsub/floodsub + +proc createMplex(conn: Connection): Muxer = + result = newMplex(conn) + +proc createNode(privKey: Option[PrivateKey] = none(PrivateKey), + address: string = "/ip4/127.0.0.1/tcp/0"): Switch = + var peerInfo: PeerInfo + var seckey = privKey + if privKey.isNone: + seckey = some(PrivateKey.random(RSA)) + + peerInfo.peerId = some(PeerID.init(seckey.get())) + peerInfo.addrs.add(Multiaddress.init(address)) + + let mplexProvider = newMuxerProvider(createMplex, MplexCodec) + let transports = @[Transport(newTransport(TcpTransport))] + let muxers = [(MplexCodec, mplexProvider)].toTable() + let identify = newIdentify(peerInfo) + let secureManagers = [(SecioCodec, Secure(newSecio(seckey.get())))].toTable() + let pubSub = some(PubSub(newFloodSub(peerInfo))) + result = newSwitch(peerInfo, + transports, + identify, + muxers, + secureManagers = secureManagers, + pubSub = pubSub) + +proc generateNodes*(num: Natural): seq[Switch] = + for i in 0.. B": + proc testBasicPubSub(): Future[bool] {.async.} = + var passed: bool + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check topic == "foobar" + passed = true + + var nodes = generateNodes(2) + var wait = await nodes[1].start() + + await nodes[0].subscribeToPeer(nodes[1].peerInfo) + + await nodes[1].subscribe("foobar", handler) + await sleepAsync(100.millis) + + await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) + await sleepAsync(100.millis) + + await nodes[1].stop() + await allFutures(wait) + result = passed + + check: + waitFor(testBasicPubSub()) == true + + test "FloodSub basic publish/subscribe B -> A": + proc testBasicPubSub(): Future[bool] {.async.} = + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check topic == "foobar" + + var nodes = generateNodes(2) + var wait = await nodes[1].start() + + await nodes[0].subscribeToPeer(nodes[1].peerInfo) + + await nodes[0].subscribe("foobar", handler) + await sleepAsync(100.millis) + + await nodes[1].publish("foobar", cast[seq[byte]]("Hello!")) + await sleepAsync(100.millis) + + await nodes[1].stop() + await allFutures(wait) + result = true + + check: + waitFor(testBasicPubSub()) == true + + test "basic FloodSub": + proc testBasicFloodSub(): Future[bool] {.async.} = + var passed: bool + proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = + check topic == "foobar" + passed = true + + var nodes: seq[Switch] = generateNodes(4) + var awaitters: seq[Future[void]] + for node in nodes: + awaitters.add(await node.start()) + await node.subscribe("foobar", handler) + await sleepAsync(100.millis) + + await subscribeNodes(nodes) + await sleepAsync(500.millis) + + for node in nodes: + await node.publish("foobar", cast[seq[byte]]("Hello!")) + await sleepAsync(100.millis) + + await allFutures(nodes.mapIt(it.stop())) + await allFutures(awaitters) + + result = passed + + check: + waitFor(testBasicFloodSub()) == true