diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index bde345046..a0f65b11f 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -113,6 +113,9 @@ proc connectedPeers*(c: ConnManager, dir: Direction): seq[PeerId] = peers.add(peerId) return peers +proc getConnections*(c: ConnManager): Table[PeerId, seq[Muxer]] = + return c.muxed + proc addConnEventHandler*(c: ConnManager, handler: ConnEventHandler, kind: ConnEventKind) = diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index a68d80e71..13d988b47 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -248,3 +248,7 @@ method close*(m: Mplex) {.async, gcsafe.} = m.channels[true].clear() trace "Closed mplex", m + +method getStreams*(m: Mplex): seq[Connection] = + for c in m.channels[false].values: result.add(c) + for c in m.channels[true].values: result.add(c) diff --git a/libp2p/muxers/muxer.nim b/libp2p/muxers/muxer.nim index 0221ed743..ac00759b1 100644 --- a/libp2p/muxers/muxer.nim +++ b/libp2p/muxers/muxer.nim @@ -63,3 +63,5 @@ proc new*( let muxerProvider = T(newMuxer: creator, codec: codec) muxerProvider + +method getStreams*(m: Muxer): seq[Connection] {.base.} = doAssert false, "not implemented" diff --git a/libp2p/muxers/yamux/yamux.nim b/libp2p/muxers/yamux/yamux.nim index 8d94b719d..9caa365c7 100644 --- a/libp2p/muxers/yamux/yamux.nim +++ b/libp2p/muxers/yamux/yamux.nim @@ -508,6 +508,9 @@ method handle*(m: Yamux) {.async, gcsafe.} = await m.close() trace "Stopped yamux handler" +method getStreams*(m: Yamux): seq[Connection] = + for c in m.channels.values: result.add(c) + method newStream*( m: Yamux, name: string = "", diff --git a/tests/helpers.nim b/tests/helpers.nim index d79ca72f9..42533a1c7 100644 --- a/tests/helpers.nim +++ b/tests/helpers.nim @@ -4,6 +4,7 @@ else: {.push raises: [].} import chronos +import algorithm import ../libp2p/transports/tcptransport import ../libp2p/stream/bufferstream @@ -117,3 +118,19 @@ proc checkExpiringInternal(cond: proc(): bool {.raises: [Defect], gcsafe.} ): Fu template checkExpiring*(code: untyped): untyped = check await checkExpiringInternal(proc(): bool = code) + +proc unorderedCompare*[T](a, b: seq[T]): bool = + if a == b: + return true + if a.len != b.len: + return false + + var aSorted = a + var bSorted = b + aSorted.sort() + bSorted.sort() + + if aSorted == bSorted: + return true + + return false \ No newline at end of file diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index 72f2f403c..4f2b246b4 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -1,4 +1,4 @@ -import sequtils +import std/[sequtils,tables] import stew/results import chronos import ../libp2p/[connmanager, @@ -42,6 +42,19 @@ suite "Connection Manager": await connMngr.close() + asyncTest "get all connections": + let connMngr = ConnManager.new() + + let peers = toSeq(0..<2).mapIt(PeerId.random.tryGet()) + let muxs = toSeq(0..<2).mapIt(getMuxer(peers[it])) + for mux in muxs: connMngr.storeMuxer(mux) + + let conns = connMngr.getConnections() + let connsMux = toSeq(conns.values).mapIt(it[0]) + check unorderedCompare(connsMux, muxs) + + await connMngr.close() + asyncTest "shouldn't allow a closed connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 85698630e..661ff3b9a 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -15,7 +15,7 @@ import ../libp2p/[errors, import ./helpers -{.used.} +{.used.} suite "Mplex": teardown: @@ -662,9 +662,10 @@ suite "Mplex": let mplexDial = Mplex.new(conn) let mplexDialFut = mplexDial.handle() - var dialStreams: seq[Connection] - for i in 0..9: - dialStreams.add((await mplexDial.newStream())) + var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream()) + + check: + unorderedCompare(dialStreams, mplexDial.getStreams()) for i, s in dialStreams: await s.closeWithEOF() @@ -710,9 +711,10 @@ suite "Mplex": let mplexDial = Mplex.new(conn) let mplexDialFut = mplexDial.handle() - var dialStreams: seq[Connection] - for i in 0..9: - dialStreams.add((await mplexDial.newStream())) + var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream()) + + check: + unorderedCompare(dialStreams, mplexDial.getStreams()) proc dialReadLoop() {.async.} = for s in dialStreams: @@ -769,9 +771,10 @@ suite "Mplex": let mplexDial = Mplex.new(conn) let mplexDialFut = mplexDial.handle() - var dialStreams: seq[Connection] - for i in 0..9: - dialStreams.add((await mplexDial.newStream())) + var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream()) + + check: + unorderedCompare(dialStreams, mplexDial.getStreams()) await mplexDial.close() await allFuturesThrowing( @@ -812,9 +815,10 @@ suite "Mplex": let mplexDial = Mplex.new(conn) let mplexDialFut = mplexDial.handle() - var dialStreams: seq[Connection] - for i in 0..9: - dialStreams.add((await mplexDial.newStream())) + var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream()) + + check: + unorderedCompare(dialStreams, mplexDial.getStreams()) checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 @@ -858,9 +862,10 @@ suite "Mplex": let mplexDial = Mplex.new(conn) let mplexDialFut = mplexDial.handle() - var dialStreams: seq[Connection] - for i in 0..9: - dialStreams.add((await mplexDial.newStream())) + var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream()) + + check: + unorderedCompare(dialStreams, mplexDial.getStreams()) checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 @@ -901,9 +906,10 @@ suite "Mplex": let mplexDial = Mplex.new(conn) let mplexDialFut = mplexDial.handle() - var dialStreams: seq[Connection] - for i in 0..9: - dialStreams.add((await mplexDial.newStream())) + var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream()) + + check: + unorderedCompare(dialStreams, mplexDial.getStreams()) checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 @@ -947,9 +953,10 @@ suite "Mplex": let mplexDial = Mplex.new(conn) let mplexDialFut = mplexDial.handle() - var dialStreams: seq[Connection] - for i in 0..9: - dialStreams.add((await mplexDial.newStream())) + var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream()) + + check: + unorderedCompare(dialStreams, mplexDial.getStreams()) checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 diff --git a/tests/testyamux.nim b/tests/testyamux.nim index b9a4cf590..86aee697b 100644 --- a/tests/testyamux.nim +++ b/tests/testyamux.nim @@ -34,6 +34,8 @@ suite "Yamux": await conn.close() let streamA = await yamuxa.newStream() + check streamA == yamuxa.getStreams()[0] + await streamA.writeLp(fromHex("1234")) check (await streamA.readLp(100)) == fromHex("5678") await streamA.close() @@ -53,6 +55,8 @@ suite "Yamux": handlerBlocker.complete() let streamA = await yamuxa.newStream() + check streamA == yamuxa.getStreams()[0] + await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block await streamA.close() readerBlocker.complete() @@ -68,7 +72,10 @@ suite "Yamux": var buffer: array[160000, byte] discard await conn.readOnce(addr buffer[0], 160000) await conn.close() + let streamA = await yamuxa.newStream() + check streamA == yamuxa.getStreams()[0] + await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block let secondWriter = streamA.write(newSeq[byte](20)) @@ -88,7 +95,10 @@ suite "Yamux": var buffer: array[160000, byte] discard await conn.readOnce(addr buffer[0], 160000) await conn.close() + let streamA = await yamuxa.newStream() + check streamA == yamuxa.getStreams()[0] + await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block let secondWriter = streamA.write(newSeq[byte](20)) @@ -123,7 +133,10 @@ suite "Yamux": numberOfRead.inc() writerBlocker.complete() await conn.close() + let streamA = await yamuxa.newStream() + check streamA == yamuxa.getStreams()[0] + # Need to exhaust initial window first await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block await streamA.write(newSeq[byte](142)) @@ -144,6 +157,8 @@ suite "Yamux": await conn.close() let streamA = await yamuxa.newStream() + check streamA == yamuxa.getStreams()[0] + await streamA.write(newSeq[byte](256000)) let wrFut = collect(newSeq): for _ in 0..3: @@ -164,6 +179,8 @@ suite "Yamux": check (await conn.readLp(100)) == fromHex("5678") let streamA = await yamuxa.newStream() + check streamA == yamuxa.getStreams()[0] + await streamA.writeLp(fromHex("1234")) expect LPStreamRemoteClosedError: discard await streamA.readLp(100) await streamA.writeLp(fromHex("5678")) @@ -180,6 +197,8 @@ suite "Yamux": await conn.close() let streamA = await yamuxa.newStream() + check streamA == yamuxa.getStreams()[0] + await yamuxa.close() expect LPStreamClosedError: await streamA.writeLp(fromHex("1234")) expect LPStreamClosedError: discard await streamA.readLp(100)