Half closed (#174)

* call write until all is written out

* add comments to lpchannel fields

* add an eof flag to signal which end closed

* wip: rework with proper half-closed

* add eof and closed handling

* propagate closes to piped

* call parent close

* moving bufferstream trackers out

* move writeLock to bufferstream

* move writeLock out

* remove unused call

* wip

* rebasing master

* fix mplex tests

* wip

* fix bufferstream after backport

* wip

* rename to differentiate from chronos tracker

* close connection on chronos close

* make reset request asyncCheck

* fix channel cleanup

* misc

* don't use read

* fix backports

* make noise work again

* proper exception handling

* don't reraise just yet

* add convenience templates

* dont double wrap

* use async pragma

* fixes after backporting

* muxer owns connection

* remove on transport close cleanup

* revert back allread

* adding some todos

* read from stream

* inc count before closing

* rebasing master

* rebase master

* use correct exception type

* use try/finally insted of defer

* fix compile in trace mode

* reset channels on mplex close
This commit is contained in:
Dmitriy Ryajov 2020-05-19 18:14:15 -06:00 committed by GitHub
parent 681991ae48
commit 7900fd9f61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 327 additions and 235 deletions

View File

@ -7,6 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import oids, deques
import chronos, chronicles import chronos, chronicles
import types, import types,
coder, coder,
@ -22,21 +23,45 @@ export lpstream
logScope: logScope:
topic = "MplexChannel" topic = "MplexChannel"
## Channel half-closed states
##
## | State | Closed local | Closed remote
## |=============================================
## | Read | Yes (until EOF) | No
## | Write | No | Yes
##
type type
LPChannel* = ref object of BufferStream LPChannel* = ref object of BufferStream
id*: uint64 id*: uint64 # channel id
name*: string name*: string # name of the channel (for debugging)
conn*: Connection conn*: Connection # wrapped connection used to for writing
initiator*: bool initiator*: bool # initiated remotely or locally flag
isLazy*: bool isLazy*: bool # is channel lazy
isOpen*: bool isOpen*: bool # has channel been oppened (only used with isLazy)
isReset*: bool closedLocal*: bool # has channel been closed locally
closedLocal*: bool msgCode*: MessageType # cached in/out message code
closedRemote*: bool closeCode*: MessageType # cached in/out close code
handlerFuture*: Future[void] resetCode*: MessageType # cached in/out reset code
msgCode*: MessageType
closeCode*: MessageType proc open*(s: LPChannel) {.async, gcsafe.}
resetCode*: MessageType
template withWriteLock(lock: AsyncLock, body: untyped): untyped =
try:
await lock.acquire()
body
finally:
lock.release()
template withEOFExceptions(body: untyped): untyped =
try:
body
except LPStreamEOFError as exc:
trace "muxed connection EOF", exc = exc.msg
except LPStreamClosedError as exc:
trace "muxed connection closed", exc = exc.msg
except LPStreamIncompleteError as exc:
trace "incomplete message", exc = exc.msg
proc newChannel*(id: uint64, proc newChannel*(id: uint64,
conn: Connection, conn: Connection,
@ -55,110 +80,114 @@ proc newChannel*(id: uint64,
result.isLazy = lazy result.isLazy = lazy
let chan = result let chan = result
proc writeHandler(data: seq[byte]): Future[void] {.async.} = proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
if chan.isLazy and not(chan.isOpen):
await chan.open()
# writes should happen in sequence # writes should happen in sequence
trace "sending data", data = data.shortLog, trace "sending data", data = data.shortLog,
id = chan.id, id = chan.id,
initiator = chan.initiator initiator = chan.initiator,
name = chan.name,
oid = chan.oid
await conn.writeMsg(chan.id, chan.msgCode, data) # write header await conn.writeMsg(chan.id, chan.msgCode, data) # write header
result.initBufferStream(writeHandler, size) result.initBufferStream(writeHandler, size)
when chronicles.enabledLogLevel == LogLevel.TRACE:
result.name = if result.name.len > 0: result.name else: $result.oid
trace "created new lpchannel", id = result.id,
oid = result.oid,
initiator = result.initiator,
name = result.name
proc closeMessage(s: LPChannel) {.async.} = proc closeMessage(s: LPChannel) {.async.} =
await s.conn.writeMsg(s.id, s.closeCode) # write header withEOFExceptions:
withWriteLock(s.writeLock):
trace "sending close message", id = s.id,
initiator = s.initiator,
name = s.name,
oid = s.oid
proc cleanUp*(s: LPChannel): Future[void] = await s.conn.writeMsg(s.id, s.closeCode) # write close
# method which calls the underlying buffer's `close`
# method used instead of `close` since it's overloaded to
# simulate half-closed streams
result = procCall close(BufferStream(s))
proc tryCleanup(s: LPChannel) {.async, inline.} =
# if stream is EOF, then cleanup immediatelly
if s.closedRemote and s.len == 0:
await s.cleanUp()
proc closedByRemote*(s: LPChannel) {.async.} =
s.closedRemote = true
if s.len == 0:
await s.cleanUp()
proc open*(s: LPChannel): Future[void] =
s.isOpen = true
s.conn.writeMsg(s.id, MessageType.New, s.name)
method close*(s: LPChannel) {.async, gcsafe.} =
s.closedLocal = true
await s.closeMessage()
proc resetMessage(s: LPChannel) {.async.} = proc resetMessage(s: LPChannel) {.async.} =
await s.conn.writeMsg(s.id, s.resetCode) withEOFExceptions:
withWriteLock(s.writeLock):
trace "sending reset message", id = s.id,
initiator = s.initiator,
name = s.name,
oid = s.oid
proc resetByRemote*(s: LPChannel) {.async.} = await s.conn.writeMsg(s.id, s.resetCode) # write reset
# Immediately block futher calls
s.isReset = true
# start and await async teardown proc open*(s: LPChannel) {.async, gcsafe.} =
let ## NOTE: Don't call withExcAndLock or withWriteLock,
futs = await allFinished( ## because this already gets called from writeHandler
s.close(), ## which is locked
s.closedByRemote(), withEOFExceptions:
s.cleanUp() await s.conn.writeMsg(s.id, MessageType.New, s.name)
) trace "oppened channel", oid = s.oid,
name = s.name,
initiator = s.initiator
s.isOpen = true
checkFutures(futs, [LPStreamEOFError]) proc closeRemote*(s: LPChannel) {.async.} =
trace "got EOF, closing channel", id = s.id,
initiator = s.initiator,
name = s.name,
oid = s.oid
proc reset*(s: LPChannel) {.async.} = # wait for all data in the buffer to be consumed
let while s.len > 0:
futs = await allFinished( await s.dataReadEvent.wait()
s.resetMessage(), s.dataReadEvent.clear()
s.resetByRemote()
)
checkFutures(futs, [LPStreamEOFError]) # TODO: Not sure if this needs to be set here or bfore consuming
# the buffer
s.isEof = true # set EOF immediately to prevent further reads
await procCall BufferStream(s).close() # close parent bufferstream
trace "channel closed on EOF", id = s.id,
initiator = s.initiator,
oid = s.oid,
name = s.name
method closed*(s: LPChannel): bool = method closed*(s: LPChannel): bool =
trace "closing lpchannel", id = s.id, initiator = s.initiator ## this emulates half-closed behavior
result = s.closedRemote and s.len == 0 ## when closed locally writing is
## dissabled - see the table in the
## header of the file
s.closedLocal
proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] = method close*(s: LPChannel) {.async, gcsafe.} =
if s.closedRemote or s.isReset: if s.closedLocal:
var retFuture = newFuture[void]("LPChannel.pushTo") return
retFuture.fail(newLPStreamEOFError())
return retFuture
trace "pushing data to channel", data = data.shortLog, trace "closing local lpchannel", id = s.id,
id = s.id, initiator = s.initiator,
initiator = s.initiator name = s.name,
oid = s.oid
# TODO: we should install a timer that on expire
# will make sure the channel did close by the remote
# so the hald-closed flow completed, if it didn't
# we should send a `reset` and move on.
await s.closeMessage()
s.closedLocal = true
if s.atEof: # already closed by remote close parent buffer imediately
await procCall BufferStream(s).close()
result = procCall pushTo(BufferStream(s), data) trace "lpchannel closed local", id = s.id,
initiator = s.initiator,
name = s.name,
oid = s.oid
template raiseEOF(): untyped = method reset*(s: LPChannel) {.base, async.} =
if s.closed or s.isReset: # we asyncCheck here because the other end
raise newLPStreamEOFError() # might be dead already - reset is always
# optimistic
method readExactly*(s: LPChannel, asyncCheck s.resetMessage()
pbytes: pointer, await procCall BufferStream(s).close()
nbytes: int): s.isEof = true
Future[void] {.async.} = s.closedLocal = true
raiseEOF()
await procCall readExactly(BufferStream(s), pbytes, nbytes)
await s.tryCleanup()
method readOnce*(s: LPChannel,
pbytes: pointer,
nbytes: int):
Future[int] {.async.} =
raiseEOF()
result = await procCall readOnce(BufferStream(s), pbytes, nbytes)
await s.tryCleanup()
method write*(s: LPChannel, msg: seq[byte]) {.async.} =
if s.closedLocal or s.isReset:
raise newLPStreamEOFError()
if s.isLazy and not s.isOpen:
await s.open()
await procCall write(BufferStream(s), msg)

View File

@ -7,15 +7,12 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
## TODO: import tables, sequtils, oids
## Timeouts and message limits are still missing
## they need to be added ASAP
import tables, sequtils
import chronos, chronicles import chronos, chronicles
import ../muxer, import ../muxer,
../../connection, ../../connection,
../../stream/lpstream, ../../stream/lpstream,
../../stream/bufferstream,
../../utility, ../../utility,
../../errors, ../../errors,
coder, coder,
@ -27,18 +24,21 @@ logScope:
type type
Mplex* = ref object of Muxer Mplex* = ref object of Muxer
remote*: Table[uint64, LPChannel] remote: Table[uint64, LPChannel]
local*: Table[uint64, LPChannel] local: Table[uint64, LPChannel]
handlers*: array[2, Table[uint64, Future[void]]] handlerFuts: seq[Future[void]]
currentId*: uint64 currentId*: uint64
maxChannels*: uint64 maxChannels*: uint64
isClosed: bool
when chronicles.enabledLogLevel == LogLevel.TRACE:
oid*: Oid
proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] = proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] =
if initiator: if initiator:
trace "picking local channels", initiator = initiator trace "picking local channels", initiator = initiator, oid = m.oid
result = m.local result = m.local
else: else:
trace "picking remote channels", initiator = initiator trace "picking remote channels", initiator = initiator, oid = m.oid
result = m.remote result = m.remote
proc newStreamInternal*(m: Mplex, proc newStreamInternal*(m: Mplex,
@ -49,36 +49,27 @@ proc newStreamInternal*(m: Mplex,
Future[LPChannel] {.async, gcsafe.} = Future[LPChannel] {.async, gcsafe.} =
## create new channel/stream ## create new channel/stream
let id = if initiator: m.currentId.inc(); m.currentId else: chanId let id = if initiator: m.currentId.inc(); m.currentId else: chanId
trace "creating new channel", channelId = id, initiator = initiator trace "creating new channel", channelId = id,
result = newChannel(id, m.connection, initiator, name, lazy = lazy) initiator = initiator,
name = name,
oid = m.oid
result = newChannel(id,
m.connection,
initiator,
name,
lazy = lazy)
m.getChannelList(initiator)[id] = result m.getChannelList(initiator)[id] = result
proc cleanupChann(m: Mplex, chann: LPChannel, initiator: bool) {.async.} =
## call the channel's `close` to signal the
## remote that the channel is closing
if not isNil(chann) and not chann.closed:
trace "cleaning up channel", id = chann.id
await chann.close()
await chann.cleanUp()
m.getChannelList(initiator).del(chann.id)
trace "cleaned up channel", id = chann.id
proc cleanupChann(chann: LPChannel) {.async.} =
trace "cleaning up channel", id = chann.id
await chann.reset()
await chann.close()
await chann.cleanUp()
trace "cleaned up channel", id = chann.id
method handle*(m: Mplex) {.async, gcsafe.} = method handle*(m: Mplex) {.async, gcsafe.} =
trace "starting mplex main loop" trace "starting mplex main loop", oid = m.oid
try: try:
while not m.connection.closed: while not m.connection.closed:
trace "waiting for data" trace "waiting for data", oid = m.oid
let (id, msgType, data) = await m.connection.readMsg() let (id, msgType, data) = await m.connection.readMsg()
trace "read message from connection", id = id, trace "read message from connection", id = id,
msgType = msgType, msgType = msgType,
data = data.shortLog data = data.shortLog,
oid = m.oid
let initiator = bool(ord(msgType) and 1) let initiator = bool(ord(msgType) and 1)
var channel: LPChannel var channel: LPChannel
if MessageType(msgType) != MessageType.New: if MessageType(msgType) != MessageType.New:
@ -86,7 +77,8 @@ method handle*(m: Mplex) {.async, gcsafe.} =
if id notin channels: if id notin channels:
trace "Channel not found, skipping", id = id, trace "Channel not found, skipping", id = id,
initiator = initiator, initiator = initiator,
msg = msgType msg = msgType,
oid = m.oid
continue continue
channel = channels[id] channel = channels[id]
@ -94,26 +86,33 @@ method handle*(m: Mplex) {.async, gcsafe.} =
of MessageType.New: of MessageType.New:
let name = cast[string](data) let name = cast[string](data)
channel = await m.newStreamInternal(false, id, name) channel = await m.newStreamInternal(false, id, name)
trace "created channel", id = id, name = name, inititator = initiator trace "created channel", id = id,
name = name,
inititator = channel.initiator,
channoid = channel.oid,
oid = m.oid
if not isNil(m.streamHandler): if not isNil(m.streamHandler):
let stream = newConnection(channel) let stream = newConnection(channel)
stream.peerInfo = m.connection.peerInfo stream.peerInfo = m.connection.peerInfo
var fut = newFuture[void]()
proc handler() {.async.} = proc handler() {.async.} =
tryAndWarn "mplex channel handler": tryAndWarn "mplex channel handler":
await m.streamHandler(stream) await m.streamHandler(stream)
if not initiator:
await m.cleanupChann(channel, false)
if not initiator: fut = handler()
m.handlers[0][id] = handler() m.handlerFuts.add(fut)
else: fut.addCallback do(udata: pointer):
m.handlers[1][id] = handler() m.handlerFuts.keepItIf(it != fut)
of MessageType.MsgIn, MessageType.MsgOut: of MessageType.MsgIn, MessageType.MsgOut:
trace "pushing data to channel", id = id, trace "pushing data to channel", id = id,
initiator = initiator, initiator = initiator,
msgType = msgType, msgType = msgType,
size = data.len size = data.len,
name = channel.name,
channoid = channel.oid,
oid = m.oid
if data.len > MaxMsgSize: if data.len > MaxMsgSize:
raise newLPStreamLimitError() raise newLPStreamLimitError()
@ -121,27 +120,40 @@ method handle*(m: Mplex) {.async, gcsafe.} =
of MessageType.CloseIn, MessageType.CloseOut: of MessageType.CloseIn, MessageType.CloseOut:
trace "closing channel", id = id, trace "closing channel", id = id,
initiator = initiator, initiator = initiator,
msgType = msgType msgType = msgType,
name = channel.name,
channoid = channel.oid,
oid = m.oid
await channel.closedByRemote() await channel.closeRemote()
m.getChannelList(initiator).del(id) m.getChannelList(initiator).del(id)
trace "deleted channel", id = id,
initiator = initiator,
msgType = msgType,
name = channel.name,
channoid = channel.oid,
oid = m.oid
of MessageType.ResetIn, MessageType.ResetOut: of MessageType.ResetIn, MessageType.ResetOut:
trace "resetting channel", id = id, trace "resetting channel", id = id,
initiator = initiator, initiator = initiator,
msgType = msgType msgType = msgType,
name = channel.name,
channoid = channel.oid,
oid = m.oid
await channel.resetByRemote() await channel.reset()
m.getChannelList(initiator).del(id) m.getChannelList(initiator).del(id)
trace "deleted channel", id = id,
initiator = initiator,
msgType = msgType,
name = channel.name,
channoid = channel.oid,
oid = m.oid
break break
except CatchableError as exc: except CatchableError as exc:
trace "Exception occurred", exception = exc.msg trace "Exception occurred", exception = exc.msg, oid = m.oid
finally: finally:
trace "stopping mplex main loop" trace "stopping mplex main loop", oid = m.oid
await m.close()
proc internalCleanup(m: Mplex, conn: Connection) {.async.} =
await conn.closeEvent.wait()
trace "connection closed, cleaning up mplex"
await m.close() await m.close()
proc newMplex*(conn: Connection, proc newMplex*(conn: Connection,
@ -152,7 +164,16 @@ proc newMplex*(conn: Connection,
result.remote = initTable[uint64, LPChannel]() result.remote = initTable[uint64, LPChannel]()
result.local = initTable[uint64, LPChannel]() result.local = initTable[uint64, LPChannel]()
asyncCheck result.internalCleanup(conn) when chronicles.enabledLogLevel == LogLevel.TRACE:
result.oid = genOid()
proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} =
## remove the local channel from the internal tables
##
await chann.closeEvent.wait()
if not isNil(chann):
m.getChannelList(true).del(chann.id)
trace "cleaned up channel", id = chann.id
method newStream*(m: Mplex, method newStream*(m: Mplex,
name: string = "", name: string = "",
@ -163,24 +184,23 @@ method newStream*(m: Mplex,
result = newConnection(channel) result = newConnection(channel)
result.peerInfo = m.connection.peerInfo result.peerInfo = m.connection.peerInfo
asyncCheck m.cleanupChann(channel)
method close*(m: Mplex) {.async, gcsafe.} = method close*(m: Mplex) {.async, gcsafe.} =
trace "closing mplex muxer" if m.isClosed:
return
trace "closing mplex muxer", oid = m.oid
checkFutures(
await allFinished(
toSeq(m.remote.values).mapIt(it.reset()) &
toSeq(m.local.values).mapIt(it.reset())))
checkFutures(await allFinished(m.handlerFuts))
if not m.connection.closed():
await m.connection.close() await m.connection.close()
let
futs = await allFinished(
toSeq(m.remote.values).mapIt(it.cleanupChann()) &
toSeq(m.local.values).mapIt(it.cleanupChann()) &
toSeq(m.handlers[0].values).mapIt(it) &
toSeq(m.handlers[1].values).mapIt(it))
checkFutures(futs)
m.handlers[0].clear()
m.handlers[1].clear()
m.remote.clear() m.remote.clear()
m.local.clear() m.local.clear()
m.handlerFuts = @[]
trace "mplex muxer closed" m.isClosed = true

View File

@ -20,7 +20,8 @@ import pubsub,
../../peerinfo, ../../peerinfo,
../../connection, ../../connection,
../../peer, ../../peer,
../../errors ../../errors,
../../utility
logScope: logScope:
topic = "GossipSub" topic = "GossipSub"

View File

@ -113,7 +113,7 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
p.onConnect.wait().addCallback do (udata: pointer): p.onConnect.wait().addCallback do (udata: pointer):
asyncCheck sendToRemote() asyncCheck sendToRemote()
trace "enqueued message to send at a later time", peer = p.id, trace "enqueued message to send at a later time", peer = p.id,
encoded = encodedHex encoded = digest
except CatchableError as exc: except CatchableError as exc:
trace "Exception occurred in PubSubPeer.send", exc = exc.msg trace "Exception occurred in PubSubPeer.send", exc = exc.msg

View File

@ -55,7 +55,6 @@ method secure*(s: Secure, conn: Connection, initiator: bool): Future[Connection]
result = await s.handleConn(conn, initiator) result = await s.handleConn(conn, initiator)
except CatchableError as exc: except CatchableError as exc:
warn "securing connection failed", msg = exc.msg warn "securing connection failed", msg = exc.msg
if not conn.closed():
await conn.close() await conn.close()
method readExactly*(s: SecureConn, method readExactly*(s: SecureConn,

View File

@ -36,26 +36,18 @@ import ../stream/lpstream
export lpstream export lpstream
logScope:
topic = "BufferStream"
declareGauge libp2p_open_bufferstream, "open BufferStream instances"
const const
BufferStreamTrackerName* = "libp2p.bufferstream"
DefaultBufferSize* = 1024 DefaultBufferSize* = 1024
const
BufferStreamTrackerName* = "libp2p.bufferstream"
type type
# TODO: figure out how to make this generic to avoid casts
WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.}
BufferStream* = ref object of LPStream
maxSize*: int # buffer's max size in bytes
readBuf: Deque[byte] # this is a ring buffer based dequeue, this makes it perfect as the backing store here
readReqs: Deque[Future[void]] # use dequeue to fire reads in order
dataReadEvent: AsyncEvent
writeHandler*: WriteHandler
lock: AsyncLock
isPiped: bool
AlreadyPipedError* = object of CatchableError
NotWritableError* = object of CatchableError
BufferStreamTracker* = ref object of TrackerBase BufferStreamTracker* = ref object of TrackerBase
opened*: uint64 opened*: uint64
closed*: uint64 closed*: uint64
@ -83,7 +75,23 @@ proc setupBufferStreamTracker(): BufferStreamTracker =
result.dump = dumpTracking result.dump = dumpTracking
result.isLeaked = leakTransport result.isLeaked = leakTransport
addTracker(BufferStreamTrackerName, result) addTracker(BufferStreamTrackerName, result)
declareGauge libp2p_open_bufferstream, "open BufferStream instances"
type
# TODO: figure out how to make this generic to avoid casts
WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.}
BufferStream* = ref object of LPStream
maxSize*: int # buffer's max size in bytes
readBuf: Deque[byte] # this is a ring buffer based dequeue
readReqs*: Deque[Future[void]] # use dequeue to fire reads in order
dataReadEvent*: AsyncEvent # event triggered when data has been consumed from the internal buffer
writeHandler*: WriteHandler # user provided write callback
writeLock*: AsyncLock # write lock to guarantee ordered writes
lock: AsyncLock # pushTo lock to guarantee ordered reads
piped: BufferStream # a piped bufferstream instance
AlreadyPipedError* = object of CatchableError
NotWritableError* = object of CatchableError
proc newAlreadyPipedError*(): ref Exception {.inline.} = proc newAlreadyPipedError*(): ref Exception {.inline.} =
result = newException(AlreadyPipedError, "stream already piped") result = newException(AlreadyPipedError, "stream already piped")
@ -96,7 +104,7 @@ proc requestReadBytes(s: BufferStream): Future[void] =
## data becomes available in the read buffer ## data becomes available in the read buffer
result = newFuture[void]() result = newFuture[void]()
s.readReqs.addLast(result) s.readReqs.addLast(result)
trace "requestReadBytes(): added a future to readReqs" trace "requestReadBytes(): added a future to readReqs", oid = s.oid
proc initBufferStream*(s: BufferStream, proc initBufferStream*(s: BufferStream,
handler: WriteHandler = nil, handler: WriteHandler = nil,
@ -106,12 +114,29 @@ proc initBufferStream*(s: BufferStream,
s.readReqs = initDeque[Future[void]]() s.readReqs = initDeque[Future[void]]()
s.dataReadEvent = newAsyncEvent() s.dataReadEvent = newAsyncEvent()
s.lock = newAsyncLock() s.lock = newAsyncLock()
s.writeHandler = handler s.writeLock = newAsyncLock()
s.closeEvent = newAsyncEvent() s.closeEvent = newAsyncEvent()
inc getBufferStreamTracker().opened s.isClosed = false
if not(isNil(handler)):
s.writeHandler = proc (data: seq[byte]) {.async, gcsafe.} =
try:
# Using a lock here to guarantee
# proper write ordering. This is
# specially important when
# implementing half-closed in mplex
# or other functionality that requires
# strict message ordering
await s.writeLock.acquire()
await handler(data)
finally:
s.writeLock.release()
when chronicles.enabledLogLevel == LogLevel.TRACE: when chronicles.enabledLogLevel == LogLevel.TRACE:
s.oid = genOid() s.oid = genOid()
s.isClosed = false
trace "created bufferstream", oid = s.oid
inc getBufferStreamTracker().opened
libp2p_open_bufferstream.inc() libp2p_open_bufferstream.inc()
proc newBufferStream*(handler: WriteHandler = nil, proc newBufferStream*(handler: WriteHandler = nil,
@ -133,7 +158,7 @@ proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) =
proc len*(s: BufferStream): int = s.readBuf.len proc len*(s: BufferStream): int = s.readBuf.len
proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} = method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
## Write bytes to internal read buffer, use this to fill up the ## Write bytes to internal read buffer, use this to fill up the
## buffer with data. ## buffer with data.
## ##
@ -142,9 +167,8 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
## is preserved. ## is preserved.
## ##
when chronicles.enabledLogLevel == LogLevel.TRACE: if s.atEof:
logScope: raise newLPStreamEOFError()
stream_oid = $s.oid
try: try:
await s.lock.acquire() await s.lock.acquire()
@ -153,12 +177,12 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
while index < data.len and s.readBuf.len < s.maxSize: while index < data.len and s.readBuf.len < s.maxSize:
s.readBuf.addLast(data[index]) s.readBuf.addLast(data[index])
inc(index) inc(index)
trace "pushTo()", msg = "added " & $index & " bytes to readBuf" trace "pushTo()", msg = "added " & $index & " bytes to readBuf", oid = s.oid
# resolve the next queued read request # resolve the next queued read request
if s.readReqs.len > 0: if s.readReqs.len > 0:
s.readReqs.popFirst().complete() s.readReqs.popFirst().complete()
trace "pushTo(): completed a readReqs future" trace "pushTo(): completed a readReqs future", oid = s.oid
if index >= data.len: if index >= data.len:
return return
@ -180,11 +204,11 @@ method readExactly*(s: BufferStream,
## If EOF is received and ``nbytes`` is not yet read, the procedure ## If EOF is received and ``nbytes`` is not yet read, the procedure
## will raise ``LPStreamIncompleteError``. ## will raise ``LPStreamIncompleteError``.
## ##
when chronicles.enabledLogLevel == LogLevel.TRACE:
logScope:
stream_oid = $s.oid
trace "read()", requested_bytes = nbytes if s.atEof:
raise newLPStreamEOFError()
trace "readExactly()", requested_bytes = nbytes, oid = s.oid
var index = 0 var index = 0
if s.readBuf.len() == 0: if s.readBuf.len() == 0:
@ -195,7 +219,7 @@ method readExactly*(s: BufferStream,
while s.readBuf.len() > 0 and index < nbytes: while s.readBuf.len() > 0 and index < nbytes:
output[index] = s.popFirst() output[index] = s.popFirst()
inc(index) inc(index)
trace "readExactly()", read_bytes = index trace "readExactly()", read_bytes = index, oid = s.oid
if index < nbytes: if index < nbytes:
await s.requestReadBytes() await s.requestReadBytes()
@ -209,6 +233,10 @@ method readOnce*(s: BufferStream,
## If internal buffer is not empty, ``nbytes`` bytes will be transferred from ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from
## internal buffer, otherwise it will wait until some bytes will be received. ## internal buffer, otherwise it will wait until some bytes will be received.
## ##
if s.atEof:
raise newLPStreamEOFError()
if s.readBuf.len == 0: if s.readBuf.len == 0:
await s.requestReadBytes() await s.requestReadBytes()
@ -216,7 +244,7 @@ method readOnce*(s: BufferStream,
await s.readExactly(pbytes, len) await s.readExactly(pbytes, len)
result = len result = len
method write*(s: BufferStream, msg: seq[byte]): Future[void] = method write*(s: BufferStream, msg: seq[byte]) {.async.} =
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer ## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
## stream ``wstream``. ## stream ``wstream``.
## ##
@ -226,12 +254,14 @@ method write*(s: BufferStream, msg: seq[byte]): Future[void] =
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
## stream. ## stream.
## ##
if isNil(s.writeHandler):
var retFuture = newFuture[void]("BufferStream.write(seq)")
retFuture.fail(newNotWritableError())
return retFuture
result = s.writeHandler(msg) if s.closed:
raise newLPStreamClosedError()
if isNil(s.writeHandler):
raise newNotWritableError()
await s.writeHandler(msg)
proc pipe*(s: BufferStream, proc pipe*(s: BufferStream,
target: BufferStream): BufferStream = target: BufferStream): BufferStream =
@ -242,10 +272,10 @@ proc pipe*(s: BufferStream,
## interface methods `read*` and `write` are ## interface methods `read*` and `write` are
## piped. ## piped.
## ##
if s.isPiped: if not(isNil(s.piped)):
raise newAlreadyPipedError() raise newAlreadyPipedError()
s.isPiped = true s.piped = target
let oldHandler = target.writeHandler let oldHandler = target.writeHandler
proc handler(data: seq[byte]) {.async, closure, gcsafe.} = proc handler(data: seq[byte]) {.async, closure, gcsafe.} =
if not isNil(oldHandler): if not isNil(oldHandler):
@ -272,10 +302,10 @@ proc `|`*(s: BufferStream, target: BufferStream): BufferStream =
## pipe operator to make piping less verbose ## pipe operator to make piping less verbose
pipe(s, target) pipe(s, target)
method close*(s: BufferStream) {.async.} = method close*(s: BufferStream) {.async, gcsafe.} =
## close the stream and clear the buffer ## close the stream and clear the buffer
if not s.isClosed: if not s.isClosed:
trace "closing bufferstream" trace "closing bufferstream", oid = s.oid
for r in s.readReqs: for r in s.readReqs:
if not(isNil(r)) and not(r.finished()): if not(isNil(r)) and not(r.finished()):
r.fail(newLPStreamEOFError()) r.fail(newLPStreamEOFError())
@ -283,7 +313,10 @@ method close*(s: BufferStream) {.async.} =
s.readBuf.clear() s.readBuf.clear()
s.closeEvent.fire() s.closeEvent.fire()
s.isClosed = true s.isClosed = true
inc getBufferStreamTracker().closed inc getBufferStreamTracker().closed
libp2p_open_bufferstream.dec() libp2p_open_bufferstream.dec()
trace "bufferstream closed", oid = s.oid
else: else:
trace "attempt to close an already closed bufferstream", trace = getStackTrace() trace "attempt to close an already closed bufferstream", trace = getStackTrace()

View File

@ -21,6 +21,7 @@ proc newChronosStream*(client: StreamTransport): ChronosStream =
result.client = client result.client = client
result.closeEvent = newAsyncEvent() result.closeEvent = newAsyncEvent()
template withExceptions(body: untyped) = template withExceptions(body: untyped) =
try: try:
body body
@ -38,20 +39,23 @@ template withExceptions(body: untyped) =
method readExactly*(s: ChronosStream, method readExactly*(s: ChronosStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): Future[void] {.async.} = nbytes: int): Future[void] {.async.} =
if s.client.atEof: if s.atEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()
withExceptions: withExceptions:
await s.client.readExactly(pbytes, nbytes) await s.client.readExactly(pbytes, nbytes)
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} = method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
if s.client.atEof: if s.atEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()
withExceptions: withExceptions:
result = await s.client.readOnce(pbytes, nbytes) result = await s.client.readOnce(pbytes, nbytes)
method write*(s: ChronosStream, msg: seq[byte]) {.async.} = method write*(s: ChronosStream, msg: seq[byte]) {.async.} =
if s.closed:
raise newLPStreamClosedError()
if msg.len == 0: if msg.len == 0:
return return
@ -63,6 +67,9 @@ method write*(s: ChronosStream, msg: seq[byte]) {.async.} =
method closed*(s: ChronosStream): bool {.inline.} = method closed*(s: ChronosStream): bool {.inline.} =
result = s.client.closed result = s.client.closed
method atEof*(s: ChronosStream): bool {.inline.} =
s.client.atEof()
method close*(s: ChronosStream) {.async.} = method close*(s: ChronosStream) {.async.} =
if not s.closed: if not s.closed:
trace "shutting chronos stream", address = $s.client.remoteAddress() trace "shutting chronos stream", address = $s.client.remoteAddress()

View File

@ -15,6 +15,7 @@ import ../varint,
type type
LPStream* = ref object of RootObj LPStream* = ref object of RootObj
isClosed*: bool isClosed*: bool
isEof*: bool
closeEvent*: AsyncEvent closeEvent*: AsyncEvent
when chronicles.enabledLogLevel == LogLevel.TRACE: when chronicles.enabledLogLevel == LogLevel.TRACE:
oid*: Oid oid*: Oid
@ -28,6 +29,7 @@ type
LPStreamWriteError* = object of LPStreamError LPStreamWriteError* = object of LPStreamError
par*: ref Exception par*: ref Exception
LPStreamEOFError* = object of LPStreamError LPStreamEOFError* = object of LPStreamError
LPStreamClosedError* = object of LPStreamError
InvalidVarintError* = object of LPStreamError InvalidVarintError* = object of LPStreamError
MaxSizeError* = object of LPStreamError MaxSizeError* = object of LPStreamError
@ -59,9 +61,15 @@ proc newLPStreamIncorrectDefect*(m: string): ref Exception =
proc newLPStreamEOFError*(): ref Exception = proc newLPStreamEOFError*(): ref Exception =
result = newException(LPStreamEOFError, "Stream EOF!") result = newException(LPStreamEOFError, "Stream EOF!")
proc newLPStreamClosedError*(): ref Exception =
result = newException(LPStreamClosedError, "Stream Closed!")
method closed*(s: LPStream): bool {.base, inline.} = method closed*(s: LPStream): bool {.base, inline.} =
s.isClosed s.isClosed
method atEof*(s: LPStream): bool {.base, inline.} =
s.isEof
method readExactly*(s: LPStream, method readExactly*(s: LPStream,
pbytes: pointer, pbytes: pointer,
nbytes: int): nbytes: int):

View File

@ -43,8 +43,8 @@ proc getTcpTransportTracker(): TcpTransportTracker {.gcsafe.} =
proc dumpTracking(): string {.gcsafe.} = proc dumpTracking(): string {.gcsafe.} =
var tracker = getTcpTransportTracker() var tracker = getTcpTransportTracker()
result = "Opened transports: " & $tracker.opened & "\n" & result = "Opened tcp transports: " & $tracker.opened & "\n" &
"Closed transports: " & $tracker.closed "Closed tcp transports: " & $tracker.closed
proc leakTransport(): bool {.gcsafe.} = proc leakTransport(): bool {.gcsafe.} =
var tracker = getTcpTransportTracker() var tracker = getTcpTransportTracker()
@ -98,7 +98,6 @@ proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T =
method initTransport*(t: TcpTransport) = method initTransport*(t: TcpTransport) =
t.multicodec = multiCodec("tcp") t.multicodec = multiCodec("tcp")
inc getTcpTransportTracker().opened inc getTcpTransportTracker().opened
method close*(t: TcpTransport) {.async, gcsafe.} = method close*(t: TcpTransport) {.async, gcsafe.} =

View File

@ -122,7 +122,7 @@ suite "Mplex":
await chann.close() await chann.close()
try: try:
await chann.write("Hello") await chann.write("Hello")
except LPStreamEOFError: except LPStreamClosedError:
result = true result = true
finally: finally:
await chann.reset() await chann.reset()
@ -141,7 +141,7 @@ suite "Mplex":
chann = newChannel(1, conn, true) chann = newChannel(1, conn, true)
await chann.pushTo(("Hello!").toBytes) await chann.pushTo(("Hello!").toBytes)
let closeFut = chann.closedByRemote() let closeFut = chann.closeRemote()
var data = newSeq[byte](6) var data = newSeq[byte](6)
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
@ -163,7 +163,7 @@ suite "Mplex":
let let
conn = newConnection(newBufferStream(writeHandler)) conn = newConnection(newBufferStream(writeHandler))
chann = newChannel(1, conn, true) chann = newChannel(1, conn, true)
await chann.closedByRemote() await chann.closeRemote()
try: try:
await chann.pushTo(@[byte(1)]) await chann.pushTo(@[byte(1)])
except LPStreamEOFError: except LPStreamEOFError:
@ -204,7 +204,7 @@ suite "Mplex":
await chann.reset() await chann.reset()
try: try:
await chann.write(("Hello!").toBytes) await chann.write(("Hello!").toBytes)
except LPStreamEOFError: except LPStreamClosedError:
result = true result = true
finally: finally:
await conn.close() await conn.close()
@ -401,7 +401,7 @@ suite "Mplex":
let mplexDial = newMplex(conn) let mplexDial = newMplex(conn)
# TODO: Reenable once half-closed is working properly # TODO: Reenable once half-closed is working properly
# let mplexDialFut = mplexDial.handle() let mplexDialFut = mplexDial.handle()
for i in 1..10: for i in 1..10:
let stream = await mplexDial.newStream() let stream = await mplexDial.newStream()
await stream.writeLp(&"stream {i}!") await stream.writeLp(&"stream {i}!")
@ -409,7 +409,7 @@ suite "Mplex":
await done.wait(10.seconds) await done.wait(10.seconds)
await conn.close() await conn.close()
# await mplexDialFut await mplexDialFut
await allFuturesThrowing(transport1.close(), transport2.close()) await allFuturesThrowing(transport1.close(), transport2.close())
await listenFut await listenFut

View File

@ -71,6 +71,7 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) =
suite "Noise": suite "Noise":
teardown: teardown:
for tracker in testTrackers(): for tracker in testTrackers():
# echo tracker.dump()
check tracker.isLeaked() == false check tracker.isLeaked() == false
test "e2e: handle write + noise": test "e2e: handle write + noise":

View File

@ -56,7 +56,7 @@ suite "Switch":
check tracker.isLeaked() == false check tracker.isLeaked() == false
test "e2e use switch dial proto string": test "e2e use switch dial proto string":
proc testSwitch(): Future[bool] {.async, gcsafe.} = proc testSwitch() {.async, gcsafe.} =
let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
let ma2: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma2: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
@ -86,16 +86,12 @@ suite "Switch":
let conn = await switch2.dial(switch1.peerInfo, TestCodec) let conn = await switch2.dial(switch1.peerInfo, TestCodec)
try:
await conn.writeLp("Hello!") await conn.writeLp("Hello!")
let msg = cast[string](await conn.readLp(1024)) let msg = cast[string](await conn.readLp(1024))
check "Hello!" == msg check "Hello!" == msg
result = true
except LPStreamError:
result = false
await allFuturesThrowing( await allFuturesThrowing(
done.wait(5000.millis) #[if OK won't happen!!]#, done.wait(5.seconds) #[if OK won't happen!!]#,
conn.close(), conn.close(),
switch1.stop(), switch1.stop(),
switch2.stop(), switch2.stop(),
@ -104,8 +100,7 @@ suite "Switch":
# this needs to go at end # this needs to go at end
await allFuturesThrowing(awaiters) await allFuturesThrowing(awaiters)
check: waitFor(testSwitch())
waitFor(testSwitch()) == true
test "e2e use switch no proto string": test "e2e use switch no proto string":
proc testSwitch(): Future[bool] {.async, gcsafe.} = proc testSwitch(): Future[bool] {.async, gcsafe.} =