diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index 502f4f2..dcb2d64 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -106,6 +106,11 @@ method init(f: FloodSub) = f.handler = handler f.codec = FloodSubCodec +method subscribeToPeer*(p: FloodSub, + conn: Connection) {.async.} = + await procCall PubSub(p).subscribeToPeer(conn) + asyncCheck p.handleConn(conn, FloodSubCodec) + method publish*(f: FloodSub, topic: string, data: seq[byte]) {.async.} = diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 4fb0df2..2cc95d2 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -101,6 +101,11 @@ method handleDisconnect(g: GossipSub, peer: PubSubPeer) {.async.} = for t in g.fanout.keys: g.fanout[t].excl(peer.id) +method subscribeToPeer*(p: GossipSub, + conn: Connection) {.async.} = + await procCall PubSub(p).subscribeToPeer(conn) + asyncCheck p.handleConn(conn, GossipSubCodec) + method subscribeTopic*(g: GossipSub, topic: string, subscribe: bool, diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 8c3fb2c..a18581b 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -110,6 +110,16 @@ proc getPeer(p: PubSub, peer.observers = p.observers result = peer +proc internalClenaup(p: PubSub, conn: Connection) {.async.} = + # handle connection close + if conn.closed: + return + + var peer = p.getPeer(conn.peerInfo, p.codec) + await conn.closeEvent.wait() + trace "connection closed, cleaning up peer", peer = conn.peerInfo.id + await p.cleanUpHelper(peer) + method handleConn*(p: PubSub, conn: Connection, proto: string) {.base, async.} = @@ -141,15 +151,7 @@ method handleConn*(p: PubSub, peer.handler = handler await peer.handle(conn) # spawn peer read loop trace "pubsub peer handler ended, cleaning up" - await p.cleanUpHelper(peer) - -proc internalClenaup(p: PubSub, conn: Connection) {.async.} = - # handle connection close - var peer = p.getPeer(conn.peerInfo, p.codec) - await conn.closeEvent.wait() - trace "connection closed, cleaning up peer", peer = conn.peerInfo.id - - await p.cleanUpHelper(peer) + await p.internalClenaup(conn) method subscribeToPeer*(p: PubSub, conn: Connection) {.base, async.} = diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index 597f3a4..c1bca76 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -71,7 +71,8 @@ method atEof*(s: ChronosStream): bool {.inline.} = s.client.atEof() method close*(s: ChronosStream) {.async.} = - if not s.closed: + if not s.isClosed: + s.isClosed = true trace "shutting chronos stream", address = $s.client.remoteAddress() if not s.client.closed(): try: diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 2f284ab..288d491 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -439,6 +439,15 @@ proc newSwitch*(peerInfo: PeerInfo, muxer.connection.peerInfo = await s.identify(stream) await stream.close() + # store muxer for connection + s.muxed[muxer.connection.peerInfo.id] = muxer + + # store muxed connection + s.connections[muxer.connection.peerInfo.id] = muxer.connection + + # try establishing a pubsub connection + await s.subscribeToPeer(muxer.connection.peerInfo) + for k in secureManagers.keys: trace "adding secure manager ", codec = secureManagers[k].codec result.secureManagers[k] = secureManagers[k] diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index c45ceec..4388c7a 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -24,10 +24,6 @@ import utils, ../../libp2p/[errors, import ../helpers -proc createGossipSub(): GossipSub = - var peerInfo = PeerInfo.init(PrivateKey.random(RSA).get()) - result = newPubSub(GossipSub, peerInfo) - proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = if sender == receiver: return @@ -35,7 +31,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = # this is for testing purposes only # peers can be inside `mesh` and `fanout`, not just `gossipsub` var ceil = 15 - let fsub = cast[GossipSub](sender.pubSub.get()) + let fsub = GossipSub(sender.pubSub.get()) while (not fsub.gossipsub.hasKey(key) or not fsub.gossipsub[key].contains(receiver.peerInfo.id)) and (not fsub.mesh.hasKey(key) or @@ -277,7 +273,7 @@ suite "GossipSub": check: "foobar" in gossipSub1.gossipsub - await passed.wait(5.seconds) + await passed.wait(1.seconds) trace "test done, stopping..."