channel close race and deadlock fixes (#368)

* channel close race and deadlock fixes

* remove send lock, write chunks in one go
* push some of half-closed implementation to BufferStream
* fix some hangs where LPChannel readers and writers would not always
wake up
* simplify lazy channels
* fix close happening more than once in some orderings
* reenable connection tracking tests
* close channels first on mplex close such that consumers can read bytes

A notable behavioural difference is that a BufferStream is no longer considered EOF
until a reader has actually consumed the EOF marker from its queue.

* docs, simplification
Jacek Sieka 2020-09-21 19:48:19 +02:00 committed by GitHub
parent b99d2039a8
commit 49a12e619d
19 changed files with 287 additions and 271 deletions
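
Aside (editor's illustration, not part of this commit): the EOF note above can be shown with a minimal Nim sketch. It assumes the BufferStream API that appears in the diffs below (newBufferStream, pushData, pushEof, readOnce) and that atEof reflects the stream's internal EOF flag; import paths assume libp2p is available as a package.

import chronos, stew/byteutils
import libp2p/stream/bufferstream

proc demo() {.async.} =
  let buf = newBufferStream()
  await buf.pushData("hi".toBytes())   # queue one data chunk

  var data: array[2, byte]
  discard await buf.readOnce(addr data[0], data.len)  # drain the data chunk
  await buf.pushEof()                  # queue the EOF marker

  # Not EOF yet: the marker is still sitting unread in the queue
  doAssert not buf.atEof()

  # Once a reader pops the marker (readOnce returns 0), the stream is EOF
  doAssert (await buf.readOnce(addr data[0], data.len)) == 0
  doAssert buf.atEof()

  await buf.close()

waitFor demo()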

View File

@@ -51,33 +51,47 @@ proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} =
   if msgType.int > ord(MessageType.ResetOut):
     raise newInvalidMplexMsgType()

-  result = (header shr 3, MessageType(msgType), data)
+  return (header shr 3, MessageType(msgType), data)

 proc writeMsg*(conn: Connection,
                id: uint64,
                msgType: MessageType,
                data: seq[byte] = @[]) {.async, gcsafe.} =
-  trace "sending data over mplex", conn,
-                                   id,
-                                   msgType,
-                                   data = data.len
+  if conn.closed:
+    return # No point in trying to write to an already-closed connection

   var
     left = data.len
     offset = 0
+    buf = initVBuffer()

+  # Split message into length-prefixed chunks
   while left > 0 or data.len == 0:
     let
       chunkSize = if left > MaxMsgSize: MaxMsgSize - 64 else: left
-    ## write length prefixed
-    var buf = initVBuffer()
+
     buf.writePBVarint(id shl 3 or ord(msgType).uint64)
     buf.writeSeq(data.toOpenArray(offset, offset + chunkSize - 1))
-    buf.finish()
     left = left - chunkSize
     offset = offset + chunkSize
-    await conn.write(buf.buffer)

     if data.len == 0:
-      return
+      break
+
+  trace "writing mplex message",
+    conn, id, msgType, data = data.len, encoded = buf.buffer.len
+
+  try:
+    # Write all chunks in a single write to avoid async races where a close
+    # message gets written before some of the chunks
+    await conn.write(buf.buffer)
+
+    trace "wrote mplex", conn, id, msgType
+  except CatchableError as exc:
+    # If the write to the underlying connection failed it should be closed so
+    # that the other channels are notified as well
+    trace "failed write", conn, id, msg = exc.msg
+    await conn.close()
+
+    raise exc

 proc writeMsg*(conn: Connection,
                id: uint64,
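
Aside (editor's illustration, not part of the diff): the chunk splitting above reduced to plain arithmetic, to make the single-write behaviour easier to follow. MaxMsgSize is redefined locally with an assumed value purely for the example; the real constant is the mplex coder's message size limit.

const MaxMsgSize = 1 shl 20  # assumed value, for illustration only

proc chunkSizes(dataLen: int): seq[int] =
  ## Sizes of the length-prefixed chunks that writeMsg appends to a single
  ## buffer before issuing one conn.write for all of them.
  var left = dataLen
  while left > 0 or dataLen == 0:
    let chunkSize = if left > MaxMsgSize: MaxMsgSize - 64 else: left
    result.add chunkSize
    left -= chunkSize
    if dataLen == 0:
      break

doAssert chunkSizes(0) == @[0]                # empty payload still gets one frame
doAssert chunkSizes(1000) == @[1000]          # small payload: a single chunk
doAssert chunkSizes(MaxMsgSize + 1) == @[MaxMsgSize - 64, 65]  # large payload is split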

View File

@@ -27,19 +27,9 @@ logScope:
 ##   | Read  | Yes (until EOF) | No
 ##   | Write | No              | Yes
 ##
-
-# TODO: this is one place where we need to use
-# a proper state machine, but I've opted out of
-# it for now for two reasons:
-#
-# 1) we don't have that many states to manage
-# 2) I'm not sure if adding the state machine
-#    would have simplified or complicated the code
-#
-# But now that this is in place, we should perhaps
-# reconsider reworking it again, this time with a
-# more formal approach.
-#
+## Channels are considered fully closed when both outgoing and incoming
+## directions are closed and when the reader of the channel has read the
+## EOF marker
+
 type
   LPChannel* = ref object of BufferStream

@@ -47,154 +37,103 @@ type
     name*: string                   # name of the channel (for debugging)
     conn*: Connection               # wrapped connection used to for writing
     initiator*: bool                # initiated remotely or locally flag
-    isLazy*: bool                   # is channel lazy
-    isOpen*: bool                   # has channel been opened (only used with isLazy)
-    isReset*: bool                  # channel was reset, pushTo should drop data
-    pushing*: bool
+    isOpen*: bool                   # has channel been opened
     closedLocal*: bool              # has channel been closed locally
     msgCode*: MessageType           # cached in/out message code
     closeCode*: MessageType         # cached in/out close code
     resetCode*: MessageType         # cached in/out reset code
-    writeLock: AsyncLock

 proc open*(s: LPChannel) {.async, gcsafe.}

-template withWriteLock(lock: AsyncLock, body: untyped): untyped =
-  try:
-    await lock.acquire()
-    body
-  finally:
-    if not(isNil(lock)) and lock.locked:
-      lock.release()
-
 func shortLog*(s: LPChannel): auto =
   if s.isNil: "LPChannel(nil)"
   elif s.conn.peerInfo.isNil: $s.oid
-  elif s.name != $s.oid: &"{shortLog(s.conn.peerInfo.peerId)}:{s.oid}:{s.name}"
+  elif s.name != $s.oid and s.name.len > 0:
+    &"{shortLog(s.conn.peerInfo.peerId)}:{s.oid}:{s.name}"
   else: &"{shortLog(s.conn.peerInfo.peerId)}:{s.oid}"

 chronicles.formatIt(LPChannel): shortLog(it)

-proc closeMessage(s: LPChannel) {.async.} =
-  ## send close message
-  withWriteLock(s.writeLock):
-    trace "sending close message", s
-    await s.conn.writeMsg(s.id, s.closeCode) # write close
-
-proc resetMessage(s: LPChannel) {.async.} =
-  ## send reset message - this will not raise
-  try:
-    withWriteLock(s.writeLock):
-      trace "sending reset message", s
-      await s.conn.writeMsg(s.id, s.resetCode) # write reset
-  except CancelledError:
-    # This procedure is called from one place and never awaited, so there no
-    # need to re-raise CancelledError.
-    debug "Unexpected cancellation while resetting channel", s
-  except LPStreamEOFError as exc:
-    trace "muxed connection EOF", s, msg = exc.msg
-  except LPStreamClosedError as exc:
-    trace "muxed connection closed", s, msg = exc.msg
-  except LPStreamIncompleteError as exc:
-    trace "incomplete message", s, msg = exc.msg
-  except CatchableError as exc:
-    debug "Unhandled exception leak", s, msg = exc.msg
-
 proc open*(s: LPChannel) {.async, gcsafe.} =
+  trace "Opening channel", s, conn = s.conn
   await s.conn.writeMsg(s.id, MessageType.New, s.name)
+  trace "Opened channel", s
   s.isOpen = true

-proc closeRemote*(s: LPChannel) {.async.} =
-  trace "Closing remote", s
-  try:
-    # close parent bufferstream to prevent further reads
-    await procCall BufferStream(s).close()
-  except CancelledError as exc:
-    raise exc
-  except CatchableError as exc:
-    trace "exception closing remote channel", s, msg = exc.msg
-  trace "Closed remote", s
-
 method closed*(s: LPChannel): bool =
-  ## this emulates half-closed behavior
-  ## when closed locally writing is
-  ## disabled - see the table in the
-  ## header of the file
   s.closedLocal

-method pushTo*(s: LPChannel, data: seq[byte]) {.async.} =
-  if s.isReset:
-    raise newLPStreamClosedError() # Terminate mplex loop
-  try:
-    s.pushing = true
-    await procCall BufferStream(s).pushTo(data)
-  finally:
-    s.pushing = false
-
-method reset*(s: LPChannel) {.base, async, gcsafe.} =
+proc closeUnderlying(s: LPChannel): Future[void] {.async.} =
+  ## Channels may be closed for reading and writing in any order - we'll close
+  ## the underlying bufferstream when both directions are closed
   if s.closedLocal and s.isEof:
-    trace "channel already closed or reset", s
+    await procCall BufferStream(s).close()
+
+proc reset*(s: LPChannel) {.async, gcsafe.} =
+  if s.isClosed:
+    trace "Already closed", s
     return

   trace "Resetting channel", s, len = s.len

-  # First, make sure any new calls to `readOnce` and `pushTo` will fail - there
-  # may already be such calls in the event queue
-  s.isEof = true
-  s.isReset = true
-  s.readBuf = StreamSeq()
+  # First, make sure any new calls to `readOnce` and `pushData` etc will fail -
+  # there may already be such calls in the event queue however
   s.closedLocal = true
+  s.isEof = true
+  s.readBuf = StreamSeq()
+  s.pushedEof = true

-  asyncSpawn s.resetMessage()
+  for i in 0..<s.pushing:
+    # Make sure to drain any ongoing pushes - there's already at least one item
+    # more in the queue already so any ongoing reads shouldn't interfere
+    # Notably, popFirst is not fair - which reader/writer gets woken up depends
+    discard await s.readQueue.popFirst()
+
+  if s.readQueue.len == 0 and s.pushing == 0:
+    # There is no push ongoing and nothing on the queue - let's place an
+    # EOF marker there so that any reader is woken up - we don't need to
+    # synchronize here
+    await s.readQueue.addLast(@[])
+
+  if not s.conn.isClosed:
+    # If the connection is still active, notify the other end
+    proc resetMessage() {.async.} =
+      try:
+        trace "sending reset message", s, conn = s.conn
+        await s.conn.writeMsg(s.id, s.resetCode) # write reset
+      except CatchableError as exc:
+        # No cancellations, errors handled in writeMsg
+        trace "Can't send reset message", s, conn = s.conn, msg = exc.msg
+
+    asyncSpawn resetMessage()

   # This should wake up any readers by pushing an EOF marker at least
   await procCall BufferStream(s).close() # noraises, nocancels

-  if s.pushing:
-    # When data is being pushed, there will be two items competing for the
-    # readQueue slot - the BufferStream.close EOF marker and the pushTo data.
-    # If the EOF wins, the pushTo call will get stuck because there will be no
-    # new readers to clear the data. It's worth noting that if there's a reader
-    # already waiting for data, this reader will be unblocked by the pushTo -
-    # this is necessary or it will get stuck
-    if s.readQueue.len > 0:
-      discard s.readQueue.popFirstNoWait()
-
   trace "Channel reset", s

 method close*(s: LPChannel) {.async, gcsafe.} =
+  ## Close channel for writing - a message will be sent to the other peer
+  ## informing them that the channel is closed and that we're waiting for
+  ## their acknowledgement.
   if s.closedLocal:
     trace "Already closed", s
     return
-
-  trace "Closing channel", s, len = s.len
-
-  proc closeInternal() {.async.} =
-    try:
-      await s.closeMessage().wait(2.minutes)
-      if s.atEof: # already closed by remote close parent buffer immediately
-        await procCall BufferStream(s).close()
-    except CancelledError:
-      debug "Unexpected cancellation while closing channel", s
-      await s.reset()
-      # This is top-level procedure which will work as separate task, so it
-      # do not need to propogate CancelledError.
-    except LPStreamClosedError, LPStreamEOFError:
-      trace "Connection already closed", s
-    except CatchableError as exc: # Shouldn't happen?
-      warn "Exception closing channel", s, msg = exc.msg
-      await s.reset()
-    trace "Closed channel", s
-
   s.closedLocal = true
-  # All the errors are handled inside `closeInternal()` procedure.
-  asyncSpawn closeInternal()
+
+  trace "Closing channel", s, conn = s.conn, len = s.len
+
+  if s.isOpen:
+    try:
+      await s.conn.writeMsg(s.id, s.closeCode) # write close
+    except CancelledError as exc:
+      raise exc
+    except CatchableError as exc:
+      # It's harmless that close message cannot be sent - the connection is
+      # likely down already
+      trace "Cannot send close message", s, id = s.id
+
+  await s.closeUnderlying() # maybe already eofed
+
+  trace "Closed channel", s, len = s.len

 method initStream*(s: LPChannel) =
   if s.objName.len == 0:

@@ -206,21 +145,33 @@ method initStream*(s: LPChannel) =
   procCall BufferStream(s).initStream()

-  s.writeLock = newAsyncLock()
+method readOnce*(s: LPChannel,
+                 pbytes: pointer,
+                 nbytes: int):
+                 Future[int] {.async.} =
+  try:
+    let bytes = await procCall BufferStream(s).readOnce(pbytes, nbytes)
+    trace "readOnce", s, bytes
+    if bytes == 0:
+      await s.closeUnderlying()
+    return bytes
+  except CatchableError as exc:
+    await s.closeUnderlying()
+    raise exc

 method write*(s: LPChannel, msg: seq[byte]): Future[void] {.async.} =
   if s.closedLocal:
     raise newLPStreamClosedError()

+  doAssert msg.len > 0
+
   try:
-    if s.isLazy and not(s.isOpen):
+    if not s.isOpen:
       await s.open()

     # writes should happen in sequence
-    trace "write msg", len = msg.len
+    trace "write msg", s, conn = s.conn, len = msg.len

-    withWriteLock(s.writeLock):
-      await s.conn.writeMsg(s.id, s.msgCode, msg)
+    await s.conn.writeMsg(s.id, s.msgCode, msg)
     s.activity = true
   except CatchableError as exc:
     trace "exception in lpchannel write handler", s, msg = exc.msg

@@ -233,7 +184,6 @@ proc init*(
     conn: Connection,
     initiator: bool,
     name: string = "",
-    lazy: bool = false,
     timeout: Duration = DefaultChanTimeout): LPChannel =

   let chann = L(

@@ -241,8 +191,8 @@ proc init*(
     name: name,
     conn: conn,
     initiator: initiator,
-    isLazy: lazy,
     timeout: timeout,
+    isOpen: if initiator: false else: true,
     msgCode: if initiator: MessageType.MsgOut else: MessageType.MsgIn,
     closeCode: if initiator: MessageType.CloseOut else: MessageType.CloseIn,
     resetCode: if initiator: MessageType.ResetOut else: MessageType.ResetIn,

@@ -253,6 +203,6 @@ proc init*(
   when chronicles.enabledLogLevel == LogLevel.TRACE:
     chann.name = if chann.name.len > 0: chann.name else: $chann.oid

-  trace "Created new lpchannel", chann
+  trace "Created new lpchannel", chann, id, initiator
   return chann

View File

@@ -76,7 +76,6 @@ proc newStreamInternal*(m: Mplex,
                         initiator: bool = true,
                         chanId: uint64 = 0,
                         name: string = "",
-                        lazy: bool = false,
                         timeout: Duration):
                         LPChannel {.gcsafe.} =
   ## create new channel/stream

@@ -93,7 +92,6 @@ proc newStreamInternal*(m: Mplex,
     m.connection,
     initiator,
     name,
-    lazy = lazy,
     timeout = timeout)

   result.peerInfo = m.connection.peerInfo

@@ -176,11 +174,11 @@ method handle*(m: Mplex) {.async, gcsafe.} =
             raise newLPStreamLimitError()

           trace "pushing data to channel", m, channel, len = data.len
-          await channel.pushTo(data)
+          await channel.pushData(data)
           trace "pushed data to channel", m, channel, len = data.len
         of MessageType.CloseIn, MessageType.CloseOut:
-          await channel.closeRemote()
+          await channel.pushEof()
         of MessageType.ResetIn, MessageType.ResetOut:
           await channel.reset()
   except CancelledError:

@@ -208,8 +206,7 @@ proc init*(M: type Mplex,
 method newStream*(m: Mplex,
                   name: string = "",
                   lazy: bool = false): Future[Connection] {.async, gcsafe.} =
-  let channel = m.newStreamInternal(
-    lazy = lazy, timeout = m.inChannTimeout)
+  let channel = m.newStreamInternal(timeout = m.inChannTimeout)

   if not lazy:
     await channel.open()

@@ -224,15 +221,21 @@ method close*(m: Mplex) {.async, gcsafe.} =
   trace "Closing mplex", m

-  let channs = toSeq(m.channels[false].values) & toSeq(m.channels[true].values)
+  var channs = toSeq(m.channels[false].values) & toSeq(m.channels[true].values)

   for chann in channs:
-    await chann.reset()
+    await chann.close()

   await m.connection.close()

   # TODO while we're resetting, new channels may be created that will not be
   # closed properly

+  channs = toSeq(m.channels[false].values) & toSeq(m.channels[true].values)
+  for chann in channs:
+    await chann.reset()
+
   m.channels[false].clear()
   m.channels[true].clear()

View File

@@ -79,7 +79,10 @@ proc setupBufferStreamTracker(): BufferStreamTracker =
 type
   BufferStream* = ref object of Connection
     readQueue*: AsyncQueue[seq[byte]] # read queue for managing backpressure
     readBuf*: StreamSeq               # overflow buffer for readOnce
+    pushing*: int                     # number of ongoing push operations
+    pushedEof*: bool

 func shortLog*(s: BufferStream): auto =
   if s.isNil: "BufferStream(nil)"

@@ -106,14 +109,13 @@ proc newBufferStream*(timeout: Duration = DefaultConnectionTimeout): BufferStrea
   result.timeout = timeout
   result.initStream()

-method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
+method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
   ## Write bytes to internal read buffer, use this to fill up the
   ## buffer with data.
   ##
   ## `pushTo` will block if the queue is full, thus maintaining backpressure.
   ##
-
-  if s.isClosed:
+  if s.isClosed or s.pushedEof:
     raise newLPStreamEOFError()

   if data.len == 0:

@@ -121,15 +123,32 @@ method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
   # We will block here if there is already data queued, until it has been
   # processed
-  trace "Pushing readQueue", s, len = data.len
-  await s.readQueue.addLast(data)
+  inc s.pushing
+  try:
+    trace "Pushing data", s, data = data.len
+    await s.readQueue.addLast(data)
+  finally:
+    dec s.pushing
+
+method pushEof*(s: BufferStream) {.base, async.} =
+  if s.pushedEof:
+    return
+  s.pushedEof = true
+
+  # We will block here if there is already data queued, until it has been
+  # processed
+  inc s.pushing
+  try:
+    trace "Pushing EOF", s
+    await s.readQueue.addLast(@[])
+  finally:
+    dec s.pushing

 method readOnce*(s: BufferStream,
                  pbytes: pointer,
                  nbytes: int):
                  Future[int] {.async.} =
   doAssert(nbytes > 0, "nbytes must be positive integer")

   if s.isEof and s.readBuf.len() == 0:
     raise newLPStreamEOFError()

@@ -144,7 +163,7 @@ method readOnce*(s: BufferStream,
     trace "popping readQueue", s, rbytes, nbytes
     let buf = await s.readQueue.popFirst()

-    if buf.len == 0:
+    if buf.len == 0 or s.isEof: # Another task might have set EOF!
       # No more data will arrive on read queue
       s.isEof = true
     else:

@@ -165,18 +184,14 @@ method readOnce*(s: BufferStream,
   return rbytes

-method close*(s: BufferStream) {.async, gcsafe.} =
+method closeImpl*(s: BufferStream): Future[void] =
   ## close the stream and clear the buffer
-  if s.isClosed:
-    trace "Already closed", s
-    return
-
-  trace "Closing BufferStream", s
+  trace "Closing BufferStream", s, len = s.len

-  # Push empty block to signal close, but don't block
-  asyncSpawn s.readQueue.addLast(@[])
-
-  await procCall Connection(s).close() # noraises, nocancels
+  if not s.pushedEof: # Potentially wake up reader
+    asyncSpawn s.pushEof()

   inc getBufferStreamTracker().closed
   trace "Closed BufferStream", s
+
+  procCall Connection(s).closeImpl() # noraises, nocancels
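
Aside (editor's illustration, not part of the diff): a small sketch of the backpressure behaviour of pushData/pushEof above. It relies on the read queue holding a single chunk at a time, which matches the "push and wait" test further down; import paths are assumptions.

import chronos, stew/byteutils
import libp2p/stream/bufferstream

proc demo() {.async.} =
  let buf = newBufferStream()
  await buf.pushData("1234".toBytes())      # completes: queue has room
  let second = buf.pushData("5".toBytes())  # blocks: a chunk is already queued
  doAssert not second.finished

  var data: array[4, byte]
  discard await buf.readOnce(addr data[0], data.len)  # drain the first chunk
  await second                              # now the second push completes

  discard await buf.readOnce(addr data[0], data.len)  # drain "5" as well
  await buf.close()

waitFor demo()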

View File

@@ -90,16 +90,15 @@ method closed*(s: ChronosStream): bool {.inline.} =
 method atEof*(s: ChronosStream): bool {.inline.} =
   s.client.atEof()

-method close*(s: ChronosStream) {.async.} =
+method closeImpl*(s: ChronosStream) {.async.} =
   try:
-    if not s.isClosed:
-      trace "shutting down chronos stream", address = $s.client.remoteAddress(),
-        s
-      if not s.client.closed():
-        await s.client.closeWait()
-
-      await procCall Connection(s).close()
+    trace "shutting down chronos stream", address = $s.client.remoteAddress(),
+      s
+    if not s.client.closed():
+      await s.client.closeWait()
   except CancelledError as exc:
     raise exc
   except CatchableError as exc:
-    trace "error closing chronosstream", exc = exc.msg, s
+    trace "error closing chronosstream", s, msg = exc.msg
+
+  await procCall Connection(s).closeImpl()

View File

@@ -89,14 +89,16 @@ method initStream*(s: Connection) =
   inc getConnectionTracker().opened

-method close*(s: Connection) {.async.} =
-  ## cleanup timers
+method closeImpl*(s: Connection): Future[void] =
+  # Cleanup timeout timer
+  trace "Closing connection", s
+
   if not isNil(s.timerTaskFut) and not s.timerTaskFut.finished:
     s.timerTaskFut.cancel()

-  if not s.isClosed:
-    await procCall LPStream(s).close()
-    inc getConnectionTracker().closed
+  inc getConnectionTracker().closed
+
+  trace "Closed connection"
+
+  procCall LPStream(s).closeImpl()

 func hash*(p: Connection): Hash =
   cast[pointer](p).hash

View File

@@ -210,15 +210,22 @@ proc write*(s: LPStream, pbytes: pointer, nbytes: int): Future[void] {.deprecate
 proc write*(s: LPStream, msg: string): Future[void] =
   s.write(msg.toBytes())

-# TODO: split `close` into `close` and `dispose/destroy`
-method close*(s: LPStream) {.base, async.} = # {.raises [Defect].}
+method closeImpl*(s: LPStream): Future[void] {.async, base.} =
+  ## Implementation of close - called only once
+  trace "Closing stream", s, objName = s.objName
+  s.closeEvent.fire()
+  libp2p_open_streams.dec(labelValues = [s.objName])
+  trace "Closed stream", s, objName = s.objName
+
+method close*(s: LPStream): Future[void] {.base, async.} = # {.raises [Defect].}
   ## close the stream - this may block, but will not raise exceptions
   ##
   if s.isClosed:
     trace "Already closed", s
     return

-  s.isClosed = true
-  s.closeEvent.fire()
-  libp2p_open_streams.dec(labelValues = [s.objName])
-  trace "Closed stream", s, objName = s.objName
+  s.isClosed = true # Set flag before performing virtual close
+
+  # An separate implementation method is used so that even when derived types
+  # override `closeImpl`, it is called only once - anyone overriding `close`
+  # itself must implement this - once-only check as well, with their own field
+  await closeImpl(s)

View File

@@ -1,3 +1,5 @@
+import std/unittest
+
 import chronos, bearssl

 import ../libp2p/transports/tcptransport

@@ -10,7 +12,7 @@ const
   StreamServerTrackerName = "stream.server"

   trackerNames = [
-    # ConnectionTrackerName,
+    ConnectionTrackerName,
     BufferStreamTrackerName,
     TcpTransportTrackerName,
     StreamTransportTrackerName,

@@ -25,6 +27,12 @@ iterator testTrackers*(extras: openArray[string] = []): TrackerBase =
     let t = getTracker(name)
     if not isNil(t): yield t

+template checkTrackers*() =
+  for tracker in testTrackers():
+    if tracker.isLeaked():
+      checkpoint tracker.dump()
+      fail()
+
 type RngWrap = object
   rng: ref BrHmacDrbgContext

View File

@@ -36,9 +36,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
 suite "FloodSub":
   teardown:
-    for tracker in testTrackers():
-      # echo tracker.dump()
-      check tracker.isLeaked() == false
+    checkTrackers()

   test "FloodSub basic publish/subscribe A -> B":
     proc runTests() {.async.} =

View File

@ -28,9 +28,7 @@ proc randomPeerInfo(): PeerInfo =
suite "GossipSub internal": suite "GossipSub internal":
teardown: teardown:
for tracker in testTrackers(): checkTrackers()
# echo tracker.dump()
check tracker.isLeaked() == false
test "topic params": test "topic params":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =

View File

@ -73,9 +73,7 @@ template tryPublish(call: untyped, require: int, wait: Duration = 1.seconds, tim
suite "GossipSub": suite "GossipSub":
teardown: teardown:
for tracker in testTrackers(): checkTrackers()
# echo tracker.dump()
check tracker.isLeaked() == false
test "GossipSub validation should succeed": test "GossipSub validation should succeed":
proc runTests() {.async.} = proc runTests() {.async.} =

View File

@@ -11,26 +11,26 @@ suite "BufferStream":
     check getTracker("libp2p.bufferstream").isLeaked() == false

   test "push data to buffer":
-    proc testPushTo(): Future[bool] {.async.} =
+    proc testpushData(): Future[bool] {.async.} =
       let buff = newBufferStream()
       check buff.len == 0

       var data = "12345"
-      await buff.pushTo(data.toBytes())
+      await buff.pushData(data.toBytes())
       check buff.len == 5

       result = true

       await buff.close()

     check:
-      waitFor(testPushTo()) == true
+      waitFor(testpushData()) == true

   test "push and wait":
-    proc testPushTo(): Future[bool] {.async.} =
+    proc testpushData(): Future[bool] {.async.} =
       let buff = newBufferStream()
       check buff.len == 0

-      let fut0 = buff.pushTo("1234".toBytes())
-      let fut1 = buff.pushTo("5".toBytes())
+      let fut0 = buff.pushData("1234".toBytes())
+      let fut1 = buff.pushData("5".toBytes())
       check buff.len == 4 # the second write should not be visible yet

       var data: array[1, byte]

@@ -46,14 +46,14 @@ suite "BufferStream":
       await buff.close()

     check:
-      waitFor(testPushTo()) == true
+      waitFor(testpushData()) == true

   test "read with size":
     proc testRead(): Future[bool] {.async.} =
       let buff = newBufferStream()
       check buff.len == 0

-      await buff.pushTo("12345".toBytes())
+      await buff.pushData("12345".toBytes())
       var data: array[3, byte]
       await buff.readExactly(addr data[0], data.len)
       check ['1', '2', '3'] == string.fromBytes(data)

@@ -70,7 +70,7 @@ suite "BufferStream":
       let buff = newBufferStream()
       check buff.len == 0

-      await buff.pushTo("12345".toBytes())
+      await buff.pushData("12345".toBytes())
       check buff.len == 5
       var data: array[2, byte]
       await buff.readExactly(addr data[0], data.len)

@@ -88,7 +88,7 @@ suite "BufferStream":
       let buff = newBufferStream()
       check buff.len == 0

-      await buff.pushTo("123".toBytes())
+      await buff.pushData("123".toBytes())
       var data: array[5, byte]
       var readFut = buff.readExactly(addr data[0], data.len)
       await buff.close()

@@ -108,7 +108,7 @@ suite "BufferStream":
       var data: array[3, byte]
       let readFut = buff.readOnce(addr data[0], data.len)
-      await buff.pushTo("123".toBytes())
+      await buff.pushData("123".toBytes())
       check buff.len == 3

       check (await readFut) == 3

@@ -126,9 +126,9 @@ suite "BufferStream":
       let buff = newBufferStream()
       check buff.len == 0

-      let w1 = buff.pushTo("Msg 1".toBytes())
-      let w2 = buff.pushTo("Msg 2".toBytes())
-      let w3 = buff.pushTo("Msg 3".toBytes())
+      let w1 = buff.pushData("Msg 1".toBytes())
+      let w2 = buff.pushData("Msg 2".toBytes())
+      let w3 = buff.pushData("Msg 3".toBytes())

       var data: array[5, byte]
       await buff.readExactly(addr data[0], data.len)

@@ -143,9 +143,9 @@ suite "BufferStream":
       for f in [w1, w2, w3]: await f

-      let w4 = buff.pushTo("Msg 4".toBytes())
-      let w5 = buff.pushTo("Msg 5".toBytes())
-      let w6 = buff.pushTo("Msg 6".toBytes())
+      let w4 = buff.pushData("Msg 4".toBytes())
+      let w5 = buff.pushData("Msg 5".toBytes())
+      let w6 = buff.pushData("Msg 6".toBytes())

       await buff.close()

@@ -173,7 +173,7 @@ suite "BufferStream":
       var writes: seq[Future[void]]
       var str: string
       for i in 0..<10:
-        writes.add buff.pushTo("123".toBytes())
+        writes.add buff.pushData("123".toBytes())
         str &= "123"

       await buff.close() # all data should still be read after close

@@ -201,8 +201,8 @@ suite "BufferStream":
     proc closeTest(): Future[bool] {.async.} =
       var stream = newBufferStream()
       var
-        fut = stream.pushTo(toBytes("hello"))
-        fut2 = stream.pushTo(toBytes("again"))
+        fut = stream.pushData(toBytes("hello"))
+        fut2 = stream.pushData(toBytes("again"))
       await stream.close()
       try:
         await wait(fut, 100.milliseconds)

@@ -219,13 +219,13 @@ suite "BufferStream":
   test "no push after close":
     proc closeTest(): Future[bool] {.async.} =
       var stream = newBufferStream()
-      await stream.pushTo("123".toBytes())
+      await stream.pushData("123".toBytes())
       var data: array[3, byte]
       await stream.readExactly(addr data[0], data.len)
       await stream.close()

       try:
-        await stream.pushTo("123".toBytes())
+        await stream.pushData("123".toBytes())
       except LPStreamClosedError:
         result = true

View File

@ -21,9 +21,7 @@ method newStream*(
suite "Connection Manager": suite "Connection Manager":
teardown: teardown:
for tracker in testTrackers(): checkTrackers()
# echo tracker.dump()
check tracker.isLeaked() == false
test "add and retrive a connection": test "add and retrive a connection":
let connMngr = ConnManager.init() let connMngr = ConnManager.init()

View File

@ -15,9 +15,7 @@ when defined(nimHasUsed): {.used.}
suite "Identify": suite "Identify":
teardown: teardown:
for tracker in testTrackers(): checkTrackers()
# echo tracker.dump()
check tracker.isLeaked() == false
test "handle identify message": test "handle identify message":
proc testHandle(): Future[bool] {.async.} = proc testHandle(): Future[bool] {.async.} =

View File

@@ -18,9 +18,7 @@ import ./helpers
 suite "Mplex":
   teardown:
-    for tracker in testTrackers():
-      # echo tracker.dump()
-      check tracker.isLeaked() == false
+    checkTrackers()

   test "encode header with channel id 0":
     proc testEncodeHeader() {.async.} =

@@ -70,7 +68,7 @@ suite "Mplex":
     proc testDecodeHeader() {.async.} =
       let stream = newBufferStream()
       let conn = stream
-      await stream.pushTo(fromHex("000873747265616d2031"))
+      await stream.pushData(fromHex("000873747265616d2031"))
       let msg = await conn.readMsg()

       check msg.id == 0

@@ -83,7 +81,7 @@ suite "Mplex":
     proc testDecodeHeader() {.async.} =
       let stream = newBufferStream()
       let conn = stream
-      await stream.pushTo(fromHex("021668656C6C6F2066726F6D206368616E6E656C20302121"))
+      await stream.pushData(fromHex("021668656C6C6F2066726F6D206368616E6E656C20302121"))
       let msg = await conn.readMsg()

       check msg.id == 0

@@ -97,7 +95,7 @@ suite "Mplex":
     proc testDecodeHeader() {.async.} =
      let stream = newBufferStream()
      let conn = stream
-      await stream.pushTo(fromHex("8a011668656C6C6F2066726F6D206368616E6E656C20302121"))
+      await stream.pushData(fromHex("8a011668656C6C6F2066726F6D206368616E6E656C20302121"))
      let msg = await conn.readMsg()

      check msg.id == 17

@@ -134,14 +132,14 @@ suite "Mplex":
         )
         chann = LPChannel.init(1, conn, true)

-      await chann.pushTo(("Hello!").toBytes)
+      await chann.pushData(("Hello!").toBytes)

       var data = newSeq[byte](6)
       await chann.close() # closing channel
       # should be able to read on local clsoe
       await chann.readExactly(addr data[0], 3)
       # closing remote end
-      let closeFut = chann.closeRemote()
+      let closeFut = chann.pushEof()
       # should still allow reading until buffer EOF
       await chann.readExactly(addr data[3], 3)
       try:

@@ -166,11 +164,11 @@ suite "Mplex":
         )
         chann = LPChannel.init(1, conn, true)

-      await chann.pushTo(("Hello!").toBytes)
+      await chann.pushData(("Hello!").toBytes)
       var data = newSeq[byte](6)
       await chann.readExactly(addr data[0], 3)
-      let closeFut = chann.closeRemote() # closing channel
+      let closeFut = chann.pushEof() # closing channel
       let readFut = chann.readExactly(addr data[3], 3)
       await all(closeFut, readFut)
       try:

@@ -184,7 +182,7 @@ suite "Mplex":
     check:
       waitFor(testClosedForRead()) == true

-  test "half closed (remote close) - channel should allow writting on remote close":
+  test "half closed (remote close) - channel should allow writing on remote close":
     proc testClosedForRead(): Future[bool] {.async.} =
       let
         testData = "Hello!".toBytes

@@ -194,12 +192,12 @@ suite "Mplex":
         )
         chann = LPChannel.init(1, conn, true)

-      await chann.closeRemote() # closing channel
+      await chann.pushEof() # closing channel
       try:
         await chann.writeLp(testData)
         return true
       finally:
-        await chann.close()
+        await chann.reset() # there's nobody reading the EOF!
         await conn.close()

     check:

@@ -211,9 +209,11 @@ suite "Mplex":
       let
         conn = newBufferStream(writeHandler)
         chann = LPChannel.init(1, conn, true)
-      await chann.closeRemote()
+      await chann.pushEof()
+      var buf: array[1, byte]
+      check: (await chann.readOnce(addr buf[0], 1)) == 0 # EOF marker read
       try:
-        await chann.pushTo(@[byte(1)])
+        await chann.pushData(@[byte(1)])
       except LPStreamEOFError:
         result = true
       finally:

@@ -234,7 +234,6 @@ suite "Mplex":
       var data = newSeq[byte](1)
       try:
         await chann.readExactly(addr data[0], 1)
-        check data.len == 1
       except LPStreamEOFError:
         result = true
       finally:

@@ -243,6 +242,60 @@ suite "Mplex":
     check:
       waitFor(testResetRead()) == true

+  test "reset - should complete read":
+    proc testResetRead(): Future[bool] {.async.} =
+      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
+      let
+        conn = newBufferStream(writeHandler)
+        chann = LPChannel.init(1, conn, true)
+
+      var data = newSeq[byte](1)
+      let fut = chann.readExactly(addr data[0], 1)
+      await chann.reset()
+      try:
+        await fut
+      except LPStreamEOFError:
+        result = true
+      finally:
+        await conn.close()
+
+    check:
+      waitFor(testResetRead()) == true
+
+  test "reset - should complete pushData":
+    proc testResetRead(): Future[bool] {.async.} =
+      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
+      let
+        conn = newBufferStream(writeHandler)
+        chann = LPChannel.init(1, conn, true)
+
+      await chann.pushData(@[0'u8])
+      let fut = chann.pushData(@[0'u8])
+      await chann.reset()
+      result = await fut.withTimeout(100.millis)
+      await conn.close()
+
+    check:
+      waitFor(testResetRead()) == true
+
+  test "reset - should complete both read and push":
+    proc testResetRead(): Future[bool] {.async.} =
+      proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
+      let
+        conn = newBufferStream(writeHandler)
+        chann = LPChannel.init(1, conn, true)
+
+      var data = newSeq[byte](1)
+      let rfut = chann.readExactly(addr data[0], 1)
+      let wfut = chann.pushData(@[0'u8])
+      let wfut2 = chann.pushData(@[0'u8])
+      await chann.reset()
+      result = await allFutures(rfut, wfut, wfut2).withTimeout(100.millis)
+      await conn.close()
+
+    check:
+      waitFor(testResetRead()) == true
+
   test "reset - channel should fail writing":
     proc testResetWrite(): Future[bool] {.async.} =
       proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard

@@ -533,6 +586,7 @@ suite "Mplex":
       await done.wait(5.seconds)
       await conn.close()
       await mplexDialFut
+      await mplexDial.close()
       await allFuturesThrowing(
         transport1.close(),
         transport2.close())

View File

@ -168,9 +168,7 @@ proc newTestNaStream(na: NaHandler): TestNaStream =
suite "Multistream select": suite "Multistream select":
teardown: teardown:
for tracker in testTrackers(): checkTrackers()
# echo tracker.dump()
check tracker.isLeaked() == false
test "test select custom proto": test "test select custom proto":
proc testSelect(): Future[bool] {.async.} = proc testSelect(): Future[bool] {.async.} =

View File

@ -68,9 +68,7 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) =
suite "Noise": suite "Noise":
teardown: teardown:
for tracker in testTrackers(): checkTrackers()
# echo tracker.dump()
check tracker.isLeaked() == false
test "e2e: handle write + noise": test "e2e: handle write + noise":
proc testListenerDialer(): Future[bool] {.async.} = proc testListenerDialer(): Future[bool] {.async.} =

View File

@@ -15,7 +15,6 @@ import ../libp2p/[errors,
                   crypto/crypto,
                   protocols/protocol,
                   muxers/muxer,
-                  muxers/mplex/mplex,
                   stream/lpstream]

 import ./helpers

@@ -27,9 +26,7 @@ type
 suite "Switch":
   teardown:
-    for tracker in testTrackers():
-      # echo tracker.dump()
-      check tracker.isLeaked() == false
+    checkTrackers()

   test "e2e use switch dial proto string":
     proc testSwitch() {.async, gcsafe.} =

@@ -112,23 +109,6 @@ suite "Switch":
       check "Hello!" == msg
       await conn.close()

-      await sleepAsync(2.seconds) # wait a little for cleanup to happen
-      var bufferTracker = getTracker(BufferStreamTrackerName)
-      # echo bufferTracker.dump()
-      # plus 4 for the pubsub streams
-      check (BufferStreamTracker(bufferTracker).opened ==
-        (BufferStreamTracker(bufferTracker).closed))
-
-      var connTracker = getTracker(ConnectionTrackerName)
-      # echo connTracker.dump()
-      # plus 8 is for the secured connection and the socket
-      # and the pubsub streams that won't clean up until
-      # `disconnect()` or `stop()`
-      check (ConnectionTracker(connTracker).opened ==
-        (ConnectionTracker(connTracker).closed + 4.uint64))
-
       await allFuturesThrowing(
         done.wait(5.seconds),
         switch1.stop(),

View File

@ -11,9 +11,7 @@ import ./helpers
suite "TCP transport": suite "TCP transport":
teardown: teardown:
for tracker in testTrackers(): checkTrackers()
# echo tracker.dump()
check tracker.isLeaked() == false
test "test listener: handle write": test "test listener: handle write":
proc testListener(): Future[bool] {.async, gcsafe.} = proc testListener(): Future[bool] {.async, gcsafe.} =