Add getters for conns and streams (#878)

This commit is contained in:
Alvaro Revuelta 2023-03-31 00:16:39 +02:00 committed by GitHub
parent af5299f26c
commit 53b060f8f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 91 additions and 23 deletions

View File

@ -113,6 +113,9 @@ proc connectedPeers*(c: ConnManager, dir: Direction): seq[PeerId] =
peers.add(peerId) peers.add(peerId)
return peers return peers
proc getConnections*(c: ConnManager): Table[PeerId, seq[Muxer]] =
return c.muxed
proc addConnEventHandler*(c: ConnManager, proc addConnEventHandler*(c: ConnManager,
handler: ConnEventHandler, handler: ConnEventHandler,
kind: ConnEventKind) = kind: ConnEventKind) =

View File

@ -248,3 +248,7 @@ method close*(m: Mplex) {.async, gcsafe.} =
m.channels[true].clear() m.channels[true].clear()
trace "Closed mplex", m trace "Closed mplex", m
method getStreams*(m: Mplex): seq[Connection] =
for c in m.channels[false].values: result.add(c)
for c in m.channels[true].values: result.add(c)

View File

@ -63,3 +63,5 @@ proc new*(
let muxerProvider = T(newMuxer: creator, codec: codec) let muxerProvider = T(newMuxer: creator, codec: codec)
muxerProvider muxerProvider
method getStreams*(m: Muxer): seq[Connection] {.base.} = doAssert false, "not implemented"

View File

@ -508,6 +508,9 @@ method handle*(m: Yamux) {.async, gcsafe.} =
await m.close() await m.close()
trace "Stopped yamux handler" trace "Stopped yamux handler"
method getStreams*(m: Yamux): seq[Connection] =
for c in m.channels.values: result.add(c)
method newStream*( method newStream*(
m: Yamux, m: Yamux,
name: string = "", name: string = "",

View File

@ -4,6 +4,7 @@ else:
{.push raises: [].} {.push raises: [].}
import chronos import chronos
import algorithm
import ../libp2p/transports/tcptransport import ../libp2p/transports/tcptransport
import ../libp2p/stream/bufferstream import ../libp2p/stream/bufferstream
@ -117,3 +118,19 @@ proc checkExpiringInternal(cond: proc(): bool {.raises: [Defect], gcsafe.} ): Fu
template checkExpiring*(code: untyped): untyped = template checkExpiring*(code: untyped): untyped =
check await checkExpiringInternal(proc(): bool = code) check await checkExpiringInternal(proc(): bool = code)
proc unorderedCompare*[T](a, b: seq[T]): bool =
if a == b:
return true
if a.len != b.len:
return false
var aSorted = a
var bSorted = b
aSorted.sort()
bSorted.sort()
if aSorted == bSorted:
return true
return false

View File

@ -1,4 +1,4 @@
import sequtils import std/[sequtils,tables]
import stew/results import stew/results
import chronos import chronos
import ../libp2p/[connmanager, import ../libp2p/[connmanager,
@ -42,6 +42,19 @@ suite "Connection Manager":
await connMngr.close() await connMngr.close()
asyncTest "get all connections":
let connMngr = ConnManager.new()
let peers = toSeq(0..<2).mapIt(PeerId.random.tryGet())
let muxs = toSeq(0..<2).mapIt(getMuxer(peers[it]))
for mux in muxs: connMngr.storeMuxer(mux)
let conns = connMngr.getConnections()
let connsMux = toSeq(conns.values).mapIt(it[0])
check unorderedCompare(connsMux, muxs)
await connMngr.close()
asyncTest "shouldn't allow a closed connection": asyncTest "shouldn't allow a closed connection":
let connMngr = ConnManager.new() let connMngr = ConnManager.new()
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()

View File

@ -662,9 +662,10 @@ suite "Mplex":
let mplexDial = Mplex.new(conn) let mplexDial = Mplex.new(conn)
let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
var dialStreams: seq[Connection] var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
for i in 0..9:
dialStreams.add((await mplexDial.newStream())) check:
unorderedCompare(dialStreams, mplexDial.getStreams())
for i, s in dialStreams: for i, s in dialStreams:
await s.closeWithEOF() await s.closeWithEOF()
@ -710,9 +711,10 @@ suite "Mplex":
let mplexDial = Mplex.new(conn) let mplexDial = Mplex.new(conn)
let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
var dialStreams: seq[Connection] var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
for i in 0..9:
dialStreams.add((await mplexDial.newStream())) check:
unorderedCompare(dialStreams, mplexDial.getStreams())
proc dialReadLoop() {.async.} = proc dialReadLoop() {.async.} =
for s in dialStreams: for s in dialStreams:
@ -769,9 +771,10 @@ suite "Mplex":
let mplexDial = Mplex.new(conn) let mplexDial = Mplex.new(conn)
let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
var dialStreams: seq[Connection] var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
for i in 0..9:
dialStreams.add((await mplexDial.newStream())) check:
unorderedCompare(dialStreams, mplexDial.getStreams())
await mplexDial.close() await mplexDial.close()
await allFuturesThrowing( await allFuturesThrowing(
@ -812,9 +815,10 @@ suite "Mplex":
let mplexDial = Mplex.new(conn) let mplexDial = Mplex.new(conn)
let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
var dialStreams: seq[Connection] var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
for i in 0..9:
dialStreams.add((await mplexDial.newStream())) check:
unorderedCompare(dialStreams, mplexDial.getStreams())
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
@ -858,9 +862,10 @@ suite "Mplex":
let mplexDial = Mplex.new(conn) let mplexDial = Mplex.new(conn)
let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
var dialStreams: seq[Connection] var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
for i in 0..9:
dialStreams.add((await mplexDial.newStream())) check:
unorderedCompare(dialStreams, mplexDial.getStreams())
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
@ -901,9 +906,10 @@ suite "Mplex":
let mplexDial = Mplex.new(conn) let mplexDial = Mplex.new(conn)
let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
var dialStreams: seq[Connection] var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
for i in 0..9:
dialStreams.add((await mplexDial.newStream())) check:
unorderedCompare(dialStreams, mplexDial.getStreams())
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
@ -947,9 +953,10 @@ suite "Mplex":
let mplexDial = Mplex.new(conn) let mplexDial = Mplex.new(conn)
let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
var dialStreams: seq[Connection] var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
for i in 0..9:
dialStreams.add((await mplexDial.newStream())) check:
unorderedCompare(dialStreams, mplexDial.getStreams())
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10 checkExpiring: listenStreams.len == 10 and dialStreams.len == 10

View File

@ -34,6 +34,8 @@ suite "Yamux":
await conn.close() await conn.close()
let streamA = await yamuxa.newStream() let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]
await streamA.writeLp(fromHex("1234")) await streamA.writeLp(fromHex("1234"))
check (await streamA.readLp(100)) == fromHex("5678") check (await streamA.readLp(100)) == fromHex("5678")
await streamA.close() await streamA.close()
@ -53,6 +55,8 @@ suite "Yamux":
handlerBlocker.complete() handlerBlocker.complete()
let streamA = await yamuxa.newStream() let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
await streamA.close() await streamA.close()
readerBlocker.complete() readerBlocker.complete()
@ -68,7 +72,10 @@ suite "Yamux":
var buffer: array[160000, byte] var buffer: array[160000, byte]
discard await conn.readOnce(addr buffer[0], 160000) discard await conn.readOnce(addr buffer[0], 160000)
await conn.close() await conn.close()
let streamA = await yamuxa.newStream() let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
let secondWriter = streamA.write(newSeq[byte](20)) let secondWriter = streamA.write(newSeq[byte](20))
@ -88,7 +95,10 @@ suite "Yamux":
var buffer: array[160000, byte] var buffer: array[160000, byte]
discard await conn.readOnce(addr buffer[0], 160000) discard await conn.readOnce(addr buffer[0], 160000)
await conn.close() await conn.close()
let streamA = await yamuxa.newStream() let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
let secondWriter = streamA.write(newSeq[byte](20)) let secondWriter = streamA.write(newSeq[byte](20))
@ -123,7 +133,10 @@ suite "Yamux":
numberOfRead.inc() numberOfRead.inc()
writerBlocker.complete() writerBlocker.complete()
await conn.close() await conn.close()
let streamA = await yamuxa.newStream() let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]
# Need to exhaust initial window first # Need to exhaust initial window first
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
await streamA.write(newSeq[byte](142)) await streamA.write(newSeq[byte](142))
@ -144,6 +157,8 @@ suite "Yamux":
await conn.close() await conn.close()
let streamA = await yamuxa.newStream() let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]
await streamA.write(newSeq[byte](256000)) await streamA.write(newSeq[byte](256000))
let wrFut = collect(newSeq): let wrFut = collect(newSeq):
for _ in 0..3: for _ in 0..3:
@ -164,6 +179,8 @@ suite "Yamux":
check (await conn.readLp(100)) == fromHex("5678") check (await conn.readLp(100)) == fromHex("5678")
let streamA = await yamuxa.newStream() let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]
await streamA.writeLp(fromHex("1234")) await streamA.writeLp(fromHex("1234"))
expect LPStreamRemoteClosedError: discard await streamA.readLp(100) expect LPStreamRemoteClosedError: discard await streamA.readLp(100)
await streamA.writeLp(fromHex("5678")) await streamA.writeLp(fromHex("5678"))
@ -180,6 +197,8 @@ suite "Yamux":
await conn.close() await conn.close()
let streamA = await yamuxa.newStream() let streamA = await yamuxa.newStream()
check streamA == yamuxa.getStreams()[0]
await yamuxa.close() await yamuxa.close()
expect LPStreamClosedError: await streamA.writeLp(fromHex("1234")) expect LPStreamClosedError: await streamA.writeLp(fromHex("1234"))
expect LPStreamClosedError: discard await streamA.readLp(100) expect LPStreamClosedError: discard await streamA.readLp(100)