use stream directly in chronosstream (#163)

* use stream directly in chronosstream

for now, chronos.AsyncStream is not used to provide any features on top
of chronos.Stream, so in order to simplify the code, chronosstream can
be used directly.

In particular, the exception handling is broken in the current
chronosstream - opening and closing the stream is simplified this way as
well.

A future implementation that actually takes advantage of the AsyncStream
features would wrap AsyncStream instead as a separate lpstream
implementation, leaving this one as-is.

* work around chronos exception type issue
This commit is contained in:
Jacek Sieka 2020-05-08 22:10:06 +02:00 committed by GitHub
parent c889224012
commit ccd019b328
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 84 additions and 191 deletions

View File

@ -19,7 +19,7 @@ macro checkFutures*[T](futs: seq[Future[T]], exclude: untyped = []): untyped =
if res.failed: if res.failed:
let exc = res.readError() let exc = res.readError()
# We still don't abort but warn # We still don't abort but warn
warn "Something went wrong in a future", error=exc.name warn "Something went wrong in a future", error=exc.name, msg = exc.msg
else: else:
quote do: quote do:
for res in `futs`: for res in `futs`:
@ -28,10 +28,10 @@ macro checkFutures*[T](futs: seq[Future[T]], exclude: untyped = []): untyped =
let exc = res.readError() let exc = res.readError()
for i in 0..<`nexclude`: for i in 0..<`nexclude`:
if exc of `exclude`[i]: if exc of `exclude`[i]:
trace "Ignoring an error (no warning)", error=exc.name trace "Ignoring an error (no warning)", error=exc.name, msg = exc.msg
break check break check
# We still don't abort but warn # We still don't abort but warn
warn "Something went wrong in a future", error=exc.name warn "Something went wrong in a future", error=exc.name, msg = exc.msg
proc allFuturesThrowing*[T](args: varargs[Future[T]]): Future[void] = proc allFuturesThrowing*[T](args: varargs[Future[T]]): Future[void] =
var futs: seq[Future[T]] var futs: seq[Future[T]]

View File

@ -14,18 +14,11 @@ logScope:
topic = "ChronosStream" topic = "ChronosStream"
type ChronosStream* = ref object of LPStream type ChronosStream* = ref object of LPStream
reader: AsyncStreamReader
writer: AsyncStreamWriter
server: StreamServer
client: StreamTransport client: StreamTransport
proc newChronosStream*(server: StreamServer, proc newChronosStream*(client: StreamTransport): ChronosStream =
client: StreamTransport): ChronosStream =
new result new result
result.server = server
result.client = client result.client = client
result.reader = newAsyncStreamReader(client)
result.writer = newAsyncStreamWriter(client)
result.closeEvent = newAsyncEvent() result.closeEvent = newAsyncEvent()
template withExceptions(body: untyped) = template withExceptions(body: untyped) =
@ -35,47 +28,44 @@ template withExceptions(body: untyped) =
raise newLPStreamIncompleteError() raise newLPStreamIncompleteError()
except TransportLimitError: except TransportLimitError:
raise newLPStreamLimitError() raise newLPStreamLimitError()
except TransportError as exc: except TransportUseClosedError:
raise newLPStreamIncorrectDefect(exc.msg) raise newLPStreamEOFError()
except AsyncStreamIncompleteError: except TransportError:
raise newLPStreamIncompleteError() # TODO https://github.com/status-im/nim-chronos/pull/99
raise newLPStreamEOFError()
# raise (ref LPStreamError)(msg: exc.msg, parent: exc)
method readExactly*(s: ChronosStream, method readExactly*(s: ChronosStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): Future[void] {.async.} = nbytes: int): Future[void] {.async.} =
if s.reader.atEof: if s.client.atEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()
withExceptions: withExceptions:
await s.reader.readExactly(pbytes, nbytes) await s.client.readExactly(pbytes, nbytes)
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} = method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
if s.reader.atEof: if s.client.atEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()
withExceptions: withExceptions:
result = await s.reader.readOnce(pbytes, nbytes) result = await s.client.readOnce(pbytes, nbytes)
method write*(s: ChronosStream, msg: seq[byte]) {.async.} = method write*(s: ChronosStream, msg: seq[byte]) {.async.} =
if s.writer.atEof: if msg.len == 0:
raise newLPStreamEOFError() return
withExceptions: withExceptions:
await s.writer.write(msg) # Returns 0 sometimes when write fails - but there's not much we can do here?
if (await s.client.write(msg)) != msg.len:
raise (ref LPStreamError)(msg: "Write couldn't finish writing")
method closed*(s: ChronosStream): bool {.inline.} = method closed*(s: ChronosStream): bool {.inline.} =
# TODO: we might only need to check for reader's EOF result = s.client.closed
result = s.reader.atEof()
method close*(s: ChronosStream) {.async.} = method close*(s: ChronosStream) {.async.} =
if not s.closed: if not s.closed:
trace "shutting chronos stream", address = $s.client.remoteAddress() trace "shutting chronos stream", address = $s.client.remoteAddress()
if not s.writer.closed():
await s.writer.closeWait()
if not s.reader.closed():
await s.reader.closeWait()
if not s.client.closed(): if not s.client.closed():
await s.client.closeWait() await s.client.closeWait()

View File

@ -62,11 +62,10 @@ proc cleanup(t: Transport, conn: Connection) {.async.} =
t.connections.keepItIf(it != conn) t.connections.keepItIf(it != conn)
proc connHandler*(t: TcpTransport, proc connHandler*(t: TcpTransport,
server: StreamServer,
client: StreamTransport, client: StreamTransport,
initiator: bool): Connection = initiator: bool): Connection =
trace "handling connection for", address = $client.remoteAddress trace "handling connection for", address = $client.remoteAddress
let conn: Connection = newConnection(newChronosStream(server, client)) let conn: Connection = newConnection(newChronosStream(client))
conn.observedAddrs = MultiAddress.init(client.remoteAddress) conn.observedAddrs = MultiAddress.init(client.remoteAddress)
if not initiator: if not initiator:
if not isNil(t.handler): if not isNil(t.handler):
@ -83,7 +82,7 @@ proc connCb(server: StreamServer,
let t = cast[TcpTransport](server.udata) let t = cast[TcpTransport](server.udata)
# we don't need result connection in this case # we don't need result connection in this case
# as it's added inside connHandler # as it's added inside connHandler
discard t.connHandler(server, client, false) discard t.connHandler(client, false)
method init*(t: TcpTransport) = method init*(t: TcpTransport) =
t.multicodec = multiCodec("tcp") t.multicodec = multiCodec("tcp")
@ -141,7 +140,7 @@ method dial*(t: TcpTransport,
trace "dialing remote peer", address = $address trace "dialing remote peer", address = $address
## dial a peer ## dial a peer
let client: StreamTransport = await connect(address) let client: StreamTransport = await connect(address)
result = t.connHandler(t.server, client, true) result = t.connHandler(client, true)
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} = method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
if procCall Transport(t).handles(address): if procCall Transport(t).handles(address):

23
tests/helpers.nim Normal file
View File

@ -0,0 +1,23 @@
import chronos
import ../libp2p/transports/tcptransport
import ../libp2p/stream/bufferstream
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
trackerNames = [
BufferStreamTrackerName,
TcpTransportTrackerName,
StreamTransportTrackerName,
StreamServerTrackerName
]
iterator testTrackers*(extras: openArray[string] = []): TrackerBase =
for name in trackerNames:
let t = getTracker(name)
if not isNil(t): yield t
for name in extras:
let t = getTracker(name)
if not isNil(t): yield t

View File

@ -20,9 +20,7 @@ import utils,
protocols/pubsub/rpc/messages, protocols/pubsub/rpc/messages,
protocols/pubsub/rpc/message] protocols/pubsub/rpc/message]
const import ../helpers
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
# turn things deterministic # turn things deterministic
@ -37,18 +35,8 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
suite "FloodSub": suite "FloodSub":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
# getTracker(ConnectionTrackerName),
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
check tracker.isLeaked() == false
test "FloodSub basic publish/subscribe A -> B": test "FloodSub basic publish/subscribe A -> B":
proc runTests(): Future[bool] {.async.} = proc runTests(): Future[bool] {.async.} =

View File

@ -4,27 +4,15 @@ import unittest
import ../../libp2p/errors import ../../libp2p/errors
import ../../libp2p/stream/bufferstream import ../../libp2p/stream/bufferstream
import ../helpers
type type
TestGossipSub = ref object of GossipSub TestGossipSub = ref object of GossipSub
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "GossipSub internal": suite "GossipSub internal":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "`rebalanceMesh` Degree Lo": test "`rebalanceMesh` Degree Lo":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =

View File

@ -20,9 +20,7 @@ import utils, ../../libp2p/[errors,
protocols/pubsub/gossipsub, protocols/pubsub/gossipsub,
protocols/pubsub/rpc/messages] protocols/pubsub/rpc/messages]
const import ../helpers
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
proc createGossipSub(): GossipSub = proc createGossipSub(): GossipSub =
var peerInfo = PeerInfo.init(PrivateKey.random(RSA)) var peerInfo = PeerInfo.init(PrivateKey.random(RSA))
@ -49,18 +47,8 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
suite "GossipSub": suite "GossipSub":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "GossipSub validation should succeed": test "GossipSub validation should succeed":
proc runTests(): Future[bool] {.async.} = proc runTests(): Future[bool] {.async.} =

View File

@ -9,27 +9,14 @@ import ../libp2p/[protocols/identify,
transports/transport, transports/transport,
transports/tcptransport, transports/tcptransport,
crypto/crypto] crypto/crypto]
import ./helpers
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "Identify": suite "Identify":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "handle identify message": test "handle identify message":
proc testHandle(): Future[bool] {.async.} = proc testHandle(): Future[bool] {.async.} =

View File

@ -14,27 +14,14 @@ import ../libp2p/[errors,
vbuffer, vbuffer,
varint] varint]
when defined(nimHasUsed): {.used.} import ./helpers
const when defined(nimHasUsed): {.used.}
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "Mplex": suite "Mplex":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "encode header with channel id 0": test "encode header with channel id 0":
proc testEncodeHeader(): Future[bool] {.async.} = proc testEncodeHeader(): Future[bool] {.async.} =

View File

@ -11,6 +11,8 @@ import ../libp2p/errors,
../libp2p/transports/tcptransport, ../libp2p/transports/tcptransport,
../libp2p/protocols/protocol ../libp2p/protocols/protocol
import ./helpers
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
## Mock stream for select test ## Mock stream for select test
@ -18,10 +20,6 @@ type
TestSelectStream = ref object of LPStream TestSelectStream = ref object of LPStream
step*: int step*: int
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
method readExactly*(s: TestSelectStream, method readExactly*(s: TestSelectStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): Future[void] {.async, gcsafe.} = nbytes: int): Future[void] {.async, gcsafe.} =
@ -152,19 +150,8 @@ proc newTestNaStream(na: NaHandler): TestNaStream =
suite "Multistream select": suite "Multistream select":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
# getTracker(ConnectionTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "test select custom proto": test "test select custom proto":
proc testSelect(): Future[bool] {.async.} = proc testSelect(): Future[bool] {.async.} =

View File

@ -31,11 +31,10 @@ import ../libp2p/[switch,
muxers/mplex/types, muxers/mplex/types,
protocols/secure/noise, protocols/secure/noise,
protocols/secure/secure] protocols/secure/secure]
import ./helpers
const const
TestCodec = "/test/proto/1.0.0" TestCodec = "/test/proto/1.0.0"
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
type type
TestProto = ref object of LPProtocol TestProto = ref object of LPProtocol
@ -71,19 +70,8 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) =
suite "Noise": suite "Noise":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "e2e: handle write + noise": test "e2e: handle write + noise":
proc testListenerDialer(): Future[bool] {.async.} = proc testListenerDialer(): Future[bool] {.async.} =

View File

@ -5,24 +5,12 @@ import chronos
import ../libp2p/crypto/crypto, import ../libp2p/crypto/crypto,
../libp2p/peerinfo, ../libp2p/peerinfo,
../libp2p/peer ../libp2p/peer
import ./helpers
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "PeerInfo": suite "PeerInfo":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
getTracker(AsyncStreamWriterTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "Should init with private key": test "Should init with private key":
let seckey = PrivateKey.random(RSA) let seckey = PrivateKey.random(RSA)

View File

@ -1,3 +1,5 @@
{.used.}
import unittest, tables import unittest, tables
import chronos import chronos
import chronicles import chronicles
@ -20,13 +22,10 @@ import ../libp2p/[errors,
protocols/secure/secio, protocols/secure/secio,
protocols/secure/secure, protocols/secure/secure,
stream/lpstream] stream/lpstream]
import ./helpers
when defined(nimHasUsed): {.used.}
const const
TestCodec = "/test/proto/1.0.0" TestCodec = "/test/proto/1.0.0"
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
type type
TestProto = ref object of LPProtocol TestProto = ref object of LPProtocol
@ -52,20 +51,8 @@ proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) =
suite "Switch": suite "Switch":
teardown: teardown:
let for tracker in testTrackers():
trackers = [ check tracker.isLeaked() == false
# getTracker(ConnectionTrackerName),
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "e2e use switch dial proto string": test "e2e use switch dial proto string":
proc testSwitch(): Future[bool] {.async, gcsafe.} = proc testSwitch(): Future[bool] {.async, gcsafe.} =

View File

@ -1,3 +1,5 @@
{.used.}
import unittest import unittest
import chronos import chronos
import ../libp2p/[connection, import ../libp2p/[connection,
@ -5,21 +7,12 @@ import ../libp2p/[connection,
transports/tcptransport, transports/tcptransport,
multiaddress, multiaddress,
wire] wire]
import ./helpers
when defined(nimHasUsed): {.used.}
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "TCP transport": suite "TCP transport":
teardown: teardown:
check: for tracker in testTrackers():
# getTracker(ConnectionTrackerName).isLeaked() == false check tracker.isLeaked() == false
getTracker(AsyncStreamReaderTrackerName).isLeaked() == false
getTracker(AsyncStreamWriterTrackerName).isLeaked() == false
getTracker(StreamTransportTrackerName).isLeaked() == false
getTracker(StreamServerTrackerName).isLeaked() == false
test "test listener: handle write": test "test listener: handle write":
proc testListener(): Future[bool] {.async, gcsafe.} = proc testListener(): Future[bool] {.async, gcsafe.} =