mirror of
https://github.com/codex-storage/nim-libp2p.git
synced 2025-01-11 19:44:18 +00:00
use stream directly in chronosstream (#163)
* use stream directly in chronosstream for now, chronos.AsyncStream is not used to provide any features on top of chronos.Stream, so in order to simplify the code, chronosstream can be used directly. In particular, the exception handling is broken in the current chronosstream - opening and closing the stream is simplified this way as well. A future implementation that actually takes advantage of the AsyncStream features would wrap AsyncStream instead as a separate lpstream implementation, leaving this one as-is. * work around chronos exception type issue
This commit is contained in:
parent
c889224012
commit
ccd019b328
@ -19,7 +19,7 @@ macro checkFutures*[T](futs: seq[Future[T]], exclude: untyped = []): untyped =
|
||||
if res.failed:
|
||||
let exc = res.readError()
|
||||
# We still don't abort but warn
|
||||
warn "Something went wrong in a future", error=exc.name
|
||||
warn "Something went wrong in a future", error=exc.name, msg = exc.msg
|
||||
else:
|
||||
quote do:
|
||||
for res in `futs`:
|
||||
@ -28,10 +28,10 @@ macro checkFutures*[T](futs: seq[Future[T]], exclude: untyped = []): untyped =
|
||||
let exc = res.readError()
|
||||
for i in 0..<`nexclude`:
|
||||
if exc of `exclude`[i]:
|
||||
trace "Ignoring an error (no warning)", error=exc.name
|
||||
trace "Ignoring an error (no warning)", error=exc.name, msg = exc.msg
|
||||
break check
|
||||
# We still don't abort but warn
|
||||
warn "Something went wrong in a future", error=exc.name
|
||||
warn "Something went wrong in a future", error=exc.name, msg = exc.msg
|
||||
|
||||
proc allFuturesThrowing*[T](args: varargs[Future[T]]): Future[void] =
|
||||
var futs: seq[Future[T]]
|
||||
|
@ -14,18 +14,11 @@ logScope:
|
||||
topic = "ChronosStream"
|
||||
|
||||
type ChronosStream* = ref object of LPStream
|
||||
reader: AsyncStreamReader
|
||||
writer: AsyncStreamWriter
|
||||
server: StreamServer
|
||||
client: StreamTransport
|
||||
|
||||
proc newChronosStream*(server: StreamServer,
|
||||
client: StreamTransport): ChronosStream =
|
||||
proc newChronosStream*(client: StreamTransport): ChronosStream =
|
||||
new result
|
||||
result.server = server
|
||||
result.client = client
|
||||
result.reader = newAsyncStreamReader(client)
|
||||
result.writer = newAsyncStreamWriter(client)
|
||||
result.closeEvent = newAsyncEvent()
|
||||
|
||||
template withExceptions(body: untyped) =
|
||||
@ -35,47 +28,44 @@ template withExceptions(body: untyped) =
|
||||
raise newLPStreamIncompleteError()
|
||||
except TransportLimitError:
|
||||
raise newLPStreamLimitError()
|
||||
except TransportError as exc:
|
||||
raise newLPStreamIncorrectDefect(exc.msg)
|
||||
except AsyncStreamIncompleteError:
|
||||
raise newLPStreamIncompleteError()
|
||||
except TransportUseClosedError:
|
||||
raise newLPStreamEOFError()
|
||||
except TransportError:
|
||||
# TODO https://github.com/status-im/nim-chronos/pull/99
|
||||
raise newLPStreamEOFError()
|
||||
# raise (ref LPStreamError)(msg: exc.msg, parent: exc)
|
||||
|
||||
method readExactly*(s: ChronosStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void] {.async.} =
|
||||
if s.reader.atEof:
|
||||
if s.client.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
withExceptions:
|
||||
await s.reader.readExactly(pbytes, nbytes)
|
||||
await s.client.readExactly(pbytes, nbytes)
|
||||
|
||||
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
|
||||
if s.reader.atEof:
|
||||
if s.client.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
withExceptions:
|
||||
result = await s.reader.readOnce(pbytes, nbytes)
|
||||
result = await s.client.readOnce(pbytes, nbytes)
|
||||
|
||||
method write*(s: ChronosStream, msg: seq[byte]) {.async.} =
|
||||
if s.writer.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
if msg.len == 0:
|
||||
return
|
||||
|
||||
withExceptions:
|
||||
await s.writer.write(msg)
|
||||
# Returns 0 sometimes when write fails - but there's not much we can do here?
|
||||
if (await s.client.write(msg)) != msg.len:
|
||||
raise (ref LPStreamError)(msg: "Write couldn't finish writing")
|
||||
|
||||
method closed*(s: ChronosStream): bool {.inline.} =
|
||||
# TODO: we might only need to check for reader's EOF
|
||||
result = s.reader.atEof()
|
||||
result = s.client.closed
|
||||
|
||||
method close*(s: ChronosStream) {.async.} =
|
||||
if not s.closed:
|
||||
trace "shutting chronos stream", address = $s.client.remoteAddress()
|
||||
if not s.writer.closed():
|
||||
await s.writer.closeWait()
|
||||
|
||||
if not s.reader.closed():
|
||||
await s.reader.closeWait()
|
||||
|
||||
if not s.client.closed():
|
||||
await s.client.closeWait()
|
||||
|
||||
|
@ -62,11 +62,10 @@ proc cleanup(t: Transport, conn: Connection) {.async.} =
|
||||
t.connections.keepItIf(it != conn)
|
||||
|
||||
proc connHandler*(t: TcpTransport,
|
||||
server: StreamServer,
|
||||
client: StreamTransport,
|
||||
initiator: bool): Connection =
|
||||
trace "handling connection for", address = $client.remoteAddress
|
||||
let conn: Connection = newConnection(newChronosStream(server, client))
|
||||
let conn: Connection = newConnection(newChronosStream(client))
|
||||
conn.observedAddrs = MultiAddress.init(client.remoteAddress)
|
||||
if not initiator:
|
||||
if not isNil(t.handler):
|
||||
@ -83,7 +82,7 @@ proc connCb(server: StreamServer,
|
||||
let t = cast[TcpTransport](server.udata)
|
||||
# we don't need result connection in this case
|
||||
# as it's added inside connHandler
|
||||
discard t.connHandler(server, client, false)
|
||||
discard t.connHandler(client, false)
|
||||
|
||||
method init*(t: TcpTransport) =
|
||||
t.multicodec = multiCodec("tcp")
|
||||
@ -141,7 +140,7 @@ method dial*(t: TcpTransport,
|
||||
trace "dialing remote peer", address = $address
|
||||
## dial a peer
|
||||
let client: StreamTransport = await connect(address)
|
||||
result = t.connHandler(t.server, client, true)
|
||||
result = t.connHandler(client, true)
|
||||
|
||||
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
|
||||
if procCall Transport(t).handles(address):
|
||||
|
23
tests/helpers.nim
Normal file
23
tests/helpers.nim
Normal file
@ -0,0 +1,23 @@
|
||||
import chronos
|
||||
|
||||
import ../libp2p/transports/tcptransport
|
||||
import ../libp2p/stream/bufferstream
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
|
||||
trackerNames = [
|
||||
BufferStreamTrackerName,
|
||||
TcpTransportTrackerName,
|
||||
StreamTransportTrackerName,
|
||||
StreamServerTrackerName
|
||||
]
|
||||
|
||||
iterator testTrackers*(extras: openArray[string] = []): TrackerBase =
|
||||
for name in trackerNames:
|
||||
let t = getTracker(name)
|
||||
if not isNil(t): yield t
|
||||
for name in extras:
|
||||
let t = getTracker(name)
|
||||
if not isNil(t): yield t
|
@ -20,9 +20,7 @@ import utils,
|
||||
protocols/pubsub/rpc/messages,
|
||||
protocols/pubsub/rpc/message]
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
import ../helpers
|
||||
|
||||
proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
|
||||
# turn things deterministic
|
||||
@ -37,17 +35,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
|
||||
|
||||
suite "FloodSub":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
# getTracker(ConnectionTrackerName),
|
||||
getTracker(BufferStreamTrackerName),
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "FloodSub basic publish/subscribe A -> B":
|
||||
|
@ -4,26 +4,14 @@ import unittest
|
||||
import ../../libp2p/errors
|
||||
import ../../libp2p/stream/bufferstream
|
||||
|
||||
import ../helpers
|
||||
|
||||
type
|
||||
TestGossipSub = ref object of GossipSub
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
|
||||
suite "GossipSub internal":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
getTracker(BufferStreamTrackerName),
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
# echo tracker.dump()
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "`rebalanceMesh` Degree Lo":
|
||||
|
@ -20,9 +20,7 @@ import utils, ../../libp2p/[errors,
|
||||
protocols/pubsub/gossipsub,
|
||||
protocols/pubsub/rpc/messages]
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
import ../helpers
|
||||
|
||||
proc createGossipSub(): GossipSub =
|
||||
var peerInfo = PeerInfo.init(PrivateKey.random(RSA))
|
||||
@ -49,17 +47,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
|
||||
|
||||
suite "GossipSub":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
getTracker(BufferStreamTrackerName),
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
# echo tracker.dump()
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "GossipSub validation should succeed":
|
||||
|
@ -9,26 +9,13 @@ import ../libp2p/[protocols/identify,
|
||||
transports/transport,
|
||||
transports/tcptransport,
|
||||
crypto/crypto]
|
||||
import ./helpers
|
||||
|
||||
when defined(nimHasUsed): {.used.}
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
|
||||
suite "Identify":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(TcpTransportTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
# echo tracker.dump()
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "handle identify message":
|
||||
|
@ -14,26 +14,13 @@ import ../libp2p/[errors,
|
||||
vbuffer,
|
||||
varint]
|
||||
|
||||
when defined(nimHasUsed): {.used.}
|
||||
import ./helpers
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
when defined(nimHasUsed): {.used.}
|
||||
|
||||
suite "Mplex":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
getTracker(BufferStreamTrackerName),
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(TcpTransportTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
# echo tracker.dump()
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "encode header with channel id 0":
|
||||
|
@ -11,6 +11,8 @@ import ../libp2p/errors,
|
||||
../libp2p/transports/tcptransport,
|
||||
../libp2p/protocols/protocol
|
||||
|
||||
import ./helpers
|
||||
|
||||
when defined(nimHasUsed): {.used.}
|
||||
|
||||
## Mock stream for select test
|
||||
@ -18,10 +20,6 @@ type
|
||||
TestSelectStream = ref object of LPStream
|
||||
step*: int
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
|
||||
method readExactly*(s: TestSelectStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void] {.async, gcsafe.} =
|
||||
@ -152,18 +150,7 @@ proc newTestNaStream(na: NaHandler): TestNaStream =
|
||||
|
||||
suite "Multistream select":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
# getTracker(ConnectionTrackerName),
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(TcpTransportTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
# echo tracker.dump()
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "test select custom proto":
|
||||
|
@ -31,11 +31,10 @@ import ../libp2p/[switch,
|
||||
muxers/mplex/types,
|
||||
protocols/secure/noise,
|
||||
protocols/secure/secure]
|
||||
import ./helpers
|
||||
|
||||
const
|
||||
TestCodec = "/test/proto/1.0.0"
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
|
||||
type
|
||||
TestProto = ref object of LPProtocol
|
||||
@ -71,18 +70,7 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) =
|
||||
|
||||
suite "Noise":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
getTracker(BufferStreamTrackerName),
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(TcpTransportTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
# echo tracker.dump()
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "e2e: handle write + noise":
|
||||
|
@ -5,23 +5,11 @@ import chronos
|
||||
import ../libp2p/crypto/crypto,
|
||||
../libp2p/peerinfo,
|
||||
../libp2p/peer
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
import ./helpers
|
||||
|
||||
suite "PeerInfo":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
# echo tracker.dump()
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "Should init with private key":
|
||||
|
@ -1,3 +1,5 @@
|
||||
{.used.}
|
||||
|
||||
import unittest, tables
|
||||
import chronos
|
||||
import chronicles
|
||||
@ -20,13 +22,10 @@ import ../libp2p/[errors,
|
||||
protocols/secure/secio,
|
||||
protocols/secure/secure,
|
||||
stream/lpstream]
|
||||
|
||||
when defined(nimHasUsed): {.used.}
|
||||
import ./helpers
|
||||
|
||||
const
|
||||
TestCodec = "/test/proto/1.0.0"
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
|
||||
type
|
||||
TestProto = ref object of LPProtocol
|
||||
@ -52,19 +51,7 @@ proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) =
|
||||
|
||||
suite "Switch":
|
||||
teardown:
|
||||
let
|
||||
trackers = [
|
||||
# getTracker(ConnectionTrackerName),
|
||||
getTracker(BufferStreamTrackerName),
|
||||
getTracker(AsyncStreamWriterTrackerName),
|
||||
getTracker(TcpTransportTrackerName),
|
||||
getTracker(AsyncStreamReaderTrackerName),
|
||||
getTracker(StreamTransportTrackerName),
|
||||
getTracker(StreamServerTrackerName)
|
||||
]
|
||||
for tracker in trackers:
|
||||
if not isNil(tracker):
|
||||
# echo tracker.dump()
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "e2e use switch dial proto string":
|
||||
|
@ -1,3 +1,5 @@
|
||||
{.used.}
|
||||
|
||||
import unittest
|
||||
import chronos
|
||||
import ../libp2p/[connection,
|
||||
@ -5,21 +7,12 @@ import ../libp2p/[connection,
|
||||
transports/tcptransport,
|
||||
multiaddress,
|
||||
wire]
|
||||
|
||||
when defined(nimHasUsed): {.used.}
|
||||
|
||||
const
|
||||
StreamTransportTrackerName = "stream.transport"
|
||||
StreamServerTrackerName = "stream.server"
|
||||
import ./helpers
|
||||
|
||||
suite "TCP transport":
|
||||
teardown:
|
||||
check:
|
||||
# getTracker(ConnectionTrackerName).isLeaked() == false
|
||||
getTracker(AsyncStreamReaderTrackerName).isLeaked() == false
|
||||
getTracker(AsyncStreamWriterTrackerName).isLeaked() == false
|
||||
getTracker(StreamTransportTrackerName).isLeaked() == false
|
||||
getTracker(StreamServerTrackerName).isLeaked() == false
|
||||
for tracker in testTrackers():
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "test listener: handle write":
|
||||
proc testListener(): Future[bool] {.async, gcsafe.} =
|
||||
|
Loading…
x
Reference in New Issue
Block a user