Add chronos trackers and used them to sanitize resource disposal (#131)

* Add chronos trackers and used them to sanitize resource disposal

* Chronos trackers for transport tests wip

* No more chronos leaks in testtransport

* Make tcp transport and test more robust when closing

* Test async leaking tracking wip

* Fix a regression in wire connect

* Add chronos trackers to more tests and sanitize resource closure

* Wip fixing floodsub tests

* Floodsub wip

* Made floodsub basically deterministic, hit a nim bug with captures tho

* Wrap up floodsub tests refactor

* Wrapping up

* Add allFuturesThrowing utility

* Fix missing allFuturesThrowing in noise tests!

* Make tests green

* attempt fixing gossipsub failing cases

* Make sure to check also fanout in waitSub

* More verbose traces

* Gossipsub test improvments

* Refactor TcpTransport remove asyncCheck

* Add Connection trackers

* Add stricter connection tracking, wip mplex fix

* More asynccheck removal, in order to avoid connection leaks

* bump chronicles requirement

* Enable tracker dump to check CI output

* Wait for more futures in testmplex

* Remove tracker dump messages

* add tryAndWarn utility, fix mplex issue with go interop

* All allFuturesThrowing to directchat too

* make sure to cleanup on transport close
This commit is contained in:
Giovanni Petrantoni 2020-04-21 10:24:42 +09:00 committed by GitHub
parent 027e8227ea
commit 4c6a123d31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1079 additions and 275 deletions

View File

@ -6,6 +6,7 @@ import chronos # an efficient library for async
import ../libp2p/[switch, # manage transports, a single entry point for dialing and listening import ../libp2p/[switch, # manage transports, a single entry point for dialing and listening
multistream, # tag stream with short header to identify it multistream, # tag stream with short header to identify it
crypto/crypto, # cryptographic functions crypto/crypto, # cryptographic functions
errors, # error handling utilities
protocols/identify, # identify the peer info of a peer protocols/identify, # identify the peer info of a peer
connection, # create and close stream read / write connections connection, # create and close stream read / write connections
transports/transport, # listen and dial to other peers using p2p protocol transports/transport, # listen and dial to other peers using p2p protocol
@ -196,7 +197,7 @@ proc processInput(rfd: AsyncFD) {.async.} =
echo &"{a}/ipfs/{id}" echo &"{a}/ipfs/{id}"
await chatProto.readWriteLoop() await chatProto.readWriteLoop()
await allFutures(libp2pFuts) await allFuturesThrowing(libp2pFuts)
proc main() {.async.} = proc main() {.async.} =
let (rfd, wfd) = createAsyncPipe() let (rfd, wfd) = createAsyncPipe()

View File

@ -9,12 +9,10 @@ skipDirs = @["tests", "examples", "Nim"]
requires "nim >= 1.2.0", requires "nim >= 1.2.0",
"nimcrypto >= 0.4.1", "nimcrypto >= 0.4.1",
"chronos >= 2.3.8",
"bearssl >= 0.1.4", "bearssl >= 0.1.4",
"chronicles >= 0.7.1", "chronicles >= 0.7.2",
"chronos >= 2.3.8", "chronos >= 2.3.8",
"metrics", "metrics",
"nimcrypto >= 0.4.1",
"secp256k1", "secp256k1",
"stew" "stew"

View File

@ -10,6 +10,7 @@
import oids import oids
import chronos, chronicles, metrics import chronos, chronicles, metrics
import peerinfo, import peerinfo,
errors,
multiaddress, multiaddress,
stream/lpstream, stream/lpstream,
peerinfo, peerinfo,
@ -19,17 +20,50 @@ import peerinfo,
logScope: logScope:
topic = "Connection" topic = "Connection"
const DefaultReadSize* = 1 shl 20 const
DefaultReadSize* = 1 shl 20
ConnectionTrackerName* = "libp2p.connection"
type type
Connection* = ref object of LPStream Connection* = ref object of LPStream
peerInfo*: PeerInfo peerInfo*: PeerInfo
stream*: LPStream stream*: LPStream
observedAddrs*: Multiaddress observedAddrs*: Multiaddress
# notice this is a ugly circular reference collection
# (we got many actually :-))
readLoops*: seq[Future[void]]
InvalidVarintException = object of LPStreamError InvalidVarintException = object of LPStreamError
InvalidVarintSizeException = object of LPStreamError InvalidVarintSizeException = object of LPStreamError
ConnectionTracker* = ref object of TrackerBase
opened*: uint64
closed*: uint64
proc setupConnectionTracker(): ConnectionTracker {.gcsafe.}
proc getConnectionTracker*(): ConnectionTracker {.gcsafe.} =
result = cast[ConnectionTracker](getTracker(ConnectionTrackerName))
if isNil(result):
result = setupConnectionTracker()
proc dumpTracking(): string {.gcsafe.} =
var tracker = getConnectionTracker()
result = "Opened conns: " & $tracker.opened & "\n" &
"Closed conns: " & $tracker.closed
proc leakTransport(): bool {.gcsafe.} =
var tracker = getConnectionTracker()
result = (tracker.opened != tracker.closed)
proc setupConnectionTracker(): ConnectionTracker =
result = new ConnectionTracker
result.opened = 0
result.closed = 0
result.dump = dumpTracking
result.isLeaked = leakTransport
addTracker(ConnectionTrackerName, result)
declareGauge libp2p_open_connection, "open Connection instances" declareGauge libp2p_open_connection, "open Connection instances"
proc newInvalidVarintException*(): ref InvalidVarintException = proc newInvalidVarintException*(): ref InvalidVarintException =
@ -50,9 +84,9 @@ proc bindStreamClose(conn: Connection) {.async.} =
trace "wrapped stream closed, closing conn", closed = conn.isClosed, trace "wrapped stream closed, closing conn", closed = conn.isClosed,
peer = if not isNil(conn.peerInfo): peer = if not isNil(conn.peerInfo):
conn.peerInfo.id else: "" conn.peerInfo.id else: ""
asyncCheck conn.close() await conn.close()
proc init*[T: Connection](self: var T, stream: LPStream): T = proc init[T: Connection](self: var T, stream: LPStream): T =
## create a new Connection for the specified async reader/writer ## create a new Connection for the specified async reader/writer
new self new self
self.stream = stream self.stream = stream
@ -60,6 +94,7 @@ proc init*[T: Connection](self: var T, stream: LPStream): T =
when chronicles.enabledLogLevel == LogLevel.TRACE: when chronicles.enabledLogLevel == LogLevel.TRACE:
self.oid = genOid() self.oid = genOid()
asyncCheck self.bindStreamClose() asyncCheck self.bindStreamClose()
inc getConnectionTracker().opened
libp2p_open_connection.inc() libp2p_open_connection.inc()
return self return self
@ -116,15 +151,18 @@ method write*(s: Connection,
method closed*(s: Connection): bool = method closed*(s: Connection): bool =
if isNil(s.stream): if isNil(s.stream):
return false return true
result = s.stream.closed result = s.stream.closed
method close*(s: Connection) {.async, gcsafe.} = method close*(s: Connection) {.async, gcsafe.} =
if not s.closed: trace "about to close connection", closed = s.closed,
trace "about to close connection", closed = s.closed, peer = if not isNil(s.peerInfo):
peer = if not isNil(s.peerInfo): s.peerInfo.id else: ""
s.peerInfo.id else: ""
if not s.isClosed:
s.isClosed = true
inc getConnectionTracker().closed
if not isNil(s.stream) and not s.stream.closed: if not isNil(s.stream) and not s.stream.closed:
trace "closing child stream", closed = s.closed, trace "closing child stream", closed = s.closed,
@ -133,7 +171,11 @@ method close*(s: Connection) {.async, gcsafe.} =
await s.stream.close() await s.stream.close()
s.closeEvent.fire() s.closeEvent.fire()
s.isClosed = true
trace "waiting readloops", count=s.readLoops.len
let loopFuts = await allFinished(s.readLoops)
checkFutures(loopFuts)
s.readLoops = @[]
trace "connection closed", closed = s.closed, trace "connection closed", closed = s.closed,
peer = if not isNil(s.peerInfo): peer = if not isNil(s.peerInfo):

View File

@ -37,3 +37,22 @@ macro checkFutures*[T](futs: seq[Future[T]], exclude: untyped = []): untyped =
# We still don't abort but warn # We still don't abort but warn
warn "Something went wrong in a future", warn "Something went wrong in a future",
error=exc.name, file=pos.filename, line=pos.line error=exc.name, file=pos.filename, line=pos.line
proc allFuturesThrowing*[T](args: varargs[Future[T]]): Future[void] =
var futs: seq[Future[T]]
for fut in args:
futs &= fut
proc call() {.async.} =
futs = await allFinished(futs)
for fut in futs:
if fut.failed:
raise fut.readError()
return call()
template tryAndWarn*(msg: static[string]; body: untyped): untyped =
try:
body
except CancelledError as ex:
raise ex
except CatchableError as ex:
warn "ignored an error", name=ex.name, msg=msg

View File

@ -11,6 +11,7 @@ import strutils
import chronos, chronicles import chronos, chronicles
import connection, import connection,
vbuffer, vbuffer,
errors,
protocols/protocol protocols/protocol
logScope: logScope:
@ -116,7 +117,7 @@ proc list*(m: MultistreamSelect,
proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} =
trace "handle: starting multistream handling" trace "handle: starting multistream handling"
try: tryAndWarn "multistream handle":
while not conn.closed: while not conn.closed:
var ms = cast[string]((await conn.readLp())) var ms = cast[string]((await conn.readLp()))
ms.removeSuffix("\n") ms.removeSuffix("\n")
@ -145,18 +146,14 @@ proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} =
if (not isNil(h.match) and h.match(ms)) or ms == h.proto: if (not isNil(h.match) and h.match(ms)) or ms == h.proto:
trace "found handler for", protocol = ms trace "found handler for", protocol = ms
await conn.writeLp((h.proto & "\n")) await conn.writeLp((h.proto & "\n"))
try: tryAndWarn "multistream handle handler":
await h.protocol.handler(conn, ms) await h.protocol.handler(conn, ms)
return return
except CatchableError as exc:
warn "exception while handling", msg = exc.msg
return
warn "no handlers for ", protocol = ms warn "no handlers for ", protocol = ms
await conn.write(m.na) await conn.write(m.na)
except CatchableError as exc: trace "leaving multistream loop"
trace "Exception occurred", exc = exc.msg # we might be tempted to close conn here but that would be a bad idea!
finally: # we indeed will reuse it later on
trace "leaving multistream loop"
proc addHandler*[T: LPProtocol](m: MultistreamSelect, proc addHandler*[T: LPProtocol](m: MultistreamSelect,
codec: string, codec: string,

View File

@ -92,12 +92,15 @@ method handle*(m: Mplex) {.async, gcsafe.} =
let stream = newConnection(channel) let stream = newConnection(channel)
stream.peerInfo = m.connection.peerInfo stream.peerInfo = m.connection.peerInfo
# cleanup channel once handler is finished proc handler() {.async.} =
# stream.closeEvent.wait().addCallback( tryAndWarn "mplex channel handler":
# proc(udata: pointer) = await m.streamHandler(stream)
# asyncCheck cleanupChann(m, channel, initiator)) # TODO closing stream
# or doing cleanupChann
# will make go interop tests fail
# need to investigate why
asyncCheck m.streamHandler(stream) asynccheck handler()
continue continue
of MessageType.MsgIn, MessageType.MsgOut: of MessageType.MsgIn, MessageType.MsgOut:
trace "pushing data to channel", id = id, trace "pushing data to channel", id = id,

View File

@ -87,7 +87,10 @@ method init(g: GossipSub) =
method handleDisconnect(g: GossipSub, peer: PubSubPeer) {.async.} = method handleDisconnect(g: GossipSub, peer: PubSubPeer) {.async.} =
## handle peer disconnects ## handle peer disconnects
trace "peer disconnected", peer=peer.id
await procCall FloodSub(g).handleDisconnect(peer) await procCall FloodSub(g).handleDisconnect(peer)
for t in g.gossipsub.keys: for t in g.gossipsub.keys:
g.gossipsub[t].excl(peer.id) g.gossipsub[t].excl(peer.id)
@ -263,9 +266,9 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} =
if g.mesh[topic].len < GossipSubDlo: if g.mesh[topic].len < GossipSubDlo:
trace "replenishing mesh" trace "replenishing mesh"
# replenish the mesh if we're bellow GossipSubDlo # replenish the mesh if we're below GossipSubDlo
while g.mesh[topic].len < GossipSubD: while g.mesh[topic].len < GossipSubD:
trace "gattering peers", peers = g.mesh[topic].len trace "gathering peers", peers = g.mesh[topic].len
var id: string var id: string
if topic in g.fanout and g.fanout[topic].len > 0: if topic in g.fanout and g.fanout[topic].len > 0:
id = g.fanout[topic].pop() id = g.fanout[topic].pop()
@ -457,12 +460,31 @@ when isMainModule:
## ##
import unittest import unittest
import ../../errors
import ../../stream/bufferstream import ../../stream/bufferstream
type type
TestGossipSub = ref object of GossipSub TestGossipSub = ref object of GossipSub
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "GossipSub": suite "GossipSub":
teardown:
let
trackers = [
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "`rebalanceMesh` Degree Lo": test "`rebalanceMesh` Degree Lo":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, let gossipSub = newPubSub(TestGossipSub,
@ -473,8 +495,10 @@ when isMainModule:
proc writeHandler(data: seq[byte]) {.async.} = proc writeHandler(data: seq[byte]) {.async.} =
discard discard
var conns = newSeq[Connection]()
for i in 0..<15: for i in 0..<15:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -485,6 +509,8 @@ when isMainModule:
await gossipSub.rebalanceMesh(topic) await gossipSub.rebalanceMesh(topic)
check gossipSub.mesh[topic].len == GossipSubD check gossipSub.mesh[topic].len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:
@ -500,8 +526,10 @@ when isMainModule:
proc writeHandler(data: seq[byte]) {.async.} = proc writeHandler(data: seq[byte]) {.async.} =
discard discard
var conns = newSeq[Connection]()
for i in 0..<15: for i in 0..<15:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -512,6 +540,8 @@ when isMainModule:
await gossipSub.rebalanceMesh(topic) await gossipSub.rebalanceMesh(topic)
check gossipSub.mesh[topic].len == GossipSubD check gossipSub.mesh[topic].len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:
@ -530,8 +560,10 @@ when isMainModule:
proc writeHandler(data: seq[byte]) {.async.} = proc writeHandler(data: seq[byte]) {.async.} =
discard discard
var conns = newSeq[Connection]()
for i in 0..<15: for i in 0..<15:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
var peerInfo = PeerInfo.init(PrivateKey.random(RSA)) var peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -542,6 +574,8 @@ when isMainModule:
await gossipSub.replenishFanout(topic) await gossipSub.replenishFanout(topic)
check gossipSub.fanout[topic].len == GossipSubD check gossipSub.fanout[topic].len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:
@ -561,8 +595,10 @@ when isMainModule:
proc writeHandler(data: seq[byte]) {.async.} = proc writeHandler(data: seq[byte]) {.async.} =
discard discard
var conns = newSeq[Connection]()
for i in 0..<6: for i in 0..<6:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -574,6 +610,8 @@ when isMainModule:
await gossipSub.dropFanoutPeers() await gossipSub.dropFanoutPeers()
check topic notin gossipSub.fanout check topic notin gossipSub.fanout
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:
@ -597,8 +635,10 @@ when isMainModule:
proc writeHandler(data: seq[byte]) {.async.} = proc writeHandler(data: seq[byte]) {.async.} =
discard discard
var conns = newSeq[Connection]()
for i in 0..<6: for i in 0..<6:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -613,6 +653,8 @@ when isMainModule:
check topic1 notin gossipSub.fanout check topic1 notin gossipSub.fanout
check topic2 in gossipSub.fanout check topic2 in gossipSub.fanout
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:
@ -633,8 +675,10 @@ when isMainModule:
gossipSub.mesh[topic] = initHashSet[string]() gossipSub.mesh[topic] = initHashSet[string]()
gossipSub.fanout[topic] = initHashSet[string]() gossipSub.fanout[topic] = initHashSet[string]()
gossipSub.gossipsub[topic] = initHashSet[string]() gossipSub.gossipsub[topic] = initHashSet[string]()
var conns = newSeq[Connection]()
for i in 0..<30: for i in 0..<30:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -646,6 +690,7 @@ when isMainModule:
for i in 0..<15: for i in 0..<15:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -662,6 +707,8 @@ when isMainModule:
check p notin gossipSub.fanout[topic] check p notin gossipSub.fanout[topic]
check p notin gossipSub.mesh[topic] check p notin gossipSub.mesh[topic]
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:
@ -681,8 +728,10 @@ when isMainModule:
let topic = "foobar" let topic = "foobar"
gossipSub.fanout[topic] = initHashSet[string]() gossipSub.fanout[topic] = initHashSet[string]()
gossipSub.gossipsub[topic] = initHashSet[string]() gossipSub.gossipsub[topic] = initHashSet[string]()
var conns = newSeq[Connection]()
for i in 0..<30: for i in 0..<30:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -694,6 +743,9 @@ when isMainModule:
let peers = gossipSub.getGossipPeers() let peers = gossipSub.getGossipPeers()
check peers.len == GossipSubD check peers.len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:
@ -713,8 +765,10 @@ when isMainModule:
let topic = "foobar" let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[string]() gossipSub.mesh[topic] = initHashSet[string]()
gossipSub.gossipsub[topic] = initHashSet[string]() gossipSub.gossipsub[topic] = initHashSet[string]()
var conns = newSeq[Connection]()
for i in 0..<30: for i in 0..<30:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -726,6 +780,9 @@ when isMainModule:
let peers = gossipSub.getGossipPeers() let peers = gossipSub.getGossipPeers()
check peers.len == GossipSubD check peers.len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:
@ -745,8 +802,10 @@ when isMainModule:
let topic = "foobar" let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[string]() gossipSub.mesh[topic] = initHashSet[string]()
gossipSub.fanout[topic] = initHashSet[string]() gossipSub.fanout[topic] = initHashSet[string]()
var conns = newSeq[Connection]()
for i in 0..<30: for i in 0..<30:
let conn = newConnection(newBufferStream(writeHandler)) let conn = newConnection(newBufferStream(writeHandler))
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(RSA)) let peerInfo = PeerInfo.init(PrivateKey.random(RSA))
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec)
@ -758,6 +817,9 @@ when isMainModule:
let peers = gossipSub.getGossipPeers() let peers = gossipSub.getGossipPeers()
check peers.len == 0 check peers.len == 0
await allFuturesThrowing(conns.mapIt(it.close()))
result = true result = true
check: check:

View File

@ -23,7 +23,6 @@ proc decodeGraft*(pb: var ProtoBuffer): seq[ControlGraft] {.gcsafe.} =
while true: while true:
var topic: string var topic: string
if pb.getString(1, topic) < 0: if pb.getString(1, topic) < 0:
trace "unable to read topic field from graft msg, breaking"
break break
trace "read topic field from graft msg", topicID = topic trace "read topic field from graft msg", topicID = topic
@ -38,8 +37,8 @@ proc decodePrune*(pb: var ProtoBuffer): seq[ControlPrune] {.gcsafe.} =
var topic: string var topic: string
if pb.getString(1, topic) < 0: if pb.getString(1, topic) < 0:
break break
trace "read topic field", topicID = topic
trace "read topic field from prune msg", topicID = topic
result.add(ControlPrune(topicID: topic)) result.add(ControlPrune(topicID: topic))
proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} = proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} =

View File

@ -499,6 +499,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool = false): Future[S
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId) raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
var secure = new NoiseConnection var secure = new NoiseConnection
inc getConnectionTracker().opened
secure.stream = conn secure.stream = conn
secure.closeEvent = newAsyncEvent() secure.closeEvent = newAsyncEvent()
secure.peerInfo = PeerInfo.init(remotePubKey) secure.peerInfo = PeerInfo.init(remotePubKey)

View File

@ -266,6 +266,8 @@ proc newSecioConn(conn: Connection,
when chronicles.enabledLogLevel == LogLevel.TRACE: when chronicles.enabledLogLevel == LogLevel.TRACE:
result.oid = genOid() result.oid = genOid()
inc getConnectionTracker().opened
proc transactMessage(conn: Connection, proc transactMessage(conn: Connection,
msg: seq[byte]): Future[seq[byte]] {.async.} = msg: seq[byte]): Future[seq[byte]] {.async.} =
var buf = newSeq[byte](4) var buf = newSeq[byte](4)

View File

@ -61,7 +61,7 @@ proc handleConn*(s: Secure, conn: Connection, initiator: bool = false): Future[C
await sconn.writeMessage(data) await sconn.writeMessage(data)
result = newConnection(newBufferStream(writeHandler)) result = newConnection(newBufferStream(writeHandler))
asyncCheck readLoop(sconn, result) conn.readLoops &= readLoop(sconn, result)
if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome: if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome:
result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get())

View File

@ -34,7 +34,9 @@ import deques, math, oids
import chronos, chronicles, metrics import chronos, chronicles, metrics
import ../stream/lpstream import ../stream/lpstream
const DefaultBufferSize* = 1024 const
BufferStreamTrackerName* = "libp2p.bufferstream"
DefaultBufferSize* = 1024
type type
# TODO: figure out how to make this generic to avoid casts # TODO: figure out how to make this generic to avoid casts
@ -52,6 +54,33 @@ type
AlreadyPipedError* = object of CatchableError AlreadyPipedError* = object of CatchableError
NotWritableError* = object of CatchableError NotWritableError* = object of CatchableError
BufferStreamTracker* = ref object of TrackerBase
opened*: uint64
closed*: uint64
proc setupBufferStreamTracker(): BufferStreamTracker {.gcsafe.}
proc getBufferStreamTracker(): BufferStreamTracker {.gcsafe.} =
result = cast[BufferStreamTracker](getTracker(BufferStreamTrackerName))
if isNil(result):
result = setupBufferStreamTracker()
proc dumpTracking(): string {.gcsafe.} =
var tracker = getBufferStreamTracker()
result = "Opened buffers: " & $tracker.opened & "\n" &
"Closed buffers: " & $tracker.closed
proc leakTransport(): bool {.gcsafe.} =
var tracker = getBufferStreamTracker()
result = (tracker.opened != tracker.closed)
proc setupBufferStreamTracker(): BufferStreamTracker =
result = new BufferStreamTracker
result.opened = 0
result.closed = 0
result.dump = dumpTracking
result.isLeaked = leakTransport
addTracker(BufferStreamTrackerName, result)
declareGauge libp2p_open_bufferstream, "open BufferStream instances" declareGauge libp2p_open_bufferstream, "open BufferStream instances"
proc newAlreadyPipedError*(): ref Exception {.inline.} = proc newAlreadyPipedError*(): ref Exception {.inline.} =
@ -77,6 +106,7 @@ proc initBufferStream*(s: BufferStream,
s.lock = newAsyncLock() s.lock = newAsyncLock()
s.writeHandler = handler s.writeHandler = handler
s.closeEvent = newAsyncEvent() s.closeEvent = newAsyncEvent()
inc getBufferStreamTracker().opened
when chronicles.enabledLogLevel == LogLevel.TRACE: when chronicles.enabledLogLevel == LogLevel.TRACE:
s.oid = genOid() s.oid = genOid()
s.isClosed = false s.isClosed = false
@ -181,7 +211,7 @@ method readExactly*(s: BufferStream,
try: try:
buff = await s.read(nbytes) buff = await s.read(nbytes)
except LPStreamEOFError as exc: except LPStreamEOFError as exc:
trace "Exception occured", exc = exc.msg trace "Exception occurred", exc = exc.msg
if nbytes > buff.len(): if nbytes > buff.len():
raise newLPStreamIncompleteError() raise newLPStreamIncompleteError()
@ -399,5 +429,7 @@ method close*(s: BufferStream) {.async.} =
s.readBuf.clear() s.readBuf.clear()
s.closeEvent.fire() s.closeEvent.fire()
s.isClosed = true s.isClosed = true
inc getBufferStreamTracker().closed
libp2p_open_bufferstream.dec() libp2p_open_bufferstream.dec()
else:
trace "attempt to close an already closed bufferstream", trace=getStackTrace()

View File

@ -324,6 +324,8 @@ proc stop*(s: Switch) {.async.} =
futs = await allFinished(futs) futs = await allFinished(futs)
checkFutures(futs) checkFutures(futs)
trace "switch stopped"
proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
## Subscribe to pub sub peer ## Subscribe to pub sub peer
if s.pubSub.isSome and peerInfo.id notin s.dialedPubSubPeers: if s.pubSub.isSome and peerInfo.id notin s.dialedPubSubPeers:

View File

@ -7,8 +7,9 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import chronos, chronicles, sequtils import chronos, chronicles, sequtils, sets
import transport, import transport,
../errors,
../wire, ../wire,
../connection, ../connection,
../multiaddress, ../multiaddress,
@ -18,40 +19,78 @@ import transport,
logScope: logScope:
topic = "TcpTransport" topic = "TcpTransport"
type TcpTransport* = ref object of Transport const
server*: StreamServer TcpTransportTrackerName* = "libp2p.tcptransport"
type
TcpTransport* = ref object of Transport
server*: StreamServer
cleanups*: seq[Future[void]]
handlers*: seq[Future[void]]
TcpTransportTracker* = ref object of TrackerBase
opened*: uint64
closed*: uint64
proc setupTcpTransportTracker(): TcpTransportTracker {.gcsafe.}
proc getTcpTransportTracker(): TcpTransportTracker {.gcsafe.} =
result = cast[TcpTransportTracker](getTracker(TcpTransportTrackerName))
if isNil(result):
result = setupTcpTransportTracker()
proc dumpTracking(): string {.gcsafe.} =
var tracker = getTcpTransportTracker()
result = "Opened transports: " & $tracker.opened & "\n" &
"Closed transports: " & $tracker.closed
proc leakTransport(): bool {.gcsafe.} =
var tracker = getTcpTransportTracker()
result = (tracker.opened != tracker.closed)
proc setupTcpTransportTracker(): TcpTransportTracker =
result = new TcpTransportTracker
result.opened = 0
result.closed = 0
result.dump = dumpTracking
result.isLeaked = leakTransport
addTracker(TcpTransportTrackerName, result)
proc cleanup(t: Transport, conn: Connection) {.async.} = proc cleanup(t: Transport, conn: Connection) {.async.} =
await conn.closeEvent.wait() await conn.closeEvent.wait()
trace "connection cleanup event wait ended"
t.connections.keepItIf(it != conn) t.connections.keepItIf(it != conn)
proc connHandler*(t: Transport, proc connHandler*(t: TcpTransport,
server: StreamServer, server: StreamServer,
client: StreamTransport, client: StreamTransport,
initiator: bool = false): initiator: bool): Connection =
Future[Connection] {.async, gcsafe.} =
trace "handling connection for", address = $client.remoteAddress trace "handling connection for", address = $client.remoteAddress
let conn: Connection = newConnection(newChronosStream(server, client)) let conn: Connection = newConnection(newChronosStream(server, client))
conn.observedAddrs = MultiAddress.init(client.remoteAddress) conn.observedAddrs = MultiAddress.init(client.remoteAddress)
if not initiator: if not initiator:
if not isNil(t.handler): if not isNil(t.handler):
asyncCheck t.handler(conn) t.handlers &= t.handler(conn)
t.connections.add(conn) t.connections.add(conn)
asyncCheck t.cleanup(conn) t.cleanups &= t.cleanup(conn)
result = conn result = conn
proc connCb(server: StreamServer, proc connCb(server: StreamServer,
client: StreamTransport) {.async, gcsafe.} = client: StreamTransport) {.async, gcsafe.} =
trace "incomming connection for", address = $client.remoteAddress trace "incomming connection for", address = $client.remoteAddress
let t: Transport = cast[Transport](server.udata) let t = cast[TcpTransport](server.udata)
asyncCheck t.connHandler(server, client) # we don't need result connection in this case
# as it's added inside connHandler
discard t.connHandler(server, client, false)
method init*(t: TcpTransport) = method init*(t: TcpTransport) =
t.multicodec = multiCodec("tcp") t.multicodec = multiCodec("tcp")
method close*(t: TcpTransport): Future[void] {.async, gcsafe.} = inc getTcpTransportTracker().opened
method close*(t: TcpTransport) {.async, gcsafe.} =
## start the transport ## start the transport
trace "stopping transport" trace "stopping transport"
await procCall Transport(t).close() # call base await procCall Transport(t).close() # call base
@ -59,11 +98,28 @@ method close*(t: TcpTransport): Future[void] {.async, gcsafe.} =
# server can be nil # server can be nil
if not isNil(t.server): if not isNil(t.server):
t.server.stop() t.server.stop()
t.server.close() await t.server.closeWait()
await t.server.join()
t.server = nil
for fut in t.handlers:
if not fut.finished:
fut.cancel()
t.handlers = await allFinished(t.handlers)
checkFutures(t.handlers)
t.handlers = @[]
for fut in t.cleanups:
if not fut.finished:
fut.cancel()
t.cleanups = await allFinished(t.cleanups)
checkFutures(t.cleanups)
t.cleanups = @[]
trace "transport stopped" trace "transport stopped"
inc getTcpTransportTracker().closed
method listen*(t: TcpTransport, method listen*(t: TcpTransport,
ma: MultiAddress, ma: MultiAddress,
handler: ConnHandler): handler: ConnHandler):
@ -85,7 +141,7 @@ method dial*(t: TcpTransport,
trace "dialing remote peer", address = $address trace "dialing remote peer", address = $address
## dial a peer ## dial a peer
let client: StreamTransport = await connect(address) let client: StreamTransport = await connect(address)
result = await t.connHandler(t.server, client, true) result = t.connHandler(t.server, client, true)
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} = method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
if procCall Transport(t).handles(address): if procCall Transport(t).handles(address):

View File

@ -56,7 +56,7 @@ proc initTAddress*(ma: MultiAddress): TransportAddress =
"Could not initialize address!") "Could not initialize address!")
proc connect*(ma: MultiAddress, bufferSize = DefaultStreamBufferSize, proc connect*(ma: MultiAddress, bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil): Future[StreamTransport] = child: StreamTransport = nil): Future[StreamTransport] {.async.} =
## Open new connection to remote peer with address ``ma`` and create ## Open new connection to remote peer with address ``ma`` and create
## new transport object ``StreamTransport`` for established connection. ## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport. ## ``bufferSize`` is size of internal buffer for transport.
@ -64,11 +64,8 @@ proc connect*(ma: MultiAddress, bufferSize = DefaultStreamBufferSize,
let address = initTAddress(ma) let address = initTAddress(ma)
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
if ma[1].protoCode() != multiCodec("tcp"): if ma[1].protoCode() != multiCodec("tcp"):
var retFuture = newFuture[StreamTransport]() raise newException(TransportAddressError, "Incorrect address type!")
retFuture.fail(newException(TransportAddressError, result = await connect(address, bufferSize, child)
"Incorrect address type!"))
return retFuture
result = connect(address, bufferSize, child)
proc createStreamServer*[T](ma: MultiAddress, proc createStreamServer*[T](ma: MultiAddress,
cbproc: StreamCallback, cbproc: StreamCallback,

View File

@ -7,16 +7,49 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import unittest, sequtils import unittest, sequtils, options, tables, sets
import chronos import chronos
import utils, import utils,
../../libp2p/[switch, ../../libp2p/[errors,
switch,
connection,
stream/bufferstream,
crypto/crypto, crypto/crypto,
protocols/pubsub/pubsub, protocols/pubsub/pubsub,
protocols/pubsub/floodsub,
protocols/pubsub/rpc/messages, protocols/pubsub/rpc/messages,
protocols/pubsub/rpc/message] protocols/pubsub/rpc/message]
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
# turn things deterministic
# this is for testing purposes only
var ceil = 15
let fsub = cast[FloodSub](sender.pubSub.get())
while not fsub.floodsub.hasKey(key) or
not fsub.floodsub[key].contains(receiver.peerInfo.id):
await sleepAsync(100.millis)
dec ceil
doAssert(ceil > 0, "waitSub timeout!")
suite "FloodSub": suite "FloodSub":
teardown:
let
trackers = [
# getTracker(ConnectionTrackerName),
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
check tracker.isLeaked() == false
test "FloodSub basic publish/subscribe A -> B": test "FloodSub basic publish/subscribe A -> B":
proc runTests(): Future[bool] {.async.} = proc runTests(): Future[bool] {.async.} =
var completionFut = newFuture[bool]() var completionFut = newFuture[bool]()
@ -24,21 +57,30 @@ suite "FloodSub":
check topic == "foobar" check topic == "foobar"
completionFut.complete(true) completionFut.complete(true)
var nodes = generateNodes(2) let
var awaiters: seq[Future[void]] nodes = generateNodes(2)
awaiters.add((await nodes[0].start())) nodesFut = await allFinished(
awaiters.add((await nodes[1].start())) nodes[0].start(),
nodes[1].start()
)
await subscribeNodes(nodes) await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await sleepAsync(1000.millis) await waitSub(nodes[0], nodes[1], "foobar")
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
result = await completionFut result = await completionFut.wait(5.seconds)
await allFutures(nodes[0].stop(), nodes[1].stop())
await allFutures(awaiters)
await allFuturesThrowing(
nodes[0].stop(),
nodes[1].stop()
)
for fut in nodesFut:
let res = fut.read()
await allFuturesThrowing(res)
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
@ -55,14 +97,16 @@ suite "FloodSub":
awaiters.add((await nodes[1].start())) awaiters.add((await nodes[1].start()))
await subscribeNodes(nodes) await subscribeNodes(nodes)
await nodes[0].subscribe("foobar", handler) await nodes[0].subscribe("foobar", handler)
await sleepAsync(1000.millis) await waitSub(nodes[1], nodes[0], "foobar")
await nodes[1].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[1].publish("foobar", cast[seq[byte]]("Hello!"))
result = await completionFut result = await completionFut.wait(5.seconds)
await allFutures(nodes[0].stop(), nodes[1].stop())
await allFutures(awaiters) await allFuturesThrowing(nodes[0].stop(), nodes[1].stop())
await allFuturesThrowing(awaiters)
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
@ -81,7 +125,7 @@ suite "FloodSub":
await subscribeNodes(nodes) await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await sleepAsync(1000.millis) await waitSub(nodes[0], nodes[1], "foobar")
var validatorFut = newFuture[bool]() var validatorFut = newFuture[bool]()
proc validator(topic: string, proc validator(topic: string,
@ -91,11 +135,12 @@ suite "FloodSub":
result = true result = true
nodes[1].addValidator("foobar", validator) nodes[1].addValidator("foobar", validator)
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
await allFutures(handlerFut, handlerFut) await allFuturesThrowing(handlerFut, handlerFut)
await allFutures(nodes[0].stop(), nodes[1].stop()) await allFuturesThrowing(nodes[0].stop(), nodes[1].stop())
await allFutures(awaiters) await allFuturesThrowing(awaiters)
result = true result = true
check: check:
@ -113,7 +158,7 @@ suite "FloodSub":
await subscribeNodes(nodes) await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await sleepAsync(100.millis) await waitSub(nodes[0], nodes[1], "foobar")
var validatorFut = newFuture[bool]() var validatorFut = newFuture[bool]()
proc validator(topic: string, proc validator(topic: string,
@ -122,9 +167,11 @@ suite "FloodSub":
result = false result = false
nodes[1].addValidator("foobar", validator) nodes[1].addValidator("foobar", validator)
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
await allFutures(nodes[0].stop(), nodes[1].stop())
await allFutures(awaiters) await allFuturesThrowing(nodes[0].stop(), nodes[1].stop())
await allFuturesThrowing(awaiters)
result = true result = true
check: check:
@ -144,8 +191,9 @@ suite "FloodSub":
await subscribeNodes(nodes) await subscribeNodes(nodes)
await nodes[1].subscribe("foo", handler) await nodes[1].subscribe("foo", handler)
await waitSub(nodes[0], nodes[1], "foo")
await nodes[1].subscribe("bar", handler) await nodes[1].subscribe("bar", handler)
await sleepAsync(1000.millis) await waitSub(nodes[0], nodes[1], "bar")
proc validator(topic: string, proc validator(topic: string,
message: Message): Future[bool] {.async.} = message: Message): Future[bool] {.async.} =
@ -155,12 +203,12 @@ suite "FloodSub":
result = false result = false
nodes[1].addValidator("foo", "bar", validator) nodes[1].addValidator("foo", "bar", validator)
await nodes[0].publish("foo", cast[seq[byte]]("Hello!")) await nodes[0].publish("foo", cast[seq[byte]]("Hello!"))
await nodes[0].publish("bar", cast[seq[byte]]("Hello!")) await nodes[0].publish("bar", cast[seq[byte]]("Hello!"))
await sleepAsync(100.millis) await allFuturesThrowing(nodes[0].stop(), nodes[1].stop())
await allFutures(nodes[0].stop(), nodes[1].stop()) await allFuturesThrowing(awaiters)
await allFutures(awaiters)
result = true result = true
check: check:
@ -169,65 +217,107 @@ suite "FloodSub":
test "FloodSub multiple peers, no self trigger": test "FloodSub multiple peers, no self trigger":
proc runTests(): Future[bool] {.async.} = proc runTests(): Future[bool] {.async.} =
var passed = 0 var passed = 0
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" var futs = newSeq[(Future[void], TopicHandler, ref int)](10)
passed.inc() for i in 0..<10:
closureScope:
var
fut = newFuture[void]()
counter = new int
futs[i] = (
fut,
(proc(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar"
inc counter[]
if counter[] == 9:
fut.complete()),
counter
)
var nodes: seq[Switch] = newSeq[Switch]() var nodes: seq[Switch] = newSeq[Switch]()
for i in 0..<10: for i in 0..<10:
nodes.add(newStandardSwitch()) nodes.add newStandardSwitch()
var awaitters: seq[Future[void]] var awaitters: seq[Future[void]]
for node in nodes: for i in 0..<10:
awaitters.add(await node.start()) awaitters.add(await nodes[i].start())
await node.subscribe("foobar", handler)
await sleepAsync(100.millis)
await subscribeNodes(nodes) await subscribeNodes(nodes)
await sleepAsync(1000.millis)
for node in nodes: for i in 0..<10:
await node.publish("foobar", cast[seq[byte]]("Hello!")) await nodes[i].subscribe("foobar", futs[i][1])
await sleepAsync(100.millis)
await sleepAsync(1.minutes) var subs: seq[Future[void]]
await allFutures(nodes.mapIt(it.stop())) for i in 0..<10:
await allFutures(awaitters) for y in 0..<10:
if y != i:
subs &= waitSub(nodes[i], nodes[y], "foobar")
await allFuturesThrowing(subs)
result = passed >= 10 # non deterministic, so at least 10 times var pubs: seq[Future[void]]
for i in 0..<10:
pubs &= nodes[i].publish("foobar", cast[seq[byte]]("Hello!"))
await allFuturesThrowing(pubs)
await allFuturesThrowing(futs.mapIt(it[0]))
await allFuturesThrowing(nodes.mapIt(it.stop()))
await allFuturesThrowing(awaitters)
result = true
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
test "FloodSub multiple peers, with self trigger": test "FloodSub multiple peers, with self trigger":
proc runTests(): Future[bool] {.async.} = proc runTests(): Future[bool] {.async.} =
var passed = 0 var passed = 0
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" var futs = newSeq[(Future[void], TopicHandler, ref int)](10)
passed.inc() for i in 0..<10:
closureScope:
var
fut = newFuture[void]()
counter = new int
futs[i] = (
fut,
(proc(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar"
inc counter[]
if counter[] == 10:
fut.complete()),
counter
)
var nodes: seq[Switch] = newSeq[Switch]() var nodes: seq[Switch] = newSeq[Switch]()
for i in 0..<10: for i in 0..<10:
nodes.add newStandardSwitch(triggerSelf = true) nodes.add newStandardSwitch(triggerSelf = true)
var awaitters: seq[Future[void]] var awaitters: seq[Future[void]]
for node in nodes: for i in 0..<10:
awaitters.add((await node.start())) awaitters.add(await nodes[i].start())
await node.subscribe("foobar", handler)
await sleepAsync(100.millis)
await subscribeNodes(nodes) await subscribeNodes(nodes)
await sleepAsync(1000.millis)
for node in nodes: for i in 0..<10:
await node.publish("foobar", cast[seq[byte]]("Hello!")) await nodes[i].subscribe("foobar", futs[i][1])
await sleepAsync(100.millis)
await sleepAsync(1.minutes) var subs: seq[Future[void]]
await allFutures(nodes.mapIt(it.stop())) for i in 0..<10:
await allFutures(awaitters) for y in 0..<10:
if y != i:
subs &= waitSub(nodes[i], nodes[y], "foobar")
await allFuturesThrowing(subs)
result = passed >= 20 # non deterministic, so at least 10 times var pubs: seq[Future[void]]
for i in 0..<10:
pubs &= nodes[i].publish("foobar", cast[seq[byte]]("Hello!"))
await allFuturesThrowing(pubs)
await allFuturesThrowing(futs.mapIt(it[0]))
await allFuturesThrowing(nodes.mapIt(it.stop()))
await allFuturesThrowing(awaitters)
result = true
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true

View File

@ -9,7 +9,9 @@
import unittest, sequtils, options, tables, sets import unittest, sequtils, options, tables, sets
import chronos import chronos
import utils, ../../libp2p/[peer, import chronicles
import utils, ../../libp2p/[errors,
peer,
peerinfo, peerinfo,
connection, connection,
crypto/crypto, crypto/crypto,
@ -18,11 +20,48 @@ import utils, ../../libp2p/[peer,
protocols/pubsub/gossipsub, protocols/pubsub/gossipsub,
protocols/pubsub/rpc/messages] protocols/pubsub/rpc/messages]
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
proc createGossipSub(): GossipSub = proc createGossipSub(): GossipSub =
var peerInfo = PeerInfo.init(PrivateKey.random(RSA)) var peerInfo = PeerInfo.init(PrivateKey.random(RSA))
result = newPubSub(GossipSub, peerInfo) result = newPubSub(GossipSub, peerInfo)
proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
if sender == receiver:
return
# turn things deterministic
# this is for testing purposes only
# peers can be inside `mesh` and `fanout`, not just `gossipsub`
var ceil = 15
let fsub = cast[GossipSub](sender.pubSub.get())
while (not fsub.gossipsub.hasKey(key) or
not fsub.gossipsub[key].contains(receiver.peerInfo.id)) and
(not fsub.mesh.hasKey(key) or
not fsub.mesh[key].contains(receiver.peerInfo.id)) and
(not fsub.fanout.hasKey(key) or
not fsub.fanout[key].contains(receiver.peerInfo.id)):
trace "waitSub sleeping...", peers=fsub.gossipsub[key]
await sleepAsync(100.millis)
dec ceil
doAssert(ceil > 0, "waitSub timeout!")
suite "GossipSub": suite "GossipSub":
teardown:
let
trackers = [
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "GossipSub validation should succeed": test "GossipSub validation should succeed":
proc runTests(): Future[bool] {.async.} = proc runTests(): Future[bool] {.async.} =
var handlerFut = newFuture[bool]() var handlerFut = newFuture[bool]()
@ -35,10 +74,12 @@ suite "GossipSub":
awaiters.add((await nodes[0].start())) awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start())) awaiters.add((await nodes[1].start()))
await nodes[0].subscribe("foobar", handler)
await nodes[1].subscribe("foobar", handler)
await subscribeNodes(nodes) await subscribeNodes(nodes)
await sleepAsync(100.millis)
await nodes[0].subscribe("foobar", handler)
await waitSub(nodes[1], nodes[0], "foobar")
await nodes[1].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar")
var validatorFut = newFuture[bool]() var validatorFut = newFuture[bool]()
proc validator(topic: string, proc validator(topic: string,
@ -52,8 +93,8 @@ suite "GossipSub":
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
result = (await validatorFut) and (await handlerFut) result = (await validatorFut) and (await handlerFut)
await allFutures(nodes[0].stop(), nodes[1].stop()) await allFuturesThrowing(nodes[0].stop(), nodes[1].stop())
await allFutures(awaiters) await allFuturesThrowing(awaiters)
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
@ -69,8 +110,9 @@ suite "GossipSub":
awaiters.add((await nodes[1].start())) awaiters.add((await nodes[1].start()))
await subscribeNodes(nodes) await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await sleepAsync(100.millis) await waitSub(nodes[0], nodes[1], "foobar")
var validatorFut = newFuture[bool]() var validatorFut = newFuture[bool]()
proc validator(topic: string, proc validator(topic: string,
@ -82,10 +124,9 @@ suite "GossipSub":
nodes[1].addValidator("foobar", validator) nodes[1].addValidator("foobar", validator)
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
await sleepAsync(100.millis)
result = await validatorFut result = await validatorFut
await allFutures(nodes[0].stop(), nodes[1].stop()) await allFuturesThrowing(nodes[0].stop(), nodes[1].stop())
await allFutures(awaiters) await allFuturesThrowing(awaiters)
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
@ -102,10 +143,11 @@ suite "GossipSub":
awaiters.add((await nodes[0].start())) awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start())) awaiters.add((await nodes[1].start()))
await nodes[1].subscribe("foo", handler)
await nodes[1].subscribe("bar", handler)
await subscribeNodes(nodes) await subscribeNodes(nodes)
await sleepAsync(100.millis) await nodes[1].subscribe("foo", handler)
await waitSub(nodes[0], nodes[1], "foo")
await nodes[1].subscribe("bar", handler)
await waitSub(nodes[0], nodes[1], "bar")
var passed, failed: Future[bool] = newFuture[bool]() var passed, failed: Future[bool] = newFuture[bool]()
proc validator(topic: string, proc validator(topic: string,
@ -123,8 +165,8 @@ suite "GossipSub":
await nodes[0].publish("bar", cast[seq[byte]]("Hello!")) await nodes[0].publish("bar", cast[seq[byte]]("Hello!"))
result = ((await passed) and (await failed) and (await handlerFut)) result = ((await passed) and (await failed) and (await handlerFut))
await allFutures(nodes[0].stop(), nodes[1].stop()) await allFuturesThrowing(nodes[0].stop(), nodes[1].stop())
await allFutures(awaiters) await allFuturesThrowing(awaiters)
result = true result = true
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
@ -151,12 +193,17 @@ suite "GossipSub":
asyncCheck gossip2.handleConn(conn1, GossipSubCodec) asyncCheck gossip2.handleConn(conn1, GossipSubCodec)
await gossip1.subscribe("foobar", handler) await gossip1.subscribe("foobar", handler)
await sleepAsync(10.millis) await sleepAsync(1.seconds)
check: check:
"foobar" in gossip2.gossipsub "foobar" in gossip2.gossipsub
gossip1.peerInfo.id in gossip2.gossipsub["foobar"] gossip1.peerInfo.id in gossip2.gossipsub["foobar"]
await allFuturesThrowing(
buf1.close(),
buf2.close()
)
result = true result = true
check: check:
@ -175,9 +222,9 @@ suite "GossipSub":
for node in nodes: for node in nodes:
awaitters.add(await node.start()) awaitters.add(await node.start())
await nodes[1].subscribe("foobar", handler)
await subscribeNodes(nodes) await subscribeNodes(nodes)
await sleepAsync(100.millis) await nodes[1].subscribe("foobar", handler)
await sleepAsync(1.seconds)
let gossip1 = GossipSub(nodes[0].pubSub.get()) let gossip1 = GossipSub(nodes[0].pubSub.get())
let gossip2 = GossipSub(nodes[1].pubSub.get()) let gossip2 = GossipSub(nodes[1].pubSub.get())
@ -187,8 +234,8 @@ suite "GossipSub":
"foobar" in gossip1.gossipsub "foobar" in gossip1.gossipsub
gossip2.peerInfo.id in gossip1.gossipsub["foobar"] gossip2.peerInfo.id in gossip1.gossipsub["foobar"]
await allFutures(nodes.mapIt(it.stop())) await allFuturesThrowing(nodes.mapIt(it.stop()))
await allFutures(awaitters) await allFuturesThrowing(awaitters)
result = true result = true
@ -221,7 +268,7 @@ suite "GossipSub":
await gossip1.subscribe("foobar", handler) await gossip1.subscribe("foobar", handler)
await gossip2.subscribe("foobar", handler) await gossip2.subscribe("foobar", handler)
await sleepAsync(100.millis) await sleepAsync(1.seconds)
check: check:
"foobar" in gossip1.topics "foobar" in gossip1.topics
@ -236,6 +283,11 @@ suite "GossipSub":
gossip1.peerInfo.id in gossip1.gossipsub["foobar"] gossip1.peerInfo.id in gossip1.gossipsub["foobar"]
gossip2.peerInfo.id in gossip2.gossipsub["foobar"] gossip2.peerInfo.id in gossip2.gossipsub["foobar"]
await allFuturesThrowing(
buf1.close(),
buf2.close()
)
result = true result = true
check: check:
@ -254,13 +306,19 @@ suite "GossipSub":
for node in nodes: for node in nodes:
awaitters.add(await node.start()) awaitters.add(await node.start())
await subscribeNodes(nodes)
await nodes[0].subscribe("foobar", handler) await nodes[0].subscribe("foobar", handler)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await subscribeNodes(nodes)
await sleepAsync(100.millis)
let gossip1 = GossipSub(nodes[0].pubSub.get()) var subs: seq[Future[void]]
let gossip2 = GossipSub(nodes[1].pubSub.get()) subs &= waitSub(nodes[1], nodes[0], "foobar")
subs &= waitSub(nodes[0], nodes[1], "foobar")
await allFuturesThrowing(subs)
let
gossip1 = GossipSub(nodes[0].pubSub.get())
gossip2 = GossipSub(nodes[1].pubSub.get())
check: check:
"foobar" in gossip1.topics "foobar" in gossip1.topics
@ -269,11 +327,14 @@ suite "GossipSub":
"foobar" in gossip1.gossipsub "foobar" in gossip1.gossipsub
"foobar" in gossip2.gossipsub "foobar" in gossip2.gossipsub
gossip1.peerInfo.id in gossip2.gossipsub["foobar"] gossip2.peerInfo.id in gossip1.gossipsub["foobar"] or
gossip2.peerInfo.id in gossip1.gossipsub["foobar"] gossip2.peerInfo.id in gossip1.mesh["foobar"]
await allFutures(nodes.mapIt(it.stop())) gossip1.peerInfo.id in gossip2.gossipsub["foobar"] or
await allFutures(awaitters) gossip1.peerInfo.id in gossip2.mesh["foobar"]
await allFuturesThrowing(nodes.mapIt(it.stop()))
await allFuturesThrowing(awaitters)
result = true result = true
@ -322,10 +383,10 @@ suite "GossipSub":
test "e2e - GossipSub send over fanout A -> B": test "e2e - GossipSub send over fanout A -> B":
proc runTests(): Future[bool] {.async.} = proc runTests(): Future[bool] {.async.} =
var passed: bool var passed = newFuture[void]()
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" check topic == "foobar"
passed = true passed.complete()
var nodes = generateNodes(2, true) var nodes = generateNodes(2, true)
var wait = newSeq[Future[void]]() var wait = newSeq[Future[void]]()
@ -335,33 +396,36 @@ suite "GossipSub":
await subscribeNodes(nodes) await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await sleepAsync(1000.millis) await waitSub(nodes[0], nodes[1], "foobar")
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
await sleepAsync(1000.millis)
var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get()) var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get())
check: check:
"foobar" in gossipSub1.gossipsub "foobar" in gossipSub1.gossipsub
await nodes[1].stop() await passed.wait(5.seconds)
await nodes[0].stop()
await allFutures(wait) trace "test done, stopping..."
result = passed
await nodes[0].stop()
await nodes[1].stop()
await allFuturesThrowing(wait)
result = true
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
# test "send over mesh A -> B": # test "send over mesh A -> B":
# proc runTests(): Future[bool] {.async.} = # proc runTests(): Future[bool] {.async.} =
# var passed: bool # var passed = newFuture[void]()
# proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = # proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
# check: # check:
# topic == "foobar" # topic == "foobar"
# cast[string](data) == "Hello!" # cast[string](data) == "Hello!"
# passed = true # passed.complete()
# let gossip1 = createGossipSub() # let gossip1 = createGossipSub()
# let gossip2 = createGossipSub() # let gossip2 = createGossipSub()
@ -387,7 +451,11 @@ suite "GossipSub":
# await gossip2.publish("foobar", cast[seq[byte]]("Hello!")) # await gossip2.publish("foobar", cast[seq[byte]]("Hello!"))
# await sleepAsync(1.seconds) # await sleepAsync(1.seconds)
# result = passed
# await passed.wait(5.seconds)
# result = true
# await allFuturesThrowing(buf1.close(), buf2.close())
# check: # check:
# waitFor(runTests()) == true # waitFor(runTests()) == true
@ -405,17 +473,17 @@ suite "GossipSub":
wait.add(await nodes[1].start()) wait.add(await nodes[1].start())
await subscribeNodes(nodes) await subscribeNodes(nodes)
await sleepAsync(100.millis)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await sleepAsync(100.millis) await waitSub(nodes[0], nodes[1], "foobar")
await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) await nodes[0].publish("foobar", cast[seq[byte]]("Hello!"))
result = await passed result = await passed
await nodes[0].stop() await nodes[0].stop()
await nodes[1].stop() await nodes[1].stop()
await allFutures(wait) await allFuturesThrowing(wait)
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
@ -466,8 +534,8 @@ suite "GossipSub":
# nodes[1].peerInfo.peerId.get().pretty)) # nodes[1].peerInfo.peerId.get().pretty))
# await sleepAsync(1000.millis) # await sleepAsync(1000.millis)
# await allFutures(nodes.mapIt(it.stop())) # await allFuturesThrowing(nodes.mapIt(it.stop()))
# await allFutures(awaitters) # await allFuturesThrowing(awaitters)
# check: seen.len == 9 # check: seen.len == 9
# for k, v in seen.pairs: # for k, v in seen.pairs:
@ -487,6 +555,8 @@ suite "GossipSub":
nodes.add newStandardSwitch(triggerSelf = true, gossip = true) nodes.add newStandardSwitch(triggerSelf = true, gossip = true)
awaitters.add((await nodes[i].start())) awaitters.add((await nodes[i].start()))
await subscribeNodes(nodes)
var seen: Table[string, int] var seen: Table[string, int]
var subs: seq[Future[void]] var subs: seq[Future[void]]
var seenFut = newFuture[void]() var seenFut = newFuture[void]()
@ -502,10 +572,9 @@ suite "GossipSub":
if not seenFut.finished() and seen.len == 10: if not seenFut.finished() and seen.len == 10:
seenFut.complete() seenFut.complete()
subs.add(dialer.subscribe("foobar", handler)) subs.add(allFutures(dialer.subscribe("foobar", handler), waitSub(nodes[0], dialer, "foobar")))
await allFutures(subs)
await subscribeNodes(nodes) await allFuturesThrowing(subs)
await sleepAsync(1.seconds)
await wait(nodes[0].publish("foobar", await wait(nodes[0].publish("foobar",
cast[seq[byte]]("from node " & cast[seq[byte]]("from node " &
@ -517,8 +586,8 @@ suite "GossipSub":
for k, v in seen.pairs: for k, v in seen.pairs:
check: v == 1 check: v == 1
await allFutures(nodes.mapIt(it.stop())) await allFuturesThrowing(nodes.mapIt(it.stop()))
await allFutures(awaitters) await allFuturesThrowing(awaitters)
result = true result = true
check: check:

View File

@ -12,5 +12,4 @@ proc subscribeNodes*(nodes: seq[Switch]) {.async.} =
for node in nodes: for node in nodes:
if dialer.peerInfo.peerId != node.peerInfo.peerId: if dialer.peerInfo.peerId != node.peerInfo.peerId:
dials.add(dialer.connect(node.peerInfo)) dials.add(dialer.connect(node.peerInfo))
await sleepAsync(100.millis)
await allFutures(dials) await allFutures(dials)

View File

@ -1,10 +1,15 @@
import unittest, strformat import unittest, strformat
import chronos import chronos
import ../libp2p/errors
import ../libp2p/stream/bufferstream import ../libp2p/stream/bufferstream
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
suite "BufferStream": suite "BufferStream":
teardown:
# echo getTracker("libp2p.bufferstream").dump()
check getTracker("libp2p.bufferstream").isLeaked() == false
test "push data to buffer": test "push data to buffer":
proc testPushTo(): Future[bool] {.async.} = proc testPushTo(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
@ -16,6 +21,8 @@ suite "BufferStream":
check buff.len == 5 check buff.len == 5
result = true result = true
await buff.close()
check: check:
waitFor(testPushTo()) == true waitFor(testPushTo()) == true
@ -33,6 +40,8 @@ suite "BufferStream":
result = true result = true
await buff.close()
check: check:
waitFor(testPushTo()) == true waitFor(testPushTo()) == true
@ -47,6 +56,8 @@ suite "BufferStream":
result = true result = true
await buff.close()
check: check:
waitFor(testRead()) == true waitFor(testRead()) == true
@ -62,6 +73,8 @@ suite "BufferStream":
result = true result = true
await buff.close()
check: check:
waitFor(testRead()) == true waitFor(testRead()) == true
@ -81,6 +94,8 @@ suite "BufferStream":
result = true result = true
await buff.close()
check: check:
waitFor(testRead()) == true waitFor(testRead()) == true
@ -102,8 +117,11 @@ suite "BufferStream":
var fut = reader() var fut = reader()
await buff.pushTo(cast[seq[byte]](@"12345")) await buff.pushTo(cast[seq[byte]](@"12345"))
await fut await fut
result = true result = true
await buff.close()
check: check:
waitFor(testRead()) == true waitFor(testRead()) == true
@ -118,8 +136,11 @@ suite "BufferStream":
var data: seq[byte] = newSeq[byte](2) var data: seq[byte] = newSeq[byte](2)
await buff.readExactly(addr data[0], 2) await buff.readExactly(addr data[0], 2)
check cast[string](data) == @['1', '2'] check cast[string](data) == @['1', '2']
result = true result = true
await buff.close()
check: check:
waitFor(testReadExactly()) == true waitFor(testReadExactly()) == true
@ -132,8 +153,11 @@ suite "BufferStream":
await buff.pushTo(cast[seq[byte]](@"12345\n67890")) await buff.pushTo(cast[seq[byte]](@"12345\n67890"))
check buff.len == 11 check buff.len == 11
check "12345" == await buff.readLine(0, "\n") check "12345" == await buff.readLine(0, "\n")
result = true result = true
await buff.close()
check: check:
waitFor(testReadLine()) == true waitFor(testReadLine()) == true
@ -150,8 +174,11 @@ suite "BufferStream":
check (await readFut) == 3 check (await readFut) == 3
check cast[string](data) == @['1', '2', '3'] check cast[string](data) == @['1', '2', '3']
result = true result = true
await buff.close()
check: check:
waitFor(testReadOnce()) == true waitFor(testReadOnce()) == true
@ -168,8 +195,11 @@ suite "BufferStream":
check (await readFut) == 4 check (await readFut) == 4
check cast[string](data) == @['1', '2', '3'] check cast[string](data) == @['1', '2', '3']
result = true result = true
await buff.close()
check: check:
waitFor(testReadUntil()) == true waitFor(testReadUntil()) == true
@ -183,8 +213,11 @@ suite "BufferStream":
var data = "Hello!" var data = "Hello!"
await buff.write(addr data[0], data.len) await buff.write(addr data[0], data.len)
result = true result = true
await buff.close()
check: check:
waitFor(testWritePtr()) == true waitFor(testWritePtr()) == true
@ -197,8 +230,11 @@ suite "BufferStream":
check buff.len == 0 check buff.len == 0
await buff.write("Hello!", 6) await buff.write("Hello!", 6)
result = true result = true
await buff.close()
check: check:
waitFor(testWritePtr()) == true waitFor(testWritePtr()) == true
@ -211,8 +247,11 @@ suite "BufferStream":
check buff.len == 0 check buff.len == 0
await buff.write(cast[seq[byte]]("Hello!"), 6) await buff.write(cast[seq[byte]]("Hello!"), 6)
result = true result = true
await buff.close()
check: check:
waitFor(testWritePtr()) == true waitFor(testWritePtr()) == true
@ -236,8 +275,11 @@ suite "BufferStream":
await buff.write("Msg 8") await buff.write("Msg 8")
await buff.write("Msg 9") await buff.write("Msg 9")
await buff.write("Msg 10") await buff.write("Msg 10")
result = true result = true
await buff.close()
check: check:
waitFor(testWritePtr()) == true waitFor(testWritePtr()) == true
@ -265,6 +307,8 @@ suite "BufferStream":
result = true result = true
await buff.close()
check: check:
waitFor(testWritePtr()) == true waitFor(testWritePtr()) == true
@ -295,7 +339,7 @@ suite "BufferStream":
await buf1.pushTo(cast[seq[byte]]("Hello2!")) await buf1.pushTo(cast[seq[byte]]("Hello2!"))
await buf2.pushTo(cast[seq[byte]]("Hello1!")) await buf2.pushTo(cast[seq[byte]]("Hello1!"))
await allFutures(readFut1, readFut2) await allFuturesThrowing(readFut1, readFut2)
check: check:
res1 == cast[seq[byte]]("Hello2!") res1 == cast[seq[byte]]("Hello2!")
@ -303,6 +347,9 @@ suite "BufferStream":
result = true result = true
await buf1.close()
await buf2.close()
check: check:
waitFor(pipeTest()) == true waitFor(pipeTest()) == true
@ -321,6 +368,9 @@ suite "BufferStream":
result = true result = true
await buf1.close()
await buf2.close()
check: check:
waitFor(pipeTest()) == true waitFor(pipeTest()) == true
@ -339,7 +389,7 @@ suite "BufferStream":
await buf1.write(cast[seq[byte]]("Hello1!")) await buf1.write(cast[seq[byte]]("Hello1!"))
await buf2.write(cast[seq[byte]]("Hello2!")) await buf2.write(cast[seq[byte]]("Hello2!"))
await allFutures(readFut1, readFut2) await allFuturesThrowing(readFut1, readFut2)
check: check:
res1 == cast[seq[byte]]("Hello2!") res1 == cast[seq[byte]]("Hello2!")
@ -347,6 +397,9 @@ suite "BufferStream":
result = true result = true
await buf1.close()
await buf2.close()
check: check:
waitFor(pipeTest()) == true waitFor(pipeTest()) == true
@ -368,6 +421,8 @@ suite "BufferStream":
result = true result = true
await buf1.close()
check: check:
waitFor(pipeTest()) == true waitFor(pipeTest()) == true
@ -386,6 +441,9 @@ suite "BufferStream":
result = true result = true
await buf1.close()
await buf2.close()
check: check:
waitFor(pipeTest()) == true waitFor(pipeTest()) == true
@ -404,7 +462,7 @@ suite "BufferStream":
await buf1.write(cast[seq[byte]]("Hello1!")) await buf1.write(cast[seq[byte]]("Hello1!"))
await buf2.write(cast[seq[byte]]("Hello2!")) await buf2.write(cast[seq[byte]]("Hello2!"))
await allFutures(readFut1, readFut2) await allFuturesThrowing(readFut1, readFut2)
check: check:
res1 == cast[seq[byte]]("Hello2!") res1 == cast[seq[byte]]("Hello2!")
@ -412,6 +470,9 @@ suite "BufferStream":
result = true result = true
await buf1.close()
await buf2.close()
check: check:
waitFor(pipeTest()) == true waitFor(pipeTest()) == true
@ -433,6 +494,8 @@ suite "BufferStream":
result = true result = true
await buf1.close()
check: check:
waitFor(pipeTest()) == true waitFor(pipeTest()) == true
@ -458,9 +521,11 @@ suite "BufferStream":
var writerFut = writer() var writerFut = writer()
var readerFut = reader() var readerFut = reader()
await allFutures(readerFut, writerFut) await allFuturesThrowing(readerFut, writerFut)
result = true result = true
await buf1.close()
check: check:
waitFor(pipeTest()) == true waitFor(pipeTest()) == true
@ -481,6 +546,8 @@ suite "BufferStream":
except AsyncTimeoutError: except AsyncTimeoutError:
result = false result = false
await stream.close()
check: check:
waitFor(closeTest()) == true waitFor(closeTest()) == true

View File

@ -13,7 +13,25 @@ import ../libp2p/[protocols/identify,
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "Identify": suite "Identify":
teardown:
let
trackers = [
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "handle identify message": test "handle identify message":
proc testHandle(): Future[bool] {.async.} = proc testHandle(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
@ -51,8 +69,11 @@ suite "Identify":
await conn.close() await conn.close()
await transport1.close() await transport1.close()
await serverFut await serverFut
result = true result = true
await transport2.close()
check: check:
waitFor(testHandle()) == true waitFor(testHandle()) == true
@ -63,9 +84,13 @@ suite "Identify":
let identifyProto1 = newIdentify(remotePeerInfo) let identifyProto1 = newIdentify(remotePeerInfo)
let msListen = newMultistream() let msListen = newMultistream()
let done = newFuture[void]()
msListen.addHandler(IdentifyCodec, identifyProto1) msListen.addHandler(IdentifyCodec, identifyProto1)
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} =
await msListen.handle(conn) await msListen.handle(conn)
await conn.close()
done.complete()
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
asyncCheck transport1.listen(ma, connHandler) asyncCheck transport1.listen(ma, connHandler)
@ -76,9 +101,15 @@ suite "Identify":
var localPeerInfo = PeerInfo.init(PrivateKey.random(RSA), [ma]) var localPeerInfo = PeerInfo.init(PrivateKey.random(RSA), [ma])
let identifyProto2 = newIdentify(localPeerInfo) let identifyProto2 = newIdentify(localPeerInfo)
discard await msDial.select(conn, IdentifyCodec)
discard await identifyProto2.identify(conn, PeerInfo.init(PrivateKey.random(RSA))) try:
await conn.close() discard await msDial.select(conn, IdentifyCodec)
discard await identifyProto2.identify(conn, PeerInfo.init(PrivateKey.random(RSA)))
finally:
await done.wait(5000.millis) # when no issues will not wait that long!
await conn.close()
await transport2.close()
await transport1.close()
expect IdentityNoMatchError: expect IdentityNoMatchError:
waitFor(testHandleError()) waitFor(testHandleError())

View File

@ -1,6 +1,7 @@
import unittest, sequtils, sugar, strformat, options, strformat, random import unittest, sequtils, sugar, strformat, options, strformat, random
import chronos, nimcrypto/utils, chronicles import chronos, nimcrypto/utils, chronicles
import ../libp2p/[connection, import ../libp2p/[errors,
connection,
stream/lpstream, stream/lpstream,
stream/bufferstream, stream/bufferstream,
transports/tcptransport, transports/tcptransport,
@ -16,7 +17,26 @@ import ../libp2p/[connection,
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "Mplex": suite "Mplex":
teardown:
let
trackers = [
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "encode header with channel id 0": test "encode header with channel id 0":
proc testEncodeHeader(): Future[bool] {.async.} = proc testEncodeHeader(): Future[bool] {.async.} =
proc encHandler(msg: seq[byte]) {.async.} = proc encHandler(msg: seq[byte]) {.async.} =
@ -25,8 +45,11 @@ suite "Mplex":
let stream = newBufferStream(encHandler) let stream = newBufferStream(encHandler)
let conn = newConnection(stream) let conn = newConnection(stream)
await conn.writeMsg(0, MessageType.New, cast[seq[byte]]("stream 1")) await conn.writeMsg(0, MessageType.New, cast[seq[byte]]("stream 1"))
result = true result = true
await stream.close()
check: check:
waitFor(testEncodeHeader()) == true waitFor(testEncodeHeader()) == true
@ -38,8 +61,11 @@ suite "Mplex":
let stream = newBufferStream(encHandler) let stream = newBufferStream(encHandler)
let conn = newConnection(stream) let conn = newConnection(stream)
await conn.writeMsg(17, MessageType.New, cast[seq[byte]]("stream 1")) await conn.writeMsg(17, MessageType.New, cast[seq[byte]]("stream 1"))
result = true result = true
await stream.close()
check: check:
waitFor(testEncodeHeader()) == true waitFor(testEncodeHeader()) == true
@ -52,8 +78,11 @@ suite "Mplex":
let stream = newBufferStream(encHandler) let stream = newBufferStream(encHandler)
let conn = newConnection(stream) let conn = newConnection(stream)
await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]]("stream 1")) await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]]("stream 1"))
result = true result = true
await stream.close()
check: check:
waitFor(testEncodeHeaderBody()) == true waitFor(testEncodeHeaderBody()) == true
@ -67,8 +96,11 @@ suite "Mplex":
let conn = newConnection(stream) let conn = newConnection(stream)
await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]]("stream 1")) await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]]("stream 1"))
await conn.close() await conn.close()
result = true result = true
await stream.close()
check: check:
waitFor(testEncodeHeaderBody()) == true waitFor(testEncodeHeaderBody()) == true
@ -81,8 +113,11 @@ suite "Mplex":
check msg.id == 0 check msg.id == 0
check msg.msgType == MessageType.New check msg.msgType == MessageType.New
result = true result = true
await stream.close()
check: check:
waitFor(testDecodeHeader()) == true waitFor(testDecodeHeader()) == true
@ -96,8 +131,11 @@ suite "Mplex":
check msg.id == 0 check msg.id == 0
check msg.msgType == MessageType.MsgOut check msg.msgType == MessageType.MsgOut
check cast[string](msg.data) == "hello from channel 0!!" check cast[string](msg.data) == "hello from channel 0!!"
result = true result = true
await stream.close()
check: check:
waitFor(testDecodeHeader()) == true waitFor(testDecodeHeader()) == true
@ -111,8 +149,11 @@ suite "Mplex":
check msg.id == 17 check msg.id == 17
check msg.msgType == MessageType.MsgOut check msg.msgType == MessageType.MsgOut
check cast[string](msg.data) == "hello from channel 0!!" check cast[string](msg.data) == "hello from channel 0!!"
result = true result = true
await stream.close()
check: check:
waitFor(testDecodeHeader()) == true waitFor(testDecodeHeader()) == true
@ -120,21 +161,25 @@ suite "Mplex":
proc testNewStream(): Future[bool] {.async.} = proc testNewStream(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
var
done = newFuture[void]()
done2 = newFuture[void]()
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
proc handleMplexListen(stream: Connection) {.async, gcsafe.} = proc handleMplexListen(stream: Connection) {.async, gcsafe.} =
let msg = await stream.readLp() let msg = await stream.readLp()
check cast[string](msg) == "Hello from stream!" check cast[string](msg) == "Hello from stream!"
await stream.close() await stream.close()
done.complete()
let mplexListen = newMplex(conn) let mplexListen = newMplex(conn)
mplexListen.streamHandler = handleMplexListen mplexListen.streamHandler = handleMplexListen
discard mplexListen.handle() await mplexListen.handle()
await conn.close()
done2.complete()
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
discard await transport1.listen(ma, connHandler) let lfut = await transport1.listen(ma, connHandler)
defer:
await transport1.close()
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(transport1.ma) let conn = await transport2.dial(transport1.ma)
@ -145,8 +190,17 @@ suite "Mplex":
await stream.writeLp("Hello from stream!") await stream.writeLp("Hello from stream!")
await conn.close() await conn.close()
check openState # not lazy check openState # not lazy
result = true result = true
await done.wait(5000.millis)
await done2.wait(5000.millis)
await stream.close()
await conn.close()
await transport2.close()
await transport1.close()
await lfut
check: check:
waitFor(testNewStream()) == true waitFor(testNewStream()) == true
@ -154,15 +208,21 @@ suite "Mplex":
proc testNewStream(): Future[bool] {.async.} = proc testNewStream(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
var
done = newFuture[void]()
done2 = newFuture[void]()
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
proc handleMplexListen(stream: Connection) {.async, gcsafe.} = proc handleMplexListen(stream: Connection) {.async, gcsafe.} =
let msg = await stream.readLp() let msg = await stream.readLp()
check cast[string](msg) == "Hello from stream!" check cast[string](msg) == "Hello from stream!"
await stream.close() await stream.close()
done.complete()
let mplexListen = newMplex(conn) let mplexListen = newMplex(conn)
mplexListen.streamHandler = handleMplexListen mplexListen.streamHandler = handleMplexListen
discard mplexListen.handle() await mplexListen.handle()
done2.complete()
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
let listenFut = await transport1.listen(ma, connHandler) let listenFut = await transport1.listen(ma, connHandler)
@ -179,7 +239,12 @@ suite "Mplex":
check not openState # assert lazy check not openState # assert lazy
result = true result = true
await done.wait(5000.millis)
await done2.wait(5000.millis)
await conn.close()
await stream.close()
await mplexDial.close() await mplexDial.close()
await transport2.close()
await transport1.close() await transport1.close()
await listenFut await listenFut
@ -214,8 +279,6 @@ suite "Mplex":
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(transport1.ma) let conn = await transport2.dial(transport1.ma)
defer:
await conn.close()
let mplexDial = newMplex(conn) let mplexDial = newMplex(conn)
let stream = await mplexDial.newStream() let stream = await mplexDial.newStream()
@ -228,7 +291,10 @@ suite "Mplex":
result = true result = true
await stream.close()
await mplexDial.close() await mplexDial.close()
await conn.close()
await transport2.close()
await transport1.close() await transport1.close()
await listenFut await listenFut
@ -239,10 +305,13 @@ suite "Mplex":
proc testNewStream(): Future[bool] {.async.} = proc testNewStream(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
let done = newFuture[void]()
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
proc handleMplexListen(stream: Connection) {.async, gcsafe.} = proc handleMplexListen(stream: Connection) {.async, gcsafe.} =
await stream.writeLp("Hello from stream!") await stream.writeLp("Hello from stream!")
await stream.close() await stream.close()
done.complete()
let mplexListen = newMplex(conn) let mplexListen = newMplex(conn)
mplexListen.streamHandler = handleMplexListen mplexListen.streamHandler = handleMplexListen
@ -259,11 +328,15 @@ suite "Mplex":
let stream = await mplexDial.newStream("DIALER") let stream = await mplexDial.newStream("DIALER")
let msg = cast[string](await stream.readLp()) let msg = cast[string](await stream.readLp())
check msg == "Hello from stream!" check msg == "Hello from stream!"
await conn.close()
# await dialFut # await dialFut
result = true result = true
await done.wait(5000.millis)
await stream.close()
await conn.close()
await mplexDial.close() await mplexDial.close()
await transport2.close()
await transport1.close() await transport1.close()
await listenFut await listenFut
@ -274,6 +347,8 @@ suite "Mplex":
proc testNewStream(): Future[bool] {.async.} = proc testNewStream(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
let done = newFuture[void]()
var count = 1 var count = 1
var listenConn: Connection var listenConn: Connection
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
@ -282,6 +357,8 @@ suite "Mplex":
check cast[string](msg) == &"stream {count}!" check cast[string](msg) == &"stream {count}!"
count.inc count.inc
await stream.close() await stream.close()
if count == 10:
done.complete()
listenConn = conn listenConn = conn
let mplexListen = newMplex(conn) let mplexListen = newMplex(conn)
@ -300,9 +377,9 @@ suite "Mplex":
await stream.writeLp(&"stream {i}!") await stream.writeLp(&"stream {i}!")
await stream.close() await stream.close()
await sleepAsync(1.seconds) # allow messages to get to the handler await done.wait(5000.millis)
await conn.close() # TODO: chronos sockets don't seem to have half-closed functionality await conn.close()
await transport2.close()
await mplexDial.close() await mplexDial.close()
await listenConn.close() await listenConn.close()
await transport1.close() await transport1.close()
@ -318,8 +395,8 @@ suite "Mplex":
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
var count = 1 var count = 1
var listenFut: Future[void]
var listenConn: Connection var listenConn: Connection
let done = newFuture[void]()
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
listenConn = conn listenConn = conn
proc handleMplexListen(stream: Connection) {.async, gcsafe.} = proc handleMplexListen(stream: Connection) {.async, gcsafe.} =
@ -328,12 +405,12 @@ suite "Mplex":
await stream.writeLp(&"stream {count} from listener!") await stream.writeLp(&"stream {count} from listener!")
count.inc count.inc
await stream.close() await stream.close()
if count == 10:
done.complete()
let mplexListen = newMplex(conn) let mplexListen = newMplex(conn)
mplexListen.streamHandler = handleMplexListen mplexListen.streamHandler = handleMplexListen
listenFut = mplexListen.handle() await mplexListen.handle()
listenFut.addCallback(proc(udata: pointer) {.gcsafe.}
= trace "completed listener")
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
let transportFut = await transport1.listen(ma, connHandler) let transportFut = await transport1.listen(ma, connHandler)
@ -352,9 +429,12 @@ suite "Mplex":
check cast[string](msg) == &"stream {i} from listener!" check cast[string](msg) == &"stream {i} from listener!"
await stream.close() await stream.close()
await done.wait(5.seconds)
await conn.close() await conn.close()
await listenConn.close() await listenConn.close()
await allFutures(dialFut, listenFut) await allFuturesThrowing(dialFut)
await mplexDial.close()
await transport2.close()
await transport1.close() await transport1.close()
await transportFut await transportFut
result = true result = true
@ -365,9 +445,16 @@ suite "Mplex":
test "half closed - channel should close for write": test "half closed - channel should close for write":
proc testClosedForWrite(): Future[void] {.async.} = proc testClosedForWrite(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) let
await chann.close() buff = newBufferStream(writeHandler)
await chann.write("Hello") conn = newConnection(buff)
chann = newChannel(1, conn, true)
try:
await chann.close()
await chann.write("Hello")
finally:
await chann.cleanUp()
await conn.close()
expect LPStreamEOFError: expect LPStreamEOFError:
waitFor(testClosedForWrite()) waitFor(testClosedForWrite())
@ -375,12 +462,19 @@ suite "Mplex":
test "half closed - channel should close for read by remote": test "half closed - channel should close for read by remote":
proc testClosedForRead(): Future[void] {.async.} = proc testClosedForRead(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) let
buff = newBufferStream(writeHandler)
conn = newConnection(buff)
chann = newChannel(1, conn, true)
await chann.pushTo(cast[seq[byte]]("Hello!")) try:
await chann.closedByRemote() await chann.pushTo(cast[seq[byte]]("Hello!"))
discard await chann.read() # this should work, since there is data in the buffer await chann.closedByRemote()
discard await chann.read() # this should throw discard await chann.read() # this should work, since there is data in the buffer
discard await chann.read() # this should throw
finally:
await chann.cleanUp()
await conn.close()
expect LPStreamEOFError: expect LPStreamEOFError:
waitFor(testClosedForRead()) waitFor(testClosedForRead())
@ -445,6 +539,7 @@ suite "Mplex":
await conn.close() await conn.close()
await complete await complete
await transport2.close()
await transport1.close() await transport1.close()
await listenFut await listenFut
@ -502,7 +597,7 @@ suite "Mplex":
await stream.close() await stream.close()
await conn.close() await conn.close()
await complete await complete
await transport2.close()
await transport1.close() await transport1.close()
await listenFut await listenFut
@ -514,10 +609,18 @@ suite "Mplex":
test "reset - channel should fail reading": test "reset - channel should fail reading":
proc testResetRead(): Future[void] {.async.} = proc testResetRead(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) let
await chann.reset() buff = newBufferStream(writeHandler)
var data = await chann.read() conn = newConnection(buff)
doAssert(len(data) == 1) chann = newChannel(1, conn, true)
try:
await chann.reset()
var data = await chann.read()
doAssert(len(data) == 1)
finally:
await chann.cleanUp()
await conn.close()
expect LPStreamEOFError: expect LPStreamEOFError:
waitFor(testResetRead()) waitFor(testResetRead())
@ -525,9 +628,16 @@ suite "Mplex":
test "reset - channel should fail writing": test "reset - channel should fail writing":
proc testResetWrite(): Future[void] {.async.} = proc testResetWrite(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) let
await chann.reset() buff = newBufferStream(writeHandler)
await chann.write(cast[seq[byte]]("Hello!")) conn = newConnection(buff)
chann = newChannel(1, conn, true)
try:
await chann.reset()
await chann.write(cast[seq[byte]]("Hello!"))
finally:
await chann.cleanUp()
await conn.close()
expect LPStreamEOFError: expect LPStreamEOFError:
waitFor(testResetWrite()) waitFor(testResetWrite())
@ -535,9 +645,16 @@ suite "Mplex":
test "should not allow pushing data to channel when remote end closed": test "should not allow pushing data to channel when remote end closed":
proc testResetWrite(): Future[void] {.async.} = proc testResetWrite(): Future[void] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) let
await chann.closedByRemote() buff = newBufferStream(writeHandler)
await chann.pushTo(@[byte(1)]) conn = newConnection(buff)
chann = newChannel(1, conn, true)
try:
await chann.closedByRemote()
await chann.pushTo(@[byte(1)])
finally:
await chann.cleanUp()
await conn.close()
expect LPStreamEOFError: expect LPStreamEOFError:
waitFor(testResetWrite()) waitFor(testResetWrite())

View File

@ -1,6 +1,7 @@
import unittest, strutils, sequtils, strformat, options import unittest, strutils, sequtils, strformat, options
import chronos import chronos
import ../libp2p/connection, import ../libp2p/errors,
../libp2p/connection,
../libp2p/multistream, ../libp2p/multistream,
../libp2p/stream/lpstream, ../libp2p/stream/lpstream,
../libp2p/stream/bufferstream, ../libp2p/stream/bufferstream,
@ -20,6 +21,10 @@ type
TestSelectStream = ref object of LPStream TestSelectStream = ref object of LPStream
step*: int step*: int
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
method readExactly*(s: TestSelectStream, method readExactly*(s: TestSelectStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): Future[void] {.async, gcsafe.} = nbytes: int): Future[void] {.async, gcsafe.} =
@ -155,11 +160,27 @@ proc newTestNaStream(na: NaHandler): TestNaStream =
result.step = 1 result.step = 1
suite "Multistream select": suite "Multistream select":
teardown:
let
trackers = [
# getTracker(ConnectionTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "test select custom proto": test "test select custom proto":
proc testSelect(): Future[bool] {.async.} = proc testSelect(): Future[bool] {.async.} =
let ms = newMultistream() let ms = newMultistream()
let conn = newConnection(newTestSelectStream()) let conn = newConnection(newTestSelectStream())
result = (await ms.select(conn, @["/test/proto/1.0.0"])) == "/test/proto/1.0.0" result = (await ms.select(conn, @["/test/proto/1.0.0"])) == "/test/proto/1.0.0"
await conn.close()
check: check:
waitFor(testSelect()) == true waitFor(testSelect()) == true
@ -190,10 +211,12 @@ suite "Multistream select":
proc testLsHandler(proto: seq[byte]) {.async, gcsafe.} # forward declaration proc testLsHandler(proto: seq[byte]) {.async, gcsafe.} # forward declaration
let conn = newConnection(newTestLsStream(testLsHandler)) let conn = newConnection(newTestLsStream(testLsHandler))
let done = newFuture[void]()
proc testLsHandler(proto: seq[byte]) {.async, gcsafe.} = proc testLsHandler(proto: seq[byte]) {.async, gcsafe.} =
var strProto: string = cast[string](proto) var strProto: string = cast[string](proto)
check strProto == "\x26/test/proto1/1.0.0\n/test/proto2/1.0.0\n" check strProto == "\x26/test/proto1/1.0.0\n/test/proto2/1.0.0\n"
await conn.close() await conn.close()
done.complete()
proc testHandler(conn: Connection, proto: string): Future[void] proc testHandler(conn: Connection, proto: string): Future[void]
{.async, gcsafe.} = discard {.async, gcsafe.} = discard
@ -204,6 +227,8 @@ suite "Multistream select":
await ms.handle(conn) await ms.handle(conn)
result = true result = true
await done.wait(5.seconds)
check: check:
waitFor(testLs()) == true waitFor(testLs()) == true
@ -235,6 +260,10 @@ suite "Multistream select":
proc endToEnd(): Future[bool] {.async.} = proc endToEnd(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
let
handlerWait1 = newFuture[void]()
handlerWait2 = newFuture[void]()
var protocol: LPProtocol = new LPProtocol var protocol: LPProtocol = new LPProtocol
proc testHandler(conn: Connection, proc testHandler(conn: Connection,
proto: string): proto: string):
@ -242,6 +271,7 @@ suite "Multistream select":
check proto == "/test/proto/1.0.0" check proto == "/test/proto/1.0.0"
await conn.writeLp("Hello!") await conn.writeLp("Hello!")
await conn.close() await conn.close()
handlerWait1.complete()
protocol.handler = testHandler protocol.handler = testHandler
let msListen = newMultistream() let msListen = newMultistream()
@ -249,6 +279,8 @@ suite "Multistream select":
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} =
await msListen.handle(conn) await msListen.handle(conn)
await conn.close()
handlerWait2.complete()
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
asyncCheck transport1.listen(ma, connHandler) asyncCheck transport1.listen(ma, connHandler)
@ -263,6 +295,11 @@ suite "Multistream select":
result = hello == "Hello!" result = hello == "Hello!"
await conn.close() await conn.close()
await transport2.close()
await transport1.close()
await allFuturesThrowing(handlerWait1.wait(5000.millis) #[if OK won't happen!!]#, handlerWait2.wait(5000.millis) #[if OK won't happen!!]#)
check: check:
waitFor(endToEnd()) == true waitFor(endToEnd()) == true
@ -270,13 +307,21 @@ suite "Multistream select":
proc endToEnd(): Future[bool] {.async.} = proc endToEnd(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
let
handlerWait = newFuture[void]()
let msListen = newMultistream() let msListen = newMultistream()
var protocol: LPProtocol = new LPProtocol var protocol: LPProtocol = new LPProtocol
protocol.handler = proc(conn: Connection, proto: string) {.async, gcsafe.} = protocol.handler = proc(conn: Connection, proto: string) {.async, gcsafe.} =
await conn.close() # never reached
discard
proc testHandler(conn: Connection, proc testHandler(conn: Connection,
proto: string): proto: string):
Future[void] {.async.} = discard Future[void] {.async.} =
# never reached
discard
protocol.handler = testHandler protocol.handler = testHandler
msListen.addHandler("/test/proto1/1.0.0", protocol) msListen.addHandler("/test/proto1/1.0.0", protocol)
msListen.addHandler("/test/proto2/1.0.0", protocol) msListen.addHandler("/test/proto2/1.0.0", protocol)
@ -284,6 +329,8 @@ suite "Multistream select":
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} =
await msListen.handle(conn) await msListen.handle(conn)
handlerWait.complete()
asyncCheck transport1.listen(ma, connHandler) asyncCheck transport1.listen(ma, connHandler)
let msDial = newMultistream() let msDial = newMultistream()
@ -292,9 +339,15 @@ suite "Multistream select":
let ls = await msDial.list(conn) let ls = await msDial.list(conn)
let protos: seq[string] = @["/test/proto1/1.0.0", "/test/proto2/1.0.0"] let protos: seq[string] = @["/test/proto1/1.0.0", "/test/proto2/1.0.0"]
await conn.close()
result = ls == protos result = ls == protos
await conn.close()
await transport2.close()
await transport1.close()
await handlerWait.wait(5000.millis) # when no issues will not wait that long!
check: check:
waitFor(endToEnd()) == true waitFor(endToEnd()) == true
@ -329,7 +382,11 @@ suite "Multistream select":
let hello = cast[string](await conn.readLp()) let hello = cast[string](await conn.readLp())
result = hello == "Hello!" result = hello == "Hello!"
await conn.close() await conn.close()
await transport2.close()
await transport1.close()
check: check:
waitFor(endToEnd()) == true waitFor(endToEnd()) == true
@ -363,7 +420,10 @@ suite "Multistream select":
check (await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"])) == "/test/proto2/1.0.0" check (await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"])) == "/test/proto2/1.0.0"
result = cast[string](await conn.readLp()) == "Hello from /test/proto2/1.0.0!" result = cast[string](await conn.readLp()) == "Hello from /test/proto2/1.0.0!"
await conn.close() await conn.close()
await transport2.close()
await transport1.close()
check: check:
waitFor(endToEnd()) == true waitFor(endToEnd()) == true

View File

@ -1,6 +1,16 @@
import testvarint import testvarint
import testrsa, testecnist, tested25519, testsecp256k1, testcrypto
import testmultibase, testmultihash, testmultiaddress, testcid, testpeer import testrsa,
testecnist,
tested25519,
testsecp256k1,
testcrypto
import testmultibase,
testmultihash,
testmultiaddress,
testcid,
testpeer
import testtransport, import testtransport,
testmultistream, testmultistream,
@ -9,7 +19,5 @@ import testtransport,
testswitch, testswitch,
testnoise, testnoise,
testpeerinfo, testpeerinfo,
pubsub/testpubsub, testmplex,
# TODO: placing this before pubsub tests, pubsub/testpubsub
# breaks some flood and gossip tests - no idea why
testmplex

View File

@ -13,7 +13,9 @@ import chronicles
import nimcrypto/sysrand import nimcrypto/sysrand
import ../libp2p/crypto/crypto import ../libp2p/crypto/crypto
import ../libp2p/[switch, import ../libp2p/[switch,
errors,
multistream, multistream,
stream/bufferstream,
protocols/identify, protocols/identify,
connection, connection,
transports/transport, transports/transport,
@ -29,7 +31,10 @@ import ../libp2p/[switch,
protocols/secure/noise, protocols/secure/noise,
protocols/secure/secure] protocols/secure/secure]
const TestCodec = "/test/proto/1.0.0" const
TestCodec = "/test/proto/1.0.0"
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
type type
TestProto = ref object of LPProtocol TestProto = ref object of LPProtocol
@ -64,6 +69,21 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) =
result = (switch, peerInfo) result = (switch, peerInfo)
suite "Noise": suite "Noise":
teardown:
let
trackers = [
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "e2e: handle write + noise": test "e2e: handle write + noise":
proc testListenerDialer(): Future[bool] {.async.} = proc testListenerDialer(): Future[bool] {.async.} =
let let
@ -75,6 +95,7 @@ suite "Noise":
let sconn = await serverNoise.secure(conn) let sconn = await serverNoise.secure(conn)
defer: defer:
await sconn.close() await sconn.close()
await conn.close()
await sconn.write(cstring("Hello!"), 6) await sconn.write(cstring("Hello!"), 6)
let let
@ -91,7 +112,9 @@ suite "Noise":
msg = await sconn.read(6) msg = await sconn.read(6)
await sconn.close() await sconn.close()
await conn.close()
await transport1.close() await transport1.close()
await transport2.close()
result = cast[string](msg) == "Hello!" result = cast[string](msg) == "Hello!"
@ -110,6 +133,7 @@ suite "Noise":
let sconn = await serverNoise.secure(conn) let sconn = await serverNoise.secure(conn)
defer: defer:
await sconn.close() await sconn.close()
await conn.close()
let msg = await sconn.read(6) let msg = await sconn.read(6)
check cast[string](msg) == "Hello!" check cast[string](msg) == "Hello!"
readTask.complete() readTask.complete()
@ -128,53 +152,58 @@ suite "Noise":
await sconn.write("Hello!".cstring, 6) await sconn.write("Hello!".cstring, 6)
await readTask await readTask
await sconn.close() await sconn.close()
await conn.close()
await transport1.close() await transport1.close()
await transport2.close()
result = true result = true
check: check:
waitFor(testListenerDialer()) == true waitFor(testListenerDialer()) == true
# test "e2e: handle read + noise fragmented": test "e2e: handle read + noise fragmented":
# proc testListenerDialer(): Future[bool] {.async.} = proc testListenerDialer(): Future[bool] {.async.} =
# let let
# server: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") server: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
# serverInfo = PeerInfo.init(PrivateKey.random(RSA), [server]) serverInfo = PeerInfo.init(PrivateKey.random(RSA), [server])
# serverNoise = newNoise(serverInfo.privateKey, outgoing = false) serverNoise = newNoise(serverInfo.privateKey, outgoing = false)
# readTask = newFuture[void]() readTask = newFuture[void]()
# var hugePayload = newSeq[byte](0xFFFFF) var hugePayload = newSeq[byte](0xFFFFF)
# check randomBytes(hugePayload) == hugePayload.len check randomBytes(hugePayload) == hugePayload.len
# trace "Sending huge payload", size = hugePayload.len trace "Sending huge payload", size = hugePayload.len
# proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
# let sconn = await serverNoise.secure(conn) let sconn = await serverNoise.secure(conn)
# defer: defer:
# await sconn.close() await sconn.close()
# let msg = await sconn.readLp() let msg = await sconn.readLp()
# check msg == hugePayload check msg == hugePayload
# readTask.complete() readTask.complete()
# let let
# transport1: TcpTransport = newTransport(TcpTransport) transport1: TcpTransport = newTransport(TcpTransport)
# asyncCheck await transport1.listen(server, connHandler) asyncCheck await transport1.listen(server, connHandler)
# let let
# transport2: TcpTransport = newTransport(TcpTransport) transport2: TcpTransport = newTransport(TcpTransport)
# clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma]) clientInfo = PeerInfo.init(PrivateKey.random(RSA), [transport1.ma])
# clientNoise = newNoise(clientInfo.privateKey, outgoing = true) clientNoise = newNoise(clientInfo.privateKey, outgoing = true)
# conn = await transport2.dial(transport1.ma) conn = await transport2.dial(transport1.ma)
# sconn = await clientNoise.secure(conn) sconn = await clientNoise.secure(conn)
# await sconn.writeLp(hugePayload) await sconn.writeLp(hugePayload)
# await readTask await readTask
# await sconn.close()
# await transport1.close()
# result = true await sconn.close()
await conn.close()
await transport2.close()
await transport1.close()
# check: result = true
# waitFor(testListenerDialer()) == true
check:
waitFor(testListenerDialer()) == true
test "e2e use switch dial proto string": test "e2e use switch dial proto string":
proc testSwitch(): Future[bool] {.async, gcsafe.} = proc testSwitch(): Future[bool] {.async, gcsafe.} =
@ -199,8 +228,8 @@ suite "Noise":
let msg = cast[string](await conn.readLp()) let msg = cast[string](await conn.readLp())
check "Hello!" == msg check "Hello!" == msg
await allFutures(switch1.stop(), switch2.stop()) await allFuturesThrowing(switch1.stop(), switch2.stop())
await allFutures(awaiters) await allFuturesThrowing(awaiters)
result = true result = true
check: check:

View File

@ -5,7 +5,24 @@ import ../libp2p/crypto/crypto,
../libp2p/peerinfo, ../libp2p/peerinfo,
../libp2p/peer ../libp2p/peer
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
suite "PeerInfo": suite "PeerInfo":
teardown:
let
trackers = [
getTracker(AsyncStreamWriterTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "Should init with private key": test "Should init with private key":
let seckey = PrivateKey.random(RSA) let seckey = PrivateKey.random(RSA)
var peerInfo = PeerInfo.init(seckey) var peerInfo = PeerInfo.init(seckey)

View File

@ -2,8 +2,10 @@ import unittest, tables
import chronos import chronos
import chronicles import chronicles
import nimcrypto/sysrand import nimcrypto/sysrand
import ../libp2p/[switch, import ../libp2p/[errors,
switch,
multistream, multistream,
stream/bufferstream,
protocols/identify, protocols/identify,
connection, connection,
transports/transport, transports/transport,
@ -22,7 +24,10 @@ import ../libp2p/[switch,
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
const TestCodec = "/test/proto/1.0.0" const
TestCodec = "/test/proto/1.0.0"
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
type type
TestProto = ref object of LPProtocol TestProto = ref object of LPProtocol
@ -47,6 +52,22 @@ proc createSwitch(ma: MultiAddress): (Switch, PeerInfo) =
result = (switch, peerInfo) result = (switch, peerInfo)
suite "Switch": suite "Switch":
teardown:
let
trackers = [
# getTracker(ConnectionTrackerName),
getTracker(BufferStreamTrackerName),
getTracker(AsyncStreamWriterTrackerName),
getTracker(TcpTransportTrackerName),
getTracker(AsyncStreamReaderTrackerName),
getTracker(StreamTransportTrackerName),
getTracker(StreamServerTrackerName)
]
for tracker in trackers:
if not isNil(tracker):
# echo tracker.dump()
check tracker.isLeaked() == false
test "e2e use switch dial proto string": test "e2e use switch dial proto string":
proc testSwitch(): Future[bool] {.async, gcsafe.} = proc testSwitch(): Future[bool] {.async, gcsafe.} =
let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
@ -58,11 +79,14 @@ suite "Switch":
(switch1, peerInfo1) = createSwitch(ma1) (switch1, peerInfo1) = createSwitch(ma1)
let done = newFuture[void]()
proc handle(conn: Connection, proto: string) {.async, gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
let msg = cast[string](await conn.readLp()) let msg = cast[string](await conn.readLp())
check "Hello!" == msg check "Hello!" == msg
await conn.writeLp("Hello!") await conn.writeLp("Hello!")
await conn.close() await conn.close()
done.complete()
let testProto = new TestProto let testProto = new TestProto
testProto.codec = TestCodec testProto.codec = TestCodec
@ -83,8 +107,15 @@ suite "Switch":
except LPStreamError: except LPStreamError:
result = false result = false
await allFutures(switch1.stop(), switch2.stop()) await allFuturesThrowing(
await allFutures(awaiters) done.wait(5000.millis) #[if OK won't happen!!]#,
conn.close(),
switch1.stop(),
switch2.stop(),
)
# this needs to go at end
await allFuturesThrowing(awaiters)
check: check:
waitFor(testSwitch()) == true waitFor(testSwitch()) == true
@ -125,8 +156,12 @@ suite "Switch":
except LPStreamError: except LPStreamError:
result = false result = false
await allFutures(switch1.stop(), switch2.stop()) await allFuturesThrowing(
await allFutures(awaiters) conn.close(),
switch1.stop(),
switch2.stop()
)
await allFuturesThrowing(awaiters)
check: check:
waitFor(testSwitch()) == true waitFor(testSwitch()) == true
@ -164,7 +199,10 @@ suite "Switch":
# await sconn.write(hugePayload) # await sconn.write(hugePayload)
# await readTask # await readTask
# await sconn.close() # await sconn.close()
# await conn.close()
# await transport2.close()
# await transport1.close() # await transport1.close()
# result = true # result = true

View File

@ -1,6 +1,7 @@
import unittest import unittest
import chronos import chronos
import ../libp2p/[connection, import ../libp2p/[errors,
connection,
transports/transport, transports/transport,
transports/tcptransport, transports/tcptransport,
multiaddress, multiaddress,
@ -8,19 +9,45 @@ import ../libp2p/[connection,
when defined(nimHasUsed): {.used.} when defined(nimHasUsed): {.used.}
const
StreamTransportTrackerName = "stream.transport"
StreamServerTrackerName = "stream.server"
template ignoreErrors(body: untyped): untyped =
try:
body
except:
echo getCurrentExceptionMsg()
suite "TCP transport": suite "TCP transport":
teardown:
check:
# getTracker(ConnectionTrackerName).isLeaked() == false
getTracker(AsyncStreamReaderTrackerName).isLeaked() == false
getTracker(AsyncStreamWriterTrackerName).isLeaked() == false
getTracker(StreamTransportTrackerName).isLeaked() == false
getTracker(StreamServerTrackerName).isLeaked() == false
test "test listener: handle write": test "test listener: handle write":
proc testListener(): Future[bool] {.async, gcsafe.} = proc testListener(): Future[bool] {.async, gcsafe.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = let handlerWait = newFuture[void]()
result = conn.write(cstring("Hello!"), 6) proc connHandler(conn: Connection) {.async, gcsafe.} =
await conn.write(cstring("Hello!"), 6)
await conn.close()
handlerWait.complete()
let transport: TcpTransport = newTransport(TcpTransport) let transport: TcpTransport = newTransport(TcpTransport)
asyncCheck await transport.listen(ma, connHandler)
let streamTransport: StreamTransport = await connect(transport.ma) asyncCheck transport.listen(ma, connHandler)
let streamTransport = await connect(transport.ma)
let msg = await streamTransport.read(6) let msg = await streamTransport.read(6)
await transport.close()
await handlerWait.wait(5000.millis) # when no issues will not wait that long!
await streamTransport.closeWait() await streamTransport.closeWait()
await transport.close()
result = cast[string](msg) == "Hello!" result = cast[string](msg) == "Hello!"
@ -30,14 +57,22 @@ suite "TCP transport":
test "test listener: handle read": test "test listener: handle read":
proc testListener(): Future[bool] {.async.} = proc testListener(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = let handlerWait = newFuture[void]()
proc connHandler(conn: Connection) {.async, gcsafe.} =
let msg = await conn.read(6) let msg = await conn.read(6)
check cast[string](msg) == "Hello!" check cast[string](msg) == "Hello!"
await conn.close()
handlerWait.complete()
let transport: TcpTransport = newTransport(TcpTransport) let transport: TcpTransport = newTransport(TcpTransport)
asyncCheck await transport.listen(ma, connHandler) asyncCheck await transport.listen(ma, connHandler)
let streamTransport: StreamTransport = await connect(transport.ma) let streamTransport: StreamTransport = await connect(transport.ma)
let sent = await streamTransport.write("Hello!", 6) let sent = await streamTransport.write("Hello!", 6)
await handlerWait.wait(5000.millis) # when no issues will not wait that long!
await streamTransport.closeWait()
await transport.close()
result = sent == 6 result = sent == 6
check: check:
@ -45,6 +80,7 @@ suite "TCP transport":
test "test dialer: handle write": test "test dialer: handle write":
proc testDialer(address: TransportAddress): Future[bool] {.async.} = proc testDialer(address: TransportAddress): Future[bool] {.async.} =
let handlerWait = newFuture[void]()
proc serveClient(server: StreamServer, proc serveClient(server: StreamServer,
transp: StreamTransport) {.async, gcsafe.} = transp: StreamTransport) {.async, gcsafe.} =
var wstream = newAsyncStreamWriter(transp) var wstream = newAsyncStreamWriter(transp)
@ -54,6 +90,7 @@ suite "TCP transport":
await transp.closeWait() await transp.closeWait()
server.stop() server.stop()
server.close() server.close()
handlerWait.complete()
var server = createStreamServer(address, serveClient, {ReuseAddr}) var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start() server.start()
@ -64,13 +101,21 @@ suite "TCP transport":
let msg = await conn.read(6) let msg = await conn.read(6)
result = cast[string](msg) == "Hello!" result = cast[string](msg) == "Hello!"
await handlerWait.wait(5000.millis) # when no issues will not wait that long!
await conn.close()
await transport.close()
server.stop() server.stop()
server.close() server.close()
await server.join() await server.join()
check waitFor(testDialer(initTAddress("0.0.0.0:0"))) == true
check:
waitFor(testDialer(initTAddress("0.0.0.0:0"))) == true
test "test dialer: handle write": test "test dialer: handle write":
proc testDialer(address: TransportAddress): Future[bool] {.async, gcsafe.} = proc testDialer(address: TransportAddress): Future[bool] {.async, gcsafe.} =
let handlerWait = newFuture[void]()
proc serveClient(server: StreamServer, proc serveClient(server: StreamServer,
transp: StreamTransport) {.async, gcsafe.} = transp: StreamTransport) {.async, gcsafe.} =
var rstream = newAsyncStreamReader(transp) var rstream = newAsyncStreamReader(transp)
@ -81,6 +126,7 @@ suite "TCP transport":
await transp.closeWait() await transp.closeWait()
server.stop() server.stop()
server.close() server.close()
handlerWait.complete()
var server = createStreamServer(address, serveClient, {ReuseAddr}) var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start() server.start()
@ -91,23 +137,37 @@ suite "TCP transport":
await conn.write(cstring("Hello!"), 6) await conn.write(cstring("Hello!"), 6)
result = true result = true
await handlerWait.wait(5000.millis) # when no issues will not wait that long!
await conn.close()
await transport.close()
server.stop() server.stop()
server.close() server.close()
await server.join() await server.join()
check waitFor(testDialer(initTAddress("0.0.0.0:0"))) == true check:
waitFor(testDialer(initTAddress("0.0.0.0:0"))) == true
test "e2e: handle write": test "e2e: handle write":
proc testListenerDialer(): Future[bool] {.async.} = proc testListenerDialer(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = let handlerWait = newFuture[void]()
result = conn.write(cstring("Hello!"), 6) proc connHandler(conn: Connection) {.async, gcsafe.} =
await conn.write(cstring("Hello!"), 6)
await conn.close()
handlerWait.complete()
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
asyncCheck await transport1.listen(ma, connHandler) asyncCheck transport1.listen(ma, connHandler)
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(transport1.ma) let conn = await transport2.dial(transport1.ma)
let msg = await conn.read(6) let msg = await conn.read(6)
await handlerWait.wait(5000.millis) # when no issues will not wait that long!
await conn.close()
await transport2.close()
await transport1.close() await transport1.close()
result = cast[string](msg) == "Hello!" result = cast[string](msg) == "Hello!"
@ -118,16 +178,24 @@ suite "TCP transport":
test "e2e: handle read": test "e2e: handle read":
proc testListenerDialer(): Future[bool] {.async.} = proc testListenerDialer(): Future[bool] {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = let handlerWait = newFuture[void]()
proc connHandler(conn: Connection) {.async, gcsafe.} =
let msg = await conn.read(6) let msg = await conn.read(6)
check cast[string](msg) == "Hello!" check cast[string](msg) == "Hello!"
await conn.close()
handlerWait.complete()
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
asyncCheck await transport1.listen(ma, connHandler) asyncCheck transport1.listen(ma, connHandler)
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(transport1.ma) let conn = await transport2.dial(transport1.ma)
await conn.write(cstring("Hello!"), 6) await conn.write(cstring("Hello!"), 6)
await handlerWait.wait(5000.millis) # when no issues will not wait that long!
await conn.close()
await transport2.close()
await transport1.close() await transport1.close()
result = true result = true