don't track Connection, track StreamTransport (#177)

* don't track Connection, track StreamTransport

* make tests more deterministic
This commit is contained in:
Dmitriy Ryajov 2020-05-18 11:05:34 -06:00 committed by GitHub
parent 5583168965
commit 773b738c12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 53 deletions

View File

@ -25,6 +25,7 @@ const
type
TcpTransport* = ref object of Transport
server*: StreamServer
clients: seq[StreamTransport]
cleanups*: seq[Future[void]]
handlers*: seq[Future[void]]
@ -56,11 +57,6 @@ proc setupTcpTransportTracker(): TcpTransportTracker =
result.isLeaked = leakTransport
addTracker(TcpTransportTrackerName, result)
proc cleanup(t: Transport, conn: Connection) {.async.} =
await conn.closeEvent.wait()
trace "connection cleanup event wait ended"
t.connections.keepItIf(it != conn)
proc connHandler*(t: TcpTransport,
client: StreamTransport,
initiator: bool): Connection =
@ -71,10 +67,7 @@ proc connHandler*(t: TcpTransport,
if not isNil(t.handler):
t.handlers &= t.handler(conn)
# TODO: store the streamtransport client here
t.connections.add(conn)
t.cleanups &= t.cleanup(conn)
t.clients.add(client)
result = conn
proc connCb(server: StreamServer,
@ -108,6 +101,9 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
trace "stopping transport"
await procCall Transport(t).close() # call base
checkFutures(await allFinished(
t.clients.mapIt(it.closeWait())))
# server can be nil
if not isNil(t.server):
t.server.stop()
@ -118,15 +114,13 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
for fut in t.handlers:
if not fut.finished:
fut.cancel()
t.handlers = await allFinished(t.handlers)
checkFutures(t.handlers)
checkFutures(await allFinished(t.handlers))
t.handlers = @[]
for fut in t.cleanups:
if not fut.finished:
fut.cancel()
t.cleanups = await allFinished(t.cleanups)
checkFutures(t.cleanups)
checkFutures(await allFinished(t.cleanups))
t.cleanups = @[]
trace "transport stopped"

View File

@ -24,7 +24,6 @@ type
Transport* = ref object of RootObj
ma*: Multiaddress
connections*: seq[Connection]
handler*: ConnHandler
multicodec*: MultiCodec
flags*: TransportFlags
@ -49,8 +48,7 @@ proc newTransport*(t: typedesc[Transport], flags: TransportFlags = {}): t {.gcsa
method close*(t: Transport) {.base, async, gcsafe.} =
## stop and cleanup the transport
## including all outstanding connections
let futs = await allFinished(t.connections.mapIt(it.close()))
checkFutures(futs)
discard
method listen*(t: Transport,
ma: MultiAddress,

View File

@ -31,7 +31,7 @@ suite "Mplex":
let stream = newBufferStream(encHandler)
let conn = newConnection(stream)
await conn.writeMsg(0, MessageType.New, cast[seq[byte]]("stream 1"))
await conn.writeMsg(0, MessageType.New, ("stream 1").toBytes)
await conn.close()
waitFor(testEncodeHeader())
@ -43,7 +43,7 @@ suite "Mplex":
let stream = newBufferStream(encHandler)
let conn = newConnection(stream)
await conn.writeMsg(17, MessageType.New, cast[seq[byte]]("stream 1"))
await conn.writeMsg(17, MessageType.New, ("stream 1").toBytes)
await conn.close()
waitFor(testEncodeHeader())
@ -55,7 +55,7 @@ suite "Mplex":
let stream = newBufferStream(encHandler)
let conn = newConnection(stream)
await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]]("stream 1"))
await conn.writeMsg(0, MessageType.MsgOut, ("stream 1").toBytes)
await conn.close()
waitFor(testEncodeHeaderBody())
@ -67,7 +67,7 @@ suite "Mplex":
let stream = newBufferStream(encHandler)
let conn = newConnection(stream)
await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]]("stream 1"))
await conn.writeMsg(17, MessageType.MsgOut, ("stream 1").toBytes)
await conn.close()
waitFor(testEncodeHeaderBody())
@ -94,7 +94,7 @@ suite "Mplex":
check msg.id == 0
check msg.msgType == MessageType.MsgOut
check cast[string](msg.data) == "hello from channel 0!!"
check string.fromBytes(msg.data) == "hello from channel 0!!"
await conn.close()
waitFor(testDecodeHeader())
@ -108,29 +108,31 @@ suite "Mplex":
check msg.id == 17
check msg.msgType == MessageType.MsgOut
check cast[string](msg.data) == "hello from channel 0!!"
check string.fromBytes(msg.data) == "hello from channel 0!!"
await conn.close()
waitFor(testDecodeHeader())
test "half closed - channel should close for write":
proc testClosedForWrite() {.async.} =
proc testClosedForWrite(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let
conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true)
await chann.close()
try:
await chann.close()
await chann.write("Hello")
except LPStreamEOFError:
result = true
finally:
await chann.reset()
await conn.close()
expect LPStreamEOFError:
waitFor(testClosedForWrite())
check:
waitFor(testClosedForWrite()) == true
test "half closed - channel should close for read by remote":
proc testClosedForRead() {.async.} =
proc testClosedForRead(): Future[bool] {.async.} =
let
conn = newConnection(newBufferStream(
proc (data: seq[byte]) {.gcsafe, async.} =
@ -138,69 +140,77 @@ suite "Mplex":
))
chann = newChannel(1, conn, true)
try:
await chann.pushTo(cast[seq[byte]]("Hello!"))
let closeFut = chann.closedByRemote()
await chann.pushTo(("Hello!").toBytes)
let closeFut = chann.closedByRemote()
var data = newSeq[byte](6)
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
var data = newSeq[byte](6)
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
try:
await chann.readExactly(addr data[0], 6) # this should throw
await closeFut
except LPStreamEOFError:
result = true
finally:
await chann.close()
await conn.close()
expect LPStreamEOFError:
waitFor(testClosedForRead())
check:
waitFor(testClosedForRead()) == true
test "should not allow pushing data to channel when remote end closed":
proc testResetWrite() {.async.} =
proc testResetWrite(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let
conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true)
await chann.closedByRemote()
try:
await chann.closedByRemote()
await chann.pushTo(@[byte(1)])
except LPStreamEOFError:
result = true
finally:
await chann.close()
await conn.close()
expect LPStreamEOFError:
waitFor(testResetWrite())
check:
waitFor(testResetWrite()) == true
test "reset - channel should fail reading":
proc testResetRead() {.async.} =
proc testResetRead(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let
conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true)
await chann.reset()
var data = newSeq[byte](1)
try:
await chann.reset()
var data = newSeq[byte](1)
await chann.readExactly(addr data[0], 1)
doAssert(len(data) == 1)
except LPStreamEOFError:
result = true
finally:
await conn.close()
expect LPStreamEOFError:
waitFor(testResetRead())
check:
waitFor(testResetRead()) == true
test "reset - channel should fail writing":
proc testResetWrite() {.async.} =
proc testResetWrite(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let
conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true)
await chann.reset()
try:
await chann.reset()
await chann.write(cast[seq[byte]]("Hello!"))
await chann.write(("Hello!").toBytes)
except LPStreamEOFError:
result = true
finally:
await conn.close()
expect LPStreamEOFError:
waitFor(testResetWrite())
check:
waitFor(testResetWrite()) == true
test "e2e - read/write receiver":
proc testNewStream() {.async.} =
@ -212,7 +222,7 @@ suite "Mplex":
mplexListen.streamHandler = proc(stream: Connection)
{.async, gcsafe.} =
let msg = await stream.readLp(1024)
check cast[string](msg) == "HELLO"
check string.fromBytes(msg) == "HELLO"
await stream.close()
done.complete()
@ -250,7 +260,7 @@ suite "Mplex":
mplexListen.streamHandler = proc(stream: Connection)
{.async, gcsafe.} =
let msg = await stream.readLp(1024)
check cast[string](msg) == "HELLO"
check string.fromBytes(msg) == "HELLO"
await stream.close()
done.complete()
@ -351,7 +361,7 @@ suite "Mplex":
let mplexDial = newMplex(conn)
let mplexDialFut = mplexDial.handle()
let stream = await mplexDial.newStream("DIALER")
let msg = cast[string](await stream.readLp(1024))
let msg = string.fromBytes(await stream.readLp(1024))
await stream.close()
check msg == "Hello from stream!"
@ -416,7 +426,7 @@ suite "Mplex":
mplexListen.streamHandler = proc(stream: Connection)
{.async, gcsafe.} =
let msg = await stream.readLp(1024)
check cast[string](msg) == &"stream {count} from dialer!"
check string.fromBytes(msg) == &"stream {count} from dialer!"
await stream.writeLp(&"stream {count} from listener!")
count.inc
await stream.close()
@ -438,7 +448,7 @@ suite "Mplex":
let stream = await mplexDial.newStream("dialer stream")
await stream.writeLp(&"stream {i} from dialer!")
let msg = await stream.readLp(1024)
check cast[string](msg) == &"stream {i} from listener!"
check string.fromBytes(msg) == &"stream {i} from listener!"
await stream.close()
await done.wait(5.seconds)