From 773b738c12e2250f204c49df8aa834b5687a197e Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Mon, 18 May 2020 11:05:34 -0600 Subject: [PATCH] don't track Connection, track StreamTransport (#177) * don't track Connection, track StreamTransport * make tests more deterministic --- libp2p/transports/tcptransport.nim | 20 +++---- libp2p/transports/transport.nim | 4 +- tests/testmplex.nim | 84 +++++++++++++++++------------- 3 files changed, 55 insertions(+), 53 deletions(-) diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index fd2d7c5..ccab6d6 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -25,6 +25,7 @@ const type TcpTransport* = ref object of Transport server*: StreamServer + clients: seq[StreamTransport] cleanups*: seq[Future[void]] handlers*: seq[Future[void]] @@ -56,11 +57,6 @@ proc setupTcpTransportTracker(): TcpTransportTracker = result.isLeaked = leakTransport addTracker(TcpTransportTrackerName, result) -proc cleanup(t: Transport, conn: Connection) {.async.} = - await conn.closeEvent.wait() - trace "connection cleanup event wait ended" - t.connections.keepItIf(it != conn) - proc connHandler*(t: TcpTransport, client: StreamTransport, initiator: bool): Connection = @@ -71,10 +67,7 @@ proc connHandler*(t: TcpTransport, if not isNil(t.handler): t.handlers &= t.handler(conn) - # TODO: store the streamtransport client here - t.connections.add(conn) - t.cleanups &= t.cleanup(conn) - + t.clients.add(client) result = conn proc connCb(server: StreamServer, @@ -108,6 +101,9 @@ method close*(t: TcpTransport) {.async, gcsafe.} = trace "stopping transport" await procCall Transport(t).close() # call base + checkFutures(await allFinished( + t.clients.mapIt(it.closeWait()))) + # server can be nil if not isNil(t.server): t.server.stop() @@ -118,15 +114,13 @@ method close*(t: TcpTransport) {.async, gcsafe.} = for fut in t.handlers: if not fut.finished: fut.cancel() - t.handlers = await allFinished(t.handlers) - checkFutures(t.handlers) + checkFutures(await allFinished(t.handlers)) t.handlers = @[] for fut in t.cleanups: if not fut.finished: fut.cancel() - t.cleanups = await allFinished(t.cleanups) - checkFutures(t.cleanups) + checkFutures(await allFinished(t.cleanups)) t.cleanups = @[] trace "transport stopped" diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index 6919f72..68f7f85 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -24,7 +24,6 @@ type Transport* = ref object of RootObj ma*: Multiaddress - connections*: seq[Connection] handler*: ConnHandler multicodec*: MultiCodec flags*: TransportFlags @@ -49,8 +48,7 @@ proc newTransport*(t: typedesc[Transport], flags: TransportFlags = {}): t {.gcsa method close*(t: Transport) {.base, async, gcsafe.} = ## stop and cleanup the transport ## including all outstanding connections - let futs = await allFinished(t.connections.mapIt(it.close())) - checkFutures(futs) + discard method listen*(t: Transport, ma: MultiAddress, diff --git a/tests/testmplex.nim b/tests/testmplex.nim index bad5ede..6641ff3 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -31,7 +31,7 @@ suite "Mplex": let stream = newBufferStream(encHandler) let conn = newConnection(stream) - await conn.writeMsg(0, MessageType.New, cast[seq[byte]]("stream 1")) + await conn.writeMsg(0, MessageType.New, ("stream 1").toBytes) await conn.close() waitFor(testEncodeHeader()) @@ -43,7 +43,7 @@ suite "Mplex": let stream = newBufferStream(encHandler) let conn = newConnection(stream) - await conn.writeMsg(17, MessageType.New, cast[seq[byte]]("stream 1")) + await conn.writeMsg(17, MessageType.New, ("stream 1").toBytes) await conn.close() waitFor(testEncodeHeader()) @@ -55,7 +55,7 @@ suite "Mplex": let stream = newBufferStream(encHandler) let conn = newConnection(stream) - await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]]("stream 1")) + await conn.writeMsg(0, MessageType.MsgOut, ("stream 1").toBytes) await conn.close() waitFor(testEncodeHeaderBody()) @@ -67,7 +67,7 @@ suite "Mplex": let stream = newBufferStream(encHandler) let conn = newConnection(stream) - await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]]("stream 1")) + await conn.writeMsg(17, MessageType.MsgOut, ("stream 1").toBytes) await conn.close() waitFor(testEncodeHeaderBody()) @@ -94,7 +94,7 @@ suite "Mplex": check msg.id == 0 check msg.msgType == MessageType.MsgOut - check cast[string](msg.data) == "hello from channel 0!!" + check string.fromBytes(msg.data) == "hello from channel 0!!" await conn.close() waitFor(testDecodeHeader()) @@ -108,29 +108,31 @@ suite "Mplex": check msg.id == 17 check msg.msgType == MessageType.MsgOut - check cast[string](msg.data) == "hello from channel 0!!" + check string.fromBytes(msg.data) == "hello from channel 0!!" await conn.close() waitFor(testDecodeHeader()) test "half closed - channel should close for write": - proc testClosedForWrite() {.async.} = + proc testClosedForWrite(): Future[bool] {.async.} = proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let conn = newConnection(newBufferStream(writeHandler)) chann = newChannel(1, conn, true) + await chann.close() try: - await chann.close() await chann.write("Hello") + except LPStreamEOFError: + result = true finally: await chann.reset() await conn.close() - expect LPStreamEOFError: - waitFor(testClosedForWrite()) + check: + waitFor(testClosedForWrite()) == true test "half closed - channel should close for read by remote": - proc testClosedForRead() {.async.} = + proc testClosedForRead(): Future[bool] {.async.} = let conn = newConnection(newBufferStream( proc (data: seq[byte]) {.gcsafe, async.} = @@ -138,69 +140,77 @@ suite "Mplex": )) chann = newChannel(1, conn, true) - try: - await chann.pushTo(cast[seq[byte]]("Hello!")) - let closeFut = chann.closedByRemote() + await chann.pushTo(("Hello!").toBytes) + let closeFut = chann.closedByRemote() - var data = newSeq[byte](6) - await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer + var data = newSeq[byte](6) + await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer + try: await chann.readExactly(addr data[0], 6) # this should throw await closeFut + except LPStreamEOFError: + result = true finally: await chann.close() await conn.close() - expect LPStreamEOFError: - waitFor(testClosedForRead()) + check: + waitFor(testClosedForRead()) == true test "should not allow pushing data to channel when remote end closed": - proc testResetWrite() {.async.} = + proc testResetWrite(): Future[bool] {.async.} = proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let conn = newConnection(newBufferStream(writeHandler)) chann = newChannel(1, conn, true) + await chann.closedByRemote() try: - await chann.closedByRemote() await chann.pushTo(@[byte(1)]) + except LPStreamEOFError: + result = true finally: await chann.close() await conn.close() - expect LPStreamEOFError: - waitFor(testResetWrite()) + check: + waitFor(testResetWrite()) == true test "reset - channel should fail reading": - proc testResetRead() {.async.} = + proc testResetRead(): Future[bool] {.async.} = proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let conn = newConnection(newBufferStream(writeHandler)) chann = newChannel(1, conn, true) + await chann.reset() + var data = newSeq[byte](1) try: - await chann.reset() - var data = newSeq[byte](1) await chann.readExactly(addr data[0], 1) doAssert(len(data) == 1) + except LPStreamEOFError: + result = true finally: await conn.close() - expect LPStreamEOFError: - waitFor(testResetRead()) + check: + waitFor(testResetRead()) == true test "reset - channel should fail writing": - proc testResetWrite() {.async.} = + proc testResetWrite(): Future[bool] {.async.} = proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let conn = newConnection(newBufferStream(writeHandler)) chann = newChannel(1, conn, true) + await chann.reset() try: - await chann.reset() - await chann.write(cast[seq[byte]]("Hello!")) + await chann.write(("Hello!").toBytes) + except LPStreamEOFError: + result = true finally: await conn.close() - expect LPStreamEOFError: - waitFor(testResetWrite()) + check: + waitFor(testResetWrite()) == true test "e2e - read/write receiver": proc testNewStream() {.async.} = @@ -212,7 +222,7 @@ suite "Mplex": mplexListen.streamHandler = proc(stream: Connection) {.async, gcsafe.} = let msg = await stream.readLp(1024) - check cast[string](msg) == "HELLO" + check string.fromBytes(msg) == "HELLO" await stream.close() done.complete() @@ -250,7 +260,7 @@ suite "Mplex": mplexListen.streamHandler = proc(stream: Connection) {.async, gcsafe.} = let msg = await stream.readLp(1024) - check cast[string](msg) == "HELLO" + check string.fromBytes(msg) == "HELLO" await stream.close() done.complete() @@ -351,7 +361,7 @@ suite "Mplex": let mplexDial = newMplex(conn) let mplexDialFut = mplexDial.handle() let stream = await mplexDial.newStream("DIALER") - let msg = cast[string](await stream.readLp(1024)) + let msg = string.fromBytes(await stream.readLp(1024)) await stream.close() check msg == "Hello from stream!" @@ -416,7 +426,7 @@ suite "Mplex": mplexListen.streamHandler = proc(stream: Connection) {.async, gcsafe.} = let msg = await stream.readLp(1024) - check cast[string](msg) == &"stream {count} from dialer!" + check string.fromBytes(msg) == &"stream {count} from dialer!" await stream.writeLp(&"stream {count} from listener!") count.inc await stream.close() @@ -438,7 +448,7 @@ suite "Mplex": let stream = await mplexDial.newStream("dialer stream") await stream.writeLp(&"stream {i} from dialer!") let msg = await stream.readLp(1024) - check cast[string](msg) == &"stream {i} from listener!" + check string.fromBytes(msg) == &"stream {i} from listener!" await stream.close() await done.wait(5.seconds)