Add getters for conns and streams (#878)
This commit is contained in:
parent
af5299f26c
commit
53b060f8f0
|
@ -113,6 +113,9 @@ proc connectedPeers*(c: ConnManager, dir: Direction): seq[PeerId] =
|
||||||
peers.add(peerId)
|
peers.add(peerId)
|
||||||
return peers
|
return peers
|
||||||
|
|
||||||
|
proc getConnections*(c: ConnManager): Table[PeerId, seq[Muxer]] =
|
||||||
|
return c.muxed
|
||||||
|
|
||||||
proc addConnEventHandler*(c: ConnManager,
|
proc addConnEventHandler*(c: ConnManager,
|
||||||
handler: ConnEventHandler,
|
handler: ConnEventHandler,
|
||||||
kind: ConnEventKind) =
|
kind: ConnEventKind) =
|
||||||
|
|
|
@ -248,3 +248,7 @@ method close*(m: Mplex) {.async, gcsafe.} =
|
||||||
m.channels[true].clear()
|
m.channels[true].clear()
|
||||||
|
|
||||||
trace "Closed mplex", m
|
trace "Closed mplex", m
|
||||||
|
|
||||||
|
method getStreams*(m: Mplex): seq[Connection] =
|
||||||
|
for c in m.channels[false].values: result.add(c)
|
||||||
|
for c in m.channels[true].values: result.add(c)
|
||||||
|
|
|
@ -63,3 +63,5 @@ proc new*(
|
||||||
|
|
||||||
let muxerProvider = T(newMuxer: creator, codec: codec)
|
let muxerProvider = T(newMuxer: creator, codec: codec)
|
||||||
muxerProvider
|
muxerProvider
|
||||||
|
|
||||||
|
method getStreams*(m: Muxer): seq[Connection] {.base.} = doAssert false, "not implemented"
|
||||||
|
|
|
@ -508,6 +508,9 @@ method handle*(m: Yamux) {.async, gcsafe.} =
|
||||||
await m.close()
|
await m.close()
|
||||||
trace "Stopped yamux handler"
|
trace "Stopped yamux handler"
|
||||||
|
|
||||||
|
method getStreams*(m: Yamux): seq[Connection] =
|
||||||
|
for c in m.channels.values: result.add(c)
|
||||||
|
|
||||||
method newStream*(
|
method newStream*(
|
||||||
m: Yamux,
|
m: Yamux,
|
||||||
name: string = "",
|
name: string = "",
|
||||||
|
|
|
@ -4,6 +4,7 @@ else:
|
||||||
{.push raises: [].}
|
{.push raises: [].}
|
||||||
|
|
||||||
import chronos
|
import chronos
|
||||||
|
import algorithm
|
||||||
|
|
||||||
import ../libp2p/transports/tcptransport
|
import ../libp2p/transports/tcptransport
|
||||||
import ../libp2p/stream/bufferstream
|
import ../libp2p/stream/bufferstream
|
||||||
|
@ -117,3 +118,19 @@ proc checkExpiringInternal(cond: proc(): bool {.raises: [Defect], gcsafe.} ): Fu
|
||||||
|
|
||||||
template checkExpiring*(code: untyped): untyped =
|
template checkExpiring*(code: untyped): untyped =
|
||||||
check await checkExpiringInternal(proc(): bool = code)
|
check await checkExpiringInternal(proc(): bool = code)
|
||||||
|
|
||||||
|
proc unorderedCompare*[T](a, b: seq[T]): bool =
|
||||||
|
if a == b:
|
||||||
|
return true
|
||||||
|
if a.len != b.len:
|
||||||
|
return false
|
||||||
|
|
||||||
|
var aSorted = a
|
||||||
|
var bSorted = b
|
||||||
|
aSorted.sort()
|
||||||
|
bSorted.sort()
|
||||||
|
|
||||||
|
if aSorted == bSorted:
|
||||||
|
return true
|
||||||
|
|
||||||
|
return false
|
|
@ -1,4 +1,4 @@
|
||||||
import sequtils
|
import std/[sequtils,tables]
|
||||||
import stew/results
|
import stew/results
|
||||||
import chronos
|
import chronos
|
||||||
import ../libp2p/[connmanager,
|
import ../libp2p/[connmanager,
|
||||||
|
@ -42,6 +42,19 @@ suite "Connection Manager":
|
||||||
|
|
||||||
await connMngr.close()
|
await connMngr.close()
|
||||||
|
|
||||||
|
asyncTest "get all connections":
|
||||||
|
let connMngr = ConnManager.new()
|
||||||
|
|
||||||
|
let peers = toSeq(0..<2).mapIt(PeerId.random.tryGet())
|
||||||
|
let muxs = toSeq(0..<2).mapIt(getMuxer(peers[it]))
|
||||||
|
for mux in muxs: connMngr.storeMuxer(mux)
|
||||||
|
|
||||||
|
let conns = connMngr.getConnections()
|
||||||
|
let connsMux = toSeq(conns.values).mapIt(it[0])
|
||||||
|
check unorderedCompare(connsMux, muxs)
|
||||||
|
|
||||||
|
await connMngr.close()
|
||||||
|
|
||||||
asyncTest "shouldn't allow a closed connection":
|
asyncTest "shouldn't allow a closed connection":
|
||||||
let connMngr = ConnManager.new()
|
let connMngr = ConnManager.new()
|
||||||
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
|
let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet()
|
||||||
|
|
|
@ -15,7 +15,7 @@ import ../libp2p/[errors,
|
||||||
|
|
||||||
import ./helpers
|
import ./helpers
|
||||||
|
|
||||||
{.used.}
|
{.used.}
|
||||||
|
|
||||||
suite "Mplex":
|
suite "Mplex":
|
||||||
teardown:
|
teardown:
|
||||||
|
@ -662,9 +662,10 @@ suite "Mplex":
|
||||||
|
|
||||||
let mplexDial = Mplex.new(conn)
|
let mplexDial = Mplex.new(conn)
|
||||||
let mplexDialFut = mplexDial.handle()
|
let mplexDialFut = mplexDial.handle()
|
||||||
var dialStreams: seq[Connection]
|
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||||
for i in 0..9:
|
|
||||||
dialStreams.add((await mplexDial.newStream()))
|
check:
|
||||||
|
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||||
|
|
||||||
for i, s in dialStreams:
|
for i, s in dialStreams:
|
||||||
await s.closeWithEOF()
|
await s.closeWithEOF()
|
||||||
|
@ -710,9 +711,10 @@ suite "Mplex":
|
||||||
|
|
||||||
let mplexDial = Mplex.new(conn)
|
let mplexDial = Mplex.new(conn)
|
||||||
let mplexDialFut = mplexDial.handle()
|
let mplexDialFut = mplexDial.handle()
|
||||||
var dialStreams: seq[Connection]
|
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||||
for i in 0..9:
|
|
||||||
dialStreams.add((await mplexDial.newStream()))
|
check:
|
||||||
|
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||||
|
|
||||||
proc dialReadLoop() {.async.} =
|
proc dialReadLoop() {.async.} =
|
||||||
for s in dialStreams:
|
for s in dialStreams:
|
||||||
|
@ -769,9 +771,10 @@ suite "Mplex":
|
||||||
|
|
||||||
let mplexDial = Mplex.new(conn)
|
let mplexDial = Mplex.new(conn)
|
||||||
let mplexDialFut = mplexDial.handle()
|
let mplexDialFut = mplexDial.handle()
|
||||||
var dialStreams: seq[Connection]
|
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||||
for i in 0..9:
|
|
||||||
dialStreams.add((await mplexDial.newStream()))
|
check:
|
||||||
|
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||||
|
|
||||||
await mplexDial.close()
|
await mplexDial.close()
|
||||||
await allFuturesThrowing(
|
await allFuturesThrowing(
|
||||||
|
@ -812,9 +815,10 @@ suite "Mplex":
|
||||||
|
|
||||||
let mplexDial = Mplex.new(conn)
|
let mplexDial = Mplex.new(conn)
|
||||||
let mplexDialFut = mplexDial.handle()
|
let mplexDialFut = mplexDial.handle()
|
||||||
var dialStreams: seq[Connection]
|
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||||
for i in 0..9:
|
|
||||||
dialStreams.add((await mplexDial.newStream()))
|
check:
|
||||||
|
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||||
|
|
||||||
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
||||||
|
|
||||||
|
@ -858,9 +862,10 @@ suite "Mplex":
|
||||||
|
|
||||||
let mplexDial = Mplex.new(conn)
|
let mplexDial = Mplex.new(conn)
|
||||||
let mplexDialFut = mplexDial.handle()
|
let mplexDialFut = mplexDial.handle()
|
||||||
var dialStreams: seq[Connection]
|
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||||
for i in 0..9:
|
|
||||||
dialStreams.add((await mplexDial.newStream()))
|
check:
|
||||||
|
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||||
|
|
||||||
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
||||||
|
|
||||||
|
@ -901,9 +906,10 @@ suite "Mplex":
|
||||||
|
|
||||||
let mplexDial = Mplex.new(conn)
|
let mplexDial = Mplex.new(conn)
|
||||||
let mplexDialFut = mplexDial.handle()
|
let mplexDialFut = mplexDial.handle()
|
||||||
var dialStreams: seq[Connection]
|
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||||
for i in 0..9:
|
|
||||||
dialStreams.add((await mplexDial.newStream()))
|
check:
|
||||||
|
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||||
|
|
||||||
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
||||||
|
|
||||||
|
@ -947,9 +953,10 @@ suite "Mplex":
|
||||||
|
|
||||||
let mplexDial = Mplex.new(conn)
|
let mplexDial = Mplex.new(conn)
|
||||||
let mplexDialFut = mplexDial.handle()
|
let mplexDialFut = mplexDial.handle()
|
||||||
var dialStreams: seq[Connection]
|
var dialStreams = toSeq(0..9).mapIt(await mplexDial.newStream())
|
||||||
for i in 0..9:
|
|
||||||
dialStreams.add((await mplexDial.newStream()))
|
check:
|
||||||
|
unorderedCompare(dialStreams, mplexDial.getStreams())
|
||||||
|
|
||||||
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
checkExpiring: listenStreams.len == 10 and dialStreams.len == 10
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,8 @@ suite "Yamux":
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
let streamA = await yamuxa.newStream()
|
let streamA = await yamuxa.newStream()
|
||||||
|
check streamA == yamuxa.getStreams()[0]
|
||||||
|
|
||||||
await streamA.writeLp(fromHex("1234"))
|
await streamA.writeLp(fromHex("1234"))
|
||||||
check (await streamA.readLp(100)) == fromHex("5678")
|
check (await streamA.readLp(100)) == fromHex("5678")
|
||||||
await streamA.close()
|
await streamA.close()
|
||||||
|
@ -53,6 +55,8 @@ suite "Yamux":
|
||||||
handlerBlocker.complete()
|
handlerBlocker.complete()
|
||||||
|
|
||||||
let streamA = await yamuxa.newStream()
|
let streamA = await yamuxa.newStream()
|
||||||
|
check streamA == yamuxa.getStreams()[0]
|
||||||
|
|
||||||
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
||||||
await streamA.close()
|
await streamA.close()
|
||||||
readerBlocker.complete()
|
readerBlocker.complete()
|
||||||
|
@ -68,7 +72,10 @@ suite "Yamux":
|
||||||
var buffer: array[160000, byte]
|
var buffer: array[160000, byte]
|
||||||
discard await conn.readOnce(addr buffer[0], 160000)
|
discard await conn.readOnce(addr buffer[0], 160000)
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
let streamA = await yamuxa.newStream()
|
let streamA = await yamuxa.newStream()
|
||||||
|
check streamA == yamuxa.getStreams()[0]
|
||||||
|
|
||||||
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
||||||
|
|
||||||
let secondWriter = streamA.write(newSeq[byte](20))
|
let secondWriter = streamA.write(newSeq[byte](20))
|
||||||
|
@ -88,7 +95,10 @@ suite "Yamux":
|
||||||
var buffer: array[160000, byte]
|
var buffer: array[160000, byte]
|
||||||
discard await conn.readOnce(addr buffer[0], 160000)
|
discard await conn.readOnce(addr buffer[0], 160000)
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
let streamA = await yamuxa.newStream()
|
let streamA = await yamuxa.newStream()
|
||||||
|
check streamA == yamuxa.getStreams()[0]
|
||||||
|
|
||||||
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
||||||
|
|
||||||
let secondWriter = streamA.write(newSeq[byte](20))
|
let secondWriter = streamA.write(newSeq[byte](20))
|
||||||
|
@ -123,7 +133,10 @@ suite "Yamux":
|
||||||
numberOfRead.inc()
|
numberOfRead.inc()
|
||||||
writerBlocker.complete()
|
writerBlocker.complete()
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
let streamA = await yamuxa.newStream()
|
let streamA = await yamuxa.newStream()
|
||||||
|
check streamA == yamuxa.getStreams()[0]
|
||||||
|
|
||||||
# Need to exhaust initial window first
|
# Need to exhaust initial window first
|
||||||
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
await wait(streamA.write(newSeq[byte](256000)), 1.seconds) # shouldn't block
|
||||||
await streamA.write(newSeq[byte](142))
|
await streamA.write(newSeq[byte](142))
|
||||||
|
@ -144,6 +157,8 @@ suite "Yamux":
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
let streamA = await yamuxa.newStream()
|
let streamA = await yamuxa.newStream()
|
||||||
|
check streamA == yamuxa.getStreams()[0]
|
||||||
|
|
||||||
await streamA.write(newSeq[byte](256000))
|
await streamA.write(newSeq[byte](256000))
|
||||||
let wrFut = collect(newSeq):
|
let wrFut = collect(newSeq):
|
||||||
for _ in 0..3:
|
for _ in 0..3:
|
||||||
|
@ -164,6 +179,8 @@ suite "Yamux":
|
||||||
check (await conn.readLp(100)) == fromHex("5678")
|
check (await conn.readLp(100)) == fromHex("5678")
|
||||||
|
|
||||||
let streamA = await yamuxa.newStream()
|
let streamA = await yamuxa.newStream()
|
||||||
|
check streamA == yamuxa.getStreams()[0]
|
||||||
|
|
||||||
await streamA.writeLp(fromHex("1234"))
|
await streamA.writeLp(fromHex("1234"))
|
||||||
expect LPStreamRemoteClosedError: discard await streamA.readLp(100)
|
expect LPStreamRemoteClosedError: discard await streamA.readLp(100)
|
||||||
await streamA.writeLp(fromHex("5678"))
|
await streamA.writeLp(fromHex("5678"))
|
||||||
|
@ -180,6 +197,8 @@ suite "Yamux":
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
let streamA = await yamuxa.newStream()
|
let streamA = await yamuxa.newStream()
|
||||||
|
check streamA == yamuxa.getStreams()[0]
|
||||||
|
|
||||||
await yamuxa.close()
|
await yamuxa.close()
|
||||||
expect LPStreamClosedError: await streamA.writeLp(fromHex("1234"))
|
expect LPStreamClosedError: await streamA.writeLp(fromHex("1234"))
|
||||||
expect LPStreamClosedError: discard await streamA.readLp(100)
|
expect LPStreamClosedError: discard await streamA.readLp(100)
|
||||||
|
|
Loading…
Reference in New Issue