diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 530ccb6..33342a9 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -13,8 +13,7 @@ import types, coder, ../muxer, nimcrypto/utils, - ../../stream/connection, - ../../stream/bufferstream, + ../../stream/[bufferstream, connection, streamseq], ../../peerinfo export connection @@ -51,10 +50,13 @@ type initiator*: bool # initiated remotely or locally flag isLazy*: bool # is channel lazy isOpen*: bool # has channel been opened (only used with isLazy) + isReset*: bool # channel was reset, pushTo should drop data + pushing*: bool closedLocal*: bool # has channel been closed locally msgCode*: MessageType # cached in/out message code closeCode*: MessageType # cached in/out close code resetCode*: MessageType # cached in/out reset code + writeLock: AsyncLock proc open*(s: LPChannel) {.async, gcsafe.} @@ -92,34 +94,30 @@ proc resetMessage(s: LPChannel) {.async.} = # need to re-raise CancelledError. debug "Unexpected cancellation while resetting channel", s except LPStreamEOFError as exc: - trace "muxed connection EOF", exc = exc.msg, s + trace "muxed connection EOF", s, exc = exc.msg except LPStreamClosedError as exc: - trace "muxed connection closed", exc = exc.msg, s + trace "muxed connection closed", s, exc = exc.msg except LPStreamIncompleteError as exc: - trace "incomplete message", exc = exc.msg, s + trace "incomplete message", s, exc = exc.msg except CatchableError as exc: - debug "Unhandled exception leak", exc = exc.msg, s + debug "Unhandled exception leak", s, exc = exc.msg proc open*(s: LPChannel) {.async, gcsafe.} = await s.conn.writeMsg(s.id, MessageType.New, s.name) - trace "opened channel", s + trace "Opened channel", s s.isOpen = true proc closeRemote*(s: LPChannel) {.async.} = - trace "closing remote", s + trace "Closing remote", s try: - await s.drainBuffer() - s.isEof = true # set EOF immediately to prevent further reads # close parent bufferstream to prevent further reads await procCall BufferStream(s).close() - - trace "channel closed on EOF", s except CancelledError as exc: raise exc except CatchableError as exc: - trace "exception closing remote channel", exc = exc.msg, s + trace "exception closing remote channel", s, exc = exc.msg - trace "closed remote", s + trace "Closed remote", s method closed*(s: LPChannel): bool = ## this emulates half-closed behavior @@ -128,6 +126,16 @@ method closed*(s: LPChannel): bool = ## header of the file s.closedLocal +method pushTo*(s: LPChannel, data: seq[byte]) {.async.} = + if s.isReset: + raise newLPStreamClosedError() # Terminate mplex loop + + try: + s.pushing = true + await procCall BufferStream(s).pushTo(data) + finally: + s.pushing = false + method reset*(s: LPChannel) {.base, async, gcsafe.} = if s.closedLocal and s.isEof: trace "channel already closed or reset", s @@ -135,29 +143,38 @@ method reset*(s: LPChannel) {.base, async, gcsafe.} = trace "Resetting channel", s + # First, make sure any new calls to `readOnce` and `pushTo` will fail - there + # may already be such calls in the event queue + s.isEof = true + s.isReset = true + + s.readBuf = StreamSeq() + + s.closedLocal = true + asyncSpawn s.resetMessage() - try: - # drain the buffer before closing - await s.drainBuffer() - await procCall BufferStream(s).close() + # This should wake up any readers by pushing an EOF marker at least + await procCall BufferStream(s).close() # noraises, nocancels - s.isEof = true - s.closedLocal = true - - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "Exception in reset", exc = exc.msg, s + if s.pushing: + # When data is being pushed, there will be two items competing for the + # readQueue slot - the BufferStream.close EOF marker and the pushTo data. + # If the EOF wins, the pushTo call will get stuck because there will be no + # new readers to clear the data. It's worth noting that if there's a reader + # already waiting for data, this reader will be unblocked by the pushTo - + # this is necessary or it will get stuck + if s.readQueue.len > 0: + discard s.readQueue.popFirstNoWait() trace "Channel reset", s method close*(s: LPChannel) {.async, gcsafe.} = if s.closedLocal: - trace "channel already closed", s + trace "Already closed", s return - trace "closing local lpchannel", s + trace "Closing channel", s proc closeInternal() {.async.} = try: @@ -165,15 +182,17 @@ method close*(s: LPChannel) {.async, gcsafe.} = if s.atEof: # already closed by remote close parent buffer immediately await procCall BufferStream(s).close() except CancelledError: - trace "Unexpected cancellation while closing channel", s + debug "Unexpected cancellation while closing channel", s await s.reset() # This is top-level procedure which will work as separate task, so it # do not need to propogate CancelledError. - except CatchableError as exc: - trace "exception closing channel", exc = exc.msg, s + except LPStreamClosedError, LPStreamEOFError: + trace "Connection already closed", s + except CatchableError as exc: # Shouldn't happen? + debug "Exception closing channel", s, exc = exc.msg await s.reset() - trace "lpchannel closed local", s + trace "Closed channel", s s.closedLocal = true # All the errors are handled inside `closeInternal()` procedure. @@ -183,19 +202,38 @@ method initStream*(s: LPChannel) = if s.objName.len == 0: s.objName = "LPChannel" - s.timeoutHandler = proc() {.async, gcsafe.} = - trace "idle timeout expired, resetting LPChannel", s - await s.reset() + s.timeoutHandler = proc(): Future[void] {.gcsafe.} = + trace "Idle timeout expired, resetting LPChannel", s + s.reset() procCall BufferStream(s).initStream() + s.writeLock = newAsyncLock() + +method write*(s: LPChannel, msg: seq[byte]): Future[void] {.async.} = + if s.closedLocal: + raise newLPStreamClosedError() + + try: + if s.isLazy and not(s.isOpen): + await s.open() + + # writes should happen in sequence + trace "write msg", len = msg.len + + await s.conn.writeMsg(s.id, s.msgCode, msg) + s.activity = true + except CatchableError as exc: + trace "exception in lpchannel write handler", s, exc = exc.msg + await s.conn.close() + raise exc + proc init*( L: type LPChannel, id: uint64, conn: Connection, initiator: bool, name: string = "", - size: int = DefaultBufferSize, lazy: bool = false, timeout: Duration = DefaultChanTimeout): LPChannel = @@ -211,26 +249,11 @@ proc init*( resetCode: if initiator: MessageType.ResetOut else: MessageType.ResetIn, dir: if initiator: Direction.Out else: Direction.In) - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = - try: - if chann.isLazy and not(chann.isOpen): - await chann.open() + chann.initStream() - # writes should happen in sequence - trace "sending data", len = data.len, conn, chann - - await conn.writeMsg(chann.id, - chann.msgCode, - data) - except CatchableError as exc: - trace "exception in lpchannel write handler", exc = exc.msg, chann - asyncSpawn conn.close() - raise exc - - chann.initBufferStream(writeHandler, size) when chronicles.enabledLogLevel == LogLevel.TRACE: chann.name = if chann.name.len > 0: chann.name else: $chann.oid - trace "created new lpchannel", chann + trace "Created new lpchannel", chann return chann diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 71d24c0..d54b844 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -129,7 +129,7 @@ proc handleStream(m: Mplex, chann: LPChannel) {.async.} = await chann.reset() method handle*(m: Mplex) {.async, gcsafe.} = - trace "Starting mplex main loop", m + trace "Starting mplex handler", m try: while not m.connection.atEof: trace "waiting for data", m @@ -189,11 +189,13 @@ method handle*(m: Mplex) {.async, gcsafe.} = # This procedure is spawned as task and it is not part of public API, so # there no way for this procedure to be cancelled implicitely. debug "Unexpected cancellation in mplex handler", m + except LPStreamEOFError as exc: + trace "Stream EOF", msg = exc.msg, m except CatchableError as exc: - trace "Exception occurred", exception = exc.msg, m + warn "Unexpected exception in mplex read loop", msg = exc.msg, m finally: - trace "stopping mplex main loop", m await m.close() + trace "Stopped mplex handler", m proc init*(M: type Mplex, conn: Connection, @@ -218,12 +220,12 @@ method newStream*(m: Mplex, method close*(m: Mplex) {.async, gcsafe.} = if m.isClosed: + trace "Already closed", m return - - trace "closing mplex muxer", m - m.isClosed = true + trace "Closing mplex", m + let channs = toSeq(m.channels[false].values) & toSeq(m.channels[true].values) for chann in channs: @@ -235,3 +237,5 @@ method close*(m: Mplex) {.async, gcsafe.} = # closed properly m.channels[false].clear() m.channels[true].clear() + + trace "Closed mplex", m diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 47962d7..caf378c 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -167,11 +167,12 @@ proc getSendConn(p: PubSubPeer): Future[Connection] {.async.} = # Another concurrent dial may have populated p.sendConn if p.sendConn != nil: let current = p.sendConn - if not current.isNil: - if not (current.closed() or current.atEof): - # The existing send connection looks like it might work - reuse it - trace "Reusing existing connection", oid = $current.oid - return current + if not (current.closed() or current.atEof): + # The existing send connection looks like it might work - reuse it + debug "Reusing existing connection", current + return current + else: + p.sendConn = nil # Grab a new send connection let (newConn, handshake) = await p.getConn() # ...and here diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 9ca59d7..26056a7 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -30,9 +30,11 @@ ## will suspend until either the amount of elements in the ## buffer goes below ``maxSize`` or more data becomes available. -import std/[deques, math, strformat] +import std/strformat +import stew/byteutils import chronos, chronicles, metrics import ../stream/connection +import ./streamseq when chronicles.enabledLogLevel == LogLevel.TRACE: import oids @@ -42,9 +44,6 @@ export connection logScope: topics = "bufferstream" -const - DefaultBufferSize* = 128 - const BufferStreamTrackerName* = "libp2p.bufferstream" @@ -78,27 +77,9 @@ proc setupBufferStreamTracker(): BufferStreamTracker = addTracker(BufferStreamTrackerName, result) type - # TODO: figure out how to make this generic to avoid casts - WriteHandler* = proc (data: seq[byte]): Future[void] {.gcsafe.} - BufferStream* = ref object of Connection - maxSize*: int # buffer's max size in bytes - readBuf: Deque[byte] # this is a ring buffer based dequeue - readReqs*: Deque[Future[void]] # use dequeue to fire reads in order - dataReadEvent*: AsyncEvent # event triggered when data has been consumed from the internal buffer - writeHandler*: WriteHandler # user provided write callback - writeLock*: AsyncLock # write lock to guarantee ordered writes - lock: AsyncLock # pushTo lock to guarantee ordered reads - piped: BufferStream # a piped bufferstream instance - - AlreadyPipedError* = object of CatchableError - NotWritableError* = object of CatchableError - -proc newAlreadyPipedError*(): ref CatchableError {.inline.} = - result = newException(AlreadyPipedError, "stream already piped") - -proc newNotWritableError*(): ref CatchableError {.inline.} = - result = newException(NotWritableError, "stream is not writable") + readQueue*: AsyncQueue[seq[byte]] # read queue for managing backpressure + readBuf*: StreamSeq # overflow buffer for readOnce func shortLog*(s: BufferStream): auto = if s.isNil: "BufferStream(nil)" @@ -106,217 +87,94 @@ func shortLog*(s: BufferStream): auto = else: &"{shortLog(s.peerInfo.peerId)}:{s.oid}" chronicles.formatIt(BufferStream): shortLog(it) -proc requestReadBytes(s: BufferStream): Future[void] = - ## create a future that will complete when more - ## data becomes available in the read buffer - result = newFuture[void]() - s.readReqs.addLast(result) - # trace "requestReadBytes(): added a future to readReqs", oid = s.oid +proc len*(s: BufferStream): int = + s.readBuf.len + (if s.readQueue.len > 0: s.readQueue[0].len() else: 0) method initStream*(s: BufferStream) = if s.objName.len == 0: s.objName = "BufferStream" procCall Connection(s).initStream() + + s.readQueue = newAsyncQueue[seq[byte]](1) + + trace "BufferStream created", s inc getBufferStreamTracker().opened -proc initBufferStream*(s: BufferStream, - handler: WriteHandler = nil, - size: int = DefaultBufferSize) = - s.initStream() - - s.maxSize = if isPowerOfTwo(size): size else: nextPowerOfTwo(size) - s.readBuf = initDeque[byte](s.maxSize) - s.readReqs = initDeque[Future[void]]() - s.dataReadEvent = newAsyncEvent() - s.lock = newAsyncLock() - s.writeLock = newAsyncLock() - - if not(isNil(handler)): - s.writeHandler = proc (data: seq[byte]) {.async, gcsafe.} = - defer: - if s.writeLock.locked: - s.writeLock.release() - - # Using a lock here to guarantee - # proper write ordering. This is - # specially important when - # implementing half-closed in mplex - # or other functionality that requires - # strict message ordering - - await s.writeLock.acquire() - await handler(data) - - trace "created bufferstream", s - -proc newBufferStream*(handler: WriteHandler = nil, - size: int = DefaultBufferSize, - timeout: Duration = DefaultConnectionTimeout): BufferStream = +proc newBufferStream*(timeout: Duration = DefaultConnectionTimeout): BufferStream = new result result.timeout = timeout - result.initBufferStream(handler, size) - -proc popFirst*(s: BufferStream): byte = - result = s.readBuf.popFirst() - s.dataReadEvent.fire() - -proc popLast*(s: BufferStream): byte = - result = s.readBuf.popLast() - s.dataReadEvent.fire() - -proc shrink(s: BufferStream, fromFirst = 0, fromLast = 0) = - s.readBuf.shrink(fromFirst, fromLast) - s.dataReadEvent.fire() - -proc len*(s: BufferStream): int = s.readBuf.len + result.initStream() method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} = ## Write bytes to internal read buffer, use this to fill up the ## buffer with data. ## - ## This method is async and will wait until all data has been - ## written to the internal buffer; this is done so that backpressure - ## is preserved. + ## `pushTo` will block if the queue is full, thus maintaining backpressure. ## - if s.atEof: + if s.isClosed: raise newLPStreamEOFError() - defer: - # trace "ended", size = s.len - s.lock.release() + if data.len == 0: + return # Don't push 0-length buffers, these signal EOF - await s.lock.acquire() - var index = 0 - while not s.closed(): - while index < data.len and s.readBuf.len < s.maxSize: - s.readBuf.addLast(data[index]) - inc(index) - # trace "pushTo()", msg = "added " & $s.len & " bytes to readBuf", oid = s.oid - - # resolve the next queued read request - if s.readReqs.len > 0: - s.readReqs.popFirst().complete() - # trace "pushTo(): completed a readReqs future", oid = s.oid - - if index >= data.len: - return - - # if we couldn't transfer all the data to the - # internal buf wait on a read event - await s.dataReadEvent.wait() - s.dataReadEvent.clear() - -proc drainBuffer*(s: BufferStream) {.async.} = - ## wait for all data in the buffer to be consumed - ## - - trace "draining buffer", len = s.len, s - while s.len > 0: - await s.dataReadEvent.wait() - s.dataReadEvent.clear() + # We will block here if there is already data queued, until it has been + # processed + trace "Pushing readQueue", s, len = data.len + await s.readQueue.addLast(data) method readOnce*(s: BufferStream, pbytes: pointer, nbytes: int): Future[int] {.async.} = - if s.atEof: + if s.isEof and s.readBuf.len() == 0: raise newLPStreamEOFError() - if s.len() == 0: - await s.requestReadBytes() + var + p = cast[ptr UncheckedArray[byte]](pbytes) - var index = 0 - var size = min(nbytes, s.len) - let output = cast[ptr UncheckedArray[byte]](pbytes) + # First consume leftovers from previous read + var rbytes = s.readBuf.consumeTo(toOpenArray(p, 0, nbytes - 1)) - s.activity = true # reset activity flag - while s.len() > 0 and index < size: - output[index] = s.popFirst() - inc(index) + if rbytes < nbytes: + # There's space in the buffer - consume some data from the read queue + trace "popping readQueue", s, rbytes, nbytes + let buf = await s.readQueue.popFirst() - return size - -method write*(s: BufferStream, msg: seq[byte]) {.async.} = - ## Write sequence of bytes ``sbytes`` of length ``msglen`` to writer - ## stream ``wstream``. - ## - ## Sequence of bytes ``sbytes`` must not be zero-length. - ## - ## If ``msglen < 0`` whole sequence ``sbytes`` will be writen to stream. - ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to - ## stream. - ## - - if s.closed: - raise newLPStreamClosedError() - - if isNil(s.writeHandler): - raise newNotWritableError() - - s.activity = true # reset activity flag - await s.writeHandler(msg) - -# TODO: move pipe routines out -proc pipe*(s: BufferStream, - target: BufferStream): BufferStream = - ## pipe the write end of this stream to - ## be the source of the target stream - ## - ## Note that this only works with the LPStream - ## interface methods `read*` and `write` are - ## piped. - ## - if not(isNil(s.piped)): - raise newAlreadyPipedError() - - s.piped = target - let oldHandler = target.writeHandler - proc handler(data: seq[byte]) {.async, closure, gcsafe.} = - if not isNil(oldHandler): - await oldHandler(data) - - # if we're piping to self, - # then add the data to the - # buffer directly and fire - # the read event - if s == target: - for b in data: - s.readBuf.addLast(b) - - # notify main loop of available - # data - s.dataReadEvent.fire() + if buf.len == 0: + # No more data will arrive on read queue + s.isEof = true else: - await target.pushTo(data) + let remaining = min(buf.len, nbytes - rbytes) + toOpenArray(p, rbytes, nbytes - 1)[0.. 0.millis) + if s.timeout > 0.millis: trace "Monitoring for timeout", s, timeout = s.timeout s.timerTaskFut = s.timeoutMonitor() + if isNil(s.timeoutHandler): + s.timeoutHandler = proc(): Future[void] = s.close() inc getConnectionTracker().opened @@ -133,8 +130,10 @@ proc timeoutMonitor(s: Connection) {.async, gcsafe.} = except CancelledError as exc: raise exc - except CatchableError as exc: - trace "exception in timeout", exc = exc.msg, s + except CatchableError as exc: # Shouldn't happen + warn "exception in timeout", s, exc = exc.msg + finally: + s.timerTaskFut = nil proc init*(C: type Connection, peerInfo: PeerInfo, diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index c308b3b..c59366c 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -74,13 +74,20 @@ proc newLPStreamEOFError*(): ref CatchableError = proc newLPStreamClosedError*(): ref Exception = result = newException(LPStreamClosedError, "Stream Closed!") +func shortLog*(s: LPStream): auto = + if s.isNil: "LPStream(nil)" + else: $s.oid +chronicles.formatIt(LPStream): shortLog(it) + method initStream*(s: LPStream) {.base.} = if s.objName.len == 0: s.objName = "LPStream" + s.closeEvent = newAsyncEvent() s.oid = genOid() + libp2p_open_streams.inc(labelValues = [s.objName]) - trace "stream created", oid = $s.oid, name = s.objName + trace "Stream created", s, objName = s.objName proc join*(s: LPStream): Future[void] = s.closeEvent.wait() @@ -102,15 +109,13 @@ proc readExactly*(s: LPStream, pbytes: pointer, nbytes: int): Future[void] {.async.} = - if s.atEof: raise newLPStreamEOFError() logScope: + s nbytes = nbytes - obName = s.objName - stack = getStackTrace() - oid = $s.oid + objName = s.objName var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) var read = 0 @@ -202,9 +207,14 @@ proc write*(s: LPStream, msg: string): Future[void] = s.write(msg.toBytes()) # TODO: split `close` into `close` and `dispose/destroy` -method close*(s: LPStream) {.base, async.} = - if not s.isClosed: - s.isClosed = true - s.closeEvent.fire() - libp2p_open_streams.dec(labelValues = [s.objName]) - trace "stream destroyed", oid = $s.oid, name = s.objName +method close*(s: LPStream) {.base, async.} = # {.raises [Defect].} + ## close the stream - this may block, but will not raise exceptions + ## + if s.isClosed: + trace "Already closed", s + return + + s.isClosed = true + s.closeEvent.fire() + libp2p_open_streams.dec(labelValues = [s.objName]) + trace "Closed stream", s, objName = s.objName diff --git a/tests/helpers.nim b/tests/helpers.nim index 2e10c01..d28fbcd 100644 --- a/tests/helpers.nim +++ b/tests/helpers.nim @@ -41,3 +41,16 @@ proc getRng(): ref BrHmacDrbgContext = template rng*(): ref BrHmacDrbgContext = getRng() + +type + WriteHandler* = proc(data: seq[byte]): Future[void] {.gcsafe.} + TestBufferStream* = ref object of BufferStream + writeHandler*: WriteHandler + +method write*(s: TestBufferStream, msg: seq[byte]): Future[void] = + s.writeHandler(msg) + +proc newBufferStream*(writeHandler: WriteHandler): TestBufferStream = + new result + result.writeHandler = writeHandler + result.initStream() diff --git a/tests/testbufferstream.nim b/tests/testbufferstream.nim index ae134b1..847ab67 100644 --- a/tests/testbufferstream.nim +++ b/tests/testbufferstream.nim @@ -1,10 +1,9 @@ -import unittest, strformat +import unittest import chronos, stew/byteutils import ../libp2p/stream/bufferstream, - ../libp2p/stream/lpstream, - ../libp2p/errors + ../libp2p/stream/lpstream -when defined(nimHasUsed): {.used.} +{.used.} suite "BufferStream": teardown: @@ -13,8 +12,7 @@ suite "BufferStream": test "push data to buffer": proc testPushTo(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard - let buff = newBufferStream(writeHandler, 16) + let buff = newBufferStream() check buff.len == 0 var data = "12345" await buff.pushTo(data.toBytes()) @@ -28,14 +26,19 @@ suite "BufferStream": test "push and wait": proc testPushTo(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard - let buff = newBufferStream(writeHandler, 4) + let buff = newBufferStream() check buff.len == 0 - let fut = buff.pushTo("12345".toBytes()) - check buff.len == 4 - check buff.popFirst() == byte(ord('1')) - await fut + let fut0 = buff.pushTo("1234".toBytes()) + let fut1 = buff.pushTo("5".toBytes()) + check buff.len == 4 # the second write should not be visible yet + + var data: array[1, byte] + check: 1 == await buff.readOnce(addr data[0], data.len) + + check ['1'] == string.fromBytes(data) + await fut0 + await fut1 check buff.len == 4 result = true @@ -47,13 +50,12 @@ suite "BufferStream": test "read with size": proc testRead(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard - let buff = newBufferStream(writeHandler, 10) + let buff = newBufferStream() check buff.len == 0 await buff.pushTo("12345".toBytes()) - var data = newSeq[byte](3) - await buff.readExactly(addr data[0], 3) + var data: array[3, byte] + await buff.readExactly(addr data[0], data.len) check ['1', '2', '3'] == string.fromBytes(data) result = true @@ -65,15 +67,14 @@ suite "BufferStream": test "readExactly": proc testReadExactly(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard - let buff = newBufferStream(writeHandler, 10) + let buff = newBufferStream() check buff.len == 0 await buff.pushTo("12345".toBytes()) check buff.len == 5 - var data: seq[byte] = newSeq[byte](2) - await buff.readExactly(addr data[0], 2) - check string.fromBytes(data) == @['1', '2'] + var data: array[2, byte] + await buff.readExactly(addr data[0], data.len) + check string.fromBytes(data) == ['1', '2'] result = true @@ -84,19 +85,17 @@ suite "BufferStream": test "readExactly raises": proc testReadExactly(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard - let buff = newBufferStream(writeHandler, 10) + let buff = newBufferStream() check buff.len == 0 await buff.pushTo("123".toBytes()) - var data: seq[byte] = newSeq[byte](5) - var readFut: Future[void] - readFut = buff.readExactly(addr data[0], 5) + var data: array[5, byte] + var readFut = buff.readExactly(addr data[0], data.len) await buff.close() try: await readFut - except LPStreamIncompleteError, LPStreamEOFError: + except LPStreamEOFError: result = true check: @@ -104,17 +103,16 @@ suite "BufferStream": test "readOnce": proc testReadOnce(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard - let buff = newBufferStream(writeHandler, 10) + let buff = newBufferStream() check buff.len == 0 - var data: seq[byte] = newSeq[byte](3) - let readFut = buff.readOnce(addr data[0], 5) + var data: array[3, byte] + let readFut = buff.readOnce(addr data[0], data.len) await buff.pushTo("123".toBytes()) check buff.len == 3 check (await readFut) == 3 - check string.fromBytes(data) == @['1', '2', '3'] + check string.fromBytes(data) == ['1', '2', '3'] result = true @@ -123,119 +121,75 @@ suite "BufferStream": check: waitFor(testReadOnce()) == true - test "write ptr": - proc testWritePtr(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = - check string.fromBytes(data) == "Hello!" - - let buff = newBufferStream(writeHandler, 10) - check buff.len == 0 - - var data = "Hello!" - await buff.write(addr data[0], data.len) - - result = true - - await buff.close() - - check: - waitFor(testWritePtr()) == true - - test "write string": - proc testWritePtr(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = - check string.fromBytes(data) == "Hello!" - - let buff = newBufferStream(writeHandler, 10) - check buff.len == 0 - - await buff.write("Hello!") - - result = true - - await buff.close() - - check: - waitFor(testWritePtr()) == true - - test "write bytes": - proc testWritePtr(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = - check string.fromBytes(data) == "Hello!" - - let buff = newBufferStream(writeHandler, 10) - check buff.len == 0 - - await buff.write("Hello!".toBytes()) - - result = true - - await buff.close() - - check: - waitFor(testWritePtr()) == true - - test "write should happen in order": - proc testWritePtr(): Future[bool] {.async.} = - var count = 1 - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = - check string.fromBytes(data) == &"Msg {$count}" - count.inc - - let buff = newBufferStream(writeHandler, 10) - check buff.len == 0 - - await buff.write("Msg 1") - await buff.write("Msg 2") - await buff.write("Msg 3") - await buff.write("Msg 4") - await buff.write("Msg 5") - await buff.write("Msg 6") - await buff.write("Msg 7") - await buff.write("Msg 8") - await buff.write("Msg 9") - await buff.write("Msg 10") - - result = true - - await buff.close() - - check: - waitFor(testWritePtr()) == true - test "reads should happen in order": proc testWritePtr(): Future[bool] {.async.} = - proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard - let buff = newBufferStream(writeHandler, 10) + let buff = newBufferStream() check buff.len == 0 - await buff.pushTo("Msg 1".toBytes()) - await buff.pushTo("Msg 2".toBytes()) - await buff.pushTo("Msg 3".toBytes()) + let w1 = buff.pushTo("Msg 1".toBytes()) + let w2 = buff.pushTo("Msg 2".toBytes()) + let w3 = buff.pushTo("Msg 3".toBytes()) + + var data: array[5, byte] + await buff.readExactly(addr data[0], data.len) - var data = newSeq[byte](5) - await buff.readExactly(addr data[0], 5) check string.fromBytes(data) == "Msg 1" - await buff.readExactly(addr data[0], 5) + await buff.readExactly(addr data[0], data.len) check string.fromBytes(data) == "Msg 2" - await buff.readExactly(addr data[0], 5) + await buff.readExactly(addr data[0], data.len) check string.fromBytes(data) == "Msg 3" - await buff.pushTo("Msg 4".toBytes()) - await buff.pushTo("Msg 5".toBytes()) - await buff.pushTo("Msg 6".toBytes()) + for f in [w1, w2, w3]: await f - await buff.readExactly(addr data[0], 5) + let w4 = buff.pushTo("Msg 4".toBytes()) + let w5 = buff.pushTo("Msg 5".toBytes()) + let w6 = buff.pushTo("Msg 6".toBytes()) + + await buff.close() + + await buff.readExactly(addr data[0], data.len) check string.fromBytes(data) == "Msg 4" - await buff.readExactly(addr data[0], 5) + await buff.readExactly(addr data[0], data.len) check string.fromBytes(data) == "Msg 5" - await buff.readExactly(addr data[0], 5) + await buff.readExactly(addr data[0], data.len) check string.fromBytes(data) == "Msg 6" + for f in [w4, w5, w6]: await f + + result = true + + check: + waitFor(testWritePtr()) == true + + test "small reads": + proc testWritePtr(): Future[bool] {.async.} = + let buff = newBufferStream() + check buff.len == 0 + + var writes: seq[Future[void]] + var str: string + for i in 0..<10: + writes.add buff.pushTo("123".toBytes()) + str &= "123" + await buff.close() # all data should still be read after close + + var str2: string + var data: array[2, byte] + try: + while true: + let x = await buff.readOnce(addr data[0], data.len) + str2 &= string.fromBytes(data[0.. B": - proc pipeTest(): Future[bool] {.async.} = - var buf1 = newBufferStream() - var buf2 = buf1.pipe(newBufferStream()) - - var res1: seq[byte] = newSeq[byte](7) - var readFut = buf2.readExactly(addr res1[0], 7) - await buf1.write("Hello1!".toBytes()) - await readFut - - check: - res1 == "Hello1!".toBytes() - - result = true - - await buf1.close() - await buf2.close() - - check: - waitFor(pipeTest()) == true - - test "pipe A -> B and B -> A": - proc pipeTest(): Future[bool] {.async.} = - var buf1 = newBufferStream() - var buf2 = newBufferStream() - - buf1 = buf1.pipe(buf2).pipe(buf1) - - var res1: seq[byte] = newSeq[byte](7) - var readFut1 = buf1.readExactly(addr res1[0], 7) - - var res2: seq[byte] = newSeq[byte](7) - var readFut2 = buf2.readExactly(addr res2[0], 7) - - await buf1.write("Hello1!".toBytes()) - await buf2.write("Hello2!".toBytes()) - await allFuturesThrowing(readFut1, readFut2) - - check: - res1 == "Hello2!".toBytes() - res2 == "Hello1!".toBytes() - - result = true - - await buf1.close() - await buf2.close() - - check: - waitFor(pipeTest()) == true - - test "pipe A -> A (echo)": - proc pipeTest(): Future[bool] {.async.} = - var buf1 = newBufferStream() - - buf1 = buf1.pipe(buf1) - - proc reader(): Future[seq[byte]] {.async.} = - result = newSeq[byte](6) - await buf1.readExactly(addr result[0], 6) - - proc writer(): Future[void] = buf1.write("Hello!".toBytes()) - - var writerFut = writer() - var readerFut = reader() - - await writerFut - check: - (await readerFut) == "Hello!".toBytes() - - result = true - - await buf1.close() - - check: - waitFor(pipeTest()) == true - - test "pipe with `|` operator - A -> B": - proc pipeTest(): Future[bool] {.async.} = - var buf1 = newBufferStream() - var buf2 = buf1 | newBufferStream() - - var res1: seq[byte] = newSeq[byte](7) - var readFut = buf2.readExactly(addr res1[0], 7) - await buf1.write("Hello1!".toBytes()) - await readFut - - check: - res1 == "Hello1!".toBytes() - - result = true - - await buf1.close() - await buf2.close() - - check: - waitFor(pipeTest()) == true - - test "pipe with `|` operator - A -> B and B -> A": - proc pipeTest(): Future[bool] {.async.} = - var buf1 = newBufferStream() - var buf2 = newBufferStream() - - buf1 = buf1 | buf2 | buf1 - - var res1: seq[byte] = newSeq[byte](7) - var readFut1 = buf1.readExactly(addr res1[0], 7) - - var res2: seq[byte] = newSeq[byte](7) - var readFut2 = buf2.readExactly(addr res2[0], 7) - - await buf1.write("Hello1!".toBytes()) - await buf2.write("Hello2!".toBytes()) - await allFuturesThrowing(readFut1, readFut2) - - check: - res1 == "Hello2!".toBytes() - res2 == "Hello1!".toBytes() - - result = true - - await buf1.close() - await buf2.close() - - check: - waitFor(pipeTest()) == true - - test "pipe with `|` operator - A -> A (echo)": - proc pipeTest(): Future[bool] {.async.} = - var buf1 = newBufferStream() - - buf1 = buf1 | buf1 - - proc reader(): Future[seq[byte]] {.async.} = - result = newSeq[byte](6) - await buf1.readExactly(addr result[0], 6) - - proc writer(): Future[void] = buf1.write("Hello!".toBytes()) - - var writerFut = writer() - var readerFut = reader() - - await writerFut - check: - (await readerFut) == "Hello!".toBytes() - - result = true - - await buf1.close() - - check: - waitFor(pipeTest()) == true - - # TODO: Need to implement deadlock prevention when - # piping to self - test "pipe deadlock": - proc pipeTest(): Future[bool] {.async.} = - var buf1 = newBufferStream(size = 5) - - buf1 = buf1 | buf1 - - var count = 30000 - proc reader() {.async.} = - var data = newSeq[byte](7) - await buf1.readExactly(addr data[0], 7) - - proc writer() {.async.} = - while count > 0: - await buf1.write("Hello2!".toBytes()) - count.dec - - var writerFut = writer() - var readerFut = reader() - - await allFuturesThrowing(readerFut, writerFut) - result = true - - await buf1.close() - - check: - waitFor(pipeTest()) == true - test "shouldn't get stuck on close": proc closeTest(): Future[bool] {.async.} = - proc createMessage(tmplate: string, size: int): seq[byte] = - result = newSeq[byte](size) - for i in 0 ..< len(result): - result[i] = byte(tmplate[i mod len(tmplate)]) - var stream = newBufferStream() - var message = createMessage("MESSAGE", DefaultBufferSize * 2 + 1) - var fut = stream.pushTo(message) + var + fut = stream.pushTo(toBytes("hello")) + fut2 = stream.pushTo(toBytes("again")) await stream.close() try: await wait(fut, 100.milliseconds) + await wait(fut2, 100.milliseconds) result = true except AsyncTimeoutError: result = false @@ -486,3 +215,19 @@ suite "BufferStream": check: waitFor(closeTest()) == true + + test "no push after close": + proc closeTest(): Future[bool] {.async.} = + var stream = newBufferStream() + await stream.pushTo("123".toBytes()) + var data: array[3, byte] + await stream.readExactly(addr data[0], data.len) + await stream.close() + + try: + await stream.pushTo("123".toBytes()) + except LPStreamClosedError: + result = true + + check: + waitFor(closeTest()) == true diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 0ea9131..57c0e27 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -1,4 +1,4 @@ -import unittest, strformat, strformat, random +import unittest, strformat, strformat, random, oids import chronos, nimcrypto/utils, chronicles, stew/byteutils import ../libp2p/[errors, stream/connection, @@ -132,7 +132,6 @@ suite "Mplex": conn = newBufferStream( proc (data: seq[byte]) {.gcsafe, async.} = discard, - timeout = 5.minutes ) chann = LPChannel.init(1, conn, true) @@ -146,7 +145,6 @@ suite "Mplex": let closeFut = chann.closeRemote() # should still allow reading until buffer EOF await chann.readExactly(addr data[3], 3) - await closeFut try: # this should fail now await chann.readExactly(addr data[0], 3) @@ -155,6 +153,7 @@ suite "Mplex": finally: await chann.close() await conn.close() + await closeFut check: waitFor(testOpenForRead()) == true @@ -165,7 +164,6 @@ suite "Mplex": conn = newBufferStream( proc (data: seq[byte]) {.gcsafe, async.} = discard, - timeout = 5.minutes ) chann = LPChannel.init(1, conn, true) @@ -194,11 +192,9 @@ suite "Mplex": conn = newBufferStream( proc (data: seq[byte]) {.gcsafe, async.} = discard - , timeout = 5.minutes ) chann = LPChannel.init(1, conn, true) - var data = newSeq[byte](6) await chann.closeRemote() # closing channel try: await chann.writeLp(testData) @@ -273,7 +269,7 @@ suite "Mplex": chann = LPChannel.init( 1, conn, true, timeout = 100.millis) - await chann.closeEvent.wait() + check await chann.closeEvent.wait().withTimeout(1.minutes) await conn.close() result = true @@ -555,8 +551,11 @@ suite "Mplex": let mplexListen = Mplex.init(conn) mplexListen.streamHandler = proc(stream: Connection) {.async, gcsafe.} = - let msg = await stream.readLp(MsgSize) - check msg.len == MsgSize + try: + let msg = await stream.readLp(MsgSize) + check msg.len == MsgSize + except CatchableError as e: + echo e.msg await stream.close() complete.complete() @@ -602,10 +601,11 @@ suite "Mplex": buf.buffer = buf.buffer[size..^1] await writer() + await complete.wait(1.seconds) await stream.close() await conn.close() - await complete.wait(1.seconds) + await mplexDialFut await allFuturesThrowing(