diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index eca5389..6dd53b5 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -200,16 +200,26 @@ method close*(m: Mplex) {.async, gcsafe.} = try: trace "closing mplex muxer", oid = m.oid - await all( - toSeq(m.remote.values).mapIt(it.reset()) & - toSeq(m.local.values).mapIt(it.reset())) + let channs = toSeq(m.remote.values) & + toSeq(m.local.values) + + for chann in channs: + try: + await chann.reset() + except CatchableError as exc: + warn "error resetting channel", exc = exc.msg + + for conn in m.conns: + try: + await conn.close() + except CatchableError as exc: + warn "error closing channel's connection" + + checkFutures( + await allFinished(m.handlerFuts)) - await all(m.conns.mapIt(it.close())) # dispose of channel's connections - await all(m.handlerFuts) - except CatchableError as exc: - trace "exception in mplex close", exc = exc.msg - finally: await m.connection.close() + finally: m.remote.clear() m.local.clear() m.conns = @[] diff --git a/libp2p/muxers/muxer.nim b/libp2p/muxers/muxer.nim index 800ecc4..7a95b1c 100644 --- a/libp2p/muxers/muxer.nim +++ b/libp2p/muxers/muxer.nim @@ -60,7 +60,7 @@ method init(c: MuxerProvider) = if not isNil(c.muxerHandler): futs &= c.muxerHandler(muxer) - await all(futs) + checkFutures(await allFinished(futs)) except CatchableError as exc: trace "exception in muxer handler", exc = exc.msg diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 6192c6d..1a315fc 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -140,9 +140,8 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = s.muxed.del(id) if id in s.connections: - if not s.connections[id].closed: - await s.connections[id].close() s.connections.del(id) + await conn.close() s.dialedPubSubPeers.excl(id) @@ -438,7 +437,7 @@ proc newSwitch*(peerInfo: PeerInfo, if not(stream.closed): await stream.close() except CatchableError as exc: - trace "excepton in stream handler", exc = exc.msg + trace "exception in stream handler", exc = exc.msg result.mount(identity) for key, val in muxers: @@ -448,6 +447,8 @@ proc newSwitch*(peerInfo: PeerInfo, try: trace "got new muxer" stream = await muxer.newStream() + # once we got a muxed connection, attempt to + # identify it muxer.connection.peerInfo = await s.identify(stream) # store muxer for connection @@ -456,6 +457,10 @@ proc newSwitch*(peerInfo: PeerInfo, # store muxed connection s.connections[muxer.connection.peerInfo.id] = muxer.connection + muxer.connection.closeEvent.wait() + .addCallback do(udata: pointer): + asyncCheck s.cleanupConn(muxer.connection) + # try establishing a pubsub connection await s.subscribeToPeer(muxer.connection.peerInfo) except CatchableError as exc: diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 0002e39..81bde52 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -192,6 +192,9 @@ suite "Switch": awaiters.add(await switch2.start()) await switch2.connect(switch1.peerInfo) + check switch1.connections.len > 0 + check switch2.connections.len > 0 + await sleepAsync(100.millis) await switch2.disconnect(switch1.peerInfo) @@ -204,6 +207,9 @@ suite "Switch": # echo connTracker.dump() check connTracker.isLeaked() == false + check switch1.connections.len == 0 + check switch2.connections.len == 0 + await all( switch1.stop(), switch2.stop()