don't track Connection, track StreamTransport (#177)
* don't track Connection, track StreamTransport * make tests more deterministic
This commit is contained in:
parent
5583168965
commit
773b738c12
|
@ -25,6 +25,7 @@ const
|
||||||
type
|
type
|
||||||
TcpTransport* = ref object of Transport
|
TcpTransport* = ref object of Transport
|
||||||
server*: StreamServer
|
server*: StreamServer
|
||||||
|
clients: seq[StreamTransport]
|
||||||
cleanups*: seq[Future[void]]
|
cleanups*: seq[Future[void]]
|
||||||
handlers*: seq[Future[void]]
|
handlers*: seq[Future[void]]
|
||||||
|
|
||||||
|
@ -56,11 +57,6 @@ proc setupTcpTransportTracker(): TcpTransportTracker =
|
||||||
result.isLeaked = leakTransport
|
result.isLeaked = leakTransport
|
||||||
addTracker(TcpTransportTrackerName, result)
|
addTracker(TcpTransportTrackerName, result)
|
||||||
|
|
||||||
proc cleanup(t: Transport, conn: Connection) {.async.} =
|
|
||||||
await conn.closeEvent.wait()
|
|
||||||
trace "connection cleanup event wait ended"
|
|
||||||
t.connections.keepItIf(it != conn)
|
|
||||||
|
|
||||||
proc connHandler*(t: TcpTransport,
|
proc connHandler*(t: TcpTransport,
|
||||||
client: StreamTransport,
|
client: StreamTransport,
|
||||||
initiator: bool): Connection =
|
initiator: bool): Connection =
|
||||||
|
@ -71,10 +67,7 @@ proc connHandler*(t: TcpTransport,
|
||||||
if not isNil(t.handler):
|
if not isNil(t.handler):
|
||||||
t.handlers &= t.handler(conn)
|
t.handlers &= t.handler(conn)
|
||||||
|
|
||||||
# TODO: store the streamtransport client here
|
t.clients.add(client)
|
||||||
t.connections.add(conn)
|
|
||||||
t.cleanups &= t.cleanup(conn)
|
|
||||||
|
|
||||||
result = conn
|
result = conn
|
||||||
|
|
||||||
proc connCb(server: StreamServer,
|
proc connCb(server: StreamServer,
|
||||||
|
@ -108,6 +101,9 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
|
||||||
trace "stopping transport"
|
trace "stopping transport"
|
||||||
await procCall Transport(t).close() # call base
|
await procCall Transport(t).close() # call base
|
||||||
|
|
||||||
|
checkFutures(await allFinished(
|
||||||
|
t.clients.mapIt(it.closeWait())))
|
||||||
|
|
||||||
# server can be nil
|
# server can be nil
|
||||||
if not isNil(t.server):
|
if not isNil(t.server):
|
||||||
t.server.stop()
|
t.server.stop()
|
||||||
|
@ -118,15 +114,13 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
|
||||||
for fut in t.handlers:
|
for fut in t.handlers:
|
||||||
if not fut.finished:
|
if not fut.finished:
|
||||||
fut.cancel()
|
fut.cancel()
|
||||||
t.handlers = await allFinished(t.handlers)
|
checkFutures(await allFinished(t.handlers))
|
||||||
checkFutures(t.handlers)
|
|
||||||
t.handlers = @[]
|
t.handlers = @[]
|
||||||
|
|
||||||
for fut in t.cleanups:
|
for fut in t.cleanups:
|
||||||
if not fut.finished:
|
if not fut.finished:
|
||||||
fut.cancel()
|
fut.cancel()
|
||||||
t.cleanups = await allFinished(t.cleanups)
|
checkFutures(await allFinished(t.cleanups))
|
||||||
checkFutures(t.cleanups)
|
|
||||||
t.cleanups = @[]
|
t.cleanups = @[]
|
||||||
|
|
||||||
trace "transport stopped"
|
trace "transport stopped"
|
||||||
|
|
|
@ -24,7 +24,6 @@ type
|
||||||
|
|
||||||
Transport* = ref object of RootObj
|
Transport* = ref object of RootObj
|
||||||
ma*: Multiaddress
|
ma*: Multiaddress
|
||||||
connections*: seq[Connection]
|
|
||||||
handler*: ConnHandler
|
handler*: ConnHandler
|
||||||
multicodec*: MultiCodec
|
multicodec*: MultiCodec
|
||||||
flags*: TransportFlags
|
flags*: TransportFlags
|
||||||
|
@ -49,8 +48,7 @@ proc newTransport*(t: typedesc[Transport], flags: TransportFlags = {}): t {.gcsa
|
||||||
method close*(t: Transport) {.base, async, gcsafe.} =
|
method close*(t: Transport) {.base, async, gcsafe.} =
|
||||||
## stop and cleanup the transport
|
## stop and cleanup the transport
|
||||||
## including all outstanding connections
|
## including all outstanding connections
|
||||||
let futs = await allFinished(t.connections.mapIt(it.close()))
|
discard
|
||||||
checkFutures(futs)
|
|
||||||
|
|
||||||
method listen*(t: Transport,
|
method listen*(t: Transport,
|
||||||
ma: MultiAddress,
|
ma: MultiAddress,
|
||||||
|
|
|
@ -31,7 +31,7 @@ suite "Mplex":
|
||||||
|
|
||||||
let stream = newBufferStream(encHandler)
|
let stream = newBufferStream(encHandler)
|
||||||
let conn = newConnection(stream)
|
let conn = newConnection(stream)
|
||||||
await conn.writeMsg(0, MessageType.New, cast[seq[byte]]("stream 1"))
|
await conn.writeMsg(0, MessageType.New, ("stream 1").toBytes)
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
waitFor(testEncodeHeader())
|
waitFor(testEncodeHeader())
|
||||||
|
@ -43,7 +43,7 @@ suite "Mplex":
|
||||||
|
|
||||||
let stream = newBufferStream(encHandler)
|
let stream = newBufferStream(encHandler)
|
||||||
let conn = newConnection(stream)
|
let conn = newConnection(stream)
|
||||||
await conn.writeMsg(17, MessageType.New, cast[seq[byte]]("stream 1"))
|
await conn.writeMsg(17, MessageType.New, ("stream 1").toBytes)
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
waitFor(testEncodeHeader())
|
waitFor(testEncodeHeader())
|
||||||
|
@ -55,7 +55,7 @@ suite "Mplex":
|
||||||
|
|
||||||
let stream = newBufferStream(encHandler)
|
let stream = newBufferStream(encHandler)
|
||||||
let conn = newConnection(stream)
|
let conn = newConnection(stream)
|
||||||
await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]]("stream 1"))
|
await conn.writeMsg(0, MessageType.MsgOut, ("stream 1").toBytes)
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
waitFor(testEncodeHeaderBody())
|
waitFor(testEncodeHeaderBody())
|
||||||
|
@ -67,7 +67,7 @@ suite "Mplex":
|
||||||
|
|
||||||
let stream = newBufferStream(encHandler)
|
let stream = newBufferStream(encHandler)
|
||||||
let conn = newConnection(stream)
|
let conn = newConnection(stream)
|
||||||
await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]]("stream 1"))
|
await conn.writeMsg(17, MessageType.MsgOut, ("stream 1").toBytes)
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
waitFor(testEncodeHeaderBody())
|
waitFor(testEncodeHeaderBody())
|
||||||
|
@ -94,7 +94,7 @@ suite "Mplex":
|
||||||
|
|
||||||
check msg.id == 0
|
check msg.id == 0
|
||||||
check msg.msgType == MessageType.MsgOut
|
check msg.msgType == MessageType.MsgOut
|
||||||
check cast[string](msg.data) == "hello from channel 0!!"
|
check string.fromBytes(msg.data) == "hello from channel 0!!"
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
waitFor(testDecodeHeader())
|
waitFor(testDecodeHeader())
|
||||||
|
@ -108,29 +108,31 @@ suite "Mplex":
|
||||||
|
|
||||||
check msg.id == 17
|
check msg.id == 17
|
||||||
check msg.msgType == MessageType.MsgOut
|
check msg.msgType == MessageType.MsgOut
|
||||||
check cast[string](msg.data) == "hello from channel 0!!"
|
check string.fromBytes(msg.data) == "hello from channel 0!!"
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
waitFor(testDecodeHeader())
|
waitFor(testDecodeHeader())
|
||||||
|
|
||||||
test "half closed - channel should close for write":
|
test "half closed - channel should close for write":
|
||||||
proc testClosedForWrite() {.async.} =
|
proc testClosedForWrite(): Future[bool] {.async.} =
|
||||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||||
let
|
let
|
||||||
conn = newConnection(newBufferStream(writeHandler))
|
conn = newConnection(newBufferStream(writeHandler))
|
||||||
chann = newChannel(1, conn, true)
|
chann = newChannel(1, conn, true)
|
||||||
try:
|
|
||||||
await chann.close()
|
await chann.close()
|
||||||
|
try:
|
||||||
await chann.write("Hello")
|
await chann.write("Hello")
|
||||||
|
except LPStreamEOFError:
|
||||||
|
result = true
|
||||||
finally:
|
finally:
|
||||||
await chann.reset()
|
await chann.reset()
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
expect LPStreamEOFError:
|
check:
|
||||||
waitFor(testClosedForWrite())
|
waitFor(testClosedForWrite()) == true
|
||||||
|
|
||||||
test "half closed - channel should close for read by remote":
|
test "half closed - channel should close for read by remote":
|
||||||
proc testClosedForRead() {.async.} =
|
proc testClosedForRead(): Future[bool] {.async.} =
|
||||||
let
|
let
|
||||||
conn = newConnection(newBufferStream(
|
conn = newConnection(newBufferStream(
|
||||||
proc (data: seq[byte]) {.gcsafe, async.} =
|
proc (data: seq[byte]) {.gcsafe, async.} =
|
||||||
|
@ -138,69 +140,77 @@ suite "Mplex":
|
||||||
))
|
))
|
||||||
chann = newChannel(1, conn, true)
|
chann = newChannel(1, conn, true)
|
||||||
|
|
||||||
try:
|
await chann.pushTo(("Hello!").toBytes)
|
||||||
await chann.pushTo(cast[seq[byte]]("Hello!"))
|
|
||||||
let closeFut = chann.closedByRemote()
|
let closeFut = chann.closedByRemote()
|
||||||
|
|
||||||
var data = newSeq[byte](6)
|
var data = newSeq[byte](6)
|
||||||
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
|
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
|
||||||
|
try:
|
||||||
await chann.readExactly(addr data[0], 6) # this should throw
|
await chann.readExactly(addr data[0], 6) # this should throw
|
||||||
await closeFut
|
await closeFut
|
||||||
|
except LPStreamEOFError:
|
||||||
|
result = true
|
||||||
finally:
|
finally:
|
||||||
await chann.close()
|
await chann.close()
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
expect LPStreamEOFError:
|
check:
|
||||||
waitFor(testClosedForRead())
|
waitFor(testClosedForRead()) == true
|
||||||
|
|
||||||
test "should not allow pushing data to channel when remote end closed":
|
test "should not allow pushing data to channel when remote end closed":
|
||||||
proc testResetWrite() {.async.} =
|
proc testResetWrite(): Future[bool] {.async.} =
|
||||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||||
let
|
let
|
||||||
conn = newConnection(newBufferStream(writeHandler))
|
conn = newConnection(newBufferStream(writeHandler))
|
||||||
chann = newChannel(1, conn, true)
|
chann = newChannel(1, conn, true)
|
||||||
try:
|
|
||||||
await chann.closedByRemote()
|
await chann.closedByRemote()
|
||||||
|
try:
|
||||||
await chann.pushTo(@[byte(1)])
|
await chann.pushTo(@[byte(1)])
|
||||||
|
except LPStreamEOFError:
|
||||||
|
result = true
|
||||||
finally:
|
finally:
|
||||||
await chann.close()
|
await chann.close()
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
expect LPStreamEOFError:
|
check:
|
||||||
waitFor(testResetWrite())
|
waitFor(testResetWrite()) == true
|
||||||
|
|
||||||
test "reset - channel should fail reading":
|
test "reset - channel should fail reading":
|
||||||
proc testResetRead() {.async.} =
|
proc testResetRead(): Future[bool] {.async.} =
|
||||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||||
let
|
let
|
||||||
conn = newConnection(newBufferStream(writeHandler))
|
conn = newConnection(newBufferStream(writeHandler))
|
||||||
chann = newChannel(1, conn, true)
|
chann = newChannel(1, conn, true)
|
||||||
|
|
||||||
try:
|
|
||||||
await chann.reset()
|
await chann.reset()
|
||||||
var data = newSeq[byte](1)
|
var data = newSeq[byte](1)
|
||||||
|
try:
|
||||||
await chann.readExactly(addr data[0], 1)
|
await chann.readExactly(addr data[0], 1)
|
||||||
doAssert(len(data) == 1)
|
doAssert(len(data) == 1)
|
||||||
|
except LPStreamEOFError:
|
||||||
|
result = true
|
||||||
finally:
|
finally:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
expect LPStreamEOFError:
|
check:
|
||||||
waitFor(testResetRead())
|
waitFor(testResetRead()) == true
|
||||||
|
|
||||||
test "reset - channel should fail writing":
|
test "reset - channel should fail writing":
|
||||||
proc testResetWrite() {.async.} =
|
proc testResetWrite(): Future[bool] {.async.} =
|
||||||
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
|
||||||
let
|
let
|
||||||
conn = newConnection(newBufferStream(writeHandler))
|
conn = newConnection(newBufferStream(writeHandler))
|
||||||
chann = newChannel(1, conn, true)
|
chann = newChannel(1, conn, true)
|
||||||
try:
|
|
||||||
await chann.reset()
|
await chann.reset()
|
||||||
await chann.write(cast[seq[byte]]("Hello!"))
|
try:
|
||||||
|
await chann.write(("Hello!").toBytes)
|
||||||
|
except LPStreamEOFError:
|
||||||
|
result = true
|
||||||
finally:
|
finally:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
expect LPStreamEOFError:
|
check:
|
||||||
waitFor(testResetWrite())
|
waitFor(testResetWrite()) == true
|
||||||
|
|
||||||
test "e2e - read/write receiver":
|
test "e2e - read/write receiver":
|
||||||
proc testNewStream() {.async.} =
|
proc testNewStream() {.async.} =
|
||||||
|
@ -212,7 +222,7 @@ suite "Mplex":
|
||||||
mplexListen.streamHandler = proc(stream: Connection)
|
mplexListen.streamHandler = proc(stream: Connection)
|
||||||
{.async, gcsafe.} =
|
{.async, gcsafe.} =
|
||||||
let msg = await stream.readLp(1024)
|
let msg = await stream.readLp(1024)
|
||||||
check cast[string](msg) == "HELLO"
|
check string.fromBytes(msg) == "HELLO"
|
||||||
await stream.close()
|
await stream.close()
|
||||||
done.complete()
|
done.complete()
|
||||||
|
|
||||||
|
@ -250,7 +260,7 @@ suite "Mplex":
|
||||||
mplexListen.streamHandler = proc(stream: Connection)
|
mplexListen.streamHandler = proc(stream: Connection)
|
||||||
{.async, gcsafe.} =
|
{.async, gcsafe.} =
|
||||||
let msg = await stream.readLp(1024)
|
let msg = await stream.readLp(1024)
|
||||||
check cast[string](msg) == "HELLO"
|
check string.fromBytes(msg) == "HELLO"
|
||||||
await stream.close()
|
await stream.close()
|
||||||
done.complete()
|
done.complete()
|
||||||
|
|
||||||
|
@ -351,7 +361,7 @@ suite "Mplex":
|
||||||
let mplexDial = newMplex(conn)
|
let mplexDial = newMplex(conn)
|
||||||
let mplexDialFut = mplexDial.handle()
|
let mplexDialFut = mplexDial.handle()
|
||||||
let stream = await mplexDial.newStream("DIALER")
|
let stream = await mplexDial.newStream("DIALER")
|
||||||
let msg = cast[string](await stream.readLp(1024))
|
let msg = string.fromBytes(await stream.readLp(1024))
|
||||||
await stream.close()
|
await stream.close()
|
||||||
check msg == "Hello from stream!"
|
check msg == "Hello from stream!"
|
||||||
|
|
||||||
|
@ -416,7 +426,7 @@ suite "Mplex":
|
||||||
mplexListen.streamHandler = proc(stream: Connection)
|
mplexListen.streamHandler = proc(stream: Connection)
|
||||||
{.async, gcsafe.} =
|
{.async, gcsafe.} =
|
||||||
let msg = await stream.readLp(1024)
|
let msg = await stream.readLp(1024)
|
||||||
check cast[string](msg) == &"stream {count} from dialer!"
|
check string.fromBytes(msg) == &"stream {count} from dialer!"
|
||||||
await stream.writeLp(&"stream {count} from listener!")
|
await stream.writeLp(&"stream {count} from listener!")
|
||||||
count.inc
|
count.inc
|
||||||
await stream.close()
|
await stream.close()
|
||||||
|
@ -438,7 +448,7 @@ suite "Mplex":
|
||||||
let stream = await mplexDial.newStream("dialer stream")
|
let stream = await mplexDial.newStream("dialer stream")
|
||||||
await stream.writeLp(&"stream {i} from dialer!")
|
await stream.writeLp(&"stream {i} from dialer!")
|
||||||
let msg = await stream.readLp(1024)
|
let msg = await stream.readLp(1024)
|
||||||
check cast[string](msg) == &"stream {i} from listener!"
|
check string.fromBytes(msg) == &"stream {i} from listener!"
|
||||||
await stream.close()
|
await stream.close()
|
||||||
|
|
||||||
await done.wait(5.seconds)
|
await done.wait(5.seconds)
|
||||||
|
|
Loading…
Reference in New Issue