mirror of
https://github.com/vacp2p/nim-libp2p-experimental.git
synced 2025-01-11 19:04:26 +00:00
Half closed (#174)
* call write until all is written out * add comments to lpchannel fields * add an eof flag to signal which end closed * wip: rework with proper half-closed * add eof and closed handling * propagate closes to piped * call parent close * moving bufferstream trackers out * move writeLock to bufferstream * move writeLock out * remove unused call * wip * rebasing master * fix mplex tests * wip * fix bufferstream after backport * wip * rename to differentiate from chronos tracker * close connection on chronos close * make reset request asyncCheck * fix channel cleanup * misc * don't use read * fix backports * make noise work again * proper exception handling * don't reraise just yet * add convenience templates * dont double wrap * use async pragma * fixes after backporting * muxer owns connection * remove on transport close cleanup * revert back allread * adding some todos * read from stream * inc count before closing * rebasing master * rebase master * use correct exception type * use try/finally insted of defer * fix compile in trace mode * reset channels on mplex close
This commit is contained in:
parent
681991ae48
commit
7900fd9f61
@ -7,6 +7,7 @@
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import oids, deques
|
||||
import chronos, chronicles
|
||||
import types,
|
||||
coder,
|
||||
@ -22,21 +23,45 @@ export lpstream
|
||||
logScope:
|
||||
topic = "MplexChannel"
|
||||
|
||||
## Channel half-closed states
|
||||
##
|
||||
## | State | Closed local | Closed remote
|
||||
## |=============================================
|
||||
## | Read | Yes (until EOF) | No
|
||||
## | Write | No | Yes
|
||||
##
|
||||
|
||||
type
|
||||
LPChannel* = ref object of BufferStream
|
||||
id*: uint64
|
||||
name*: string
|
||||
conn*: Connection
|
||||
initiator*: bool
|
||||
isLazy*: bool
|
||||
isOpen*: bool
|
||||
isReset*: bool
|
||||
closedLocal*: bool
|
||||
closedRemote*: bool
|
||||
handlerFuture*: Future[void]
|
||||
msgCode*: MessageType
|
||||
closeCode*: MessageType
|
||||
resetCode*: MessageType
|
||||
id*: uint64 # channel id
|
||||
name*: string # name of the channel (for debugging)
|
||||
conn*: Connection # wrapped connection used to for writing
|
||||
initiator*: bool # initiated remotely or locally flag
|
||||
isLazy*: bool # is channel lazy
|
||||
isOpen*: bool # has channel been oppened (only used with isLazy)
|
||||
closedLocal*: bool # has channel been closed locally
|
||||
msgCode*: MessageType # cached in/out message code
|
||||
closeCode*: MessageType # cached in/out close code
|
||||
resetCode*: MessageType # cached in/out reset code
|
||||
|
||||
proc open*(s: LPChannel) {.async, gcsafe.}
|
||||
|
||||
template withWriteLock(lock: AsyncLock, body: untyped): untyped =
|
||||
try:
|
||||
await lock.acquire()
|
||||
body
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
template withEOFExceptions(body: untyped): untyped =
|
||||
try:
|
||||
body
|
||||
except LPStreamEOFError as exc:
|
||||
trace "muxed connection EOF", exc = exc.msg
|
||||
except LPStreamClosedError as exc:
|
||||
trace "muxed connection closed", exc = exc.msg
|
||||
except LPStreamIncompleteError as exc:
|
||||
trace "incomplete message", exc = exc.msg
|
||||
|
||||
proc newChannel*(id: uint64,
|
||||
conn: Connection,
|
||||
@ -55,110 +80,114 @@ proc newChannel*(id: uint64,
|
||||
result.isLazy = lazy
|
||||
|
||||
let chan = result
|
||||
proc writeHandler(data: seq[byte]): Future[void] {.async.} =
|
||||
proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
|
||||
if chan.isLazy and not(chan.isOpen):
|
||||
await chan.open()
|
||||
|
||||
# writes should happen in sequence
|
||||
trace "sending data ", data = data.shortLog,
|
||||
trace "sending data", data = data.shortLog,
|
||||
id = chan.id,
|
||||
initiator = chan.initiator
|
||||
initiator = chan.initiator,
|
||||
name = chan.name,
|
||||
oid = chan.oid
|
||||
|
||||
await conn.writeMsg(chan.id, chan.msgCode, data) # write header
|
||||
|
||||
result.initBufferStream(writeHandler, size)
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
result.name = if result.name.len > 0: result.name else: $result.oid
|
||||
|
||||
trace "created new lpchannel", id = result.id,
|
||||
oid = result.oid,
|
||||
initiator = result.initiator,
|
||||
name = result.name
|
||||
|
||||
proc closeMessage(s: LPChannel) {.async.} =
|
||||
await s.conn.writeMsg(s.id, s.closeCode) # write header
|
||||
withEOFExceptions:
|
||||
withWriteLock(s.writeLock):
|
||||
trace "sending close message", id = s.id,
|
||||
initiator = s.initiator,
|
||||
name = s.name,
|
||||
oid = s.oid
|
||||
|
||||
proc cleanUp*(s: LPChannel): Future[void] =
|
||||
# method which calls the underlying buffer's `close`
|
||||
# method used instead of `close` since it's overloaded to
|
||||
# simulate half-closed streams
|
||||
result = procCall close(BufferStream(s))
|
||||
|
||||
proc tryCleanup(s: LPChannel) {.async, inline.} =
|
||||
# if stream is EOF, then cleanup immediatelly
|
||||
if s.closedRemote and s.len == 0:
|
||||
await s.cleanUp()
|
||||
|
||||
proc closedByRemote*(s: LPChannel) {.async.} =
|
||||
s.closedRemote = true
|
||||
if s.len == 0:
|
||||
await s.cleanUp()
|
||||
|
||||
proc open*(s: LPChannel): Future[void] =
|
||||
s.isOpen = true
|
||||
s.conn.writeMsg(s.id, MessageType.New, s.name)
|
||||
|
||||
method close*(s: LPChannel) {.async, gcsafe.} =
|
||||
s.closedLocal = true
|
||||
await s.closeMessage()
|
||||
await s.conn.writeMsg(s.id, s.closeCode) # write close
|
||||
|
||||
proc resetMessage(s: LPChannel) {.async.} =
|
||||
await s.conn.writeMsg(s.id, s.resetCode)
|
||||
withEOFExceptions:
|
||||
withWriteLock(s.writeLock):
|
||||
trace "sending reset message", id = s.id,
|
||||
initiator = s.initiator,
|
||||
name = s.name,
|
||||
oid = s.oid
|
||||
|
||||
proc resetByRemote*(s: LPChannel) {.async.} =
|
||||
# Immediately block futher calls
|
||||
s.isReset = true
|
||||
await s.conn.writeMsg(s.id, s.resetCode) # write reset
|
||||
|
||||
# start and await async teardown
|
||||
let
|
||||
futs = await allFinished(
|
||||
s.close(),
|
||||
s.closedByRemote(),
|
||||
s.cleanUp()
|
||||
)
|
||||
proc open*(s: LPChannel) {.async, gcsafe.} =
|
||||
## NOTE: Don't call withExcAndLock or withWriteLock,
|
||||
## because this already gets called from writeHandler
|
||||
## which is locked
|
||||
withEOFExceptions:
|
||||
await s.conn.writeMsg(s.id, MessageType.New, s.name)
|
||||
trace "oppened channel", oid = s.oid,
|
||||
name = s.name,
|
||||
initiator = s.initiator
|
||||
s.isOpen = true
|
||||
|
||||
checkFutures(futs, [LPStreamEOFError])
|
||||
proc closeRemote*(s: LPChannel) {.async.} =
|
||||
trace "got EOF, closing channel", id = s.id,
|
||||
initiator = s.initiator,
|
||||
name = s.name,
|
||||
oid = s.oid
|
||||
|
||||
proc reset*(s: LPChannel) {.async.} =
|
||||
let
|
||||
futs = await allFinished(
|
||||
s.resetMessage(),
|
||||
s.resetByRemote()
|
||||
)
|
||||
# wait for all data in the buffer to be consumed
|
||||
while s.len > 0:
|
||||
await s.dataReadEvent.wait()
|
||||
s.dataReadEvent.clear()
|
||||
|
||||
checkFutures(futs, [LPStreamEOFError])
|
||||
# TODO: Not sure if this needs to be set here or bfore consuming
|
||||
# the buffer
|
||||
s.isEof = true # set EOF immediately to prevent further reads
|
||||
await procCall BufferStream(s).close() # close parent bufferstream
|
||||
|
||||
trace "channel closed on EOF", id = s.id,
|
||||
initiator = s.initiator,
|
||||
oid = s.oid,
|
||||
name = s.name
|
||||
|
||||
method closed*(s: LPChannel): bool =
|
||||
trace "closing lpchannel", id = s.id, initiator = s.initiator
|
||||
result = s.closedRemote and s.len == 0
|
||||
## this emulates half-closed behavior
|
||||
## when closed locally writing is
|
||||
## dissabled - see the table in the
|
||||
## header of the file
|
||||
s.closedLocal
|
||||
|
||||
proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] =
|
||||
if s.closedRemote or s.isReset:
|
||||
var retFuture = newFuture[void]("LPChannel.pushTo")
|
||||
retFuture.fail(newLPStreamEOFError())
|
||||
return retFuture
|
||||
method close*(s: LPChannel) {.async, gcsafe.} =
|
||||
if s.closedLocal:
|
||||
return
|
||||
|
||||
trace "pushing data to channel", data = data.shortLog,
|
||||
id = s.id,
|
||||
initiator = s.initiator
|
||||
trace "closing local lpchannel", id = s.id,
|
||||
initiator = s.initiator,
|
||||
name = s.name,
|
||||
oid = s.oid
|
||||
# TODO: we should install a timer that on expire
|
||||
# will make sure the channel did close by the remote
|
||||
# so the hald-closed flow completed, if it didn't
|
||||
# we should send a `reset` and move on.
|
||||
await s.closeMessage()
|
||||
s.closedLocal = true
|
||||
if s.atEof: # already closed by remote close parent buffer imediately
|
||||
await procCall BufferStream(s).close()
|
||||
|
||||
result = procCall pushTo(BufferStream(s), data)
|
||||
trace "lpchannel closed local", id = s.id,
|
||||
initiator = s.initiator,
|
||||
name = s.name,
|
||||
oid = s.oid
|
||||
|
||||
template raiseEOF(): untyped =
|
||||
if s.closed or s.isReset:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
method readExactly*(s: LPChannel,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.async.} =
|
||||
raiseEOF()
|
||||
await procCall readExactly(BufferStream(s), pbytes, nbytes)
|
||||
await s.tryCleanup()
|
||||
|
||||
method readOnce*(s: LPChannel,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[int] {.async.} =
|
||||
raiseEOF()
|
||||
result = await procCall readOnce(BufferStream(s), pbytes, nbytes)
|
||||
await s.tryCleanup()
|
||||
|
||||
method write*(s: LPChannel, msg: seq[byte]) {.async.} =
|
||||
if s.closedLocal or s.isReset:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
if s.isLazy and not s.isOpen:
|
||||
await s.open()
|
||||
|
||||
await procCall write(BufferStream(s), msg)
|
||||
method reset*(s: LPChannel) {.base, async.} =
|
||||
# we asyncCheck here because the other end
|
||||
# might be dead already - reset is always
|
||||
# optimistic
|
||||
asyncCheck s.resetMessage()
|
||||
await procCall BufferStream(s).close()
|
||||
s.isEof = true
|
||||
s.closedLocal = true
|
||||
|
@ -7,15 +7,12 @@
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
## TODO:
|
||||
## Timeouts and message limits are still missing
|
||||
## they need to be added ASAP
|
||||
|
||||
import tables, sequtils
|
||||
import tables, sequtils, oids
|
||||
import chronos, chronicles
|
||||
import ../muxer,
|
||||
../../connection,
|
||||
../../stream/lpstream,
|
||||
../../stream/bufferstream,
|
||||
../../utility,
|
||||
../../errors,
|
||||
coder,
|
||||
@ -27,18 +24,21 @@ logScope:
|
||||
|
||||
type
|
||||
Mplex* = ref object of Muxer
|
||||
remote*: Table[uint64, LPChannel]
|
||||
local*: Table[uint64, LPChannel]
|
||||
handlers*: array[2, Table[uint64, Future[void]]]
|
||||
remote: Table[uint64, LPChannel]
|
||||
local: Table[uint64, LPChannel]
|
||||
handlerFuts: seq[Future[void]]
|
||||
currentId*: uint64
|
||||
maxChannels*: uint64
|
||||
isClosed: bool
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
oid*: Oid
|
||||
|
||||
proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] =
|
||||
if initiator:
|
||||
trace "picking local channels", initiator = initiator
|
||||
trace "picking local channels", initiator = initiator, oid = m.oid
|
||||
result = m.local
|
||||
else:
|
||||
trace "picking remote channels", initiator = initiator
|
||||
trace "picking remote channels", initiator = initiator, oid = m.oid
|
||||
result = m.remote
|
||||
|
||||
proc newStreamInternal*(m: Mplex,
|
||||
@ -49,36 +49,27 @@ proc newStreamInternal*(m: Mplex,
|
||||
Future[LPChannel] {.async, gcsafe.} =
|
||||
## create new channel/stream
|
||||
let id = if initiator: m.currentId.inc(); m.currentId else: chanId
|
||||
trace "creating new channel", channelId = id, initiator = initiator
|
||||
result = newChannel(id, m.connection, initiator, name, lazy = lazy)
|
||||
trace "creating new channel", channelId = id,
|
||||
initiator = initiator,
|
||||
name = name,
|
||||
oid = m.oid
|
||||
result = newChannel(id,
|
||||
m.connection,
|
||||
initiator,
|
||||
name,
|
||||
lazy = lazy)
|
||||
m.getChannelList(initiator)[id] = result
|
||||
|
||||
proc cleanupChann(m: Mplex, chann: LPChannel, initiator: bool) {.async.} =
|
||||
## call the channel's `close` to signal the
|
||||
## remote that the channel is closing
|
||||
if not isNil(chann) and not chann.closed:
|
||||
trace "cleaning up channel", id = chann.id
|
||||
await chann.close()
|
||||
await chann.cleanUp()
|
||||
m.getChannelList(initiator).del(chann.id)
|
||||
trace "cleaned up channel", id = chann.id
|
||||
|
||||
proc cleanupChann(chann: LPChannel) {.async.} =
|
||||
trace "cleaning up channel", id = chann.id
|
||||
await chann.reset()
|
||||
await chann.close()
|
||||
await chann.cleanUp()
|
||||
trace "cleaned up channel", id = chann.id
|
||||
|
||||
method handle*(m: Mplex) {.async, gcsafe.} =
|
||||
trace "starting mplex main loop"
|
||||
trace "starting mplex main loop", oid = m.oid
|
||||
try:
|
||||
while not m.connection.closed:
|
||||
trace "waiting for data"
|
||||
trace "waiting for data", oid = m.oid
|
||||
let (id, msgType, data) = await m.connection.readMsg()
|
||||
trace "read message from connection", id = id,
|
||||
msgType = msgType,
|
||||
data = data.shortLog
|
||||
data = data.shortLog,
|
||||
oid = m.oid
|
||||
let initiator = bool(ord(msgType) and 1)
|
||||
var channel: LPChannel
|
||||
if MessageType(msgType) != MessageType.New:
|
||||
@ -86,7 +77,8 @@ method handle*(m: Mplex) {.async, gcsafe.} =
|
||||
if id notin channels:
|
||||
trace "Channel not found, skipping", id = id,
|
||||
initiator = initiator,
|
||||
msg = msgType
|
||||
msg = msgType,
|
||||
oid = m.oid
|
||||
continue
|
||||
channel = channels[id]
|
||||
|
||||
@ -94,26 +86,33 @@ method handle*(m: Mplex) {.async, gcsafe.} =
|
||||
of MessageType.New:
|
||||
let name = cast[string](data)
|
||||
channel = await m.newStreamInternal(false, id, name)
|
||||
trace "created channel", id = id, name = name, inititator = initiator
|
||||
trace "created channel", id = id,
|
||||
name = name,
|
||||
inititator = channel.initiator,
|
||||
channoid = channel.oid,
|
||||
oid = m.oid
|
||||
if not isNil(m.streamHandler):
|
||||
let stream = newConnection(channel)
|
||||
stream.peerInfo = m.connection.peerInfo
|
||||
|
||||
var fut = newFuture[void]()
|
||||
proc handler() {.async.} =
|
||||
tryAndWarn "mplex channel handler":
|
||||
await m.streamHandler(stream)
|
||||
if not initiator:
|
||||
await m.cleanupChann(channel, false)
|
||||
|
||||
if not initiator:
|
||||
m.handlers[0][id] = handler()
|
||||
else:
|
||||
m.handlers[1][id] = handler()
|
||||
fut = handler()
|
||||
m.handlerFuts.add(fut)
|
||||
fut.addCallback do(udata: pointer):
|
||||
m.handlerFuts.keepItIf(it != fut)
|
||||
|
||||
of MessageType.MsgIn, MessageType.MsgOut:
|
||||
trace "pushing data to channel", id = id,
|
||||
initiator = initiator,
|
||||
msgType = msgType,
|
||||
size = data.len
|
||||
size = data.len,
|
||||
name = channel.name,
|
||||
channoid = channel.oid,
|
||||
oid = m.oid
|
||||
|
||||
if data.len > MaxMsgSize:
|
||||
raise newLPStreamLimitError()
|
||||
@ -121,27 +120,40 @@ method handle*(m: Mplex) {.async, gcsafe.} =
|
||||
of MessageType.CloseIn, MessageType.CloseOut:
|
||||
trace "closing channel", id = id,
|
||||
initiator = initiator,
|
||||
msgType = msgType
|
||||
msgType = msgType,
|
||||
name = channel.name,
|
||||
channoid = channel.oid,
|
||||
oid = m.oid
|
||||
|
||||
await channel.closedByRemote()
|
||||
await channel.closeRemote()
|
||||
m.getChannelList(initiator).del(id)
|
||||
trace "deleted channel", id = id,
|
||||
initiator = initiator,
|
||||
msgType = msgType,
|
||||
name = channel.name,
|
||||
channoid = channel.oid,
|
||||
oid = m.oid
|
||||
of MessageType.ResetIn, MessageType.ResetOut:
|
||||
trace "resetting channel", id = id,
|
||||
initiator = initiator,
|
||||
msgType = msgType
|
||||
msgType = msgType,
|
||||
name = channel.name,
|
||||
channoid = channel.oid,
|
||||
oid = m.oid
|
||||
|
||||
await channel.resetByRemote()
|
||||
await channel.reset()
|
||||
m.getChannelList(initiator).del(id)
|
||||
trace "deleted channel", id = id,
|
||||
initiator = initiator,
|
||||
msgType = msgType,
|
||||
name = channel.name,
|
||||
channoid = channel.oid,
|
||||
oid = m.oid
|
||||
break
|
||||
except CatchableError as exc:
|
||||
trace "Exception occurred", exception = exc.msg
|
||||
trace "Exception occurred", exception = exc.msg, oid = m.oid
|
||||
finally:
|
||||
trace "stopping mplex main loop"
|
||||
await m.close()
|
||||
|
||||
proc internalCleanup(m: Mplex, conn: Connection) {.async.} =
|
||||
await conn.closeEvent.wait()
|
||||
trace "connection closed, cleaning up mplex"
|
||||
trace "stopping mplex main loop", oid = m.oid
|
||||
await m.close()
|
||||
|
||||
proc newMplex*(conn: Connection,
|
||||
@ -152,7 +164,16 @@ proc newMplex*(conn: Connection,
|
||||
result.remote = initTable[uint64, LPChannel]()
|
||||
result.local = initTable[uint64, LPChannel]()
|
||||
|
||||
asyncCheck result.internalCleanup(conn)
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
result.oid = genOid()
|
||||
|
||||
proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} =
|
||||
## remove the local channel from the internal tables
|
||||
##
|
||||
await chann.closeEvent.wait()
|
||||
if not isNil(chann):
|
||||
m.getChannelList(true).del(chann.id)
|
||||
trace "cleaned up channel", id = chann.id
|
||||
|
||||
method newStream*(m: Mplex,
|
||||
name: string = "",
|
||||
@ -163,24 +184,23 @@ method newStream*(m: Mplex,
|
||||
result = newConnection(channel)
|
||||
result.peerInfo = m.connection.peerInfo
|
||||
|
||||
asyncCheck m.cleanupChann(channel)
|
||||
|
||||
method close*(m: Mplex) {.async, gcsafe.} =
|
||||
trace "closing mplex muxer"
|
||||
if m.isClosed:
|
||||
return
|
||||
|
||||
trace "closing mplex muxer", oid = m.oid
|
||||
|
||||
checkFutures(
|
||||
await allFinished(
|
||||
toSeq(m.remote.values).mapIt(it.reset()) &
|
||||
toSeq(m.local.values).mapIt(it.reset())))
|
||||
|
||||
checkFutures(await allFinished(m.handlerFuts))
|
||||
|
||||
if not m.connection.closed():
|
||||
await m.connection.close()
|
||||
|
||||
let
|
||||
futs = await allFinished(
|
||||
toSeq(m.remote.values).mapIt(it.cleanupChann()) &
|
||||
toSeq(m.local.values).mapIt(it.cleanupChann()) &
|
||||
toSeq(m.handlers[0].values).mapIt(it) &
|
||||
toSeq(m.handlers[1].values).mapIt(it))
|
||||
|
||||
checkFutures(futs)
|
||||
|
||||
m.handlers[0].clear()
|
||||
m.handlers[1].clear()
|
||||
m.remote.clear()
|
||||
m.local.clear()
|
||||
|
||||
trace "mplex muxer closed"
|
||||
m.handlerFuts = @[]
|
||||
m.isClosed = true
|
||||
|
@ -20,7 +20,8 @@ import pubsub,
|
||||
../../peerinfo,
|
||||
../../connection,
|
||||
../../peer,
|
||||
../../errors
|
||||
../../errors,
|
||||
../../utility
|
||||
|
||||
logScope:
|
||||
topic = "GossipSub"
|
||||
|
@ -113,7 +113,7 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
|
||||
p.onConnect.wait().addCallback do (udata: pointer):
|
||||
asyncCheck sendToRemote()
|
||||
trace "enqueued message to send at a later time", peer = p.id,
|
||||
encoded = encodedHex
|
||||
encoded = digest
|
||||
|
||||
except CatchableError as exc:
|
||||
trace "Exception occurred in PubSubPeer.send", exc = exc.msg
|
||||
|
@ -55,7 +55,6 @@ method secure*(s: Secure, conn: Connection, initiator: bool): Future[Connection]
|
||||
result = await s.handleConn(conn, initiator)
|
||||
except CatchableError as exc:
|
||||
warn "securing connection failed", msg = exc.msg
|
||||
if not conn.closed():
|
||||
await conn.close()
|
||||
|
||||
method readExactly*(s: SecureConn,
|
||||
|
@ -36,26 +36,18 @@ import ../stream/lpstream
|
||||
|
||||
export lpstream
|
||||
|
||||
logScope:
|
||||
topic = "BufferStream"
|
||||
|
||||
declareGauge libp2p_open_bufferstream, "open BufferStream instances"
|
||||
|
||||
const
|
||||
BufferStreamTrackerName* = "libp2p.bufferstream"
|
||||
DefaultBufferSize* = 1024
|
||||
|
||||
const
|
||||
BufferStreamTrackerName* = "libp2p.bufferstream"
|
||||
|
||||
type
|
||||
# TODO: figure out how to make this generic to avoid casts
|
||||
WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.}
|
||||
|
||||
BufferStream* = ref object of LPStream
|
||||
maxSize*: int # buffer's max size in bytes
|
||||
readBuf: Deque[byte] # this is a ring buffer based dequeue, this makes it perfect as the backing store here
|
||||
readReqs: Deque[Future[void]] # use dequeue to fire reads in order
|
||||
dataReadEvent: AsyncEvent
|
||||
writeHandler*: WriteHandler
|
||||
lock: AsyncLock
|
||||
isPiped: bool
|
||||
|
||||
AlreadyPipedError* = object of CatchableError
|
||||
NotWritableError* = object of CatchableError
|
||||
|
||||
BufferStreamTracker* = ref object of TrackerBase
|
||||
opened*: uint64
|
||||
closed*: uint64
|
||||
@ -83,7 +75,23 @@ proc setupBufferStreamTracker(): BufferStreamTracker =
|
||||
result.dump = dumpTracking
|
||||
result.isLeaked = leakTransport
|
||||
addTracker(BufferStreamTrackerName, result)
|
||||
declareGauge libp2p_open_bufferstream, "open BufferStream instances"
|
||||
|
||||
type
|
||||
# TODO: figure out how to make this generic to avoid casts
|
||||
WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.}
|
||||
|
||||
BufferStream* = ref object of LPStream
|
||||
maxSize*: int # buffer's max size in bytes
|
||||
readBuf: Deque[byte] # this is a ring buffer based dequeue
|
||||
readReqs*: Deque[Future[void]] # use dequeue to fire reads in order
|
||||
dataReadEvent*: AsyncEvent # event triggered when data has been consumed from the internal buffer
|
||||
writeHandler*: WriteHandler # user provided write callback
|
||||
writeLock*: AsyncLock # write lock to guarantee ordered writes
|
||||
lock: AsyncLock # pushTo lock to guarantee ordered reads
|
||||
piped: BufferStream # a piped bufferstream instance
|
||||
|
||||
AlreadyPipedError* = object of CatchableError
|
||||
NotWritableError* = object of CatchableError
|
||||
|
||||
proc newAlreadyPipedError*(): ref Exception {.inline.} =
|
||||
result = newException(AlreadyPipedError, "stream already piped")
|
||||
@ -96,7 +104,7 @@ proc requestReadBytes(s: BufferStream): Future[void] =
|
||||
## data becomes available in the read buffer
|
||||
result = newFuture[void]()
|
||||
s.readReqs.addLast(result)
|
||||
trace "requestReadBytes(): added a future to readReqs"
|
||||
trace "requestReadBytes(): added a future to readReqs", oid = s.oid
|
||||
|
||||
proc initBufferStream*(s: BufferStream,
|
||||
handler: WriteHandler = nil,
|
||||
@ -106,12 +114,29 @@ proc initBufferStream*(s: BufferStream,
|
||||
s.readReqs = initDeque[Future[void]]()
|
||||
s.dataReadEvent = newAsyncEvent()
|
||||
s.lock = newAsyncLock()
|
||||
s.writeHandler = handler
|
||||
s.writeLock = newAsyncLock()
|
||||
s.closeEvent = newAsyncEvent()
|
||||
inc getBufferStreamTracker().opened
|
||||
s.isClosed = false
|
||||
|
||||
if not(isNil(handler)):
|
||||
s.writeHandler = proc (data: seq[byte]) {.async, gcsafe.} =
|
||||
try:
|
||||
# Using a lock here to guarantee
|
||||
# proper write ordering. This is
|
||||
# specially important when
|
||||
# implementing half-closed in mplex
|
||||
# or other functionality that requires
|
||||
# strict message ordering
|
||||
await s.writeLock.acquire()
|
||||
await handler(data)
|
||||
finally:
|
||||
s.writeLock.release()
|
||||
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
s.oid = genOid()
|
||||
s.isClosed = false
|
||||
|
||||
trace "created bufferstream", oid = s.oid
|
||||
inc getBufferStreamTracker().opened
|
||||
libp2p_open_bufferstream.inc()
|
||||
|
||||
proc newBufferStream*(handler: WriteHandler = nil,
|
||||
@ -133,7 +158,7 @@ proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) =
|
||||
|
||||
proc len*(s: BufferStream): int = s.readBuf.len
|
||||
|
||||
proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
|
||||
method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
|
||||
## Write bytes to internal read buffer, use this to fill up the
|
||||
## buffer with data.
|
||||
##
|
||||
@ -142,9 +167,8 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
|
||||
## is preserved.
|
||||
##
|
||||
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
logScope:
|
||||
stream_oid = $s.oid
|
||||
if s.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
try:
|
||||
await s.lock.acquire()
|
||||
@ -153,12 +177,12 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} =
|
||||
while index < data.len and s.readBuf.len < s.maxSize:
|
||||
s.readBuf.addLast(data[index])
|
||||
inc(index)
|
||||
trace "pushTo()", msg = "added " & $index & " bytes to readBuf"
|
||||
trace "pushTo()", msg = "added " & $index & " bytes to readBuf", oid = s.oid
|
||||
|
||||
# resolve the next queued read request
|
||||
if s.readReqs.len > 0:
|
||||
s.readReqs.popFirst().complete()
|
||||
trace "pushTo(): completed a readReqs future"
|
||||
trace "pushTo(): completed a readReqs future", oid = s.oid
|
||||
|
||||
if index >= data.len:
|
||||
return
|
||||
@ -180,11 +204,11 @@ method readExactly*(s: BufferStream,
|
||||
## If EOF is received and ``nbytes`` is not yet read, the procedure
|
||||
## will raise ``LPStreamIncompleteError``.
|
||||
##
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
logScope:
|
||||
stream_oid = $s.oid
|
||||
|
||||
trace "read()", requested_bytes = nbytes
|
||||
if s.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
trace "readExactly()", requested_bytes = nbytes, oid = s.oid
|
||||
var index = 0
|
||||
|
||||
if s.readBuf.len() == 0:
|
||||
@ -195,7 +219,7 @@ method readExactly*(s: BufferStream,
|
||||
while s.readBuf.len() > 0 and index < nbytes:
|
||||
output[index] = s.popFirst()
|
||||
inc(index)
|
||||
trace "readExactly()", read_bytes = index
|
||||
trace "readExactly()", read_bytes = index, oid = s.oid
|
||||
|
||||
if index < nbytes:
|
||||
await s.requestReadBytes()
|
||||
@ -209,6 +233,10 @@ method readOnce*(s: BufferStream,
|
||||
## If internal buffer is not empty, ``nbytes`` bytes will be transferred from
|
||||
## internal buffer, otherwise it will wait until some bytes will be received.
|
||||
##
|
||||
|
||||
if s.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
if s.readBuf.len == 0:
|
||||
await s.requestReadBytes()
|
||||
|
||||
@ -216,7 +244,7 @@ method readOnce*(s: BufferStream,
|
||||
await s.readExactly(pbytes, len)
|
||||
result = len
|
||||
|
||||
method write*(s: BufferStream, msg: seq[byte]): Future[void] =
|
||||
method write*(s: BufferStream, msg: seq[byte]) {.async.} =
|
||||
## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer
|
||||
## stream ``wstream``.
|
||||
##
|
||||
@ -226,12 +254,14 @@ method write*(s: BufferStream, msg: seq[byte]): Future[void] =
|
||||
## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to
|
||||
## stream.
|
||||
##
|
||||
if isNil(s.writeHandler):
|
||||
var retFuture = newFuture[void]("BufferStream.write(seq)")
|
||||
retFuture.fail(newNotWritableError())
|
||||
return retFuture
|
||||
|
||||
result = s.writeHandler(msg)
|
||||
if s.closed:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
if isNil(s.writeHandler):
|
||||
raise newNotWritableError()
|
||||
|
||||
await s.writeHandler(msg)
|
||||
|
||||
proc pipe*(s: BufferStream,
|
||||
target: BufferStream): BufferStream =
|
||||
@ -242,10 +272,10 @@ proc pipe*(s: BufferStream,
|
||||
## interface methods `read*` and `write` are
|
||||
## piped.
|
||||
##
|
||||
if s.isPiped:
|
||||
if not(isNil(s.piped)):
|
||||
raise newAlreadyPipedError()
|
||||
|
||||
s.isPiped = true
|
||||
s.piped = target
|
||||
let oldHandler = target.writeHandler
|
||||
proc handler(data: seq[byte]) {.async, closure, gcsafe.} =
|
||||
if not isNil(oldHandler):
|
||||
@ -272,10 +302,10 @@ proc `|`*(s: BufferStream, target: BufferStream): BufferStream =
|
||||
## pipe operator to make piping less verbose
|
||||
pipe(s, target)
|
||||
|
||||
method close*(s: BufferStream) {.async.} =
|
||||
method close*(s: BufferStream) {.async, gcsafe.} =
|
||||
## close the stream and clear the buffer
|
||||
if not s.isClosed:
|
||||
trace "closing bufferstream"
|
||||
trace "closing bufferstream", oid = s.oid
|
||||
for r in s.readReqs:
|
||||
if not(isNil(r)) and not(r.finished()):
|
||||
r.fail(newLPStreamEOFError())
|
||||
@ -283,7 +313,10 @@ method close*(s: BufferStream) {.async.} =
|
||||
s.readBuf.clear()
|
||||
s.closeEvent.fire()
|
||||
s.isClosed = true
|
||||
|
||||
inc getBufferStreamTracker().closed
|
||||
libp2p_open_bufferstream.dec()
|
||||
|
||||
trace "bufferstream closed", oid = s.oid
|
||||
else:
|
||||
trace "attempt to close an already closed bufferstream", trace=getStackTrace()
|
||||
trace "attempt to close an already closed bufferstream", trace = getStackTrace()
|
||||
|
@ -21,6 +21,7 @@ proc newChronosStream*(client: StreamTransport): ChronosStream =
|
||||
result.client = client
|
||||
result.closeEvent = newAsyncEvent()
|
||||
|
||||
|
||||
template withExceptions(body: untyped) =
|
||||
try:
|
||||
body
|
||||
@ -38,20 +39,23 @@ template withExceptions(body: untyped) =
|
||||
method readExactly*(s: ChronosStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int): Future[void] {.async.} =
|
||||
if s.client.atEof:
|
||||
if s.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
withExceptions:
|
||||
await s.client.readExactly(pbytes, nbytes)
|
||||
|
||||
method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
|
||||
if s.client.atEof:
|
||||
if s.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
withExceptions:
|
||||
result = await s.client.readOnce(pbytes, nbytes)
|
||||
|
||||
method write*(s: ChronosStream, msg: seq[byte]) {.async.} =
|
||||
if s.closed:
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
if msg.len == 0:
|
||||
return
|
||||
|
||||
@ -63,6 +67,9 @@ method write*(s: ChronosStream, msg: seq[byte]) {.async.} =
|
||||
method closed*(s: ChronosStream): bool {.inline.} =
|
||||
result = s.client.closed
|
||||
|
||||
method atEof*(s: ChronosStream): bool {.inline.} =
|
||||
s.client.atEof()
|
||||
|
||||
method close*(s: ChronosStream) {.async.} =
|
||||
if not s.closed:
|
||||
trace "shutting chronos stream", address = $s.client.remoteAddress()
|
||||
|
@ -15,6 +15,7 @@ import ../varint,
|
||||
type
|
||||
LPStream* = ref object of RootObj
|
||||
isClosed*: bool
|
||||
isEof*: bool
|
||||
closeEvent*: AsyncEvent
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
oid*: Oid
|
||||
@ -28,6 +29,7 @@ type
|
||||
LPStreamWriteError* = object of LPStreamError
|
||||
par*: ref Exception
|
||||
LPStreamEOFError* = object of LPStreamError
|
||||
LPStreamClosedError* = object of LPStreamError
|
||||
|
||||
InvalidVarintError* = object of LPStreamError
|
||||
MaxSizeError* = object of LPStreamError
|
||||
@ -59,9 +61,15 @@ proc newLPStreamIncorrectDefect*(m: string): ref Exception =
|
||||
proc newLPStreamEOFError*(): ref Exception =
|
||||
result = newException(LPStreamEOFError, "Stream EOF!")
|
||||
|
||||
proc newLPStreamClosedError*(): ref Exception =
|
||||
result = newException(LPStreamClosedError, "Stream Closed!")
|
||||
|
||||
method closed*(s: LPStream): bool {.base, inline.} =
|
||||
s.isClosed
|
||||
|
||||
method atEof*(s: LPStream): bool {.base, inline.} =
|
||||
s.isEof
|
||||
|
||||
method readExactly*(s: LPStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
|
@ -43,8 +43,8 @@ proc getTcpTransportTracker(): TcpTransportTracker {.gcsafe.} =
|
||||
|
||||
proc dumpTracking(): string {.gcsafe.} =
|
||||
var tracker = getTcpTransportTracker()
|
||||
result = "Opened transports: " & $tracker.opened & "\n" &
|
||||
"Closed transports: " & $tracker.closed
|
||||
result = "Opened tcp transports: " & $tracker.opened & "\n" &
|
||||
"Closed tcp transports: " & $tracker.closed
|
||||
|
||||
proc leakTransport(): bool {.gcsafe.} =
|
||||
var tracker = getTcpTransportTracker()
|
||||
@ -98,7 +98,6 @@ proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T =
|
||||
|
||||
method initTransport*(t: TcpTransport) =
|
||||
t.multicodec = multiCodec("tcp")
|
||||
|
||||
inc getTcpTransportTracker().opened
|
||||
|
||||
method close*(t: TcpTransport) {.async, gcsafe.} =
|
||||
|
@ -122,7 +122,7 @@ suite "Mplex":
|
||||
await chann.close()
|
||||
try:
|
||||
await chann.write("Hello")
|
||||
except LPStreamEOFError:
|
||||
except LPStreamClosedError:
|
||||
result = true
|
||||
finally:
|
||||
await chann.reset()
|
||||
@ -141,7 +141,7 @@ suite "Mplex":
|
||||
chann = newChannel(1, conn, true)
|
||||
|
||||
await chann.pushTo(("Hello!").toBytes)
|
||||
let closeFut = chann.closedByRemote()
|
||||
let closeFut = chann.closeRemote()
|
||||
|
||||
var data = newSeq[byte](6)
|
||||
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer
|
||||
@ -163,7 +163,7 @@ suite "Mplex":
|
||||
let
|
||||
conn = newConnection(newBufferStream(writeHandler))
|
||||
chann = newChannel(1, conn, true)
|
||||
await chann.closedByRemote()
|
||||
await chann.closeRemote()
|
||||
try:
|
||||
await chann.pushTo(@[byte(1)])
|
||||
except LPStreamEOFError:
|
||||
@ -204,7 +204,7 @@ suite "Mplex":
|
||||
await chann.reset()
|
||||
try:
|
||||
await chann.write(("Hello!").toBytes)
|
||||
except LPStreamEOFError:
|
||||
except LPStreamClosedError:
|
||||
result = true
|
||||
finally:
|
||||
await conn.close()
|
||||
@ -401,7 +401,7 @@ suite "Mplex":
|
||||
|
||||
let mplexDial = newMplex(conn)
|
||||
# TODO: Reenable once half-closed is working properly
|
||||
# let mplexDialFut = mplexDial.handle()
|
||||
let mplexDialFut = mplexDial.handle()
|
||||
for i in 1..10:
|
||||
let stream = await mplexDial.newStream()
|
||||
await stream.writeLp(&"stream {i}!")
|
||||
@ -409,7 +409,7 @@ suite "Mplex":
|
||||
|
||||
await done.wait(10.seconds)
|
||||
await conn.close()
|
||||
# await mplexDialFut
|
||||
await mplexDialFut
|
||||
await allFuturesThrowing(transport1.close(), transport2.close())
|
||||
await listenFut
|
||||
|
||||
|
@ -71,6 +71,7 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) =
|
||||
suite "Noise":
|
||||
teardown:
|
||||
for tracker in testTrackers():
|
||||
# echo tracker.dump()
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "e2e: handle write + noise":
|
||||
|
@ -56,7 +56,7 @@ suite "Switch":
|
||||
check tracker.isLeaked() == false
|
||||
|
||||
test "e2e use switch dial proto string":
|
||||
proc testSwitch(): Future[bool] {.async, gcsafe.} =
|
||||
proc testSwitch() {.async, gcsafe.} =
|
||||
let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
|
||||
let ma2: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
|
||||
|
||||
@ -86,16 +86,12 @@ suite "Switch":
|
||||
|
||||
let conn = await switch2.dial(switch1.peerInfo, TestCodec)
|
||||
|
||||
try:
|
||||
await conn.writeLp("Hello!")
|
||||
let msg = cast[string](await conn.readLp(1024))
|
||||
check "Hello!" == msg
|
||||
result = true
|
||||
except LPStreamError:
|
||||
result = false
|
||||
|
||||
await allFuturesThrowing(
|
||||
done.wait(5000.millis) #[if OK won't happen!!]#,
|
||||
done.wait(5.seconds) #[if OK won't happen!!]#,
|
||||
conn.close(),
|
||||
switch1.stop(),
|
||||
switch2.stop(),
|
||||
@ -104,8 +100,7 @@ suite "Switch":
|
||||
# this needs to go at end
|
||||
await allFuturesThrowing(awaiters)
|
||||
|
||||
check:
|
||||
waitFor(testSwitch()) == true
|
||||
waitFor(testSwitch())
|
||||
|
||||
test "e2e use switch no proto string":
|
||||
proc testSwitch(): Future[bool] {.async, gcsafe.} =
|
||||
|
Loading…
x
Reference in New Issue
Block a user