diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index bcfebf1e1..5505093b8 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -125,7 +125,7 @@ method handleConn*(p: PubSub, return proc handler(peer: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = - # call floodsub rpc handler + # call pubsub rpc handler await p.rpcHandler(peer, msgs) let peer = p.getPeer(conn.peerInfo, proto) diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 0c9c3704d..704d39de8 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import tables, sequtils, options, strformat +import tables, sequtils, options, strformat, sets import chronos, chronicles import connection, transports/transport, @@ -45,6 +45,7 @@ type streamHandler*: StreamHandler secureManagers*: Table[string, Secure] pubSub*: Option[PubSub] + dialedPubSubPeers: HashSet[string] proc newNoPubSubException(): ref Exception {.inline.} = result = newException(NoPubSubException, "no pubsub provided!") @@ -144,6 +145,9 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = await s.connections[id].close() s.connections.del(id) + if id in s.dialedPubSubPeers: + s.dialedPubSubPeers.excl(id) + # TODO: Investigate cleanupConn() always called twice for one peer. if not(conn.peerInfo.isClosed()): conn.peerInfo.close() @@ -204,6 +208,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = # handle secured connections await ms.handle(conn) +proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} + proc dial*(s: Switch, peer: PeerInfo, proto: string = ""): @@ -244,6 +250,8 @@ proc dial*(s: Switch, error "Unable to select sub-protocol", proto = proto raise newException(CatchableError, &"unable to select protocol: {proto}") + await s.subscribeToPeer(peer) + proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = if isNil(proto.handler): raise newException(CatchableError, @@ -291,11 +299,16 @@ proc stop*(s: Switch) {.async.} = await allFutures(toSeq(s.connections.values).mapIt(s.cleanupConn(it))) await allFutures(s.transports.mapIt(it.close())) -proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = +proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = ## Subscribe to pub sub peer - if s.pubSub.isSome: - let conn = await s.dial(peerInfo, s.pubSub.get().codec) - await s.pubSub.get().subscribeToPeer(conn) + if s.pubSub.isSome and peerInfo.id notin s.dialedPubSubPeers: + try: + s.dialedPubSubPeers.incl(peerInfo.id) + let conn = await s.dial(peerInfo, s.pubSub.get().codec) + await s.pubSub.get().subscribeToPeer(conn) + except CatchableError as exc: + trace "unable to initiate pubsub", exc = exc.msg + s.dialedPubSubPeers.excl(peerInfo.id) proc subscribe*(s: Switch, topic: string, handler: TopicHandler): Future[void] {.gcsafe.} = ## subscribe to a pubsub topic @@ -351,6 +364,7 @@ proc newSwitch*(peerInfo: PeerInfo, result.identity = identity result.muxers = muxers result.secureManagers = initTable[string, Secure]() + result.dialedPubSubPeers = initHashSet[string]() let s = result # can't capture result result.streamHandler = proc(stream: Connection) {.async, gcsafe.} = diff --git a/tests/pubsub/utils.nim b/tests/pubsub/utils.nim index 76926617a..79e4e4a8b 100644 --- a/tests/pubsub/utils.nim +++ b/tests/pubsub/utils.nim @@ -8,10 +8,10 @@ proc generateNodes*(num: Natural, gossip: bool = false): seq[Switch] = result.add(newStandardSwitch(gossip = gossip)) proc subscribeNodes*(nodes: seq[Switch]) {.async.} = - var dials: seq[Future[void]] + var dials: seq[Future[Connection]] for dialer in nodes: for node in nodes: if dialer.peerInfo.peerId != node.peerInfo.peerId: - dials.add(dialer.subscribeToPeer(node.peerInfo)) + dials.add(dialer.dial(node.peerInfo)) await sleepAsync(100.millis) await allFutures(dials) diff --git a/tests/testinterop.nim b/tests/testinterop.nim index a652ea9cd..33fd50076 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -93,8 +93,8 @@ proc createNode*(privKey: Option[PrivateKey] = none(PrivateKey), secureManagers = secureManagers, pubSub = pubSub) -proc testPubSubDaemonPublish(gossip: bool = false, count: int = 1): Future[ - bool] {.async.} = +proc testPubSubDaemonPublish(gossip: bool = false, + count: int = 1): Future[bool] {.async.} = var pubsubData = "TEST MESSAGE" var testTopic = "test-topic" var msgData = cast[seq[byte]](pubsubData) @@ -118,8 +118,8 @@ proc testPubSubDaemonPublish(gossip: bool = false, count: int = 1): Future[ if times >= count and not handlerFuture.finished: handlerFuture.complete(true) - await nativeNode.subscribeToPeer(NativePeerInfo.init(daemonPeer.peer, - daemonPeer.addresses)) + discard await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer, + daemonPeer.addresses)) await sleepAsync(1.seconds) await daemonNode.connect(nativePeer.peerId, nativePeer.addrs) @@ -140,8 +140,8 @@ proc testPubSubDaemonPublish(gossip: bool = false, count: int = 1): Future[ await allFutures(awaiters) await daemonNode.close() -proc testPubSubNodePublish(gossip: bool = false, count: int = 1): Future[ - bool] {.async.} = +proc testPubSubNodePublish(gossip: bool = false, + count: int = 1): Future[bool] {.async.} = var pubsubData = "TEST MESSAGE" var testTopic = "test-topic" var msgData = cast[seq[byte]](pubsubData) @@ -157,8 +157,8 @@ proc testPubSubNodePublish(gossip: bool = false, count: int = 1): Future[ let nativePeer = nativeNode.peerInfo var handlerFuture = newFuture[bool]() - await nativeNode.subscribeToPeer(NativePeerInfo.init(daemonPeer.peer, - daemonPeer.addresses)) + discard await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer, + daemonPeer.addresses)) await sleepAsync(1.seconds) await daemonNode.connect(nativePeer.peerId, nativePeer.addrs)