diff --git a/chronos/apps/http/httpbodyrw.nim b/chronos/apps/http/httpbodyrw.nim index efa043e..b0cda66 100644 --- a/chronos/apps/http/httpbodyrw.nim +++ b/chronos/apps/http/httpbodyrw.nim @@ -8,6 +8,7 @@ # MIT license (LICENSE-MIT) import ../../asyncloop, ../../asyncsync import ../../streams/[asyncstream, boundstream] +import httpcommon const HttpBodyReaderTrackerName* = "http.body.reader" @@ -17,9 +18,11 @@ const type HttpBodyReader* = ref object of AsyncStreamReader + bstate*: HttpState streams*: seq[AsyncStreamReader] HttpBodyWriter* = ref object of AsyncStreamWriter + bstate*: HttpState streams*: seq[AsyncStreamWriter] HttpBodyTracker* = ref object of TrackerBase @@ -93,21 +96,24 @@ proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader = ## ## First stream in sequence will be used as a source. doAssert(len(streams) > 0, "At least one stream must be added") - var res = HttpBodyReader(streams: @streams) + var res = HttpBodyReader(bstate: HttpState.Alive, streams: @streams) res.init(streams[0]) trackHttpBodyReader(res) res proc closeWait*(bstream: HttpBodyReader) {.async.} = ## Close and free resource allocated by body reader. - var res = newSeq[Future[void]]() - # We closing streams in reversed order because stream at position [0], uses - # data from stream at position [1]. - for index in countdown((len(bstream.streams) - 1), 0): - res.add(bstream.streams[index].closeWait()) - await allFutures(res) - await procCall(closeWait(AsyncStreamReader(bstream))) - untrackHttpBodyReader(bstream) + if bstream.bstate == HttpState.Alive: + bstream.bstate = HttpState.Closing + var res = newSeq[Future[void]]() + # We closing streams in reversed order because stream at position [0], uses + # data from stream at position [1]. + for index in countdown((len(bstream.streams) - 1), 0): + res.add(bstream.streams[index].closeWait()) + await allFutures(res) + await procCall(closeWait(AsyncStreamReader(bstream))) + bstream.bstate = HttpState.Closed + untrackHttpBodyReader(bstream) proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter = ## HttpBodyWriter is AsyncStreamWriter which holds references to all the @@ -115,19 +121,22 @@ proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter = ## ## First stream in sequence will be used as a destination. doAssert(len(streams) > 0, "At least one stream must be added") - var res = HttpBodyWriter(streams: @streams) + var res = HttpBodyWriter(bstate: HttpState.Alive, streams: @streams) res.init(streams[0]) trackHttpBodyWriter(res) res proc closeWait*(bstream: HttpBodyWriter) {.async.} = ## Close and free all the resources allocated by body writer. - var res = newSeq[Future[void]]() - for index in countdown(len(bstream.streams) - 1, 0): - res.add(bstream.streams[index].closeWait()) - await allFutures(res) - await procCall(closeWait(AsyncStreamWriter(bstream))) - untrackHttpBodyWriter(bstream) + if bstream.bstate == HttpState.Alive: + bstream.bstate = HttpState.Closing + var res = newSeq[Future[void]]() + for index in countdown(len(bstream.streams) - 1, 0): + res.add(bstream.streams[index].closeWait()) + await allFutures(res) + await procCall(closeWait(AsyncStreamWriter(bstream))) + bstream.bstate = HttpState.Closed + untrackHttpBodyWriter(bstream) proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [Defect].} = if len(bstream.streams) == 1: @@ -144,3 +153,7 @@ proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [Defect].} = false else: false + +proc closed*(bstream: HttpBodyReader | HttpBodyWriter): bool {. + raises: [Defect].} = + bstream.bstate != HttpState.Alive diff --git a/chronos/apps/http/httpclient.nim b/chronos/apps/http/httpclient.nim index a391dba..6ce6a98 100644 --- a/chronos/apps/http/httpclient.nim +++ b/chronos/apps/http/httpclient.nim @@ -1183,11 +1183,22 @@ proc redirect*(request: HttpClientRequestRef, proc fetch*(request: HttpClientRequestRef): Future[HttpResponseTuple] {. async.} = - let response = await request.send() - let data = await response.getBodyBytes() - let code = response.status - await response.closeWait() - return (code, data) + var response: HttpClientResponseRef + try: + response = await request.send() + let buffer = await response.getBodyBytes() + let status = response.status + await response.closeWait() + response = nil + return (status, buffer) + except HttpError as exc: + if not(isNil(response)): + await response.closeWait() + raise exc + except CancelledError as exc: + if not(isNil(response)): + await response.closeWait() + raise exc proc fetch*(session: HttpSessionRef, url: Uri): Future[HttpResponseTuple] {. async.} = diff --git a/chronos/apps/http/httpcommon.nim b/chronos/apps/http/httpcommon.nim index 530b507..0653aa6 100644 --- a/chronos/apps/http/httpcommon.nim +++ b/chronos/apps/http/httpcommon.nim @@ -68,6 +68,9 @@ type CommaSeparatedArray ## Enable usage of comma symbol as separator of array ## items + HttpState* {.pure.} = enum + Alive, Closing, Closed + proc raiseHttpCriticalError*(msg: string, code = Http400) {.noinline, noreturn.} = raise (ref HttpCriticalError)(code: code, msg: msg) diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index d116a56..a8c000b 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -42,8 +42,7 @@ type RequestFence* = Result[HttpRequestRef, HttpProcessError] HttpRequestFlags* {.pure.} = enum - BoundBody, UnboundBody, MultipartForm, UrlencodedForm, - ClientExpect + BoundBody, UnboundBody, MultipartForm, UrlencodedForm, ClientExpect HttpResponseFlags* {.pure.} = enum KeepAlive, Chunked @@ -83,6 +82,7 @@ type HttpServerRef* = ref HttpServer HttpRequest* = object of RootObj + state*: HttpState headers*: HttpTable query*: HttpTable postTable: Option[HttpTable] @@ -113,6 +113,7 @@ type HttpResponseRef* = ref HttpResponse HttpConnection* = object of RootObj + state*: HttpState server*: HttpServerRef transp: StreamTransport mainReader*: AsyncStreamReader @@ -250,7 +251,7 @@ proc hasBody*(request: HttpRequestRef): bool {.raises: [Defect].} = proc prepareRequest(conn: HttpConnectionRef, req: HttpRequestHeader): HttpResultCode[HttpRequestRef] {. raises: [Defect].}= - var request = HttpRequestRef(connection: conn) + var request = HttpRequestRef(connection: conn, state: HttpState.Alive) if req.version notin {HttpVersion10, HttpVersion11}: return err(Http505) @@ -402,38 +403,57 @@ proc handleExpect*(request: HttpRequestRef) {.async.} = proc getBody*(request: HttpRequestRef): Future[seq[byte]] {.async.} = ## Obtain request's body as sequence of bytes. - let res = request.getBodyReader() - if res.isErr(): + let bodyReader = request.getBodyReader() + if bodyReader.isErr(): return @[] else: - let reader = res.get() + var reader = bodyReader.get() try: await request.handleExpect() - var res = await reader.read() + let res = await reader.read() if reader.hasOverflow(): + await reader.closeWait() + reader = nil raiseHttpCriticalError(MaximumBodySizeError, Http413) - return res + else: + await reader.closeWait() + reader = nil + return res + except CancelledError as exc: + if not(isNil(reader)): + await reader.closeWait() + raise exc except AsyncStreamError: + if not(isNil(reader)): + await reader.closeWait() raiseHttpCriticalError("Unable to read request's body") - finally: - await closeWait(res.get()) proc consumeBody*(request: HttpRequestRef): Future[void] {.async.} = ## Consume/discard request's body. - let res = request.getBodyReader() - if res.isErr(): + let bodyReader = request.getBodyReader() + if bodyReader.isErr(): return else: - let reader = res.get() + var reader = bodyReader.get() try: await request.handleExpect() discard await reader.consume() if reader.hasOverflow(): + await reader.closeWait() + reader = nil raiseHttpCriticalError(MaximumBodySizeError, Http413) + else: + await reader.closeWait() + reader = nil + return + except CancelledError as exc: + if not(isNil(reader)): + await reader.closeWait() + raise exc except AsyncStreamError: + if not(isNil(reader)): + await reader.closeWait() raiseHttpCriticalError("Unable to read request's body") - finally: - await closeWait(res.get()) proc getAcceptInfo*(request: HttpRequestRef): Result[AcceptInfo, cstring] = ## Returns value of `Accept` header as `AcceptInfo` object. @@ -574,6 +594,7 @@ proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} = proc init*(value: var HttpConnection, server: HttpServerRef, transp: StreamTransport) = value = HttpConnection( + state: HttpState.Alive, server: server, transp: transp, buffer: newSeq[byte](server.maxHeadersSize), @@ -590,23 +611,32 @@ proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef, res proc closeWait*(conn: HttpConnectionRef) {.async.} = - var pending: seq[Future[void]] - if conn.reader != conn.mainReader: - pending.add(conn.reader.closeWait()) - if conn.writer != conn.mainWriter: - pending.add(conn.writer.closeWait()) - if len(pending) > 0: + if conn.state == HttpState.Alive: + conn.state = HttpState.Closing + var pending: seq[Future[void]] + if conn.reader != conn.mainReader: + pending.add(conn.reader.closeWait()) + if conn.writer != conn.mainWriter: + pending.add(conn.writer.closeWait()) + if len(pending) > 0: + await allFutures(pending) + # After we going to close everything else. + pending.setLen(3) + pending[0] = conn.mainReader.closeWait() + pending[1] = conn.mainWriter.closeWait() + pending[2] = conn.transp.closeWait() await allFutures(pending) - # After we going to close everything else. - await allFutures(conn.mainReader.closeWait(), conn.mainWriter.closeWait(), - conn.transp.closeWait()) + conn.state = HttpState.Closed proc closeWait(req: HttpRequestRef) {.async.} = - if req.response.isSome(): - let resp = req.response.get() - if (HttpResponseFlags.Chunked in resp.flags) and - not(isNil(resp.chunkedWriter)): - await resp.chunkedWriter.closeWait() + if req.state == HttpState.Alive: + if req.response.isSome(): + req.state = HttpState.Closing + let resp = req.response.get() + if (HttpResponseFlags.Chunked in resp.flags) and + not(isNil(resp.chunkedWriter)): + await resp.chunkedWriter.closeWait() + req.state = HttpState.Closed proc createConnection(server: HttpServerRef, transp: StreamTransport): Future[HttpConnectionRef] {. @@ -700,16 +730,21 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = if arg.isErr(): let code = arg.error().code - case arg.error().error - of HTTPServerError.TimeoutError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HTTPServerError.RecoverableError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HTTPServerError.CriticalError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HTTPServerError.CatchableError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.DisconnectError: + try: + case arg.error().error + of HTTPServerError.TimeoutError: + discard await conn.sendErrorResponse(HttpVersion11, code, false) + of HTTPServerError.RecoverableError: + discard await conn.sendErrorResponse(HttpVersion11, code, false) + of HTTPServerError.CriticalError: + discard await conn.sendErrorResponse(HttpVersion11, code, false) + of HTTPServerError.CatchableError: + discard await conn.sendErrorResponse(HttpVersion11, code, false) + of HttpServerError.DisconnectError: + discard + except CancelledError: + # We swallowing `CancelledError` in a loop, but we going to exit + # loop ASAP. discard break else: @@ -718,33 +753,52 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = if lastErrorCode.isNone(): if isNil(resp): # Response was `nil`. - discard await conn.sendErrorResponse(HttpVersion11, Http404, - false) + try: + discard await conn.sendErrorResponse(HttpVersion11, Http404, false) + except CancelledError: + keepConn = false else: - case resp.state - of HttpResponseState.Empty: - # Response was ignored - discard await conn.sendErrorResponse(HttpVersion11, Http404, - keepConn) - of HttpResponseState.Prepared: - # Response was prepared but not sent. - discard await conn.sendErrorResponse(HttpVersion11, Http409, - keepConn) - else: - # some data was already sent to the client. - discard + try: + case resp.state + of HttpResponseState.Empty: + # Response was ignored + discard await conn.sendErrorResponse(HttpVersion11, Http404, + keepConn) + of HttpResponseState.Prepared: + # Response was prepared but not sent. + discard await conn.sendErrorResponse(HttpVersion11, Http409, + keepConn) + else: + # some data was already sent to the client. + discard + except CancelledError: + keepConn = false else: - discard await conn.sendErrorResponse(HttpVersion11, lastErrorCode.get(), - false) + try: + discard await conn.sendErrorResponse(HttpVersion11, + lastErrorCode.get(), false) + except CancelledError: + keepConn = false + # Closing and releasing all the request resources. - await request.closeWait() + try: + await request.closeWait() + except CancelledError: + # We swallowing `CancelledError` in a loop, but we still need to close + # `request` before exiting. + await request.closeWait() if not(keepConn): break # Connection could be `nil` only when secure handshake is failed. if not(isNil(conn)): - await conn.closeWait() + try: + await conn.closeWait() + except CancelledError: + # Cancellation could be happened while we closing `conn`. But we still + # need to close it. + await conn.closeWait() server.connections.del(transp.getId()) # if server.maxConnections > 0: diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index 04b7ed4..acc92ef 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -59,11 +59,14 @@ type Error, ## Stream has stored error Stopped, ## Stream was closed while working Finished, ## Stream was properly finished + Closing, ## Stream is closing Closed ## Stream was closed - StreamReaderLoop* = proc (stream: AsyncStreamReader): Future[void] {.gcsafe, raises: [Defect].} + StreamReaderLoop* = proc (stream: AsyncStreamReader): Future[void] {. + gcsafe, raises: [Defect].} ## Main read loop for read streams. - StreamWriterLoop* = proc (stream: AsyncStreamWriter): Future[void] {.gcsafe, raises: [Defect].} + StreamWriterLoop* = proc (stream: AsyncStreamWriter): Future[void] {. + gcsafe, raises: [Defect].} ## Main write loop for write streams. AsyncStreamReader* = ref object of RootRef @@ -223,10 +226,6 @@ proc raiseAsyncStreamIncompleteError*() {. noinline, noreturn, raises: [Defect, AsyncStreamIncompleteError].} = raise newAsyncStreamIncompleteError() -proc raiseAsyncStreamIncorrectDefect*(m: string) {. - noinline, noreturn, raises: [Defect].} = - raise newException(AsyncStreamIncorrectDefect, m) - proc raiseEmptyMessageDefect*() {.noinline, noreturn.} = raise newException(AsyncStreamIncorrectDefect, "Could not write empty message") @@ -244,7 +243,7 @@ proc atEof*(rstream: AsyncStreamReader): bool = else: rstream.rsource.atEof() else: - rstream.state in {AsyncStreamState.Stopped, Finished, Closed, Error} and + (rstream.state != AsyncStreamState.Running) and (rstream.buffer.dataLen() == 0) proc atEof*(wstream: AsyncStreamWriter): bool = @@ -255,11 +254,11 @@ proc atEof*(wstream: AsyncStreamWriter): bool = else: wstream.wsource.atEof() else: - wstream.state in {AsyncStreamState.Stopped, Finished, Closed, Error} + wstream.state != AsyncStreamState.Running proc closed*(reader: AsyncStreamReader): bool = ## Returns ``true`` is reading/writing stream is closed. - (reader.state == AsyncStreamState.Closed) + reader.state in {AsyncStreamState.Closing, Closed} proc finished*(reader: AsyncStreamReader): bool = ## Returns ``true`` is reading/writing stream is finished (completed). @@ -302,7 +301,7 @@ proc failed*(reader: AsyncStreamReader): bool = proc closed*(writer: AsyncStreamWriter): bool = ## Returns ``true`` is reading/writing stream is closed. - (writer.state == AsyncStreamState.Closed) + writer.state in {AsyncStreamState.Closing, Closed} proc finished*(writer: AsyncStreamWriter): bool = ## Returns ``true`` is reading/writing stream is finished (completed). @@ -965,39 +964,38 @@ proc close*(rw: AsyncStreamRW) = ## Close and frees resources of stream ``rw``. ## ## Note close() procedure is not completed immediately! - if rw.closed(): - raiseAsyncStreamIncorrectDefect("Stream is already closed!") + if not(rw.closed()): + rw.state = AsyncStreamState.Closing - rw.state = AsyncStreamState.Closed + proc continuation(udata: pointer) {.raises: [Defect].} = + if not isNil(rw.udata): + GC_unref(cast[ref int](rw.udata)) + if not(rw.future.finished()): + rw.future.complete() + when rw is AsyncStreamReader: + untrackAsyncStreamReader(rw) + elif rw is AsyncStreamWriter: + untrackAsyncStreamWriter(rw) + rw.state = AsyncStreamState.Closed - proc continuation(udata: pointer) {.raises: [Defect].} = - if not isNil(rw.udata): - GC_unref(cast[ref int](rw.udata)) - if not(rw.future.finished()): - rw.future.complete() when rw is AsyncStreamReader: - untrackAsyncStreamReader(rw) + if isNil(rw.rsource) or isNil(rw.readerLoop) or isNil(rw.future): + callSoon(continuation) + else: + if rw.future.finished(): + callSoon(continuation) + else: + rw.future.addCallback(continuation) + rw.future.cancel() elif rw is AsyncStreamWriter: - untrackAsyncStreamWriter(rw) - - when rw is AsyncStreamReader: - if isNil(rw.rsource) or isNil(rw.readerLoop) or isNil(rw.future): - callSoon(continuation) - else: - if rw.future.finished(): + if isNil(rw.wsource) or isNil(rw.writerLoop) or isNil(rw.future): callSoon(continuation) else: - rw.future.addCallback(continuation) - rw.future.cancel() - elif rw is AsyncStreamWriter: - if isNil(rw.wsource) or isNil(rw.writerLoop) or isNil(rw.future): - callSoon(continuation) - else: - if rw.future.finished(): - callSoon(continuation) - else: - rw.future.addCallback(continuation) - rw.future.cancel() + if rw.future.finished(): + callSoon(continuation) + else: + rw.future.addCallback(continuation) + rw.future.cancel() proc closeWait*(rw: AsyncStreamRW): Future[void] = ## Close and frees resources of stream ``rw``. diff --git a/chronos/streams/boundstream.nim b/chronos/streams/boundstream.nim index b63d21e..e1997a0 100644 --- a/chronos/streams/boundstream.nim +++ b/chronos/streams/boundstream.nim @@ -110,7 +110,8 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = if toRead == 0: # When ``rstream.boundSize`` is set and we already readed # ``rstream.boundSize`` bytes. - rstream.state = AsyncStreamState.Finished + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Finished else: let res = await readUntilBoundary(rstream.rsource, addr buffer[0], toRead, rstream.boundary) @@ -123,7 +124,8 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = # consumer and declaring stream EOF. Otherwise could not be # consumed. await upload(addr rstream.buffer, addr buffer[0], length) - rstream.state = AsyncStreamState.Finished + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Finished else: rstream.offset = rstream.offset + uint64(res) # There should be one step between transferring last bytes to the @@ -134,10 +136,12 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = if (res < toRead) and rstream.rsource.atEof(): case rstream.cmpop of BoundCmp.Equal: - rstream.state = AsyncStreamState.Error - rstream.error = newBoundedStreamIncompleteError() + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Error + rstream.error = newBoundedStreamIncompleteError() of BoundCmp.LessOrEqual: - rstream.state = AsyncStreamState.Finished + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Finished else: rstream.offset = rstream.offset + uint64(res) # There should be one step between transferring last bytes to the @@ -148,24 +152,30 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = if (res < toRead) and rstream.rsource.atEof(): case rstream.cmpop of BoundCmp.Equal: - rstream.state = AsyncStreamState.Error - rstream.error = newBoundedStreamIncompleteError() + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Error + rstream.error = newBoundedStreamIncompleteError() of BoundCmp.LessOrEqual: - rstream.state = AsyncStreamState.Finished + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Finished else: case rstream.cmpop of BoundCmp.Equal: - rstream.state = AsyncStreamState.Error - rstream.error = newBoundedStreamIncompleteError() + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Error + rstream.error = newBoundedStreamIncompleteError() of BoundCmp.LessOrEqual: - rstream.state = AsyncStreamState.Finished + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Finished except AsyncStreamError as exc: - rstream.state = AsyncStreamState.Error - rstream.error = exc + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Error + rstream.error = exc except CancelledError: - rstream.state = AsyncStreamState.Error - rstream.error = newAsyncStreamUseClosedError() + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Error + rstream.error = newAsyncStreamUseClosedError() case rstream.state of AsyncStreamState.Running: @@ -178,7 +188,7 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = # Send `EOF` state to the consumer and wait until it will be received. await rstream.buffer.transfer() break - of AsyncStreamState.Closed: + of AsyncStreamState.Closing, AsyncStreamState.Closed: break proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = @@ -203,26 +213,32 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = wstream.offset = wstream.offset + uint64(item.size) item.future.complete() else: - wstream.state = AsyncStreamState.Error - error = newBoundedStreamOverflowError() + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Error + error = newBoundedStreamOverflowError() else: if wstream.offset == wstream.boundSize: - wstream.state = AsyncStreamState.Finished - item.future.complete() + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Finished + item.future.complete() else: case wstream.cmpop of BoundCmp.Equal: - wstream.state = AsyncStreamState.Error - error = newBoundedStreamIncompleteError() + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Error + error = newBoundedStreamIncompleteError() of BoundCmp.LessOrEqual: - wstream.state = AsyncStreamState.Finished - item.future.complete() + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Finished + item.future.complete() except CancelledError: - wstream.state = AsyncStreamState.Stopped - error = newAsyncStreamUseClosedError() + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Stopped + error = newAsyncStreamUseClosedError() except AsyncStreamError as exc: - wstream.state = AsyncStreamState.Error - error = exc + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Error + error = exc case wstream.state of AsyncStreamState.Running: @@ -232,7 +248,8 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = if not(item.future.finished()): item.future.fail(error) break - of AsyncStreamState.Finished, AsyncStreamState.Closed: + of AsyncStreamState.Finished, AsyncStreamState.Closing, + AsyncStreamState.Closed: error = newAsyncStreamUseClosedError() break diff --git a/chronos/streams/chunkstream.nim b/chronos/streams/chunkstream.nim index 598d641..759f709 100644 --- a/chronos/streams/chunkstream.nim +++ b/chronos/streams/chunkstream.nim @@ -108,8 +108,9 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = let cres = getChunkSize(buffer.toOpenArray(0, res - len(CRLF) - 1)) if cres.isErr(): - rstream.error = newException(ChunkedStreamProtocolError, $cres.error) - rstream.state = AsyncStreamState.Error + if rstream.state == AsyncStreamState.Running: + rstream.error = newException(ChunkedStreamProtocolError, $cres.error) + rstream.state = AsyncStreamState.Error else: var chunksize = cres.get() if chunksize > 0'u64: @@ -127,28 +128,34 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = await rstream.rsource.readExactly(addr buffer[0], 2) if buffer[0] != CRLF[0] or buffer[1] != CRLF[1]: - rstream.error = newException(ChunkedStreamProtocolError, - "Unexpected trailing bytes") - rstream.state = AsyncStreamState.Error + if rstream.state == AsyncStreamState.Running: + rstream.error = newException(ChunkedStreamProtocolError, + "Unexpected trailing bytes") + rstream.state = AsyncStreamState.Error else: # Reading trailing line for last chunk discard await rstream.rsource.readUntil(addr buffer[0], len(buffer), CRLF) - rstream.state = AsyncStreamState.Finished - await rstream.buffer.transfer() + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Finished + await rstream.buffer.transfer() except CancelledError: - rstream.state = AsyncStreamState.Stopped + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Stopped except AsyncStreamLimitError: - rstream.state = AsyncStreamState.Error - rstream.error = newException(ChunkedStreamProtocolError, - "Chunk header exceeds maximum size") + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Error + rstream.error = newException(ChunkedStreamProtocolError, + "Chunk header exceeds maximum size") except AsyncStreamIncompleteError: - rstream.state = AsyncStreamState.Error - rstream.error = newException(ChunkedStreamIncompleteError, - "Incomplete chunk received") + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Error + rstream.error = newException(ChunkedStreamIncompleteError, + "Incomplete chunk received") except AsyncStreamReadError as exc: - rstream.state = AsyncStreamState.Error - rstream.error = exc + if rstream.state == AsyncStreamState.Running: + rstream.state = AsyncStreamState.Error + rstream.error = exc if rstream.state != AsyncStreamState.Running: # We need to notify consumer about error/close, but we do not care about @@ -194,13 +201,16 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = # Everything is fine, completing queue item's future. item.future.complete() # Set stream state to Finished. - wstream.state = AsyncStreamState.Finished + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Finished except CancelledError: - wstream.state = AsyncStreamState.Stopped - error = newAsyncStreamUseClosedError() + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Stopped + error = newAsyncStreamUseClosedError() except AsyncStreamError as exc: - wstream.state = AsyncStreamState.Error - error = exc + if wstream.state == AsyncStreamState.Running: + wstream.state = AsyncStreamState.Error + error = exc if wstream.state != AsyncStreamState.Running: if wstream.state == AsyncStreamState.Finished: diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 960f56f..70a7d68 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -146,10 +146,12 @@ proc tlsWriteRec(engine: ptr SslEngineContext, sslEngineSendrecAck(engine, length) return TLSResult.Success except AsyncStreamError as exc: - writer.state = AsyncStreamState.Error - writer.error = exc + if writer.state == AsyncStreamState.Running: + writer.state = AsyncStreamState.Error + writer.error = exc except CancelledError: - writer.state = AsyncStreamState.Stopped + if writer.state == AsyncStreamState.Running: + writer.state = AsyncStreamState.Stopped return TLSResult.Error proc tlsWriteApp(engine: ptr SslEngineContext, @@ -180,7 +182,8 @@ proc tlsWriteApp(engine: ptr SslEngineContext, item.future.complete() return TLSResult.Success except CancelledError: - writer.state = AsyncStreamState.Stopped + if writer.state == AsyncStreamState.Running: + writer.state = AsyncStreamState.Stopped return TLSResult.Error proc tlsReadRec(engine: ptr SslEngineContext, @@ -197,10 +200,12 @@ proc tlsReadRec(engine: ptr SslEngineContext, else: return TLSResult.Success except CancelledError: - reader.state = AsyncStreamState.Stopped + if reader.state == AsyncStreamState.Running: + reader.state = AsyncStreamState.Stopped except AsyncStreamError as exc: - reader.state = AsyncStreamState.Error - reader.error = exc + if reader.state == AsyncStreamState.Running: + reader.state = AsyncStreamState.Error + reader.error = exc return TLSResult.Error proc tlsReadApp(engine: ptr SslEngineContext, @@ -212,7 +217,8 @@ proc tlsReadApp(engine: ptr SslEngineContext, sslEngineRecvappAck(engine, length) return TLSResult.Success except CancelledError: - reader.state = AsyncStreamState.Stopped + if reader.state == AsyncStreamState.Running: + reader.state = AsyncStreamState.Stopped return TLSResult.Error template readAndReset(fut: untyped) = @@ -224,11 +230,13 @@ template readAndReset(fut: untyped) = continue of TLSResult.Error: fut = nil - loopState = AsyncStreamState.Error + if loopState == AsyncStreamState.Running: + loopState = AsyncStreamState.Error break of TLSResult.EOF: fut = nil - loopState = AsyncStreamState.Finished + if loopState == AsyncStreamState.Running: + loopState = AsyncStreamState.Finished break proc cancelAndWait*(a, b, c, d: Future[TLSResult]): Future[void] = @@ -285,7 +293,8 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} = var state = sslEngineCurrentState(engine) if (state and SSL_CLOSED) == SSL_CLOSED: - loopState = AsyncStreamState.Finished + if loopState == AsyncStreamState.Running: + loopState = AsyncStreamState.Finished break if isNil(sendRecFut): @@ -332,7 +341,8 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} = try: discard await one(waiting) except CancelledError: - loopState = AsyncStreamState.Stopped + if loopState == AsyncStreamState.Running: + loopState = AsyncStreamState.Stopped if loopState != AsyncStreamState.Running: break