mirror of
https://github.com/vacp2p/nim-libp2p.git
synced 2025-01-11 09:16:15 +00:00
don't track Connection, track StreamTransport (#177)
* don't track Connection, track StreamTransport * make tests more deterministic
This commit is contained in:
parent
5583168965
commit
773b738c12
@ -25,6 +25,7 @@ const
|
||||
type
|
||||
TcpTransport* = ref object of Transport
|
||||
server*: StreamServer
|
||||
clients: seq[StreamTransport]
|
||||
cleanups*: seq[Future[void]]
|
||||
handlers*: seq[Future[void]]
|
||||
|
||||
@ -56,11 +57,6 @@ proc setupTcpTransportTracker(): TcpTransportTracker =
|
||||
result.isLeaked = leakTransport
|
||||
addTracker(TcpTransportTrackerName, result)
|
||||
|
||||
proc cleanup(t: Transport, conn: Connection) {.async.} =
|
||||
await conn.closeEvent.wait()
|
||||
trace "connection cleanup event wait ended"
|
||||
t.connections.keepItIf(it != conn)
|
||||
|
||||
proc connHandler*(t: TcpTransport,
|
||||
client: StreamTransport,
|
||||
initiator: bool): Connection =
|
||||
@ -71,10 +67,7 @@ proc connHandler*(t: TcpTransport,
|
||||
if not isNil(t.handler):
|
||||
t.handlers &= t.handler(conn)
|
||||
|
||||
# TODO: store the streamtransport client here
|
||||
t.connections.add(conn)
|
||||
t.cleanups &= t.cleanup(conn)
|
||||
|
||||
t.clients.add(client)
|
||||
result = conn
|
||||
|
||||
proc connCb(server: StreamServer,
|
||||
@ -108,6 +101,9 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
|
||||
trace "stopping transport"
|
||||
await procCall Transport(t).close() # call base
|
||||
|
||||
checkFutures(await allFinished(
|
||||
t.clients.mapIt(it.closeWait())))
|
||||
|
||||
# server can be nil
|
||||
if not isNil(t.server):
|
||||
t.server.stop()
|
||||
@ -118,15 +114,13 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
|
||||
for fut in t.handlers:
|
||||
if not fut.finished:
|
||||
fut.cancel()
|
||||
t.handlers = await allFinished(t.handlers)
|
||||
checkFutures(t.handlers)
|
||||
checkFutures(await allFinished(t.handlers))
|
||||
t.handlers = @[]
|
||||
|
||||
for fut in t.cleanups:
|
||||
if not fut.finished:
|
||||
fut.cancel()
|
||||
t.cleanups = await allFinished(t.cleanups)
|
||||
checkFutures(t.cleanups)
|
||||
checkFutures(await allFinished(t.cleanups))
|
||||
t.cleanups = @[]
|
||||
|
||||
trace "transport stopped"
|
||||
|
@ -24,7 +24,6 @@ type
|
||||
|
||||
Transport* = ref object of RootObj
|
||||
ma*: Multiaddress
|
||||
connections*: seq[Connection]
|
||||
handler*: ConnHandler
|
||||
multicodec*: MultiCodec
|
||||
flags*: TransportFlags
|
||||
@ -49,8 +48,7 @@ proc newTransport*(t: typedesc[Transport], flags: TransportFlags = {}): t {.gcsa
|
||||
method close*(t: Transport) {.base, async, gcsafe.} =
|
||||
## stop and cleanup the transport
|
||||
## including all outstanding connections
|
||||
let futs = await allFinished(t.connections.mapIt(it.close()))
|
||||
checkFutures(futs)
|
||||
discard
|
||||
|
||||
method listen*(t: Transport,
|
||||
ma: MultiAddress,
|
||||
|
@ -31,7 +31,7 @@ suite "Mplex":
|
||||
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await conn.writeMsg(0, MessageType.New, cast[seq[byte]]("stream 1"))
|
||||
await conn.writeMsg(0, MessageType.New, ("stream 1").toBytes)
|
||||
await conn.close()
|
||||
|
||||
waitFor(testEncodeHeader())
|
||||
@ -43,7 +43,7 @@ suite "Mplex":
|
||||
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await conn.writeMsg(17, MessageType.New, cast[seq[byte]]("stream 1"))
|
||||
await conn.writeMsg(17, MessageType.New, ("stream 1").toBytes)
|
||||
await conn.close()
|
||||
|
||||
waitFor(testEncodeHeader())
|
||||
@ -55,7 +55,7 @@ suite "Mplex":
|
||||
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]]("stream 1"))
|
||||
await conn.writeMsg(0, MessageType.MsgOut, ("stream 1").toBytes)
|
||||
await conn.close()
|
||||
|
||||
waitFor(testEncodeHeaderBody())
|
||||
@ -67,7 +67,7 @@ suite "Mplex":
|
||||
|
||||
let stream = newBufferStream(encHandler)
|
||||
let conn = newConnection(stream)
|
||||
await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]]("stream 1"))
|
||||
await conn.writeMsg(17, MessageType.MsgOut, ("stream 1").toBytes)
|
||||
await conn.close()
|
||||
|
||||
waitFor(testEncodeHeaderBody())
|
||||
@ -94,7 +94,7 @@ suite "Mplex":
|
||||
|
||||
check msg.id == 0
|
||||
check msg.msgType == MessageType.MsgOut
|
||||
check cast[string](msg.data) == "hello from channel 0!!"
|
||||
check string.fromBytes(msg.data) == "hello from channel 0!!"
|
||||
await conn.close()
|
||||
|
||||
waitFor(testDecodeHeader())
|
||||
@ -108,29 +108,31 @@ suite "Mplex":
|
||||
|
||||
check msg.id == 17
|
||||
check msg.msgType == MessageType.MsgOut
|
||||
check cast[string](msg.data) == "hello from channel 0!!"
|
||||
check string.fromBytes(msg.data) == "hello from channel 0!!"
|
||||
await conn.close()
|
||||
|
||||
waitFor(testDecodeHeader())
|
||||
|
||||
test "half closed - channel should close for write":
|
||||
proc testClosedForWrite() {.async.} =
|
||||
proc testClosedForWrite(): Future[bool] {.async.} =
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let
|
||||
conn = newConnection(newBufferStream(writeHandler))
|
||||
chann = newChannel(1, conn, true)
|
||||
await chann.close()
|
||||
try:
|
||||
await chann.close()
|
||||
await chann.write("Hello")
|
||||
except LPStreamEOFError:
|
||||
result = true
|
||||
finally:
|
||||
await chann.reset()
|
||||
await conn.close()
|
||||
|
||||
expect LPStreamEOFError:
|
||||
waitFor(testClosedForWrite())
|
||||
check:
|
||||
waitFor(testClosedForWrite()) == true
|
||||
|
||||
test "half closed - channel should close for read by remote":
|
||||
proc testClosedForRead() {.async.} =
|
||||
proc testClosedForRead(): Future[bool] {.async.} =
|
||||
let
|
||||
conn = newConnection(newBufferStream(
|
||||
proc (data: seq[byte]) {.gcsafe, async.} =
|
||||
@ -138,69 +140,77 @@ suite "Mplex":
|
||||
))
|
||||
chann = newChannel(1, conn, true)
|
||||
|
||||
try:
|
||||
await chann.pushTo(cast[seq[byte]]("Hello!"))
|
||||
let closeFut = chann.closedByRemote()
|
||||
await chann.pushTo(("Hello!").toBytes)
|
||||
let closeFut = chann.closedByRemote()
|
||||
|
||||
var data = newSeq[byte](6)
|
||||
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
|
||||
var data = newSeq[byte](6)
|
||||
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
|
||||
try:
|
||||
await chann.readExactly(addr data[0], 6) # this should throw
|
||||
await closeFut
|
||||
except LPStreamEOFError:
|
||||
result = true
|
||||
finally:
|
||||
await chann.close()
|
||||
await conn.close()
|
||||
|
||||
expect LPStreamEOFError:
|
||||
waitFor(testClosedForRead())
|
||||
check:
|
||||
waitFor(testClosedForRead()) == true
|
||||
|
||||
test "should not allow pushing data to channel when remote end closed":
|
||||
proc testResetWrite() {.async.} =
|
||||
proc testResetWrite(): Future[bool] {.async.} =
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let
|
||||
conn = newConnection(newBufferStream(writeHandler))
|
||||
chann = newChannel(1, conn, true)
|
||||
await chann.closedByRemote()
|
||||
try:
|
||||
await chann.closedByRemote()
|
||||
await chann.pushTo(@[byte(1)])
|
||||
except LPStreamEOFError:
|
||||
result = true
|
||||
finally:
|
||||
await chann.close()
|
||||
await conn.close()
|
||||
|
||||
expect LPStreamEOFError:
|
||||
waitFor(testResetWrite())
|
||||
check:
|
||||
waitFor(testResetWrite()) == true
|
||||
|
||||
test "reset - channel should fail reading":
|
||||
proc testResetRead() {.async.} =
|
||||
proc testResetRead(): Future[bool] {.async.} =
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let
|
||||
conn = newConnection(newBufferStream(writeHandler))
|
||||
chann = newChannel(1, conn, true)
|
||||
|
||||
await chann.reset()
|
||||
var data = newSeq[byte](1)
|
||||
try:
|
||||
await chann.reset()
|
||||
var data = newSeq[byte](1)
|
||||
await chann.readExactly(addr data[0], 1)
|
||||
doAssert(len(data) == 1)
|
||||
except LPStreamEOFError:
|
||||
result = true
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
expect LPStreamEOFError:
|
||||
waitFor(testResetRead())
|
||||
check:
|
||||
waitFor(testResetRead()) == true
|
||||
|
||||
test "reset - channel should fail writing":
|
||||
proc testResetWrite() {.async.} =
|
||||
proc testResetWrite(): Future[bool] {.async.} =
|
||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||
let
|
||||
conn = newConnection(newBufferStream(writeHandler))
|
||||
chann = newChannel(1, conn, true)
|
||||
await chann.reset()
|
||||
try:
|
||||
await chann.reset()
|
||||
await chann.write(cast[seq[byte]]("Hello!"))
|
||||
await chann.write(("Hello!").toBytes)
|
||||
except LPStreamEOFError:
|
||||
result = true
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
expect LPStreamEOFError:
|
||||
waitFor(testResetWrite())
|
||||
check:
|
||||
waitFor(testResetWrite()) == true
|
||||
|
||||
test "e2e - read/write receiver":
|
||||
proc testNewStream() {.async.} =
|
||||
@ -212,7 +222,7 @@ suite "Mplex":
|
||||
mplexListen.streamHandler = proc(stream: Connection)
|
||||
{.async, gcsafe.} =
|
||||
let msg = await stream.readLp(1024)
|
||||
check cast[string](msg) == "HELLO"
|
||||
check string.fromBytes(msg) == "HELLO"
|
||||
await stream.close()
|
||||
done.complete()
|
||||
|
||||
@ -250,7 +260,7 @@ suite "Mplex":
|
||||
mplexListen.streamHandler = proc(stream: Connection)
|
||||
{.async, gcsafe.} =
|
||||
let msg = await stream.readLp(1024)
|
||||
check cast[string](msg) == "HELLO"
|
||||
check string.fromBytes(msg) == "HELLO"
|
||||
await stream.close()
|
||||
done.complete()
|
||||
|
||||
@ -351,7 +361,7 @@ suite "Mplex":
|
||||
let mplexDial = newMplex(conn)
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
let stream = await mplexDial.newStream("DIALER")
|
||||
let msg = cast[string](await stream.readLp(1024))
|
||||
let msg = string.fromBytes(await stream.readLp(1024))
|
||||
await stream.close()
|
||||
check msg == "Hello from stream!"
|
||||
|
||||
@ -416,7 +426,7 @@ suite "Mplex":
|
||||
mplexListen.streamHandler = proc(stream: Connection)
|
||||
{.async, gcsafe.} =
|
||||
let msg = await stream.readLp(1024)
|
||||
check cast[string](msg) == &"stream {count} from dialer!"
|
||||
check string.fromBytes(msg) == &"stream {count} from dialer!"
|
||||
await stream.writeLp(&"stream {count} from listener!")
|
||||
count.inc
|
||||
await stream.close()
|
||||
@ -438,7 +448,7 @@ suite "Mplex":
|
||||
let stream = await mplexDial.newStream("dialer stream")
|
||||
await stream.writeLp(&"stream {i} from dialer!")
|
||||
let msg = await stream.readLp(1024)
|
||||
check cast[string](msg) == &"stream {i} from listener!"
|
||||
check string.fromBytes(msg) == &"stream {i} from listener!"
|
||||
await stream.close()
|
||||
|
||||
await done.wait(5.seconds)
|
||||
|
Loading…
x
Reference in New Issue
Block a user