diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index b031cdc..cc37083 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -71,7 +71,7 @@ method closed*(s: LPChannel): bool = proc closeUnderlying(s: LPChannel): Future[void] {.async.} = ## Channels may be closed for reading and writing in any order - we'll close ## the underlying bufferstream when both directions are closed - if s.closedLocal and s.isEof: + if s.closedLocal and s.atEof(): await procCall BufferStream(s).close() proc reset*(s: LPChannel) {.async, gcsafe.} = @@ -79,6 +79,8 @@ proc reset*(s: LPChannel) {.async, gcsafe.} = trace "Already closed", s return + s.isClosed = true + trace "Resetting channel", s, len = s.len # First, make sure any new calls to `readOnce` and `pushData` etc will fail - @@ -88,17 +90,20 @@ proc reset*(s: LPChannel) {.async, gcsafe.} = s.readBuf = StreamSeq() s.pushedEof = true - for i in 0.. 0, "nbytes must be positive integer") - if s.isEof and s.readBuf.len() == 0: + doAssert(not s.reading, "Only one concurrent read allowed") + + if s.returnedEof: raise newLPStreamEOFError() var @@ -105,12 +111,23 @@ method readOnce*(s: BufferStream, # First consume leftovers from previous read var rbytes = s.readBuf.consumeTo(toOpenArray(p, 0, nbytes - 1)) - if rbytes < nbytes: + if rbytes < nbytes and not s.isEof: # There's space in the buffer - consume some data from the read queue - trace "popping readQueue", s, rbytes, nbytes - let buf = await s.readQueue.popFirst() + s.reading = true + let buf = + try: + await s.readQueue.popFirst() + except CatchableError as exc: + # When an exception happens here, the Bufferstream is effectively + # broken and no more reads will be valid - for now, return EOF if it's + # called again, though this is not completely true - EOF represents an + # "orderly" shutdown and that's not what happened here.. + s.returnedEof = true + raise exc + finally: + s.reading = false - if buf.len == 0 or s.isEof: # Another task might have set EOF! + if buf.len == 0: # No more data will arrive on read queue trace "EOF", s s.isEof = true @@ -130,6 +147,12 @@ method readOnce*(s: BufferStream, s.activity = true + # We want to return 0 exactly once - after that, we'll start raising instead - + # this is a bit nuts in a mixed exception / return value world, but allows the + # consumer of the stream to rely on the 0-byte read as a "regular" EOF marker + # (instead of _sometimes_ getting an exception). + s.returnedEof = rbytes == 0 + return rbytes method closeImpl*(s: BufferStream): Future[void] = diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index b1966b9..c9a6314 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -84,7 +84,7 @@ proc timeoutMonitor(s: Connection) {.async, gcsafe.} = while true: await sleepAsync(s.timeout) - if s.closed or s.atEof: + if s.closed and s.atEof: return if s.activity: diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index baec119..5fbe453 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -129,10 +129,10 @@ method initStream*(s: LPStream) {.base.} = proc join*(s: LPStream): Future[void] = s.closeEvent.wait() -method closed*(s: LPStream): bool {.base, inline.} = +method closed*(s: LPStream): bool {.base.} = s.isClosed -method atEof*(s: LPStream): bool {.base, inline.} = +method atEof*(s: LPStream): bool {.base.} = s.isEof method readOnce*(s: LPStream, @@ -272,6 +272,13 @@ method close*(s: LPStream): Future[void] {.base, async.} = # {.raises [Defect].} proc closeWithEOF*(s: LPStream): Future[void] {.async.} = ## Close the stream and wait for EOF - use this with half-closed streams where ## an EOF is expected to arrive from the other end. + ## + ## Note - this should only be used when there has been an in-protocol + ## notification that no more data will arrive and that the only thing left + ## for the other end to do is to close the stream gracefully. + ## + ## In particular, it must not be used when there is another concurrent read + ## ongoing (which may be the case during cancellations)! await s.close() if s.atEof(): diff --git a/tests/testbufferstream.nim b/tests/testbufferstream.nim index 866e3e6..460838b 100644 --- a/tests/testbufferstream.nim +++ b/tests/testbufferstream.nim @@ -142,6 +142,54 @@ suite "BufferStream": check str == str2 await buff.close() + asyncTest "read all data after eof": + let buff = newBufferStream() + check buff.len == 0 + + await buff.pushData("12345".toBytes()) + var data: array[2, byte] + check: (await buff.readOnce(addr data[0], data.len)) == 2 + + await buff.pushEof() + + check: + not buff.atEof() + (await buff.readOnce(addr data[0], data.len)) == 2 + not buff.atEof() + (await buff.readOnce(addr data[0], data.len)) == 1 + buff.atEof() + # exactly one 0-byte read + (await buff.readOnce(addr data[0], data.len)) == 0 + + expect LPStreamEOFError: + discard (await buff.readOnce(addr data[0], data.len)) + + await buff.close() # all data should still be read after close + + asyncTest "read more data after eof": + let buff = newBufferStream() + check buff.len == 0 + + await buff.pushData("12345".toBytes()) + var data: array[5, byte] + check: (await buff.readOnce(addr data[0], 1)) == 1 # 4 bytes in readBuf + + await buff.pushEof() + + check: + not buff.atEof() + (await buff.readOnce(addr data[0], 1)) == 1 # 3 bytes in readBuf, eof marker processed + not buff.atEof() + (await buff.readOnce(addr data[0], data.len)) == 3 # 0 bytes in readBuf + buff.atEof() + # exactly one 0-byte read + (await buff.readOnce(addr data[0], data.len)) == 0 + + expect LPStreamEOFError: + discard (await buff.readOnce(addr data[0], data.len)) + + await buff.close() # all data should still be read after close + asyncTest "shouldn't get stuck on close": var stream = newBufferStream() var diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 07694e2..9e6dc6f 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -140,7 +140,7 @@ suite "Mplex": await chann.readExactly(addr data[0], 3) let closeFut = chann.pushEof() # closing channel let readFut = chann.readExactly(addr data[3], 3) - await all(closeFut, readFut) + await allFutures(closeFut, readFut) expect LPStreamEOFError: await chann.readExactly(addr data[0], 6) # this should fail now @@ -194,7 +194,7 @@ suite "Mplex": await conn.close() - asyncTest "should complete read": + asyncTest "reset should complete read": proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let conn = newBufferStream(writeHandler) @@ -209,33 +209,66 @@ suite "Mplex": await conn.close() - asyncTest "should complete pushData": + asyncTest "reset should complete pushData": proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let conn = newBufferStream(writeHandler) chann = LPChannel.init(1, conn, true) - await chann.pushData(@[0'u8]) - let fut = chann.pushData(@[0'u8]) + let futs = @[ + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + ] await chann.reset() - check await fut.withTimeout(100.millis) + check await allFutures(futs).withTimeout(100.millis) await conn.close() - asyncTest "should complete both read and push": + asyncTest "reset should complete both read and push": proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let conn = newBufferStream(writeHandler) chann = LPChannel.init(1, conn, true) var data = newSeq[byte](1) - let rfut = chann.readExactly(addr data[0], 1) - let wfut = chann.pushData(@[0'u8]) - let wfut2 = chann.pushData(@[0'u8]) + let futs = [ + chann.readExactly(addr data[0], 1), + chann.pushData(@[0'u8]), + ] await chann.reset() - check await allFutures(rfut, wfut, wfut2).withTimeout(100.millis) + check await allFutures(futs).withTimeout(100.millis) await conn.close() - asyncTest "should complete both read and push after cancel": + asyncTest "reset should complete both read and pushes": + proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard + let + conn = newBufferStream(writeHandler) + chann = LPChannel.init(1, conn, true) + + var data = newSeq[byte](1) + let futs = [ + chann.readExactly(addr data[0], 1), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + chann.pushData(@[0'u8]), + ] + await chann.reset() + check await allFutures(futs).withTimeout(100.millis) + await futs[0] + await conn.close() + + asyncTest "reset should complete both read and push with cancel": proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard let conn = newBufferStream(writeHandler) @@ -244,11 +277,9 @@ suite "Mplex": var data = newSeq[byte](1) let rfut = chann.readExactly(addr data[0], 1) rfut.cancel() + let xfut = chann.reset() - let wfut = chann.pushData(@[0'u8]) - let wfut2 = chann.pushData(@[0'u8]) - await chann.reset() - check await allFutures(rfut, wfut, wfut2).withTimeout(100.millis) + check await allFutures(rfut, xfut).withTimeout(100.millis) await conn.close() asyncTest "should complete both read and push after reset": @@ -259,14 +290,14 @@ suite "Mplex": var data = newSeq[byte](1) let rfut = chann.readExactly(addr data[0], 1) - let fut2 = sleepAsync(1.millis) or rfut + let rfut2 = sleepAsync(1.millis) or rfut await sleepAsync(5.millis) let wfut = chann.pushData(@[0'u8]) let wfut2 = chann.pushData(@[0'u8]) await chann.reset() - check await allFutures(rfut, wfut, wfut2).withTimeout(100.millis) + check await allFutures(rfut, rfut2, wfut, wfut2).withTimeout(100.millis) await conn.close() asyncTest "channel should fail writing": @@ -466,9 +497,9 @@ suite "Mplex": let msg = await stream.readLp(1024) check string.fromBytes(msg) == &"stream {count}!" count.inc - await stream.close() - if count == 10: + if count == 11: done.complete() + await stream.close() await mplexListen.handle() await mplexListen.close() @@ -508,9 +539,9 @@ suite "Mplex": check string.fromBytes(msg) == &"stream {count} from dialer!" await stream.writeLp(&"stream {count} from listener!") count.inc - await stream.close() - if count == 10: + if count == 11: done.complete() + await stream.close() await mplexListen.handle() await mplexListen.close()