don't track Connection, track StreamTransport (#177)

* don't track Connection, track StreamTransport

* make tests more deterministic
This commit is contained in:
Dmitriy Ryajov 2020-05-18 11:05:34 -06:00 committed by GitHub
parent 5583168965
commit 773b738c12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 53 deletions

View File

@ -25,6 +25,7 @@ const
type type
TcpTransport* = ref object of Transport TcpTransport* = ref object of Transport
server*: StreamServer server*: StreamServer
clients: seq[StreamTransport]
cleanups*: seq[Future[void]] cleanups*: seq[Future[void]]
handlers*: seq[Future[void]] handlers*: seq[Future[void]]
@ -56,11 +57,6 @@ proc setupTcpTransportTracker(): TcpTransportTracker =
result.isLeaked = leakTransport result.isLeaked = leakTransport
addTracker(TcpTransportTrackerName, result) addTracker(TcpTransportTrackerName, result)
proc cleanup(t: Transport, conn: Connection) {.async.} =
await conn.closeEvent.wait()
trace "connection cleanup event wait ended"
t.connections.keepItIf(it != conn)
proc connHandler*(t: TcpTransport, proc connHandler*(t: TcpTransport,
client: StreamTransport, client: StreamTransport,
initiator: bool): Connection = initiator: bool): Connection =
@ -71,10 +67,7 @@ proc connHandler*(t: TcpTransport,
if not isNil(t.handler): if not isNil(t.handler):
t.handlers &= t.handler(conn) t.handlers &= t.handler(conn)
# TODO: store the streamtransport client here t.clients.add(client)
t.connections.add(conn)
t.cleanups &= t.cleanup(conn)
result = conn result = conn
proc connCb(server: StreamServer, proc connCb(server: StreamServer,
@ -108,6 +101,9 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
trace "stopping transport" trace "stopping transport"
await procCall Transport(t).close() # call base await procCall Transport(t).close() # call base
checkFutures(await allFinished(
t.clients.mapIt(it.closeWait())))
# server can be nil # server can be nil
if not isNil(t.server): if not isNil(t.server):
t.server.stop() t.server.stop()
@ -118,15 +114,13 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
for fut in t.handlers: for fut in t.handlers:
if not fut.finished: if not fut.finished:
fut.cancel() fut.cancel()
t.handlers = await allFinished(t.handlers) checkFutures(await allFinished(t.handlers))
checkFutures(t.handlers)
t.handlers = @[] t.handlers = @[]
for fut in t.cleanups: for fut in t.cleanups:
if not fut.finished: if not fut.finished:
fut.cancel() fut.cancel()
t.cleanups = await allFinished(t.cleanups) checkFutures(await allFinished(t.cleanups))
checkFutures(t.cleanups)
t.cleanups = @[] t.cleanups = @[]
trace "transport stopped" trace "transport stopped"

View File

@ -24,7 +24,6 @@ type
Transport* = ref object of RootObj Transport* = ref object of RootObj
ma*: Multiaddress ma*: Multiaddress
connections*: seq[Connection]
handler*: ConnHandler handler*: ConnHandler
multicodec*: MultiCodec multicodec*: MultiCodec
flags*: TransportFlags flags*: TransportFlags
@ -49,8 +48,7 @@ proc newTransport*(t: typedesc[Transport], flags: TransportFlags = {}): t {.gcsa
method close*(t: Transport) {.base, async, gcsafe.} = method close*(t: Transport) {.base, async, gcsafe.} =
## stop and cleanup the transport ## stop and cleanup the transport
## including all outstanding connections ## including all outstanding connections
let futs = await allFinished(t.connections.mapIt(it.close())) discard
checkFutures(futs)
method listen*(t: Transport, method listen*(t: Transport,
ma: MultiAddress, ma: MultiAddress,

View File

@ -31,7 +31,7 @@ suite "Mplex":
let stream = newBufferStream(encHandler) let stream = newBufferStream(encHandler)
let conn = newConnection(stream) let conn = newConnection(stream)
await conn.writeMsg(0, MessageType.New, cast[seq[byte]]("stream 1")) await conn.writeMsg(0, MessageType.New, ("stream 1").toBytes)
await conn.close() await conn.close()
waitFor(testEncodeHeader()) waitFor(testEncodeHeader())
@ -43,7 +43,7 @@ suite "Mplex":
let stream = newBufferStream(encHandler) let stream = newBufferStream(encHandler)
let conn = newConnection(stream) let conn = newConnection(stream)
await conn.writeMsg(17, MessageType.New, cast[seq[byte]]("stream 1")) await conn.writeMsg(17, MessageType.New, ("stream 1").toBytes)
await conn.close() await conn.close()
waitFor(testEncodeHeader()) waitFor(testEncodeHeader())
@ -55,7 +55,7 @@ suite "Mplex":
let stream = newBufferStream(encHandler) let stream = newBufferStream(encHandler)
let conn = newConnection(stream) let conn = newConnection(stream)
await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]]("stream 1")) await conn.writeMsg(0, MessageType.MsgOut, ("stream 1").toBytes)
await conn.close() await conn.close()
waitFor(testEncodeHeaderBody()) waitFor(testEncodeHeaderBody())
@ -67,7 +67,7 @@ suite "Mplex":
let stream = newBufferStream(encHandler) let stream = newBufferStream(encHandler)
let conn = newConnection(stream) let conn = newConnection(stream)
await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]]("stream 1")) await conn.writeMsg(17, MessageType.MsgOut, ("stream 1").toBytes)
await conn.close() await conn.close()
waitFor(testEncodeHeaderBody()) waitFor(testEncodeHeaderBody())
@ -94,7 +94,7 @@ suite "Mplex":
check msg.id == 0 check msg.id == 0
check msg.msgType == MessageType.MsgOut check msg.msgType == MessageType.MsgOut
check cast[string](msg.data) == "hello from channel 0!!" check string.fromBytes(msg.data) == "hello from channel 0!!"
await conn.close() await conn.close()
waitFor(testDecodeHeader()) waitFor(testDecodeHeader())
@ -108,29 +108,31 @@ suite "Mplex":
check msg.id == 17 check msg.id == 17
check msg.msgType == MessageType.MsgOut check msg.msgType == MessageType.MsgOut
check cast[string](msg.data) == "hello from channel 0!!" check string.fromBytes(msg.data) == "hello from channel 0!!"
await conn.close() await conn.close()
waitFor(testDecodeHeader()) waitFor(testDecodeHeader())
test "half closed - channel should close for write": test "half closed - channel should close for write":
proc testClosedForWrite() {.async.} = proc testClosedForWrite(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let let
conn = newConnection(newBufferStream(writeHandler)) conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true) chann = newChannel(1, conn, true)
try:
await chann.close() await chann.close()
try:
await chann.write("Hello") await chann.write("Hello")
except LPStreamEOFError:
result = true
finally: finally:
await chann.reset() await chann.reset()
await conn.close() await conn.close()
expect LPStreamEOFError: check:
waitFor(testClosedForWrite()) waitFor(testClosedForWrite()) == true
test "half closed - channel should close for read by remote": test "half closed - channel should close for read by remote":
proc testClosedForRead() {.async.} = proc testClosedForRead(): Future[bool] {.async.} =
let let
conn = newConnection(newBufferStream( conn = newConnection(newBufferStream(
proc (data: seq[byte]) {.gcsafe, async.} = proc (data: seq[byte]) {.gcsafe, async.} =
@ -138,69 +140,77 @@ suite "Mplex":
)) ))
chann = newChannel(1, conn, true) chann = newChannel(1, conn, true)
try: await chann.pushTo(("Hello!").toBytes)
await chann.pushTo(cast[seq[byte]]("Hello!"))
let closeFut = chann.closedByRemote() let closeFut = chann.closedByRemote()
var data = newSeq[byte](6) var data = newSeq[byte](6)
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
try:
await chann.readExactly(addr data[0], 6) # this should throw await chann.readExactly(addr data[0], 6) # this should throw
await closeFut await closeFut
except LPStreamEOFError:
result = true
finally: finally:
await chann.close() await chann.close()
await conn.close() await conn.close()
expect LPStreamEOFError: check:
waitFor(testClosedForRead()) waitFor(testClosedForRead()) == true
test "should not allow pushing data to channel when remote end closed": test "should not allow pushing data to channel when remote end closed":
proc testResetWrite() {.async.} = proc testResetWrite(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let let
conn = newConnection(newBufferStream(writeHandler)) conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true) chann = newChannel(1, conn, true)
try:
await chann.closedByRemote() await chann.closedByRemote()
try:
await chann.pushTo(@[byte(1)]) await chann.pushTo(@[byte(1)])
except LPStreamEOFError:
result = true
finally: finally:
await chann.close() await chann.close()
await conn.close() await conn.close()
expect LPStreamEOFError: check:
waitFor(testResetWrite()) waitFor(testResetWrite()) == true
test "reset - channel should fail reading": test "reset - channel should fail reading":
proc testResetRead() {.async.} = proc testResetRead(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let let
conn = newConnection(newBufferStream(writeHandler)) conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true) chann = newChannel(1, conn, true)
try:
await chann.reset() await chann.reset()
var data = newSeq[byte](1) var data = newSeq[byte](1)
try:
await chann.readExactly(addr data[0], 1) await chann.readExactly(addr data[0], 1)
doAssert(len(data) == 1) doAssert(len(data) == 1)
except LPStreamEOFError:
result = true
finally: finally:
await conn.close() await conn.close()
expect LPStreamEOFError: check:
waitFor(testResetRead()) waitFor(testResetRead()) == true
test "reset - channel should fail writing": test "reset - channel should fail writing":
proc testResetWrite() {.async.} = proc testResetWrite(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let let
conn = newConnection(newBufferStream(writeHandler)) conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true) chann = newChannel(1, conn, true)
try:
await chann.reset() await chann.reset()
await chann.write(cast[seq[byte]]("Hello!")) try:
await chann.write(("Hello!").toBytes)
except LPStreamEOFError:
result = true
finally: finally:
await conn.close() await conn.close()
expect LPStreamEOFError: check:
waitFor(testResetWrite()) waitFor(testResetWrite()) == true
test "e2e - read/write receiver": test "e2e - read/write receiver":
proc testNewStream() {.async.} = proc testNewStream() {.async.} =
@ -212,7 +222,7 @@ suite "Mplex":
mplexListen.streamHandler = proc(stream: Connection) mplexListen.streamHandler = proc(stream: Connection)
{.async, gcsafe.} = {.async, gcsafe.} =
let msg = await stream.readLp(1024) let msg = await stream.readLp(1024)
check cast[string](msg) == "HELLO" check string.fromBytes(msg) == "HELLO"
await stream.close() await stream.close()
done.complete() done.complete()
@ -250,7 +260,7 @@ suite "Mplex":
mplexListen.streamHandler = proc(stream: Connection) mplexListen.streamHandler = proc(stream: Connection)
{.async, gcsafe.} = {.async, gcsafe.} =
let msg = await stream.readLp(1024) let msg = await stream.readLp(1024)
check cast[string](msg) == "HELLO" check string.fromBytes(msg) == "HELLO"
await stream.close() await stream.close()
done.complete() done.complete()
@ -351,7 +361,7 @@ suite "Mplex":
let mplexDial = newMplex(conn) let mplexDial = newMplex(conn)
let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
let stream = await mplexDial.newStream("DIALER") let stream = await mplexDial.newStream("DIALER")
let msg = cast[string](await stream.readLp(1024)) let msg = string.fromBytes(await stream.readLp(1024))
await stream.close() await stream.close()
check msg == "Hello from stream!" check msg == "Hello from stream!"
@ -416,7 +426,7 @@ suite "Mplex":
mplexListen.streamHandler = proc(stream: Connection) mplexListen.streamHandler = proc(stream: Connection)
{.async, gcsafe.} = {.async, gcsafe.} =
let msg = await stream.readLp(1024) let msg = await stream.readLp(1024)
check cast[string](msg) == &"stream {count} from dialer!" check string.fromBytes(msg) == &"stream {count} from dialer!"
await stream.writeLp(&"stream {count} from listener!") await stream.writeLp(&"stream {count} from listener!")
count.inc count.inc
await stream.close() await stream.close()
@ -438,7 +448,7 @@ suite "Mplex":
let stream = await mplexDial.newStream("dialer stream") let stream = await mplexDial.newStream("dialer stream")
await stream.writeLp(&"stream {i} from dialer!") await stream.writeLp(&"stream {i} from dialer!")
let msg = await stream.readLp(1024) let msg = await stream.readLp(1024)
check cast[string](msg) == &"stream {i} from listener!" check string.fromBytes(msg) == &"stream {i} from listener!"
await stream.close() await stream.close()
await done.wait(5.seconds) await done.wait(5.seconds)