Add getters for conns and streams (#878)
This commit is contained in:
parent
af5299f26c
commit
53b060f8f0
|
@ -113,6 +113,9 @@ proc connectedPeers*(c: ConnManager, dir: Direction): seq[PeerId] =
|
|||
peers.add(peerId)
|
||||
return peers
|
||||
|
||||
proc getConnections*(c: ConnManager): Table[PeerId, seq[Muxer]] =
|
||||
return c.muxed
|
||||
|
||||
proc addConnEventHandler*(c: ConnManager,
|
||||
handler: ConnEventHandler,
|
||||
kind: ConnEventKind) =
|
||||
|
|
|
@ -248,3 +248,7 @@ method close*(m: Mplex) {.async, gcsafe.} =
|
|||
m.channels[true].clear()
|
||||
|
||||
trace "Closed mplex", m
|
||||
|
||||
method getStreams*(m: Mplex): seq[Connection] =
|
||||
for c in m.channels[false].values: result.add(c)
|
||||
for c in m.channels[true].values: result.add(c)
|
||||
|
|
|
@ -63,3 +63,5 @@ proc new*(
|
|||
|
||||
let muxerProvider = T(newMuxer: creator, codec: codec)
|
||||
muxerProvider
|
||||
|
||||
method getStreams*(m: Muxer): seq[Connection] {.base.} = doAssert false, "not implemented"
|
||||
|
|
|
@ -508,6 +508,9 @@ method handle*(m: Yamux) {.async, gcsafe.} =
|
|||
await m.close()
|
||||
trace "Stopped yamux handler"
|
||||
|
||||
method getStreams*(m: Yamux): seq[Connection] =
|
||||
for c in m.channels.values: result.add(c)
|
||||
|
||||
method newStream*(
|
||||
m: Yamux,
|
||||
name: string = "",
|
||||
|
|
|
@ -4,6 +4,7 @@ else:
|
|||
{.push raises: [].}
|
||||
|
||||
import chronos
|
||||
import algorithm
|
||||
|
||||
import ../libp2p/transports/tcptransport
|
||||
import ../libp2p/stream/bufferstream
|
||||
|
@ -117,3 +118,19 @@ proc checkExpiringInternal(cond: proc(): bool {.raises: [Defect], gcsafe.} ): Fu
|
|||
|
||||
template checkExpiring*(code: untyped): untyped =
|
||||
check await checkExpiringInternal(proc(): bool = code)
|
||||
|
||||
proc unorderedCompare*[T](a, b: seq[T]): bool =
|
||||
if a == b:
|
||||
return true
|
||||
if a.len != b.len:
|
||||
return false
|
||||
|
||||
var aSorted = a
|
||||
var bSorted = b
|
||||
aSorted.sort()
|
||||
bSorted.sort()
|
||||
|
||||
if aSorted == bSorted:
|
||||
return true
|
||||
|
||||
return false
|
|
@ -1,4 +1,4 @@
|
|||
import sequtils
|
||||
import std/[sequtils,tables]
|
||||
import stew/results
|
||||
import chronos
|
||||
import ../libp2p/[connmanager,
|
||||
|
@ -42,6 +42,19 @@ suite "Connection Manager":
|
|||
|
||||
await connMngr.close()
|
||||
|
||||
asyncTest "get all connections":
|
||||
let connMngr = ConnManager.new()
|
||||
|
||||
let peers = toSeq(0..<2).mapIt(PeerId.random.tryGet())
|
||||
let muxs = toSeq(0..<2).mapIt(getMuxer(peers[it]))
|
||||
for mux in muxs: connMngr.storeMuxer(mux)
|
||||
|
||||
let conns = connMngr.getConnections()
|
||||
let connsMux = toSeq(conns.values).mapIt(it[0])
|
||||
check unorderedCompare(connsMux, muxs)
|
||||
|
||||
await connMngr.close()
|
||||
|
||||
asyncTest "shouldn't allow a closed connection":
|
||||
let connMngr = ConnManager.new()
|
||||
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
|
||||
|
|
|
@ -662,9 +662,10 @@ suite "Mplex":
|
|||
|
||||
let mplexDial = Mplex.new(conn)
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
var dialStreams: seq[Connection]
|
||||
for i in 0..9:
|
||||
dialStreams.add((await mplexDial.newStream()))
|
||||
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||
|
||||
check:
|
||||
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||
|
||||
for i, s in dialStreams:
|
||||
await s.closeWithEOF()
|
||||
|
@ -710,9 +711,10 @@ suite "Mplex":
|
|||
|
||||
let mplexDial = Mplex.new(conn)
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
var dialStreams: seq[Connection]
|
||||
for i in 0..9:
|
||||
dialStreams.add((await mplexDial.newStream()))
|
||||
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||
|
||||
check:
|
||||
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||
|
||||
proc dialReadLoop() {.async.} =
|
||||
for s in dialStreams:
|
||||
|
@ -769,9 +771,10 @@ suite "Mplex":
|
|||
|
||||
let mplexDial = Mplex.new(conn)
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
var dialStreams: seq[Connection]
|
||||
for i in 0..9:
|
||||
dialStreams.add((await mplexDial.newStream()))
|
||||
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||
|
||||
check:
|
||||
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||
|
||||
await mplexDial.close()
|
||||
await allFuturesThrowing(
|
||||
|
@ -812,9 +815,10 @@ suite "Mplex":
|
|||
|
||||
let mplexDial = Mplex.new(conn)
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
var dialStreams: seq[Connection]
|
||||
for i in 0..9:
|
||||
dialStreams.add((await mplexDial.newStream()))
|
||||
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||
|
||||
check:
|
||||
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||
|
||||
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
||||
|
||||
|
@ -858,9 +862,10 @@ suite "Mplex":
|
|||
|
||||
let mplexDial = Mplex.new(conn)
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
var dialStreams: seq[Connection]
|
||||
for i in 0..9:
|
||||
dialStreams.add((await mplexDial.newStream()))
|
||||
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||
|
||||
check:
|
||||
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||
|
||||
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
||||
|
||||
|
@ -901,9 +906,10 @@ suite "Mplex":
|
|||
|
||||
let mplexDial = Mplex.new(conn)
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
var dialStreams: seq[Connection]
|
||||
for i in 0..9:
|
||||
dialStreams.add((await mplexDial.newStream()))
|
||||
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||
|
||||
check:
|
||||
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||
|
||||
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
||||
|
||||
|
@ -947,9 +953,10 @@ suite "Mplex":
|
|||
|
||||
let mplexDial = Mplex.new(conn)
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
var dialStreams: seq[Connection]
|
||||
for i in 0..9:
|
||||
dialStreams.add((await mplexDial.newStream()))
|
||||
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||
|
||||
check:
|
||||
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||
|
||||
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
||||
|
||||
|
|
|
@ -34,6 +34,8 @@ suite "Yamux":
|
|||
await conn.close()
|
||||
|
||||
let streamA = await yamuxa.newStream()
|
||||
check streamA == yamuxa.getStreams()[0]
|
||||
|
||||
await streamA.writeLp(fromHex("1234"))
|
||||
check (await streamA.readLp(100)) == fromHex("5678")
|
||||
await streamA.close()
|
||||
|
@ -53,6 +55,8 @@ suite "Yamux":
|
|||
handlerBlocker.complete()
|
||||
|
||||
let streamA = await yamuxa.newStream()
|
||||
check streamA == yamuxa.getStreams()[0]
|
||||
|
||||
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
||||
await streamA.close()
|
||||
readerBlocker.complete()
|
||||
|
@ -68,7 +72,10 @@ suite "Yamux":
|
|||
var buffer: array[160000, byte]
|
||||
discard await conn.readOnce(addr buffer[0], 160000)
|
||||
await conn.close()
|
||||
|
||||
let streamA = await yamuxa.newStream()
|
||||
check streamA == yamuxa.getStreams()[0]
|
||||
|
||||
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
||||
|
||||
let secondWriter = streamA.write(newSeq[byte](20))
|
||||
|
@ -88,7 +95,10 @@ suite "Yamux":
|
|||
var buffer: array[160000, byte]
|
||||
discard await conn.readOnce(addr buffer[0], 160000)
|
||||
await conn.close()
|
||||
|
||||
let streamA = await yamuxa.newStream()
|
||||
check streamA == yamuxa.getStreams()[0]
|
||||
|
||||
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
||||
|
||||
let secondWriter = streamA.write(newSeq[byte](20))
|
||||
|
@ -123,7 +133,10 @@ suite "Yamux":
|
|||
numberOfRead.inc()
|
||||
writerBlocker.complete()
|
||||
await conn.close()
|
||||
|
||||
let streamA = await yamuxa.newStream()
|
||||
check streamA == yamuxa.getStreams()[0]
|
||||
|
||||
# Need to exhaust initial window first
|
||||
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
||||
await streamA.write(newSeq[byte](142))
|
||||
|
@ -144,6 +157,8 @@ suite "Yamux":
|
|||
await conn.close()
|
||||
|
||||
let streamA = await yamuxa.newStream()
|
||||
check streamA == yamuxa.getStreams()[0]
|
||||
|
||||
await streamA.write(newSeq[byte](256000))
|
||||
let wrFut = collect(newSeq):
|
||||
for _ in 0..3:
|
||||
|
@ -164,6 +179,8 @@ suite "Yamux":
|
|||
check (await conn.readLp(100)) == fromHex("5678")
|
||||
|
||||
let streamA = await yamuxa.newStream()
|
||||
check streamA == yamuxa.getStreams()[0]
|
||||
|
||||
await streamA.writeLp(fromHex("1234"))
|
||||
expect LPStreamRemoteClosedError: discard await streamA.readLp(100)
|
||||
await streamA.writeLp(fromHex("5678"))
|
||||
|
@ -180,6 +197,8 @@ suite "Yamux":
|
|||
await conn.close()
|
||||
|
||||
let streamA = await yamuxa.newStream()
|
||||
check streamA == yamuxa.getStreams()[0]
|
||||
|
||||
await yamuxa.close()
|
||||
expect LPStreamClosedError: await streamA.writeLp(fromHex("1234"))
|
||||
expect LPStreamClosedError: discard await streamA.readLp(100)
|
||||
|
|
Loading…
Reference in New Issue