diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 99aaafa..a5a00a4 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -7,6 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. +import oids, deques import chronos, chronicles import types, coder, @@ -22,21 +23,45 @@ export lpstream logScope: topic = "MplexChannel" +## Channel half-closed states +## +## | State | Closed local | Closed remote +## |============================================= +## | Read | Yes (until EOF) | No +## | Write | No | Yes +## + type LPChannel* = ref object of BufferStream - id*: uint64 - name*: string - conn*: Connection - initiator*: bool - isLazy*: bool - isOpen*: bool - isReset*: bool - closedLocal*: bool - closedRemote*: bool - handlerFuture*: Future[void] - msgCode*: MessageType - closeCode*: MessageType - resetCode*: MessageType + id*: uint64 # channel id + name*: string # name of the channel (for debugging) + conn*: Connection # wrapped connection used to for writing + initiator*: bool # initiated remotely or locally flag + isLazy*: bool # is channel lazy + isOpen*: bool # has channel been oppened (only used with isLazy) + closedLocal*: bool # has channel been closed locally + msgCode*: MessageType # cached in/out message code + closeCode*: MessageType # cached in/out close code + resetCode*: MessageType # cached in/out reset code + +proc open*(s: LPChannel) {.async, gcsafe.} + +template withWriteLock(lock: AsyncLock, body: untyped): untyped = + try: + await lock.acquire() + body + finally: + lock.release() + +template withEOFExceptions(body: untyped): untyped = + try: + body + except LPStreamEOFError as exc: + trace "muxed connection EOF", exc = exc.msg + except LPStreamClosedError as exc: + trace "muxed connection closed", exc = exc.msg + except LPStreamIncompleteError as exc: + trace "incomplete message", exc = exc.msg proc newChannel*(id: uint64, conn: Connection, @@ -55,110 +80,114 @@ proc newChannel*(id: uint64, result.isLazy = lazy let chan = result - proc writeHandler(data: seq[byte]): Future[void] {.async.} = + proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} = + if chan.isLazy and not(chan.isOpen): + await chan.open() + # writes should happen in sequence - trace "sending data ", data = data.shortLog, - id = chan.id, - initiator = chan.initiator + trace "sending data", data = data.shortLog, + id = chan.id, + initiator = chan.initiator, + name = chan.name, + oid = chan.oid await conn.writeMsg(chan.id, chan.msgCode, data) # write header result.initBufferStream(writeHandler, size) + when chronicles.enabledLogLevel == LogLevel.TRACE: + result.name = if result.name.len > 0: result.name else: $result.oid + + trace "created new lpchannel", id = result.id, + oid = result.oid, + initiator = result.initiator, + name = result.name proc closeMessage(s: LPChannel) {.async.} = - await s.conn.writeMsg(s.id, s.closeCode) # write header + withEOFExceptions: + withWriteLock(s.writeLock): + trace "sending close message", id = s.id, + initiator = s.initiator, + name = s.name, + oid = s.oid -proc cleanUp*(s: LPChannel): Future[void] = - # method which calls the underlying buffer's `close` - # method used instead of `close` since it's overloaded to - # simulate half-closed streams - result = procCall close(BufferStream(s)) - -proc tryCleanup(s: LPChannel) {.async, inline.} = - # if stream is EOF, then cleanup immediatelly - if s.closedRemote and s.len == 0: - await s.cleanUp() - -proc closedByRemote*(s: LPChannel) {.async.} = - s.closedRemote = true - if s.len == 0: - await s.cleanUp() - -proc open*(s: LPChannel): Future[void] = - s.isOpen = true - s.conn.writeMsg(s.id, MessageType.New, s.name) - -method close*(s: LPChannel) {.async, gcsafe.} = - s.closedLocal = true - await s.closeMessage() + await s.conn.writeMsg(s.id, s.closeCode) # write close proc resetMessage(s: LPChannel) {.async.} = - await s.conn.writeMsg(s.id, s.resetCode) + withEOFExceptions: + withWriteLock(s.writeLock): + trace "sending reset message", id = s.id, + initiator = s.initiator, + name = s.name, + oid = s.oid -proc resetByRemote*(s: LPChannel) {.async.} = - # Immediately block futher calls - s.isReset = true + await s.conn.writeMsg(s.id, s.resetCode) # write reset - # start and await async teardown - let - futs = await allFinished( - s.close(), - s.closedByRemote(), - s.cleanUp() - ) +proc open*(s: LPChannel) {.async, gcsafe.} = + ## NOTE: Don't call withExcAndLock or withWriteLock, + ## because this already gets called from writeHandler + ## which is locked + withEOFExceptions: + await s.conn.writeMsg(s.id, MessageType.New, s.name) + trace "oppened channel", oid = s.oid, + name = s.name, + initiator = s.initiator + s.isOpen = true - checkFutures(futs, [LPStreamEOFError]) +proc closeRemote*(s: LPChannel) {.async.} = + trace "got EOF, closing channel", id = s.id, + initiator = s.initiator, + name = s.name, + oid = s.oid -proc reset*(s: LPChannel) {.async.} = - let - futs = await allFinished( - s.resetMessage(), - s.resetByRemote() - ) + # wait for all data in the buffer to be consumed + while s.len > 0: + await s.dataReadEvent.wait() + s.dataReadEvent.clear() - checkFutures(futs, [LPStreamEOFError]) + # TODO: Not sure if this needs to be set here or bfore consuming + # the buffer + s.isEof = true # set EOF immediately to prevent further reads + await procCall BufferStream(s).close() # close parent bufferstream + + trace "channel closed on EOF", id = s.id, + initiator = s.initiator, + oid = s.oid, + name = s.name method closed*(s: LPChannel): bool = - trace "closing lpchannel", id = s.id, initiator = s.initiator - result = s.closedRemote and s.len == 0 + ## this emulates half-closed behavior + ## when closed locally writing is + ## dissabled - see the table in the + ## header of the file + s.closedLocal -proc pushTo*(s: LPChannel, data: seq[byte]): Future[void] = - if s.closedRemote or s.isReset: - var retFuture = newFuture[void]("LPChannel.pushTo") - retFuture.fail(newLPStreamEOFError()) - return retFuture +method close*(s: LPChannel) {.async, gcsafe.} = + if s.closedLocal: + return - trace "pushing data to channel", data = data.shortLog, - id = s.id, - initiator = s.initiator + trace "closing local lpchannel", id = s.id, + initiator = s.initiator, + name = s.name, + oid = s.oid + # TODO: we should install a timer that on expire + # will make sure the channel did close by the remote + # so the hald-closed flow completed, if it didn't + # we should send a `reset` and move on. + await s.closeMessage() + s.closedLocal = true + if s.atEof: # already closed by remote close parent buffer imediately + await procCall BufferStream(s).close() - result = procCall pushTo(BufferStream(s), data) + trace "lpchannel closed local", id = s.id, + initiator = s.initiator, + name = s.name, + oid = s.oid -template raiseEOF(): untyped = - if s.closed or s.isReset: - raise newLPStreamEOFError() - -method readExactly*(s: LPChannel, - pbytes: pointer, - nbytes: int): - Future[void] {.async.} = - raiseEOF() - await procCall readExactly(BufferStream(s), pbytes, nbytes) - await s.tryCleanup() - -method readOnce*(s: LPChannel, - pbytes: pointer, - nbytes: int): - Future[int] {.async.} = - raiseEOF() - result = await procCall readOnce(BufferStream(s), pbytes, nbytes) - await s.tryCleanup() - -method write*(s: LPChannel, msg: seq[byte]) {.async.} = - if s.closedLocal or s.isReset: - raise newLPStreamEOFError() - - if s.isLazy and not s.isOpen: - await s.open() - - await procCall write(BufferStream(s), msg) +method reset*(s: LPChannel) {.base, async.} = + # we asyncCheck here because the other end + # might be dead already - reset is always + # optimistic + asyncCheck s.resetMessage() + await procCall BufferStream(s).close() + s.isEof = true + s.closedLocal = true diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index fe2cb55..adf2fe3 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -7,15 +7,12 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -## TODO: -## Timeouts and message limits are still missing -## they need to be added ASAP - -import tables, sequtils +import tables, sequtils, oids import chronos, chronicles import ../muxer, ../../connection, ../../stream/lpstream, + ../../stream/bufferstream, ../../utility, ../../errors, coder, @@ -27,18 +24,21 @@ logScope: type Mplex* = ref object of Muxer - remote*: Table[uint64, LPChannel] - local*: Table[uint64, LPChannel] - handlers*: array[2, Table[uint64, Future[void]]] + remote: Table[uint64, LPChannel] + local: Table[uint64, LPChannel] + handlerFuts: seq[Future[void]] currentId*: uint64 maxChannels*: uint64 + isClosed: bool + when chronicles.enabledLogLevel == LogLevel.TRACE: + oid*: Oid proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] = if initiator: - trace "picking local channels", initiator = initiator + trace "picking local channels", initiator = initiator, oid = m.oid result = m.local else: - trace "picking remote channels", initiator = initiator + trace "picking remote channels", initiator = initiator, oid = m.oid result = m.remote proc newStreamInternal*(m: Mplex, @@ -49,36 +49,27 @@ proc newStreamInternal*(m: Mplex, Future[LPChannel] {.async, gcsafe.} = ## create new channel/stream let id = if initiator: m.currentId.inc(); m.currentId else: chanId - trace "creating new channel", channelId = id, initiator = initiator - result = newChannel(id, m.connection, initiator, name, lazy = lazy) + trace "creating new channel", channelId = id, + initiator = initiator, + name = name, + oid = m.oid + result = newChannel(id, + m.connection, + initiator, + name, + lazy = lazy) m.getChannelList(initiator)[id] = result -proc cleanupChann(m: Mplex, chann: LPChannel, initiator: bool) {.async.} = - ## call the channel's `close` to signal the - ## remote that the channel is closing - if not isNil(chann) and not chann.closed: - trace "cleaning up channel", id = chann.id - await chann.close() - await chann.cleanUp() - m.getChannelList(initiator).del(chann.id) - trace "cleaned up channel", id = chann.id - -proc cleanupChann(chann: LPChannel) {.async.} = - trace "cleaning up channel", id = chann.id - await chann.reset() - await chann.close() - await chann.cleanUp() - trace "cleaned up channel", id = chann.id - method handle*(m: Mplex) {.async, gcsafe.} = - trace "starting mplex main loop" + trace "starting mplex main loop", oid = m.oid try: while not m.connection.closed: - trace "waiting for data" + trace "waiting for data", oid = m.oid let (id, msgType, data) = await m.connection.readMsg() trace "read message from connection", id = id, msgType = msgType, - data = data.shortLog + data = data.shortLog, + oid = m.oid let initiator = bool(ord(msgType) and 1) var channel: LPChannel if MessageType(msgType) != MessageType.New: @@ -86,7 +77,8 @@ method handle*(m: Mplex) {.async, gcsafe.} = if id notin channels: trace "Channel not found, skipping", id = id, initiator = initiator, - msg = msgType + msg = msgType, + oid = m.oid continue channel = channels[id] @@ -94,26 +86,33 @@ method handle*(m: Mplex) {.async, gcsafe.} = of MessageType.New: let name = cast[string](data) channel = await m.newStreamInternal(false, id, name) - trace "created channel", id = id, name = name, inititator = initiator + trace "created channel", id = id, + name = name, + inititator = channel.initiator, + channoid = channel.oid, + oid = m.oid if not isNil(m.streamHandler): let stream = newConnection(channel) stream.peerInfo = m.connection.peerInfo + var fut = newFuture[void]() proc handler() {.async.} = tryAndWarn "mplex channel handler": await m.streamHandler(stream) - if not initiator: - await m.cleanupChann(channel, false) - if not initiator: - m.handlers[0][id] = handler() - else: - m.handlers[1][id] = handler() + fut = handler() + m.handlerFuts.add(fut) + fut.addCallback do(udata: pointer): + m.handlerFuts.keepItIf(it != fut) + of MessageType.MsgIn, MessageType.MsgOut: trace "pushing data to channel", id = id, initiator = initiator, msgType = msgType, - size = data.len + size = data.len, + name = channel.name, + channoid = channel.oid, + oid = m.oid if data.len > MaxMsgSize: raise newLPStreamLimitError() @@ -121,29 +120,42 @@ method handle*(m: Mplex) {.async, gcsafe.} = of MessageType.CloseIn, MessageType.CloseOut: trace "closing channel", id = id, initiator = initiator, - msgType = msgType + msgType = msgType, + name = channel.name, + channoid = channel.oid, + oid = m.oid - await channel.closedByRemote() + await channel.closeRemote() m.getChannelList(initiator).del(id) + trace "deleted channel", id = id, + initiator = initiator, + msgType = msgType, + name = channel.name, + channoid = channel.oid, + oid = m.oid of MessageType.ResetIn, MessageType.ResetOut: trace "resetting channel", id = id, initiator = initiator, - msgType = msgType + msgType = msgType, + name = channel.name, + channoid = channel.oid, + oid = m.oid - await channel.resetByRemote() + await channel.reset() m.getChannelList(initiator).del(id) + trace "deleted channel", id = id, + initiator = initiator, + msgType = msgType, + name = channel.name, + channoid = channel.oid, + oid = m.oid break except CatchableError as exc: - trace "Exception occurred", exception = exc.msg + trace "Exception occurred", exception = exc.msg, oid = m.oid finally: - trace "stopping mplex main loop" + trace "stopping mplex main loop", oid = m.oid await m.close() -proc internalCleanup(m: Mplex, conn: Connection) {.async.} = - await conn.closeEvent.wait() - trace "connection closed, cleaning up mplex" - await m.close() - proc newMplex*(conn: Connection, maxChanns: uint = MaxChannels): Mplex = new result @@ -152,7 +164,16 @@ proc newMplex*(conn: Connection, result.remote = initTable[uint64, LPChannel]() result.local = initTable[uint64, LPChannel]() - asyncCheck result.internalCleanup(conn) + when chronicles.enabledLogLevel == LogLevel.TRACE: + result.oid = genOid() + +proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} = + ## remove the local channel from the internal tables + ## + await chann.closeEvent.wait() + if not isNil(chann): + m.getChannelList(true).del(chann.id) + trace "cleaned up channel", id = chann.id method newStream*(m: Mplex, name: string = "", @@ -163,24 +184,23 @@ method newStream*(m: Mplex, result = newConnection(channel) result.peerInfo = m.connection.peerInfo + asyncCheck m.cleanupChann(channel) + method close*(m: Mplex) {.async, gcsafe.} = - trace "closing mplex muxer" + if m.isClosed: + return - if not m.connection.closed(): - await m.connection.close() + trace "closing mplex muxer", oid = m.oid - let - futs = await allFinished( - toSeq(m.remote.values).mapIt(it.cleanupChann()) & - toSeq(m.local.values).mapIt(it.cleanupChann()) & - toSeq(m.handlers[0].values).mapIt(it) & - toSeq(m.handlers[1].values).mapIt(it)) + checkFutures( + await allFinished( + toSeq(m.remote.values).mapIt(it.reset()) & + toSeq(m.local.values).mapIt(it.reset()))) - checkFutures(futs) + checkFutures(await allFinished(m.handlerFuts)) - m.handlers[0].clear() - m.handlers[1].clear() + await m.connection.close() m.remote.clear() m.local.clear() - - trace "mplex muxer closed" + m.handlerFuts = @[] + m.isClosed = true diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 9036183..4fb0df2 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -20,7 +20,8 @@ import pubsub, ../../peerinfo, ../../connection, ../../peer, - ../../errors + ../../errors, + ../../utility logScope: topic = "GossipSub" diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index e81aa1a..565d1e5 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -113,7 +113,7 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = p.onConnect.wait().addCallback do (udata: pointer): asyncCheck sendToRemote() trace "enqueued message to send at a later time", peer = p.id, - encoded = encodedHex + encoded = digest except CatchableError as exc: trace "Exception occurred in PubSubPeer.send", exc = exc.msg diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 83b234b..a93d04a 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -55,8 +55,7 @@ method secure*(s: Secure, conn: Connection, initiator: bool): Future[Connection] result = await s.handleConn(conn, initiator) except CatchableError as exc: warn "securing connection failed", msg = exc.msg - if not conn.closed(): - await conn.close() + await conn.close() method readExactly*(s: SecureConn, pbytes: pointer, diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 6437b7b..39934eb 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -36,26 +36,18 @@ import ../stream/lpstream export lpstream +logScope: + topic = "BufferStream" + +declareGauge libp2p_open_bufferstream, "open BufferStream instances" + const - BufferStreamTrackerName* = "libp2p.bufferstream" DefaultBufferSize* = 1024 +const + BufferStreamTrackerName* = "libp2p.bufferstream" + type - # TODO: figure out how to make this generic to avoid casts - WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.} - - BufferStream* = ref object of LPStream - maxSize*: int # buffer's max size in bytes - readBuf: Deque[byte] # this is a ring buffer based dequeue, this makes it perfect as the backing store here - readReqs: Deque[Future[void]] # use dequeue to fire reads in order - dataReadEvent: AsyncEvent - writeHandler*: WriteHandler - lock: AsyncLock - isPiped: bool - - AlreadyPipedError* = object of CatchableError - NotWritableError* = object of CatchableError - BufferStreamTracker* = ref object of TrackerBase opened*: uint64 closed*: uint64 @@ -83,7 +75,23 @@ proc setupBufferStreamTracker(): BufferStreamTracker = result.dump = dumpTracking result.isLeaked = leakTransport addTracker(BufferStreamTrackerName, result) -declareGauge libp2p_open_bufferstream, "open BufferStream instances" + +type + # TODO: figure out how to make this generic to avoid casts + WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.} + + BufferStream* = ref object of LPStream + maxSize*: int # buffer's max size in bytes + readBuf: Deque[byte] # this is a ring buffer based dequeue + readReqs*: Deque[Future[void]] # use dequeue to fire reads in order + dataReadEvent*: AsyncEvent # event triggered when data has been consumed from the internal buffer + writeHandler*: WriteHandler # user provided write callback + writeLock*: AsyncLock # write lock to guarantee ordered writes + lock: AsyncLock # pushTo lock to guarantee ordered reads + piped: BufferStream # a piped bufferstream instance + + AlreadyPipedError* = object of CatchableError + NotWritableError* = object of CatchableError proc newAlreadyPipedError*(): ref Exception {.inline.} = result = newException(AlreadyPipedError, "stream already piped") @@ -96,7 +104,7 @@ proc requestReadBytes(s: BufferStream): Future[void] = ## data becomes available in the read buffer result = newFuture[void]() s.readReqs.addLast(result) - trace "requestReadBytes(): added a future to readReqs" + trace "requestReadBytes(): added a future to readReqs", oid = s.oid proc initBufferStream*(s: BufferStream, handler: WriteHandler = nil, @@ -106,12 +114,29 @@ proc initBufferStream*(s: BufferStream, s.readReqs = initDeque[Future[void]]() s.dataReadEvent = newAsyncEvent() s.lock = newAsyncLock() - s.writeHandler = handler + s.writeLock = newAsyncLock() s.closeEvent = newAsyncEvent() - inc getBufferStreamTracker().opened + s.isClosed = false + + if not(isNil(handler)): + s.writeHandler = proc (data: seq[byte]) {.async, gcsafe.} = + try: + # Using a lock here to guarantee + # proper write ordering. This is + # specially important when + # implementing half-closed in mplex + # or other functionality that requires + # strict message ordering + await s.writeLock.acquire() + await handler(data) + finally: + s.writeLock.release() + when chronicles.enabledLogLevel == LogLevel.TRACE: s.oid = genOid() - s.isClosed = false + + trace "created bufferstream", oid = s.oid + inc getBufferStreamTracker().opened libp2p_open_bufferstream.inc() proc newBufferStream*(handler: WriteHandler = nil, @@ -133,7 +158,7 @@ proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) = proc len*(s: BufferStream): int = s.readBuf.len -proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} = +method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} = ## Write bytes to internal read buffer, use this to fill up the ## buffer with data. ## @@ -142,9 +167,8 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} = ## is preserved. ## - when chronicles.enabledLogLevel == LogLevel.TRACE: - logScope: - stream_oid = $s.oid + if s.atEof: + raise newLPStreamEOFError() try: await s.lock.acquire() @@ -153,12 +177,12 @@ proc pushTo*(s: BufferStream, data: seq[byte]) {.async.} = while index < data.len and s.readBuf.len < s.maxSize: s.readBuf.addLast(data[index]) inc(index) - trace "pushTo()", msg = "added " & $index & " bytes to readBuf" + trace "pushTo()", msg = "added " & $index & " bytes to readBuf", oid = s.oid # resolve the next queued read request if s.readReqs.len > 0: s.readReqs.popFirst().complete() - trace "pushTo(): completed a readReqs future" + trace "pushTo(): completed a readReqs future", oid = s.oid if index >= data.len: return @@ -180,11 +204,11 @@ method readExactly*(s: BufferStream, ## If EOF is received and ``nbytes`` is not yet read, the procedure ## will raise ``LPStreamIncompleteError``. ## - when chronicles.enabledLogLevel == LogLevel.TRACE: - logScope: - stream_oid = $s.oid - trace "read()", requested_bytes = nbytes + if s.atEof: + raise newLPStreamEOFError() + + trace "readExactly()", requested_bytes = nbytes, oid = s.oid var index = 0 if s.readBuf.len() == 0: @@ -195,7 +219,7 @@ method readExactly*(s: BufferStream, while s.readBuf.len() > 0 and index < nbytes: output[index] = s.popFirst() inc(index) - trace "readExactly()", read_bytes = index + trace "readExactly()", read_bytes = index, oid = s.oid if index < nbytes: await s.requestReadBytes() @@ -209,6 +233,10 @@ method readOnce*(s: BufferStream, ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from ## internal buffer, otherwise it will wait until some bytes will be received. ## + + if s.atEof: + raise newLPStreamEOFError() + if s.readBuf.len == 0: await s.requestReadBytes() @@ -216,7 +244,7 @@ method readOnce*(s: BufferStream, await s.readExactly(pbytes, len) result = len -method write*(s: BufferStream, msg: seq[byte]): Future[void] = +method write*(s: BufferStream, msg: seq[byte]) {.async.} = ## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer ## stream ``wstream``. ## @@ -226,12 +254,14 @@ method write*(s: BufferStream, msg: seq[byte]): Future[void] = ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. ## - if isNil(s.writeHandler): - var retFuture = newFuture[void]("BufferStream.write(seq)") - retFuture.fail(newNotWritableError()) - return retFuture - result = s.writeHandler(msg) + if s.closed: + raise newLPStreamClosedError() + + if isNil(s.writeHandler): + raise newNotWritableError() + + await s.writeHandler(msg) proc pipe*(s: BufferStream, target: BufferStream): BufferStream = @@ -242,10 +272,10 @@ proc pipe*(s: BufferStream, ## interface methods `read*` and `write` are ## piped. ## - if s.isPiped: + if not(isNil(s.piped)): raise newAlreadyPipedError() - s.isPiped = true + s.piped = target let oldHandler = target.writeHandler proc handler(data: seq[byte]) {.async, closure, gcsafe.} = if not isNil(oldHandler): @@ -272,10 +302,10 @@ proc `|`*(s: BufferStream, target: BufferStream): BufferStream = ## pipe operator to make piping less verbose pipe(s, target) -method close*(s: BufferStream) {.async.} = +method close*(s: BufferStream) {.async, gcsafe.} = ## close the stream and clear the buffer if not s.isClosed: - trace "closing bufferstream" + trace "closing bufferstream", oid = s.oid for r in s.readReqs: if not(isNil(r)) and not(r.finished()): r.fail(newLPStreamEOFError()) @@ -283,7 +313,10 @@ method close*(s: BufferStream) {.async.} = s.readBuf.clear() s.closeEvent.fire() s.isClosed = true + inc getBufferStreamTracker().closed libp2p_open_bufferstream.dec() + + trace "bufferstream closed", oid = s.oid else: - trace "attempt to close an already closed bufferstream", trace=getStackTrace() + trace "attempt to close an already closed bufferstream", trace = getStackTrace() diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index fce18d4..597f3a4 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -21,6 +21,7 @@ proc newChronosStream*(client: StreamTransport): ChronosStream = result.client = client result.closeEvent = newAsyncEvent() + template withExceptions(body: untyped) = try: body @@ -38,20 +39,23 @@ template withExceptions(body: untyped) = method readExactly*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[void] {.async.} = - if s.client.atEof: + if s.atEof: raise newLPStreamEOFError() withExceptions: await s.client.readExactly(pbytes, nbytes) method readOnce*(s: ChronosStream, pbytes: pointer, nbytes: int): Future[int] {.async.} = - if s.client.atEof: + if s.atEof: raise newLPStreamEOFError() withExceptions: result = await s.client.readOnce(pbytes, nbytes) method write*(s: ChronosStream, msg: seq[byte]) {.async.} = + if s.closed: + raise newLPStreamClosedError() + if msg.len == 0: return @@ -63,6 +67,9 @@ method write*(s: ChronosStream, msg: seq[byte]) {.async.} = method closed*(s: ChronosStream): bool {.inline.} = result = s.client.closed +method atEof*(s: ChronosStream): bool {.inline.} = + s.client.atEof() + method close*(s: ChronosStream) {.async.} = if not s.closed: trace "shutting chronos stream", address = $s.client.remoteAddress() diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 86188be..e22aaca 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -15,6 +15,7 @@ import ../varint, type LPStream* = ref object of RootObj isClosed*: bool + isEof*: bool closeEvent*: AsyncEvent when chronicles.enabledLogLevel == LogLevel.TRACE: oid*: Oid @@ -28,6 +29,7 @@ type LPStreamWriteError* = object of LPStreamError par*: ref Exception LPStreamEOFError* = object of LPStreamError + LPStreamClosedError* = object of LPStreamError InvalidVarintError* = object of LPStreamError MaxSizeError* = object of LPStreamError @@ -59,9 +61,15 @@ proc newLPStreamIncorrectDefect*(m: string): ref Exception = proc newLPStreamEOFError*(): ref Exception = result = newException(LPStreamEOFError, "Stream EOF!") +proc newLPStreamClosedError*(): ref Exception = + result = newException(LPStreamClosedError, "Stream Closed!") + method closed*(s: LPStream): bool {.base, inline.} = s.isClosed +method atEof*(s: LPStream): bool {.base, inline.} = + s.isEof + method readExactly*(s: LPStream, pbytes: pointer, nbytes: int): diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 06a55e6..7a90a29 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -43,8 +43,8 @@ proc getTcpTransportTracker(): TcpTransportTracker {.gcsafe.} = proc dumpTracking(): string {.gcsafe.} = var tracker = getTcpTransportTracker() - result = "Opened transports: " & $tracker.opened & "\n" & - "Closed transports: " & $tracker.closed + result = "Opened tcp transports: " & $tracker.opened & "\n" & + "Closed tcp transports: " & $tracker.closed proc leakTransport(): bool {.gcsafe.} = var tracker = getTcpTransportTracker() @@ -98,7 +98,6 @@ proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T = method initTransport*(t: TcpTransport) = t.multicodec = multiCodec("tcp") - inc getTcpTransportTracker().opened method close*(t: TcpTransport) {.async, gcsafe.} = diff --git a/tests/testmplex.nim b/tests/testmplex.nim index ad13e7c..16e31a2 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -122,7 +122,7 @@ suite "Mplex": await chann.close() try: await chann.write("Hello") - except LPStreamEOFError: + except LPStreamClosedError: result = true finally: await chann.reset() @@ -141,7 +141,7 @@ suite "Mplex": chann = newChannel(1, conn, true) await chann.pushTo(("Hello!").toBytes) - let closeFut = chann.closedByRemote() + let closeFut = chann.closeRemote() var data = newSeq[byte](6) await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer @@ -163,7 +163,7 @@ suite "Mplex": let conn = newConnection(newBufferStream(writeHandler)) chann = newChannel(1, conn, true) - await chann.closedByRemote() + await chann.closeRemote() try: await chann.pushTo(@[byte(1)]) except LPStreamEOFError: @@ -204,7 +204,7 @@ suite "Mplex": await chann.reset() try: await chann.write(("Hello!").toBytes) - except LPStreamEOFError: + except LPStreamClosedError: result = true finally: await conn.close() @@ -401,7 +401,7 @@ suite "Mplex": let mplexDial = newMplex(conn) # TODO: Reenable once half-closed is working properly - # let mplexDialFut = mplexDial.handle() + let mplexDialFut = mplexDial.handle() for i in 1..10: let stream = await mplexDial.newStream() await stream.writeLp(&"stream {i}!") @@ -409,7 +409,7 @@ suite "Mplex": await done.wait(10.seconds) await conn.close() - # await mplexDialFut + await mplexDialFut await allFuturesThrowing(transport1.close(), transport2.close()) await listenFut diff --git a/tests/testnoise.nim b/tests/testnoise.nim index 59b9e4e..c167bc8 100644 --- a/tests/testnoise.nim +++ b/tests/testnoise.nim @@ -71,6 +71,7 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) = suite "Noise": teardown: for tracker in testTrackers(): + # echo tracker.dump() check tracker.isLeaked() == false test "e2e: handle write + noise": diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 31cf31c..dfe23c4 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -56,7 +56,7 @@ suite "Switch": check tracker.isLeaked() == false test "e2e use switch dial proto string": - proc testSwitch(): Future[bool] {.async, gcsafe.} = + proc testSwitch() {.async, gcsafe.} = let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") let ma2: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") @@ -86,16 +86,12 @@ suite "Switch": let conn = await switch2.dial(switch1.peerInfo, TestCodec) - try: - await conn.writeLp("Hello!") - let msg = cast[string](await conn.readLp(1024)) - check "Hello!" == msg - result = true - except LPStreamError: - result = false + await conn.writeLp("Hello!") + let msg = cast[string](await conn.readLp(1024)) + check "Hello!" == msg await allFuturesThrowing( - done.wait(5000.millis) #[if OK won't happen!!]#, + done.wait(5.seconds) #[if OK won't happen!!]#, conn.close(), switch1.stop(), switch2.stop(), @@ -104,8 +100,7 @@ suite "Switch": # this needs to go at end await allFuturesThrowing(awaiters) - check: - waitFor(testSwitch()) == true + waitFor(testSwitch()) test "e2e use switch no proto string": proc testSwitch(): Future[bool] {.async, gcsafe.} =