diff --git a/chronos/debugutils.nim b/chronos/debugutils.nim index 0f42afa8..451189a3 100644 --- a/chronos/debugutils.nim +++ b/chronos/debugutils.nim @@ -34,9 +34,9 @@ proc dumpPendingFutures*(filter = AllFutureStates): string = ## not yet finished). ## 2. Future[T] objects with ``FutureState.Finished/Cancelled/Failed`` state ## which callbacks are scheduled, but not yet fully processed. - var count = 0'u - var res = "" when defined(chronosFutureTracking): + var count = 0'u + var res = "" for item in pendingFutures(): if item.state in filter: inc(count) diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index 9cfa9637..6abb019c 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -36,6 +36,7 @@ type par*: ref CatchableError AsyncStreamWriteError* = object of AsyncStreamError par*: ref CatchableError + AsyncStreamWriteEOFError* = object of AsyncStreamWriteError AsyncBuffer* = object offset*: int @@ -218,24 +219,25 @@ proc newAsyncStreamUseClosedError*(): ref AsyncStreamUseClosedError {. newException(AsyncStreamUseClosedError, "Stream is already closed") proc raiseAsyncStreamUseClosedError*() {. - noinline, noreturn, raises: [Defect, AsyncStreamUseClosedError].} = + noinline, noreturn, raises: [Defect, AsyncStreamUseClosedError].} = raise newAsyncStreamUseClosedError() proc raiseAsyncStreamLimitError*() {. - noinline, noreturn, raises: [Defect, AsyncStreamLimitError].} = + noinline, noreturn, raises: [Defect, AsyncStreamLimitError].} = raise newAsyncStreamLimitError() proc raiseAsyncStreamIncompleteError*() {. - noinline, noreturn, raises: [Defect, AsyncStreamIncompleteError].} = + noinline, noreturn, raises: [Defect, AsyncStreamIncompleteError].} = raise newAsyncStreamIncompleteError() proc raiseEmptyMessageDefect*() {.noinline, noreturn.} = raise newException(AsyncStreamIncorrectDefect, "Could not write empty message") -template checkStreamClosed*(t: untyped) = - if t.state == AsyncStreamState.Closed: - raiseAsyncStreamUseClosedError() +proc raiseAsyncStreamWriteEOFError*() {. + noinline, noreturn, raises: [Defect, AsyncStreamWriteEOFError].} = + raise newException(AsyncStreamWriteEOFError, + "Stream finished or remote side dropped connection") proc atEof*(rstream: AsyncStreamReader): bool = ## Returns ``true`` is reading stream is closed or finished and internal @@ -257,93 +259,81 @@ proc atEof*(wstream: AsyncStreamWriter): bool = else: wstream.wsource.atEof() else: - wstream.state != AsyncStreamState.Running - -proc closed*(reader: AsyncStreamReader): bool = - ## Returns ``true`` is reading/writing stream is closed. - reader.state in {AsyncStreamState.Closing, Closed} - -proc finished*(reader: AsyncStreamReader): bool = - ## Returns ``true`` is reading/writing stream is finished (completed). - if isNil(reader.readerLoop): - if isNil(reader.rsource): - reader.tsource.finished() + # `wstream.future` holds `rstream.writerLoop()` call's result. + # Return `true` if `writerLoop()` is not yet started or already stopped. + if isNil(wstream.future) or wstream.future.finished(): + true else: - reader.rsource.finished() - else: - (reader.state == AsyncStreamState.Finished) + wstream.state != AsyncStreamState.Running -proc stopped*(reader: AsyncStreamReader): bool = - ## Returns ``true`` is reading/writing stream is stopped (interrupted). - if isNil(reader.readerLoop): - if isNil(reader.rsource): +proc closed*(rw: AsyncStreamRW): bool = + ## Returns ``true`` is reading/writing stream is closed. + rw.state in {AsyncStreamState.Closing, Closed} + +proc finished*(rw: AsyncStreamRW): bool = + ## Returns ``true`` if reading/writing stream is finished (completed). + rw.atEof() and rw.state == AsyncStreamState.Finished + +proc stopped*(rw: AsyncStreamRW): bool = + ## Returns ``true`` if reading/writing stream is stopped (interrupted). + let loopIsNil = + when rw is AsyncStreamReader: + isNil(rw.readerLoop) + else: + isNil(rw.writerLoop) + + if loopIsNil: + when rw is AsyncStreamReader: + if isNil(rw.rsource): false else: rw.rsource.stopped() + else: + if isNil(rw.wsource): false else: rw.wsource.stopped() + else: + if isNil(rw.future) or rw.future.finished(): false else: - reader.rsource.stopped() - else: - (reader.state == AsyncStreamState.Stopped) + rw.state == AsyncStreamState.Stopped -proc running*(reader: AsyncStreamReader): bool = - ## Returns ``true`` is reading/writing stream is still pending. - if isNil(reader.readerLoop): - if isNil(reader.rsource): - reader.tsource.running() +proc running*(rw: AsyncStreamRW): bool = + ## Returns ``true`` if reading/writing stream is still pending. + let loopIsNil = + when rw is AsyncStreamReader: + isNil(rw.readerLoop) else: - reader.rsource.running() - else: - (reader.state == AsyncStreamState.Running) - -proc failed*(reader: AsyncStreamReader): bool = - if isNil(reader.readerLoop): - if isNil(reader.rsource): - reader.tsource.failed() + isNil(rw.writerLoop) + if loopIsNil: + when rw is AsyncStreamReader: + if isNil(rw.rsource): rw.tsource.running() else: rw.rsource.running() else: - reader.rsource.failed() + if isNil(rw.wsource): rw.tsource.running() else: rw.wsource.running() else: - (reader.state == AsyncStreamState.Error) - -proc closed*(writer: AsyncStreamWriter): bool = - ## Returns ``true`` is reading/writing stream is closed. - writer.state in {AsyncStreamState.Closing, Closed} - -proc finished*(writer: AsyncStreamWriter): bool = - ## Returns ``true`` is reading/writing stream is finished (completed). - if isNil(writer.writerLoop): - if isNil(writer.wsource): - writer.tsource.finished() - else: - writer.wsource.finished() - else: - (writer.state == AsyncStreamState.Finished) - -proc stopped*(writer: AsyncStreamWriter): bool = - ## Returns ``true`` is reading/writing stream is stopped (interrupted). - if isNil(writer.writerLoop): - if isNil(writer.wsource): + if isNil(rw.future) or rw.future.finished(): false else: - writer.wsource.stopped() - else: - (writer.state == AsyncStreamState.Stopped) + rw.state == AsyncStreamState.Running -proc running*(writer: AsyncStreamWriter): bool = - ## Returns ``true`` is reading/writing stream is still pending. - if isNil(writer.writerLoop): - if isNil(writer.wsource): - writer.tsource.running() +proc failed*(rw: AsyncStreamRW): bool = + ## Returns ``true`` if reading/writing stream is in failed state. + let loopIsNil = + when rw is AsyncStreamReader: + isNil(rw.readerLoop) else: - writer.wsource.running() + isNil(rw.writerLoop) + if loopIsNil: + when rw is AsyncStreamReader: + if isNil(rw.rsource): rw.tsource.failed() else: rw.rsource.failed() + else: + if isNil(rw.wsource): rw.tsource.failed() else: rw.wsource.failed() else: - (writer.state == AsyncStreamState.Running) + if isNil(rw.future) or rw.future.finished(): + false + else: + rw.state == AsyncStreamState.Error -proc failed*(writer: AsyncStreamWriter): bool = - if isNil(writer.writerLoop): - if isNil(writer.wsource): - writer.tsource.failed() - else: - writer.wsource.failed() - else: - (writer.state == AsyncStreamState.Error) +template checkStreamClosed*(t: untyped) = + if t.closed(): raiseAsyncStreamUseClosedError() + +template checkStreamFinished*(t: untyped) = + if t.atEof(): raiseAsyncStreamWriteEOFError() proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {. gcsafe, raises: [Defect].} @@ -787,6 +777,8 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, ## ## ``nbytes`` must be more then zero. checkStreamClosed(wstream) + checkStreamFinished(wstream) + if nbytes <= 0: raiseEmptyMessageDefect() @@ -834,6 +826,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: sink seq[byte], ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. checkStreamClosed(wstream) + checkStreamFinished(wstream) + let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) if length <= 0: raiseEmptyMessageDefect() @@ -885,6 +879,8 @@ proc write*(wstream: AsyncStreamWriter, sbytes: sink string, ## If ``msglen > len(sbytes)`` only ``len(sbytes)`` bytes will be written to ## stream. checkStreamClosed(wstream) + checkStreamFinished(wstream) + let length = if msglen <= 0: len(sbytes) else: min(msglen, len(sbytes)) if length <= 0: raiseEmptyMessageDefect() @@ -929,23 +925,25 @@ proc write*(wstream: AsyncStreamWriter, sbytes: sink string, proc finish*(wstream: AsyncStreamWriter) {.async.} = ## Finish write stream ``wstream``. checkStreamClosed(wstream) - - if not isNil(wstream.wsource): - if isNil(wstream.writerLoop): - await wstream.wsource.finish() - else: - var item = WriteItem(kind: Pointer) - item.size = 0 - item.future = newFuture[void]("async.stream.finish") - try: - await wstream.queue.put(item) - await item.future - except CancelledError as exc: - raise exc - except AsyncStreamError as exc: - raise exc - except CatchableError as exc: - raise newAsyncStreamWriteError(exc) + # For AsyncStreamWriter Finished state could be set manually or by stream's + # writeLoop, so we not going to raise exception here. + if not(wstream.atEof()): + if not isNil(wstream.wsource): + if isNil(wstream.writerLoop): + await wstream.wsource.finish() + else: + var item = WriteItem(kind: Pointer) + item.size = 0 + item.future = newFuture[void]("async.stream.finish") + try: + await wstream.queue.put(item) + await item.future + except CancelledError as exc: + raise exc + except AsyncStreamError as exc: + raise exc + except CatchableError as exc: + raise newAsyncStreamWriteError(exc) proc join*(rw: AsyncStreamRW): Future[void] = ## Get Future[void] which will be completed when stream become finished or diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index a0d3c1ab..7010c3a4 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -11,7 +11,7 @@ ## uses sources of BearSSL by Thomas Pornin. import bearssl/[brssl, ec, errors, pem, rsa, ssl, x509], - bearssl/abi/cacert + bearssl/certs/cacert import ../asyncloop, ../timer, ../asyncsync import asyncstream, ../transports/stream, ../transports/common export asyncloop, asyncsync, timer, asyncstream @@ -102,13 +102,6 @@ type TLSStreamProtocolError* = object of TLSStreamError errCode*: int -proc newTLSStreamReadError(p: ref AsyncStreamError): ref TLSStreamReadError {. - noinline.} = - var w = newException(TLSStreamReadError, "Read stream failed") - w.msg = w.msg & ", originated from [" & $p.name & "] " & p.msg - w.par = p - w - proc newTLSStreamWriteError(p: ref AsyncStreamError): ref TLSStreamWriteError {. noinline.} = var w = newException(TLSStreamWriteError, "Write stream failed") @@ -375,10 +368,10 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} = # Syncing state for reader and writer stream.writer.state = loopState + stream.reader.state = loopState if loopState == AsyncStreamState.Error: if isNil(stream.reader.error): - stream.reader.error = newTLSStreamReadError(error) - stream.reader.state = loopState + stream.reader.state = AsyncStreamState.Finished if not(isNil(error)): # Completing all pending writes diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 9dab2023..77bb4989 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -483,6 +483,48 @@ suite "AsyncStream test suite": await server.join() result = true check waitFor(testConsume2(initTAddress("127.0.0.1:46001"))) == true + test "AsyncStream(AsyncStream) write(eof) test": + proc testWriteEof(address: TransportAddress): Future[bool] {.async.} = + let + size = 10240 + message = createBigMessage("ABCDEFGHIJKLMNOP", size) + + proc processClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var wstream = newAsyncStreamWriter(transp) + var wbstream = newBoundedStreamWriter(wstream, uint64(size)) + try: + check wbstream.atEof() == false + await wbstream.write(message) + check wbstream.atEof() == false + await wbstream.finish() + check wbstream.atEof() == true + expect AsyncStreamWriteEOFError: + await wbstream.write(message) + expect AsyncStreamWriteEOFError: + await wbstream.write(message) + expect AsyncStreamWriteEOFError: + await wbstream.write(message) + check wbstream.atEof() == true + await wbstream.closeWait() + check wbstream.atEof() == true + finally: + await wstream.closeWait() + await transp.closeWait() + + let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay} + var server = createStreamServer(address, processClient, flags = flags) + server.start() + var conn = await connect(server.localAddress()) + try: + discard await conn.consume() + finally: + await conn.closeWait() + server.stop() + await server.closeWait() + return true + + check waitFor(testWriteEof(initTAddress("127.0.0.1:46001"))) == true test "AsyncStream(AsyncStream) leaks test": check: getTracker("async.stream.reader").isLeaked() == false