From ccd019b328a774923a379287d8e641fc62a9720e Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Fri, 8 May 2020 22:10:06 +0200 Subject: [PATCH] use stream directly in chronosstream (#163) * use stream directly in chronosstream for now, chronos.AsyncStream is not used to provide any features on top of chronos.Stream, so in order to simplify the code, chronosstream can be used directly. In particular, the exception handling is broken in the current chronosstream - opening and closing the stream is simplified this way as well. A future implementation that actually takes advantage of the AsyncStream features would wrap AsyncStream instead as a separate lpstream implementation, leaving this one as-is. * work around chronos exception type issue --- libp2p/errors.nim | 6 ++-- libp2p/stream/chronosstream.nim | 44 +++++++++++------------------ libp2p/transports/tcptransport.nim | 7 ++--- tests/helpers.nim | 23 +++++++++++++++ tests/pubsub/testfloodsub.nim | 20 +++---------- tests/pubsub/testgossipinternal.nim | 20 +++---------- tests/pubsub/testgossipsub.nim | 18 ++---------- tests/testidentify.nim | 19 ++----------- tests/testmplex.nim | 21 +++----------- tests/testmultistream.nim | 21 +++----------- tests/testnoise.nim | 18 ++---------- tests/testpeerinfo.nim | 18 ++---------- tests/testswitch.nim | 23 ++++----------- tests/testtransport.nim | 17 ++++------- 14 files changed, 84 insertions(+), 191 deletions(-) create mode 100644 tests/helpers.nim diff --git a/libp2p/errors.nim b/libp2p/errors.nim index 3cd1613c7..ba3909a8e 100644 --- a/libp2p/errors.nim +++ b/libp2p/errors.nim @@ -19,7 +19,7 @@ macro checkFutures*[T](futs: seq[Future[T]], exclude: untyped = []): untyped = if res.failed: let exc = res.readError() # We still don't abort but warn - warn "Something went wrong in a future", error=exc.name + warn "Something went wrong in a future", error=exc.name, msg = exc.msg else: quote do: for res in `futs`: @@ -28,10 +28,10 @@ macro checkFutures*[T](futs: seq[Future[T]], exclude: untyped = []): untyped = let exc = res.readError() for i in 0..<`nexclude`: if exc of `exclude`[i]: - trace "Ignoring an error (no warning)", error=exc.name + trace "Ignoring an error (no warning)", error=exc.name, msg = exc.msg break check # We still don't abort but warn - warn "Something went wrong in a future", error=exc.name + warn "Something went wrong in a future", error=exc.name, msg = exc.msg proc allFuturesThrowing*[T](args: varargs[Future[T]]): Future[void] = var futs: seq[Future[T]] diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index aaaf23ae0..fe98d4298 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -14,18 +14,11 @@ logScope: topic = "ChronosStream" type ChronosStream* = ref object of LPStream - reader: AsyncStreamReader - writer: AsyncStreamWriter - server: StreamServer client: StreamTransport -proc newChronosStream*(server: StreamServer, - client: StreamTransport): ChronosStream = +proc newChronosStream*(client: StreamTransport): ChronosStream = new result - result.server = server result.client = client - result.reader = newAsyncStreamReader(client) - result.writer = newAsyncStreamWriter(client) result.closeEvent = newAsyncEvent() template withExceptions(body: untyped) = @@ -35,47 +28,44 @@ template withExceptions(body: untyped) = raise newLPStreamIncompleteError() except TransportLimitError: raise newLPStreamLimitError() - except TransportError as exc: - raise newLPStreamIncorrectDefect(exc.msg) - except AsyncStreamIncompleteError: - raise newLPStreamIncompleteError() + except TransportUseClosedError: + raise newLPStreamEOFError() + except TransportError: + # TODO https://github.com/status-im/nim-chronos/pull/99 + raise newLPStreamEOFError() + # raise (ref LPStreamError)(msg: exc.msg, parent: exc) method readExactly*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[void] {.async.} = - if s.reader.atEof: + if s.client.atEof: raise newLPStreamEOFError() withExceptions: - await s.reader.readExactly(pbytes, nbytes) + await s.client.readExactly(pbytes, nbytes) method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} = - if s.reader.atEof: + if s.client.atEof: raise newLPStreamEOFError() withExceptions: - result = await s.reader.readOnce(pbytes, nbytes) + result = await s.client.readOnce(pbytes, nbytes) method write*(s: ChronosStream, msg: seq[byte]) {.async.} = - if s.writer.atEof: - raise newLPStreamEOFError() + if msg.len == 0: + return withExceptions: - await s.writer.write(msg) + # Returns 0 sometimes when write fails - but there's not much we can do here? + if (await s.client.write(msg)) != msg.len: + raise (ref LPStreamError)(msg: "Write couldn't finish writing") method closed*(s: ChronosStream): bool {.inline.} = - # TODO: we might only need to check for reader's EOF - result = s.reader.atEof() + result = s.client.closed method close*(s: ChronosStream) {.async.} = if not s.closed: trace "shutting chronos stream", address = $s.client.remoteAddress() - if not s.writer.closed(): - await s.writer.closeWait() - - if not s.reader.closed(): - await s.reader.closeWait() - if not s.client.closed(): await s.client.closeWait() diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index a28ef072b..4b80ee740 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -62,11 +62,10 @@ proc cleanup(t: Transport, conn: Connection) {.async.} = t.connections.keepItIf(it != conn) proc connHandler*(t: TcpTransport, - server: StreamServer, client: StreamTransport, initiator: bool): Connection = trace "handling connection for", address = $client.remoteAddress - let conn: Connection = newConnection(newChronosStream(server, client)) + let conn: Connection = newConnection(newChronosStream(client)) conn.observedAddrs = MultiAddress.init(client.remoteAddress) if not initiator: if not isNil(t.handler): @@ -83,7 +82,7 @@ proc connCb(server: StreamServer, let t = cast[TcpTransport](server.udata) # we don't need result connection in this case # as it's added inside connHandler - discard t.connHandler(server, client, false) + discard t.connHandler(client, false) method init*(t: TcpTransport) = t.multicodec = multiCodec("tcp") @@ -141,7 +140,7 @@ method dial*(t: TcpTransport, trace "dialing remote peer", address = $address ## dial a peer let client: StreamTransport = await connect(address) - result = t.connHandler(t.server, client, true) + result = t.connHandler(client, true) method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} = if procCall Transport(t).handles(address): diff --git a/tests/helpers.nim b/tests/helpers.nim new file mode 100644 index 000000000..aac3ce759 --- /dev/null +++ b/tests/helpers.nim @@ -0,0 +1,23 @@ +import chronos + +import ../libp2p/transports/tcptransport +import ../libp2p/stream/bufferstream + +const + StreamTransportTrackerName = "stream.transport" + StreamServerTrackerName = "stream.server" + + trackerNames = [ + BufferStreamTrackerName, + TcpTransportTrackerName, + StreamTransportTrackerName, + StreamServerTrackerName + ] + +iterator testTrackers*(extras: openArray[string] = []): TrackerBase = + for name in trackerNames: + let t = getTracker(name) + if not isNil(t): yield t + for name in extras: + let t = getTracker(name) + if not isNil(t): yield t diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index 46a154f3f..1f3ab2352 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -20,9 +20,7 @@ import utils, protocols/pubsub/rpc/messages, protocols/pubsub/rpc/message] -const - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" +import ../helpers proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = # turn things deterministic @@ -37,18 +35,8 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = suite "FloodSub": teardown: - let - trackers = [ - # getTracker(ConnectionTrackerName), - getTracker(BufferStreamTrackerName), - getTracker(AsyncStreamWriterTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "FloodSub basic publish/subscribe A -> B": proc runTests(): Future[bool] {.async.} = @@ -242,7 +230,7 @@ suite "FloodSub": var awaitters: seq[Future[void]] for i in 0..<10: awaitters.add(await nodes[i].start()) - + await subscribeNodes(nodes) for i in 0..<10: diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index ec8b6b6b1..06104bd90 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -4,27 +4,15 @@ import unittest import ../../libp2p/errors import ../../libp2p/stream/bufferstream +import ../helpers + type TestGossipSub = ref object of GossipSub -const - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" - suite "GossipSub internal": teardown: - let - trackers = [ - getTracker(BufferStreamTrackerName), - getTracker(AsyncStreamWriterTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - # echo tracker.dump() - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "`rebalanceMesh` Degree Lo": proc testRun(): Future[bool] {.async.} = diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index 22a4cbd6c..82c629964 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -20,9 +20,7 @@ import utils, ../../libp2p/[errors, protocols/pubsub/gossipsub, protocols/pubsub/rpc/messages] -const - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" +import ../helpers proc createGossipSub(): GossipSub = var peerInfo = PeerInfo.init(PrivateKey.random(RSA)) @@ -49,18 +47,8 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = suite "GossipSub": teardown: - let - trackers = [ - getTracker(BufferStreamTrackerName), - getTracker(AsyncStreamWriterTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - # echo tracker.dump() - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "GossipSub validation should succeed": proc runTests(): Future[bool] {.async.} = diff --git a/tests/testidentify.nim b/tests/testidentify.nim index e61464b9f..d82ed5378 100644 --- a/tests/testidentify.nim +++ b/tests/testidentify.nim @@ -9,27 +9,14 @@ import ../libp2p/[protocols/identify, transports/transport, transports/tcptransport, crypto/crypto] +import ./helpers when defined(nimHasUsed): {.used.} -const - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" - suite "Identify": teardown: - let - trackers = [ - getTracker(AsyncStreamWriterTrackerName), - getTracker(TcpTransportTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - # echo tracker.dump() - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "handle identify message": proc testHandle(): Future[bool] {.async.} = diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 7821415a4..83ac55860 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -14,27 +14,14 @@ import ../libp2p/[errors, vbuffer, varint] -when defined(nimHasUsed): {.used.} +import ./helpers -const - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" +when defined(nimHasUsed): {.used.} suite "Mplex": teardown: - let - trackers = [ - getTracker(BufferStreamTrackerName), - getTracker(AsyncStreamWriterTrackerName), - getTracker(TcpTransportTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - # echo tracker.dump() - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "encode header with channel id 0": proc testEncodeHeader(): Future[bool] {.async.} = diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 5b20ad2b9..eaba30e29 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -11,6 +11,8 @@ import ../libp2p/errors, ../libp2p/transports/tcptransport, ../libp2p/protocols/protocol +import ./helpers + when defined(nimHasUsed): {.used.} ## Mock stream for select test @@ -18,10 +20,6 @@ type TestSelectStream = ref object of LPStream step*: int -const - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" - method readExactly*(s: TestSelectStream, pbytes: pointer, nbytes: int): Future[void] {.async, gcsafe.} = @@ -152,19 +150,8 @@ proc newTestNaStream(na: NaHandler): TestNaStream = suite "Multistream select": teardown: - let - trackers = [ - # getTracker(ConnectionTrackerName), - getTracker(AsyncStreamWriterTrackerName), - getTracker(TcpTransportTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - # echo tracker.dump() - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "test select custom proto": proc testSelect(): Future[bool] {.async.} = diff --git a/tests/testnoise.nim b/tests/testnoise.nim index 1b2cda0c5..520cf5695 100644 --- a/tests/testnoise.nim +++ b/tests/testnoise.nim @@ -31,11 +31,10 @@ import ../libp2p/[switch, muxers/mplex/types, protocols/secure/noise, protocols/secure/secure] +import ./helpers const TestCodec = "/test/proto/1.0.0" - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" type TestProto = ref object of LPProtocol @@ -71,19 +70,8 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) = suite "Noise": teardown: - let - trackers = [ - getTracker(BufferStreamTrackerName), - getTracker(AsyncStreamWriterTrackerName), - getTracker(TcpTransportTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - # echo tracker.dump() - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "e2e: handle write + noise": proc testListenerDialer(): Future[bool] {.async.} = diff --git a/tests/testpeerinfo.nim b/tests/testpeerinfo.nim index a33a15bfa..2993d9354 100644 --- a/tests/testpeerinfo.nim +++ b/tests/testpeerinfo.nim @@ -5,24 +5,12 @@ import chronos import ../libp2p/crypto/crypto, ../libp2p/peerinfo, ../libp2p/peer - -const - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" +import ./helpers suite "PeerInfo": teardown: - let - trackers = [ - getTracker(AsyncStreamWriterTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - # echo tracker.dump() - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "Should init with private key": let seckey = PrivateKey.random(RSA) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 913d1458e..9f3ab7a16 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,3 +1,5 @@ +{.used.} + import unittest, tables import chronos import chronicles @@ -20,13 +22,10 @@ import ../libp2p/[errors, protocols/secure/secio, protocols/secure/secure, stream/lpstream] - -when defined(nimHasUsed): {.used.} +import ./helpers const TestCodec = "/test/proto/1.0.0" - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" type TestProto = ref object of LPProtocol @@ -52,20 +51,8 @@ proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) = suite "Switch": teardown: - let - trackers = [ - # getTracker(ConnectionTrackerName), - getTracker(BufferStreamTrackerName), - getTracker(AsyncStreamWriterTrackerName), - getTracker(TcpTransportTrackerName), - getTracker(AsyncStreamReaderTrackerName), - getTracker(StreamTransportTrackerName), - getTracker(StreamServerTrackerName) - ] - for tracker in trackers: - if not isNil(tracker): - # echo tracker.dump() - check tracker.isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "e2e use switch dial proto string": proc testSwitch(): Future[bool] {.async, gcsafe.} = diff --git a/tests/testtransport.nim b/tests/testtransport.nim index 28d3a84f3..4a538720d 100644 --- a/tests/testtransport.nim +++ b/tests/testtransport.nim @@ -1,3 +1,5 @@ +{.used.} + import unittest import chronos import ../libp2p/[connection, @@ -5,21 +7,12 @@ import ../libp2p/[connection, transports/tcptransport, multiaddress, wire] - -when defined(nimHasUsed): {.used.} - -const - StreamTransportTrackerName = "stream.transport" - StreamServerTrackerName = "stream.server" +import ./helpers suite "TCP transport": teardown: - check: - # getTracker(ConnectionTrackerName).isLeaked() == false - getTracker(AsyncStreamReaderTrackerName).isLeaked() == false - getTracker(AsyncStreamWriterTrackerName).isLeaked() == false - getTracker(StreamTransportTrackerName).isLeaked() == false - getTracker(StreamServerTrackerName).isLeaked() == false + for tracker in testTrackers(): + check tracker.isLeaked() == false test "test listener: handle write": proc testListener(): Future[bool] {.async, gcsafe.} =