From 1598471ed23a6e1aa60cca7601df1f3429dee223 Mon Sep 17 00:00:00 2001
From: Jacek Sieka
Date: Thu, 21 Dec 2023 15:52:16 +0100
Subject: [PATCH 01/11] add a test for `results.?` compatibility (#484)

Finally! (haha)
---
 tests/testmacro.nim | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/testmacro.nim b/tests/testmacro.nim
index 0133793..9b19c68 100644
--- a/tests/testmacro.nim
+++ b/tests/testmacro.nim
@@ -555,3 +555,27 @@ suite "Exceptions tracking":
       await raiseException()
 
     waitFor(callCatchAll())
+
+  test "Results compatibility":
+    proc returnOk(): Future[Result[int, string]] {.async: (raises: []).} =
+      ok(42)
+
+    proc returnErr(): Future[Result[int, string]] {.async: (raises: []).} =
+      err("failed")
+
+    proc testit(): Future[Result[void, string]] {.async: (raises: []).} =
+      let
+        v = await returnOk()
+
+      check:
+        v.isOk() and v.value() == 42
+
+      let
+        vok = ?v
+      check:
+        vok == 42
+
+      discard ?await returnErr()
+
+    check:
+      waitFor(testit()).error() == "failed"

From 41f77d261ead2508acdd3bd3f88a5cbbcefff05f Mon Sep 17 00:00:00 2001
From: Jacek Sieka
Date: Wed, 27 Dec 2023 20:57:39 +0100
Subject: [PATCH 02/11] Better line information on effect violation

We can capture the line info from the original future source and direct
violation errors there
---
 chronos/internal/asyncfutures.nim | 96 ++++++++++++++++++-------------
 chronos/internal/asyncmacro.nim   |  4 +-
 2 files changed, 57 insertions(+), 43 deletions(-)

diff --git a/chronos/internal/asyncfutures.nim b/chronos/internal/asyncfutures.nim
index a7fd961..8078952 100644
--- a/chronos/internal/asyncfutures.nim
+++ b/chronos/internal/asyncfutures.nim
@@ -478,14 +478,26 @@ when chronosStackTrace:
       #   newMsg.add "\n" & $entry
     error.msg = newMsg
 
-proc internalCheckComplete*(fut: FutureBase) {.raises: [CatchableError].} =
-  # For internal use only. Used in asyncmacro
-  if not(isNil(fut.internalError)):
-    when chronosStackTrace:
-      injectStacktrace(fut.internalError)
-    raise fut.internalError
+proc deepLineInfo(n: NimNode, p: LineInfo) =
+  n.setLineInfo(p)
+  for i in 0..<n.len:
+    deepLineInfo(n[i], p)

Date: Thu, 4 Jan 2024 16:17:42 +0100
Subject: [PATCH 03/11] prevent http `closeWait` future from being cancelled
 (#486)

* simplify `closeWait` implementations
* remove redundant cancellation callbacks
* use `noCancel` to avoid forgetting the right future flags
* add a few missing raises trackers
* enforce `OwnCancelSchedule` on manually created futures that don't raise
  `CancelledError`
* ensure cancellations don't reach internal futures
---
 chronos/apps/http/httpbodyrw.nim   |   4 +-
 chronos/apps/http/httpclient.nim   |  16 ++---
 chronos/asyncsync.nim              |   8 +--
 chronos/futures.nim                |  19 ++++++
 chronos/internal/asyncfutures.nim  |   2 +
 chronos/internal/raisesfutures.nim |  91 ++++++++++++-----------
 chronos/streams/asyncstream.nim    |  37 +++-------
 chronos/transports/common.nim      |   4 +-
 chronos/transports/datagram.nim    |  43 ++++--------
 chronos/transports/stream.nim      | 104 ++++++++++++-----------------
 10 files changed, 152 insertions(+), 176 deletions(-)

diff --git a/chronos/apps/http/httpbodyrw.nim b/chronos/apps/http/httpbodyrw.nim
index c9ac899..9a11e85 100644
--- a/chronos/apps/http/httpbodyrw.nim
+++ b/chronos/apps/http/httpbodyrw.nim
@@ -43,7 +43,7 @@ proc closeWait*(bstream: HttpBodyReader) {.async: (raises: []).} =
   ## Close and free resource allocated by body reader.
if bstream.bstate == HttpState.Alive: bstream.bstate = HttpState.Closing - var res = newSeq[Future[void]]() + var res = newSeq[Future[void].Raising([])]() # We closing streams in reversed order because stream at position [0], uses # data from stream at position [1]. for index in countdown((len(bstream.streams) - 1), 0): @@ -68,7 +68,7 @@ proc closeWait*(bstream: HttpBodyWriter) {.async: (raises: []).} = ## Close and free all the resources allocated by body writer. if bstream.bstate == HttpState.Alive: bstream.bstate = HttpState.Closing - var res = newSeq[Future[void]]() + var res = newSeq[Future[void].Raising([])]() for index in countdown(len(bstream.streams) - 1, 0): res.add(bstream.streams[index].closeWait()) await noCancel(allFutures(res)) diff --git a/chronos/apps/http/httpclient.nim b/chronos/apps/http/httpclient.nim index 5f4bd71..33a6b7f 100644 --- a/chronos/apps/http/httpclient.nim +++ b/chronos/apps/http/httpclient.nim @@ -294,7 +294,7 @@ proc new*(t: typedesc[HttpSessionRef], if HttpClientFlag.Http11Pipeline in flags: sessionWatcher(res) else: - Future[void].Raising([]).init("session.watcher.placeholder") + nil res proc getTLSFlags(flags: HttpClientFlags): set[TLSFlags] = @@ -607,7 +607,7 @@ proc closeWait(conn: HttpClientConnectionRef) {.async: (raises: []).} = conn.state = HttpClientConnectionState.Closing let pending = block: - var res: seq[Future[void]] + var res: seq[Future[void].Raising([])] if not(isNil(conn.reader)) and not(conn.reader.closed()): res.add(conn.reader.closeWait()) if not(isNil(conn.writer)) and not(conn.writer.closed()): @@ -847,14 +847,14 @@ proc sessionWatcher(session: HttpSessionRef) {.async: (raises: []).} = break proc closeWait*(request: HttpClientRequestRef) {.async: (raises: []).} = - var pending: seq[FutureBase] + var pending: seq[Future[void].Raising([])] if request.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}: request.state = HttpReqRespState.Closing if not(isNil(request.writer)): if not(request.writer.closed()): - pending.add(FutureBase(request.writer.closeWait())) + pending.add(request.writer.closeWait()) request.writer = nil - pending.add(FutureBase(request.releaseConnection())) + pending.add(request.releaseConnection()) await noCancel(allFutures(pending)) request.session = nil request.error = nil @@ -862,14 +862,14 @@ proc closeWait*(request: HttpClientRequestRef) {.async: (raises: []).} = untrackCounter(HttpClientRequestTrackerName) proc closeWait*(response: HttpClientResponseRef) {.async: (raises: []).} = - var pending: seq[FutureBase] + var pending: seq[Future[void].Raising([])] if response.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}: response.state = HttpReqRespState.Closing if not(isNil(response.reader)): if not(response.reader.closed()): - pending.add(FutureBase(response.reader.closeWait())) + pending.add(response.reader.closeWait()) response.reader = nil - pending.add(FutureBase(response.releaseConnection())) + pending.add(response.releaseConnection()) await noCancel(allFutures(pending)) response.session = nil response.error = nil diff --git a/chronos/asyncsync.nim b/chronos/asyncsync.nim index f77d5fe..5fab9b2 100644 --- a/chronos/asyncsync.nim +++ b/chronos/asyncsync.nim @@ -523,15 +523,13 @@ proc closeWait*(ab: AsyncEventQueue): Future[void] {. 
{FutureFlag.OwnCancelSchedule}) proc continuation(udata: pointer) {.gcsafe.} = retFuture.complete() - proc cancellation(udata: pointer) {.gcsafe.} = - # We are not going to change the state of `retFuture` to cancelled, so we - # will prevent the entire sequence of Futures from being cancelled. - discard + + # Ignore cancellation requests - we'll complete the future soon enough + retFuture.cancelCallback = nil ab.close() # Schedule `continuation` to be called only after all the `reader` # notifications will be scheduled and processed. - retFuture.cancelCallback = cancellation callSoon(continuation) retFuture diff --git a/chronos/futures.nim b/chronos/futures.nim index 6fb9592..fd8dbfe 100644 --- a/chronos/futures.nim +++ b/chronos/futures.nim @@ -34,6 +34,19 @@ type FutureFlag* {.pure.} = enum OwnCancelSchedule + ## When OwnCancelSchedule is set, the owner of the future is responsible + ## for implementing cancellation in one of 3 ways: + ## + ## * ensure that cancellation requests never reach the future by means of + ## not exposing it to user code, `await` and `tryCancel` + ## * set `cancelCallback` to `nil` to stop cancellation propagation - this + ## is appropriate when it is expected that the future will be completed + ## in a regular way "soon" + ## * set `cancelCallback` to a handler that implements cancellation in an + ## operation-specific way + ## + ## If `cancelCallback` is not set and the future gets cancelled, a + ## `Defect` will be raised. FutureFlags* = set[FutureFlag] @@ -104,6 +117,12 @@ proc internalInitFutureBase*(fut: FutureBase, loc: ptr SrcLoc, fut.internalState = state fut.internalLocation[LocationKind.Create] = loc fut.internalFlags = flags + if FutureFlag.OwnCancelSchedule in flags: + # Owners must replace `cancelCallback` with `nil` if they want to ignore + # cancellations + fut.internalCancelcb = proc(_: pointer) = + raiseAssert "Cancellation request for non-cancellable future" + if state != FutureState.Pending: fut.internalLocation[LocationKind.Finish] = loc diff --git a/chronos/internal/asyncfutures.nim b/chronos/internal/asyncfutures.nim index 8078952..5ce9da4 100644 --- a/chronos/internal/asyncfutures.nim +++ b/chronos/internal/asyncfutures.nim @@ -1013,6 +1013,7 @@ proc cancelAndWait*(future: FutureBase, loc: ptr SrcLoc): Future[void] {. if future.finished(): retFuture.complete() else: + retFuture.cancelCallback = nil cancelSoon(future, continuation, cast[pointer](retFuture), loc) retFuture @@ -1057,6 +1058,7 @@ proc noCancel*[F: SomeFuture](future: F): auto = # async: (raw: true, raises: as if future.finished(): completeFuture() else: + retFuture.cancelCallback = nil future.addCallback(continuation) retFuture diff --git a/chronos/internal/raisesfutures.nim b/chronos/internal/raisesfutures.nim index 20fa6ed..5b91f41 100644 --- a/chronos/internal/raisesfutures.nim +++ b/chronos/internal/raisesfutures.nim @@ -18,45 +18,6 @@ proc makeNoRaises*(): NimNode {.compileTime.} = ident"void" -macro Raising*[T](F: typedesc[Future[T]], E: varargs[typedesc]): untyped = - ## Given a Future type instance, return a type storing `{.raises.}` - ## information - ## - ## Note; this type may change in the future - E.expectKind(nnkBracket) - - let raises = if E.len == 0: - makeNoRaises() - else: - nnkTupleConstr.newTree(E.mapIt(it)) - nnkBracketExpr.newTree( - ident "InternalRaisesFuture", - nnkDotExpr.newTree(F, ident"T"), - raises - ) - -template init*[T, E]( - F: type InternalRaisesFuture[T, E], fromProc: static[string] = ""): F = - ## Creates a new pending future. 
- ## - ## Specifying ``fromProc``, which is a string specifying the name of the proc - ## that this future belongs to, is a good habit as it helps with debugging. - let res = F() - internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Pending, {}) - res - -template init*[T, E]( - F: type InternalRaisesFuture[T, E], fromProc: static[string] = "", - flags: static[FutureFlags]): F = - ## Creates a new pending future. - ## - ## Specifying ``fromProc``, which is a string specifying the name of the proc - ## that this future belongs to, is a good habit as it helps with debugging. - let res = F() - internalInitFutureBase( - res, getSrcLocation(fromProc), FutureState.Pending, flags) - res - proc dig(n: NimNode): NimNode {.compileTime.} = # Dig through the layers of type to find the raises list if n.eqIdent("void"): @@ -87,6 +48,58 @@ proc members(tup: NimNode): seq[NimNode] {.compileTime.} = for t in tup.members(): result.add(t) +macro hasException(raises: typedesc, ident: static string): bool = + newLit(raises.members.anyIt(it.eqIdent(ident))) + +macro Raising*[T](F: typedesc[Future[T]], E: varargs[typedesc]): untyped = + ## Given a Future type instance, return a type storing `{.raises.}` + ## information + ## + ## Note; this type may change in the future + E.expectKind(nnkBracket) + + let raises = if E.len == 0: + makeNoRaises() + else: + nnkTupleConstr.newTree(E.mapIt(it)) + nnkBracketExpr.newTree( + ident "InternalRaisesFuture", + nnkDotExpr.newTree(F, ident"T"), + raises + ) + +template init*[T, E]( + F: type InternalRaisesFuture[T, E], fromProc: static[string] = ""): F = + ## Creates a new pending future. + ## + ## Specifying ``fromProc``, which is a string specifying the name of the proc + ## that this future belongs to, is a good habit as it helps with debugging. + when not hasException(type(E), "CancelledError"): + static: + raiseAssert "Manually created futures must either own cancellation schedule or raise CancelledError" + + + let res = F() + internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Pending, {}) + res + +template init*[T, E]( + F: type InternalRaisesFuture[T, E], fromProc: static[string] = "", + flags: static[FutureFlags]): F = + ## Creates a new pending future. + ## + ## Specifying ``fromProc``, which is a string specifying the name of the proc + ## that this future belongs to, is a good habit as it helps with debugging. 
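+  ##
+  ## A minimal usage sketch (the future name is illustrative only): a future
+  ## whose raises list does not include `CancelledError` must own its
+  ## cancellation schedule, mirroring how the internal stream futures below
+  ## are created:
+  ##
+  ## ```nim
+  ## let fut = Future[void].Raising([]).init(
+  ##   "my.module.operation", {FutureFlag.OwnCancelSchedule})
+  ## ```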
+ let res = F() + when not hasException(type(E), "CancelledError"): + static: + doAssert FutureFlag.OwnCancelSchedule in flags, + "Manually created futures must either own cancellation schedule or raise CancelledError" + + internalInitFutureBase( + res, getSrcLocation(fromProc), FutureState.Pending, flags) + res + proc containsSignature(members: openArray[NimNode], typ: NimNode): bool {.compileTime.} = let typHash = signatureHash(typ) diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index a521084..4fbe7a4 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -77,7 +77,7 @@ type udata: pointer error*: ref AsyncStreamError bytesCount*: uint64 - future: Future[void] + future: Future[void].Raising([]) AsyncStreamWriter* = ref object of RootRef wsource*: AsyncStreamWriter @@ -88,7 +88,7 @@ type error*: ref AsyncStreamError udata: pointer bytesCount*: uint64 - future: Future[void] + future: Future[void].Raising([]) AsyncStream* = object of RootObj reader*: AsyncStreamReader @@ -897,44 +897,27 @@ proc close*(rw: AsyncStreamRW) = rw.future.addCallback(continuation) rw.future.cancelSoon() -proc closeWait*(rw: AsyncStreamRW): Future[void] {. - async: (raw: true, raises: []).} = +proc closeWait*(rw: AsyncStreamRW): Future[void] {.async: (raises: []).} = ## Close and frees resources of stream ``rw``. - const FutureName = - when rw is AsyncStreamReader: - "async.stream.reader.closeWait" - else: - "async.stream.writer.closeWait" - - let retFuture = Future[void].Raising([]).init(FutureName) - - if rw.closed(): - retFuture.complete() - return retFuture - - proc continuation(udata: pointer) {.gcsafe, raises:[].} = - retFuture.complete() - - rw.close() - if rw.future.finished(): - retFuture.complete() - else: - rw.future.addCallback(continuation, cast[pointer](retFuture)) - retFuture + if not rw.closed(): + rw.close() + await noCancel(rw.join()) proc startReader(rstream: AsyncStreamReader) = rstream.state = Running if not isNil(rstream.readerLoop): rstream.future = rstream.readerLoop(rstream) else: - rstream.future = newFuture[void]("async.stream.empty.reader") + rstream.future = Future[void].Raising([]).init( + "async.stream.empty.reader", {FutureFlag.OwnCancelSchedule}) proc startWriter(wstream: AsyncStreamWriter) = wstream.state = Running if not isNil(wstream.writerLoop): wstream.future = wstream.writerLoop(wstream) else: - wstream.future = newFuture[void]("async.stream.empty.writer") + wstream.future = Future[void].Raising([]).init( + "async.stream.empty.writer", {FutureFlag.OwnCancelSchedule}) proc init*(child, wsource: AsyncStreamWriter, loop: StreamWriterLoop, queueSize = AsyncStreamDefaultQueueSize) = diff --git a/chronos/transports/common.nim b/chronos/transports/common.nim index ba7568a..8fa062a 100644 --- a/chronos/transports/common.nim +++ b/chronos/transports/common.nim @@ -73,7 +73,7 @@ when defined(windows) or defined(nimdoc): udata*: pointer # User-defined pointer flags*: set[ServerFlags] # Flags bufferSize*: int # Size of internal transports' buffer - loopFuture*: Future[void] # Server's main Future + loopFuture*: Future[void].Raising([]) # Server's main Future domain*: Domain # Current server domain (IPv4 or IPv6) apending*: bool asock*: AsyncFD # Current AcceptEx() socket @@ -92,7 +92,7 @@ else: udata*: pointer # User-defined pointer flags*: set[ServerFlags] # Flags bufferSize*: int # Size of internal transports' buffer - loopFuture*: Future[void] # Server's main Future + loopFuture*: Future[void].Raising([]) # Server's main 
Future errorCode*: OSErrorCode # Current error code dualstack*: DualStackType # IPv4/IPv6 dualstack parameters diff --git a/chronos/transports/datagram.nim b/chronos/transports/datagram.nim index fed15d3..88db7ee 100644 --- a/chronos/transports/datagram.nim +++ b/chronos/transports/datagram.nim @@ -44,7 +44,7 @@ type remote: TransportAddress # Remote address udata*: pointer # User-driven pointer function: DatagramCallback # Receive data callback - future: Future[void] # Transport's life future + future: Future[void].Raising([]) # Transport's life future raddr: Sockaddr_storage # Reader address storage ralen: SockLen # Reader address length waddr: Sockaddr_storage # Writer address storage @@ -359,7 +359,8 @@ when defined(windows): res.queue = initDeque[GramVector]() res.udata = udata res.state = {ReadPaused, WritePaused} - res.future = newFuture[void]("datagram.transport") + res.future = Future[void].Raising([]).init( + "datagram.transport", {FutureFlag.OwnCancelSchedule}) res.rovl.data = CompletionData(cb: readDatagramLoop, udata: cast[pointer](res)) res.wovl.data = CompletionData(cb: writeDatagramLoop, @@ -568,7 +569,8 @@ else: res.queue = initDeque[GramVector]() res.udata = udata res.state = {ReadPaused, WritePaused} - res.future = newFuture[void]("datagram.transport") + res.future = Future[void].Raising([]).init( + "datagram.transport", {FutureFlag.OwnCancelSchedule}) GC_ref(res) # Start tracking transport trackCounter(DgramTransportTrackerName) @@ -840,31 +842,16 @@ proc join*(transp: DatagramTransport): Future[void] {. return retFuture +proc closed*(transp: DatagramTransport): bool {.inline.} = + ## Returns ``true`` if transport in closed state. + {ReadClosed, WriteClosed} * transp.state != {} + proc closeWait*(transp: DatagramTransport): Future[void] {. - async: (raw: true, raises: []).} = + async: (raises: []).} = ## Close transport ``transp`` and release all resources. - let retFuture = newFuture[void]( - "datagram.transport.closeWait", {FutureFlag.OwnCancelSchedule}) - - if {ReadClosed, WriteClosed} * transp.state != {}: - retFuture.complete() - return retFuture - - proc continuation(udata: pointer) {.gcsafe.} = - retFuture.complete() - - proc cancellation(udata: pointer) {.gcsafe.} = - # We are not going to change the state of `retFuture` to cancelled, so we - # will prevent the entire sequence of Futures from being cancelled. - discard - - transp.close() - if transp.future.finished(): - retFuture.complete() - else: - transp.future.addCallback(continuation, cast[pointer](retFuture)) - retFuture.cancelCallback = cancellation - retFuture + if not transp.closed(): + transp.close() + await noCancel(transp.join()) proc send*(transp: DatagramTransport, pbytes: pointer, nbytes: int): Future[void] {. @@ -1020,7 +1007,3 @@ proc getMessage*(transp: DatagramTransport): seq[byte] {. proc getUserData*[T](transp: DatagramTransport): T {.inline.} = ## Obtain user data stored in ``transp`` object. cast[T](transp.udata) - -proc closed*(transp: DatagramTransport): bool {.inline.} = - ## Returns ``true`` if transport in closed state. 
- {ReadClosed, WriteClosed} * transp.state != {} diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index c0d1cfc..73699a2 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -76,7 +76,7 @@ when defined(windows): offset: int # Reading buffer offset error: ref TransportError # Current error queue: Deque[StreamVector] # Writer queue - future: Future[void] # Stream life future + future: Future[void].Raising([]) # Stream life future # Windows specific part rwsabuf: WSABUF # Reader WSABUF wwsabuf: WSABUF # Writer WSABUF @@ -103,7 +103,7 @@ else: offset: int # Reading buffer offset error: ref TransportError # Current error queue: Deque[StreamVector] # Writer queue - future: Future[void] # Stream life future + future: Future[void].Raising([]) # Stream life future case kind*: TransportKind of TransportKind.Socket: domain: Domain # Socket transport domain (IPv4/IPv6) @@ -598,7 +598,8 @@ when defined(windows): transp.buffer = newSeq[byte](bufsize) transp.state = {ReadPaused, WritePaused} transp.queue = initDeque[StreamVector]() - transp.future = newFuture[void]("stream.socket.transport") + transp.future = Future[void].Raising([]).init( + "stream.socket.transport", {FutureFlag.OwnCancelSchedule}) GC_ref(transp) transp @@ -619,7 +620,8 @@ when defined(windows): transp.flags = flags transp.state = {ReadPaused, WritePaused} transp.queue = initDeque[StreamVector]() - transp.future = newFuture[void]("stream.pipe.transport") + transp.future = Future[void].Raising([]).init( + "stream.pipe.transport", {FutureFlag.OwnCancelSchedule}) GC_ref(transp) transp @@ -1457,7 +1459,8 @@ else: transp.buffer = newSeq[byte](bufsize) transp.state = {ReadPaused, WritePaused} transp.queue = initDeque[StreamVector]() - transp.future = newFuture[void]("socket.stream.transport") + transp.future = Future[void].Raising([]).init( + "socket.stream.transport", {FutureFlag.OwnCancelSchedule}) GC_ref(transp) transp @@ -1473,7 +1476,8 @@ else: transp.buffer = newSeq[byte](bufsize) transp.state = {ReadPaused, WritePaused} transp.queue = initDeque[StreamVector]() - transp.future = newFuture[void]("pipe.stream.transport") + transp.future = Future[void].Raising([]).init( + "pipe.stream.transport", {FutureFlag.OwnCancelSchedule}) GC_ref(transp) transp @@ -1806,6 +1810,9 @@ proc connect*(address: TransportAddress, if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay) connect(address, bufferSize, child, localAddress, mappedFlags, dualstack) +proc closed*(server: StreamServer): bool = + server.status == ServerStatus.Closed + proc close*(server: StreamServer) = ## Release ``server`` resources. ## @@ -1832,22 +1839,11 @@ proc close*(server: StreamServer) = else: server.sock.closeSocket(continuation) -proc closeWait*(server: StreamServer): Future[void] {. - async: (raw: true, raises: []).} = +proc closeWait*(server: StreamServer): Future[void] {.async: (raises: []).} = ## Close server ``server`` and release all resources. 
- let retFuture = newFuture[void]( - "stream.server.closeWait", {FutureFlag.OwnCancelSchedule}) - - proc continuation(udata: pointer) = - retFuture.complete() - - server.close() - - if not(server.loopFuture.finished()): - server.loopFuture.addCallback(continuation, cast[pointer](retFuture)) - else: - retFuture.complete() - retFuture + if not server.closed(): + server.close() + await noCancel(server.join()) proc getBacklogSize(backlog: int): cint = doAssert(backlog >= 0 and backlog <= high(int32)) @@ -2058,7 +2054,9 @@ proc createStreamServer*(host: TransportAddress, sres.init = init sres.bufferSize = bufferSize sres.status = Starting - sres.loopFuture = newFuture[void]("stream.transport.server") + sres.loopFuture = asyncloop.init( + Future[void].Raising([]), "stream.transport.server", + {FutureFlag.OwnCancelSchedule}) sres.udata = udata sres.dualstack = dualstack if localAddress.family == AddressFamily.None: @@ -2630,6 +2628,23 @@ proc join*(transp: StreamTransport): Future[void] {. retFuture.complete() return retFuture +proc closed*(transp: StreamTransport): bool {.inline.} = + ## Returns ``true`` if transport in closed state. + ({ReadClosed, WriteClosed} * transp.state != {}) + +proc finished*(transp: StreamTransport): bool {.inline.} = + ## Returns ``true`` if transport in finished (EOF) state. + ({ReadEof, WriteEof} * transp.state != {}) + +proc failed*(transp: StreamTransport): bool {.inline.} = + ## Returns ``true`` if transport in error state. + ({ReadError, WriteError} * transp.state != {}) + +proc running*(transp: StreamTransport): bool {.inline.} = + ## Returns ``true`` if transport is still pending. + ({ReadClosed, ReadEof, ReadError, + WriteClosed, WriteEof, WriteError} * transp.state == {}) + proc close*(transp: StreamTransport) = ## Closes and frees resources of transport ``transp``. ## @@ -2672,31 +2687,11 @@ proc close*(transp: StreamTransport) = elif transp.kind == TransportKind.Socket: closeSocket(transp.fd, continuation) -proc closeWait*(transp: StreamTransport): Future[void] {. - async: (raw: true, raises: []).} = +proc closeWait*(transp: StreamTransport): Future[void] {.async: (raises: []).} = ## Close and frees resources of transport ``transp``. - let retFuture = newFuture[void]( - "stream.transport.closeWait", {FutureFlag.OwnCancelSchedule}) - - if {ReadClosed, WriteClosed} * transp.state != {}: - retFuture.complete() - return retFuture - - proc continuation(udata: pointer) {.gcsafe.} = - retFuture.complete() - - proc cancellation(udata: pointer) {.gcsafe.} = - # We are not going to change the state of `retFuture` to cancelled, so we - # will prevent the entire sequence of Futures from being cancelled. - discard - - transp.close() - if transp.future.finished(): - retFuture.complete() - else: - transp.future.addCallback(continuation, cast[pointer](retFuture)) - retFuture.cancelCallback = cancellation - retFuture + if not transp.closed(): + transp.close() + await noCancel(transp.join()) proc shutdownWait*(transp: StreamTransport): Future[void] {. async: (raw: true, raises: [TransportError, CancelledError]).} = @@ -2756,23 +2751,6 @@ proc shutdownWait*(transp: StreamTransport): Future[void] {. callSoon(continuation, nil) retFuture -proc closed*(transp: StreamTransport): bool {.inline.} = - ## Returns ``true`` if transport in closed state. - ({ReadClosed, WriteClosed} * transp.state != {}) - -proc finished*(transp: StreamTransport): bool {.inline.} = - ## Returns ``true`` if transport in finished (EOF) state. 
- ({ReadEof, WriteEof} * transp.state != {}) - -proc failed*(transp: StreamTransport): bool {.inline.} = - ## Returns ``true`` if transport in error state. - ({ReadError, WriteError} * transp.state != {}) - -proc running*(transp: StreamTransport): bool {.inline.} = - ## Returns ``true`` if transport is still pending. - ({ReadClosed, ReadEof, ReadError, - WriteClosed, WriteEof, WriteError} * transp.state == {}) - proc fromPipe2*(fd: AsyncFD, child: StreamTransport = nil, bufferSize = DefaultStreamBufferSize ): Result[StreamTransport, OSErrorCode] = From f0a2d4df61302d24baa6c0f1c257f92045c9ee57 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 8 Jan 2024 14:54:50 +0100 Subject: [PATCH 04/11] Feature flag for raises support (#488) Feature flags allow consumers of chronos to target versions with and without certain features via compile-time selection. The first feature flag added is for raise tracking support. --- chronos/config.nim | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/chronos/config.nim b/chronos/config.nim index 21c3132..47bf669 100644 --- a/chronos/config.nim +++ b/chronos/config.nim @@ -11,6 +11,15 @@ ## `chronosDebug` can be defined to enable several debugging helpers that come ## with a runtime cost - it is recommeneded to not enable these in production ## code. +## +## In this file we also declare feature flags starting with `chronosHas...` - +## these constants are declared when a feature exists in a particular release - +## each flag is declared as an integer starting at 0 during experimental +## development, 1 when feature complete and higher numbers when significant +## functionality has been added. If a feature ends up being removed (or changed +## in a backwards-incompatible way), the feature flag will be removed or renamed +## also - you can use `when declared(chronosHasXxx): when chronosHasXxx >= N:` +## to require a particular version. const chronosHandleException* {.booldefine.}: bool = false ## Remap `Exception` to `AsyncExceptionError` for all `async` functions. @@ -79,6 +88,9 @@ const "" ## OS polling engine type which is going to be used by chronos. + chronosHasRaises* = 0 + ## raises effect support via `async: (raises: [])` + when defined(chronosStrictException): {.warning: "-d:chronosStrictException has been deprecated in favor of handleException".} # In chronos v3, this setting was used as the opposite of From b02b9608c3c4a4815da39583847dad026d89781d Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Fri, 12 Jan 2024 15:27:36 +0200 Subject: [PATCH 05/11] HTTP server middleware implementation. (#483) * HTTP server middleware implementation and test. * Address review comments. * Address review comments. 
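
The middleware chain is set up when the server is created; a rough sketch
(assuming a `filterHandler` middleware proc - the name is illustrative):

    let server = HttpServerRef.new(
      address, mainHandler,
      middlewares = [HttpServerMiddlewareRef(handler: filterHandler)]).get()

Each middleware either responds on its own or awaits `nextHandler(reqfence)`
to pass the request further down the chain.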
--- chronos/apps/http/httpserver.nim | 339 ++++++++++++++++++++--------- docs/examples/middleware.nim | 130 +++++++++++ docs/src/SUMMARY.md | 1 + docs/src/examples.md | 2 + docs/src/http_server_middleware.md | 102 +++++++++ tests/testhttpserver.nim | 299 ++++++++++++++++++++++++- 6 files changed, 771 insertions(+), 102 deletions(-) create mode 100644 docs/examples/middleware.nim create mode 100644 docs/src/http_server_middleware.md diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index 9646956..c716d14 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -11,11 +11,14 @@ import std/[tables, uri, strutils] import stew/[base10], httputils, results -import ../../asyncloop, ../../asyncsync +import ../../[asyncloop, asyncsync] import ../../streams/[asyncstream, boundstream, chunkstream] import "."/[httptable, httpcommon, multipart] +from ../../transports/common import TransportAddress, ServerFlags, `$`, `==` + export asyncloop, asyncsync, httptable, httpcommon, httputils, multipart, asyncstream, boundstream, chunkstream, uri, tables, results +export TransportAddress, ServerFlags, `$`, `==` type HttpServerFlags* {.pure.} = enum @@ -107,6 +110,7 @@ type maxRequestBodySize*: int processCallback*: HttpProcessCallback2 createConnCallback*: HttpConnectionCallback + middlewares: seq[HttpProcessCallback2] HttpServerRef* = ref HttpServer @@ -158,6 +162,16 @@ type HttpConnectionRef* = ref HttpConnection + MiddlewareHandleCallback* = proc( + middleware: HttpServerMiddlewareRef, request: RequestFence, + handler: HttpProcessCallback2): Future[HttpResponseRef] {. + async: (raises: [CancelledError]).} + + HttpServerMiddleware* = object of RootObj + handler*: MiddlewareHandleCallback + + HttpServerMiddlewareRef* = ref HttpServerMiddleware + ByteChar* = string | seq[byte] proc init(htype: typedesc[HttpProcessError], error: HttpServerError, @@ -175,6 +189,8 @@ proc init(htype: typedesc[HttpProcessError], proc defaultResponse*(exc: ref CatchableError): HttpResponseRef +proc defaultResponse*(msg: HttpMessage): HttpResponseRef + proc new(htype: typedesc[HttpConnectionHolderRef], server: HttpServerRef, transp: StreamTransport, connectionId: string): HttpConnectionHolderRef = @@ -188,20 +204,54 @@ proc createConnection(server: HttpServerRef, transp: StreamTransport): Future[HttpConnectionRef] {. 
async: (raises: [CancelledError, HttpConnectionError]).} -proc new*(htype: typedesc[HttpServerRef], - address: TransportAddress, - processCallback: HttpProcessCallback2, - serverFlags: set[HttpServerFlags] = {}, - socketFlags: set[ServerFlags] = {ReuseAddr}, - serverUri = Uri(), - serverIdent = "", - maxConnections: int = -1, - bufferSize: int = 4096, - backlogSize: int = DefaultBacklogSize, - httpHeadersTimeout = 10.seconds, - maxHeadersSize: int = 8192, - maxRequestBodySize: int = 1_048_576, - dualstack = DualStackType.Auto): HttpResult[HttpServerRef] = +proc prepareMiddlewares( + requestProcessCallback: HttpProcessCallback2, + middlewares: openArray[HttpServerMiddlewareRef] + ): seq[HttpProcessCallback2] = + var + handlers: seq[HttpProcessCallback2] + currentHandler = requestProcessCallback + + if len(middlewares) == 0: + return handlers + + let mws = @middlewares + handlers = newSeq[HttpProcessCallback2](len(mws)) + + for index in countdown(len(mws) - 1, 0): + let processor = + block: + var res: HttpProcessCallback2 + closureScope: + let + middleware = mws[index] + realHandler = currentHandler + res = + proc(request: RequestFence): Future[HttpResponseRef] {. + async: (raises: [CancelledError], raw: true).} = + middleware.handler(middleware, request, realHandler) + res + handlers[index] = processor + currentHandler = processor + handlers + +proc new*( + htype: typedesc[HttpServerRef], + address: TransportAddress, + processCallback: HttpProcessCallback2, + serverFlags: set[HttpServerFlags] = {}, + socketFlags: set[ServerFlags] = {ReuseAddr}, + serverUri = Uri(), + serverIdent = "", + maxConnections: int = -1, + bufferSize: int = 4096, + backlogSize: int = DefaultBacklogSize, + httpHeadersTimeout = 10.seconds, + maxHeadersSize: int = 8192, + maxRequestBodySize: int = 1_048_576, + dualstack = DualStackType.Auto, + middlewares: openArray[HttpServerMiddlewareRef] = [] + ): HttpResult[HttpServerRef] = let serverUri = if len(serverUri.hostname) > 0: @@ -240,24 +290,28 @@ proc new*(htype: typedesc[HttpServerRef], # else: # nil lifetime: newFuture[void]("http.server.lifetime"), - connections: initOrderedTable[string, HttpConnectionHolderRef]() + connections: initOrderedTable[string, HttpConnectionHolderRef](), + middlewares: prepareMiddlewares(processCallback, middlewares) ) ok(res) -proc new*(htype: typedesc[HttpServerRef], - address: TransportAddress, - processCallback: HttpProcessCallback, - serverFlags: set[HttpServerFlags] = {}, - socketFlags: set[ServerFlags] = {ReuseAddr}, - serverUri = Uri(), - serverIdent = "", - maxConnections: int = -1, - bufferSize: int = 4096, - backlogSize: int = DefaultBacklogSize, - httpHeadersTimeout = 10.seconds, - maxHeadersSize: int = 8192, - maxRequestBodySize: int = 1_048_576, - dualstack = DualStackType.Auto): HttpResult[HttpServerRef] {. +proc new*( + htype: typedesc[HttpServerRef], + address: TransportAddress, + processCallback: HttpProcessCallback, + serverFlags: set[HttpServerFlags] = {}, + socketFlags: set[ServerFlags] = {ReuseAddr}, + serverUri = Uri(), + serverIdent = "", + maxConnections: int = -1, + bufferSize: int = 4096, + backlogSize: int = DefaultBacklogSize, + httpHeadersTimeout = 10.seconds, + maxHeadersSize: int = 8192, + maxRequestBodySize: int = 1_048_576, + dualstack = DualStackType.Auto, + middlewares: openArray[HttpServerMiddlewareRef] = [] + ): HttpResult[HttpServerRef] {. 
deprecated: "Callback could raise only CancelledError, annotate with " & "{.async: (raises: [CancelledError]).}".} = @@ -273,7 +327,7 @@ proc new*(htype: typedesc[HttpServerRef], HttpServerRef.new(address, wrap, serverFlags, socketFlags, serverUri, serverIdent, maxConnections, bufferSize, backlogSize, httpHeadersTimeout, maxHeadersSize, maxRequestBodySize, - dualstack) + dualstack, middlewares) proc getServerFlags(req: HttpRequestRef): set[HttpServerFlags] = var defaultFlags: set[HttpServerFlags] = {} @@ -345,6 +399,18 @@ proc defaultResponse*(exc: ref CatchableError): HttpResponseRef = else: HttpResponseRef(state: HttpResponseState.ErrorCode, status: Http503) +proc defaultResponse*(msg: HttpMessage): HttpResponseRef = + HttpResponseRef(state: HttpResponseState.ErrorCode, status: msg.code) + +proc defaultResponse*(err: HttpProcessError): HttpResponseRef = + HttpResponseRef(state: HttpResponseState.ErrorCode, status: err.code) + +proc dropResponse*(): HttpResponseRef = + HttpResponseRef(state: HttpResponseState.Failed) + +proc codeResponse*(status: HttpCode): HttpResponseRef = + HttpResponseRef(state: HttpResponseState.ErrorCode, status: status) + proc dumbResponse*(): HttpResponseRef {. deprecated: "Please use defaultResponse() instead".} = ## Create an empty response to return when request processor got no request. @@ -362,29 +428,21 @@ proc hasBody*(request: HttpRequestRef): bool = request.requestFlags * {HttpRequestFlags.BoundBody, HttpRequestFlags.UnboundBody} != {} -proc prepareRequest(conn: HttpConnectionRef, - req: HttpRequestHeader): HttpResultMessage[HttpRequestRef] = - var request = HttpRequestRef(connection: conn, state: HttpState.Alive) +func new(t: typedesc[HttpRequestRef], conn: HttpConnectionRef): HttpRequestRef = + HttpRequestRef(connection: conn, state: HttpState.Alive) - if req.version notin {HttpVersion10, HttpVersion11}: - return err(HttpMessage.init(Http505, "Unsupported HTTP protocol version")) +proc updateRequest*(request: HttpRequestRef, scheme: string, meth: HttpMethod, + version: HttpVersion, requestUri: string, + headers: HttpTable): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. - request.scheme = - if HttpServerFlags.Secure in conn.server.flags: - "https" - else: - "http" - - request.version = req.version - request.meth = req.meth - - request.rawPath = - block: - let res = req.uri() - if len(res) == 0: - return err(HttpMessage.init(Http400, "Invalid request URI")) - res + # Store request version and call method. + request.scheme = scheme + request.version = version + request.meth = meth + # Processing request's URI + request.rawPath = requestUri request.uri = if request.rawPath != "*": let uri = parseUri(request.rawPath) @@ -396,10 +454,11 @@ proc prepareRequest(conn: HttpConnectionRef, uri.path = "*" uri + # Conversion of request query string to HttpTable. request.query = block: let queryFlags = - if QueryCommaSeparatedArray in conn.server.flags: + if QueryCommaSeparatedArray in request.connection.server.flags: {QueryParamsFlag.CommaSeparatedArray} else: {} @@ -408,22 +467,8 @@ proc prepareRequest(conn: HttpConnectionRef, table.add(key, value) table - request.headers = - block: - var table = HttpTable.init() - # Retrieve headers and values - for key, value in req.headers(): - table.add(key, value) - # Validating HTTP request headers - # Some of the headers must be present only once. 
- if table.count(ContentTypeHeader) > 1: - return err(HttpMessage.init(Http400, "Multiple Content-Type headers")) - if table.count(ContentLengthHeader) > 1: - return err(HttpMessage.init(Http400, "Multiple Content-Length headers")) - if table.count(TransferEncodingHeader) > 1: - return err(HttpMessage.init(Http400, - "Multuple Transfer-Encoding headers")) - table + # Store request headers + request.headers = headers # Preprocessing "Content-Encoding" header. request.contentEncoding = @@ -443,15 +488,17 @@ proc prepareRequest(conn: HttpConnectionRef, # steps to reveal information about body. request.contentLength = if ContentLengthHeader in request.headers: + # Request headers has `Content-Length` header present. let length = request.headers.getInt(ContentLengthHeader) if length != 0: if request.meth == MethodTrace: let msg = "TRACE requests could not have request body" return err(HttpMessage.init(Http400, msg)) - # Because of coversion to `int` we should avoid unexpected OverflowError. + # Because of coversion to `int` we should avoid unexpected + # OverflowError. if length > uint64(high(int)): return err(HttpMessage.init(Http413, "Unsupported content length")) - if length > uint64(conn.server.maxRequestBodySize): + if length > uint64(request.connection.server.maxRequestBodySize): return err(HttpMessage.init(Http413, "Content length exceeds limits")) request.requestFlags.incl(HttpRequestFlags.BoundBody) int(length) @@ -459,6 +506,7 @@ proc prepareRequest(conn: HttpConnectionRef, 0 else: if TransferEncodingFlags.Chunked in request.transferEncoding: + # Request headers has "Transfer-Encoding: chunked" header present. if request.meth == MethodTrace: let msg = "TRACE requests could not have request body" return err(HttpMessage.init(Http400, msg)) @@ -466,8 +514,9 @@ proc prepareRequest(conn: HttpConnectionRef, 0 if request.hasBody(): - # If request has body, we going to understand how its encoded. + # If the request has a body, we will determine how it is encoded. if ContentTypeHeader in request.headers: + # Request headers has "Content-Type" header present. let contentType = getContentType(request.headers.getList(ContentTypeHeader)).valueOr: let msg = "Incorrect or missing Content-Type header" @@ -477,12 +526,67 @@ proc prepareRequest(conn: HttpConnectionRef, elif contentType == MultipartContentType: request.requestFlags.incl(HttpRequestFlags.MultipartForm) request.contentTypeData = Opt.some(contentType) - + # If `Expect` header is present, we will handle expectation procedure. if ExpectHeader in request.headers: let expectHeader = request.headers.getString(ExpectHeader) if strip(expectHeader).toLowerAscii() == "100-continue": request.requestFlags.incl(HttpRequestFlags.ClientExpect) + ok() + +proc updateRequest*(request: HttpRequestRef, meth: HttpMethod, + requestUri: string, + headers: HttpTable): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. + updateRequest(request, request.scheme, meth, request.version, requestUri, + headers) + +proc updateRequest*(request: HttpRequestRef, requestUri: string, + headers: HttpTable): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. + updateRequest(request, request.scheme, request.meth, request.version, + requestUri, headers) + +proc updateRequest*(request: HttpRequestRef, + requestUri: string): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. 
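+  ##
+  ## A short sketch of rerouting a request from a middleware (the target
+  ## path is illustrative):
+  ##
+  ## ```nim
+  ## let res = request.updateRequest("/path/to/new/location")
+  ## if res.isErr():
+  ##   return defaultResponse(res.error)
+  ## ```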
+ updateRequest(request, request.scheme, request.meth, request.version, + requestUri, request.headers) + +proc updateRequest*(request: HttpRequestRef, + headers: HttpTable): HttpResultMessage[void] = + ## Update HTTP request object using base request object with new properties. + updateRequest(request, request.scheme, request.meth, request.version, + request.rawPath, headers) + +proc prepareRequest(conn: HttpConnectionRef, + req: HttpRequestHeader): HttpResultMessage[HttpRequestRef] = + let + request = HttpRequestRef.new(conn) + scheme = + if HttpServerFlags.Secure in conn.server.flags: + "https" + else: + "http" + headers = + block: + var table = HttpTable.init() + # Retrieve headers and values + for key, value in req.headers(): + table.add(key, value) + # Validating HTTP request headers + # Some of the headers must be present only once. + if table.count(ContentTypeHeader) > 1: + return err(HttpMessage.init(Http400, + "Multiple Content-Type headers")) + if table.count(ContentLengthHeader) > 1: + return err(HttpMessage.init(Http400, + "Multiple Content-Length headers")) + if table.count(TransferEncodingHeader) > 1: + return err(HttpMessage.init(Http400, + "Multuple Transfer-Encoding headers")) + table + ? updateRequest(request, scheme, req.meth, req.version, req.uri(), headers) trackCounter(HttpServerRequestTrackerName) ok(request) @@ -736,16 +840,19 @@ proc sendDefaultResponse( # Response was ignored, so we respond with not found. await conn.sendErrorResponse(version, Http404, keepConnection.toBool()) + response.setResponseState(HttpResponseState.Finished) keepConnection of HttpResponseState.Prepared: # Response was prepared but not sent, so we can respond with some # error code await conn.sendErrorResponse(HttpVersion11, Http409, keepConnection.toBool()) + response.setResponseState(HttpResponseState.Finished) keepConnection of HttpResponseState.ErrorCode: # Response with error code await conn.sendErrorResponse(version, response.status, false) + response.setResponseState(HttpResponseState.Finished) HttpProcessExitType.Immediate of HttpResponseState.Sending, HttpResponseState.Failed, HttpResponseState.Cancelled: @@ -755,6 +862,7 @@ proc sendDefaultResponse( # Response was ignored, so we respond with not found. await conn.sendErrorResponse(version, Http404, keepConnection.toBool()) + response.setResponseState(HttpResponseState.Finished) keepConnection of HttpResponseState.Finished: keepConnection @@ -878,6 +986,25 @@ proc getRemoteAddress(connection: HttpConnectionRef): Opt[TransportAddress] = if isNil(connection): return Opt.none(TransportAddress) getRemoteAddress(connection.transp) +proc getLocalAddress(transp: StreamTransport): Opt[TransportAddress] = + if isNil(transp): return Opt.none(TransportAddress) + try: + Opt.some(transp.localAddress()) + except TransportOsError: + Opt.none(TransportAddress) + +proc getLocalAddress(connection: HttpConnectionRef): Opt[TransportAddress] = + if isNil(connection): return Opt.none(TransportAddress) + getLocalAddress(connection.transp) + +proc remote*(request: HttpRequestRef): Opt[TransportAddress] = + ## Returns remote address of HTTP request's connection. + request.connection.getRemoteAddress() + +proc local*(request: HttpRequestRef): Opt[TransportAddress] = + ## Returns local address of HTTP request's connection. + request.connection.getLocalAddress() + proc getRequestFence*(server: HttpServerRef, connection: HttpConnectionRef): Future[RequestFence] {. 
async: (raises: []).} = @@ -920,6 +1047,14 @@ proc getConnectionFence*(server: HttpServerRef, ConnectionFence.err(HttpProcessError.init( HttpServerError.DisconnectError, exc, address, Http400)) +proc invokeProcessCallback(server: HttpServerRef, + req: RequestFence): Future[HttpResponseRef] {. + async: (raw: true, raises: [CancelledError]).} = + if len(server.middlewares) > 0: + server.middlewares[0](req) + else: + server.processCallback(req) + proc processRequest(server: HttpServerRef, connection: HttpConnectionRef, connId: string): Future[HttpProcessExitType] {. @@ -941,7 +1076,7 @@ proc processRequest(server: HttpServerRef, try: let response = try: - await connection.server.processCallback(requestFence) + await invokeProcessCallback(connection.server, requestFence) except CancelledError: # Cancelled, exiting return HttpProcessExitType.Immediate @@ -962,7 +1097,7 @@ proc processLoop(holder: HttpConnectionHolderRef) {.async: (raises: []).} = if res.isErr(): if res.error.kind != HttpServerError.InterruptError: discard await noCancel( - server.processCallback(RequestFence.err(res.error))) + invokeProcessCallback(server, RequestFence.err(res.error))) server.connections.del(connectionId) return res.get() @@ -1160,31 +1295,6 @@ proc post*(req: HttpRequestRef): Future[HttpTable] {. elif HttpRequestFlags.UnboundBody in req.requestFlags: raiseHttpProtocolError(Http400, "Unsupported request body") -proc setHeader*(resp: HttpResponseRef, key, value: string) = - ## Sets value of header ``key`` to ``value``. - doAssert(resp.getResponseState() == HttpResponseState.Empty) - resp.headersTable.set(key, value) - -proc setHeaderDefault*(resp: HttpResponseRef, key, value: string) = - ## Sets value of header ``key`` to ``value``, only if header ``key`` is not - ## present in the headers table. - discard resp.headersTable.hasKeyOrPut(key, value) - -proc addHeader*(resp: HttpResponseRef, key, value: string) = - ## Adds value ``value`` to header's ``key`` value. - doAssert(resp.getResponseState() == HttpResponseState.Empty) - resp.headersTable.add(key, value) - -proc getHeader*(resp: HttpResponseRef, key: string, - default: string = ""): string = - ## Returns value of header with name ``name`` or ``default``, if header is - ## not present in the table. - resp.headersTable.getString(key, default) - -proc hasHeader*(resp: HttpResponseRef, key: string): bool = - ## Returns ``true`` if header with name ``key`` present in the headers table. - key in resp.headersTable - template checkPending(t: untyped) = let currentState = t.getResponseState() doAssert(currentState == HttpResponseState.Empty, @@ -1199,10 +1309,41 @@ template checkStreamResponseState(t: untyped) = {HttpResponseState.Prepared, HttpResponseState.Sending}, "Response is in the wrong state") +template checkResponseCanBeModified(t: untyped) = + doAssert(t.getResponseState() in + {HttpResponseState.Empty, HttpResponseState.ErrorCode}, + "Response could not be modified at this stage") + template checkPointerLength(t1, t2: untyped) = doAssert(not(isNil(t1)), "pbytes must not be nil") doAssert(t2 >= 0, "nbytes should be bigger or equal to zero") +proc setHeader*(resp: HttpResponseRef, key, value: string) = + ## Sets value of header ``key`` to ``value``. + checkResponseCanBeModified(resp) + resp.headersTable.set(key, value) + +proc setHeaderDefault*(resp: HttpResponseRef, key, value: string) = + ## Sets value of header ``key`` to ``value``, only if header ``key`` is not + ## present in the headers table. 
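+  ##
+  ## For example (a sketch - header name and value are illustrative),
+  ## ``resp.setHeaderDefault("Server", "chronos")`` only takes effect when no
+  ## ``Server`` header has been set on the response yet.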
+ checkResponseCanBeModified(resp) + discard resp.headersTable.hasKeyOrPut(key, value) + +proc addHeader*(resp: HttpResponseRef, key, value: string) = + ## Adds value ``value`` to header's ``key`` value. + checkResponseCanBeModified(resp) + resp.headersTable.add(key, value) + +proc getHeader*(resp: HttpResponseRef, key: string, + default: string = ""): string = + ## Returns value of header with name ``name`` or ``default``, if header is + ## not present in the table. + resp.headersTable.getString(key, default) + +proc hasHeader*(resp: HttpResponseRef, key: string): bool = + ## Returns ``true`` if header with name ``key`` present in the headers table. + key in resp.headersTable + func createHeaders(resp: HttpResponseRef): string = var answer = $(resp.version) & " " & $(resp.status) & "\r\n" for k, v in resp.headersTable.stringItems(): diff --git a/docs/examples/middleware.nim b/docs/examples/middleware.nim new file mode 100644 index 0000000..9d06a89 --- /dev/null +++ b/docs/examples/middleware.nim @@ -0,0 +1,130 @@ +import chronos/apps/http/httpserver + +{.push raises: [].} + +proc firstMiddlewareHandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 +): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Ignore request errors + return await nextHandler(reqfence) + + let request = reqfence.get() + var headers = request.headers + + if request.uri.path.startsWith("/path/to/hidden/resources"): + headers.add("X-Filter", "drop") + elif request.uri.path.startsWith("/path/to/blocked/resources"): + headers.add("X-Filter", "block") + else: + headers.add("X-Filter", "pass") + + # Updating request by adding new HTTP header `X-Filter`. + let res = request.updateRequest(headers) + if res.isErr(): + # We use default error handler in case of error which will respond with + # proper HTTP status code error. + return defaultResponse(res.error) + + # Calling next handler. + await nextHandler(reqfence) + +proc secondMiddlewareHandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 +): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Ignore request errors + return await nextHandler(reqfence) + + let + request = reqfence.get() + filtered = request.headers.getString("X-Filter", "pass") + + if filtered == "drop": + # Force HTTP server to drop connection with remote peer. + dropResponse() + elif filtered == "block": + # Force HTTP server to respond with HTTP `404 Not Found` error code. + codeResponse(Http404) + else: + # Calling next handler. + await nextHandler(reqfence) + +proc thirdMiddlewareHandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 +): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} = + if reqfence.isErr(): + # Ignore request errors + return await nextHandler(reqfence) + + let request = reqfence.get() + echo "QUERY = [", request.rawPath, "]" + echo request.headers + try: + if request.uri.path == "/path/to/plugin/resources/page1": + await request.respond(Http200, "PLUGIN PAGE1") + elif request.uri.path == "/path/to/plugin/resources/page2": + await request.respond(Http200, "PLUGIN PAGE2") + else: + # Calling next handler. + await nextHandler(reqfence) + except HttpWriteError as exc: + # We use default error handler if we unable to send response. 
+    defaultResponse(exc)
+
+proc mainHandler(
+    reqfence: RequestFence
+): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
+  if reqfence.isErr():
+    return defaultResponse()
+
+  let request = reqfence.get()
+  try:
+    if request.uri.path == "/path/to/original/page1":
+      await request.respond(Http200, "ORIGINAL PAGE1")
+    elif request.uri.path == "/path/to/original/page2":
+      await request.respond(Http200, "ORIGINAL PAGE2")
+    else:
+      # Force HTTP server to respond with `404 Not Found` status code.
+      codeResponse(Http404)
+  except HttpWriteError as exc:
+    defaultResponse(exc)
+
+proc middlewareExample() {.async: (raises: []).} =
+  let
+    middlewares = [
+      HttpServerMiddlewareRef(handler: firstMiddlewareHandler),
+      HttpServerMiddlewareRef(handler: secondMiddlewareHandler),
+      HttpServerMiddlewareRef(handler: thirdMiddlewareHandler)
+    ]
+    socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
+    boundAddress =
+      if isAvailable(AddressFamily.IPv6):
+        AnyAddress6
+      else:
+        AnyAddress
+    res = HttpServerRef.new(boundAddress, mainHandler,
+                            socketFlags = socketFlags,
+                            middlewares = middlewares)
+
+  doAssert(res.isOk(), "Unable to start HTTP server")
+  let server = res.get()
+  server.start()
+  let address = server.instance.localAddress()
+  echo "HTTP server running on ", address
+  try:
+    await server.join()
+  except CancelledError:
+    discard
+  finally:
+    await server.stop()
+    await server.closeWait()
+
+when isMainModule:
+  waitFor(middlewareExample())
diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md
index 4f2ee56..f834367 100644
--- a/docs/src/SUMMARY.md
+++ b/docs/src/SUMMARY.md
@@ -8,6 +8,7 @@
 - [Errors and exceptions](./error_handling.md)
 - [Tips, tricks and best practices](./tips.md)
 - [Porting code to `chronos`](./porting.md)
+- [HTTP server middleware](./http_server_middleware.md)
 
 # Developer guide
diff --git a/docs/src/examples.md b/docs/src/examples.md
index c71247c..49c6dc4 100644
--- a/docs/src/examples.md
+++ b/docs/src/examples.md
@@ -16,3 +16,4 @@ Examples are available in the [`docs/examples/`](https://github.com/status-im/ni
 * [httpget](https://github.com/status-im/nim-chronos/tree/master/docs/examples/httpget.nim) - Downloading a web page using the http client
 * [twogets](https://github.com/status-im/nim-chronos/tree/master/docs/examples/twogets.nim) - Download two pages concurrently
+* [middleware](https://github.com/status-im/nim-chronos/tree/master/docs/examples/middleware.nim) - Deploy multiple HTTP server middlewares
diff --git a/docs/src/http_server_middleware.md b/docs/src/http_server_middleware.md
new file mode 100644
index 0000000..6edd9b5
--- /dev/null
+++ b/docs/src/http_server_middleware.md
@@ -0,0 +1,102 @@
+## HTTP server middleware
+
+Chronos provides a powerful mechanism for customizing HTTP request handlers via
+middlewares.
+
+A middleware is a coroutine that can modify, block or filter an HTTP request.
+
+A single HTTP server can support an unlimited number of middlewares, but keep in mind that, in the worst case, each request could pass through all of the middlewares, so a large number of middlewares can have a significant impact on HTTP server performance.
+
+The order of middlewares is also important: right after the HTTP server receives a request, the request is sent to the first middleware in the list, and each middleware is responsible for passing control to the ones that follow it.
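+
+Middlewares are supplied to the server at construction time, in the order in which they should run. A minimal registration sketch (assuming `firstMiddlewareHandler` is a middleware coroutine and `mainHandler` is the ordinary process callback):
+
+```nim
+let server = HttpServerRef.new(
+  initTAddress("127.0.0.1:8080"), mainHandler,
+  middlewares = [HttpServerMiddlewareRef(handler: firstMiddlewareHandler)]).get()
+```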
+Therefore, when building the list, it is a good idea to place the request handlers at the end of the list, while keeping the middlewares that can block or modify requests at the beginning.
+
+A middleware can also modify the HTTP request, and these changes will be visible to all subsequent handlers (both middlewares and the original request handler). This can be done using the following helpers:
+
+```nim
+  proc updateRequest*(request: HttpRequestRef, scheme: string, meth: HttpMethod,
+                      version: HttpVersion, requestUri: string,
+                      headers: HttpTable): HttpResultMessage[void]
+
+  proc updateRequest*(request: HttpRequestRef, meth: HttpMethod,
+                      requestUri: string,
+                      headers: HttpTable): HttpResultMessage[void]
+
+  proc updateRequest*(request: HttpRequestRef, requestUri: string,
+                      headers: HttpTable): HttpResultMessage[void]
+
+  proc updateRequest*(request: HttpRequestRef,
+                      requestUri: string): HttpResultMessage[void]
+
+  proc updateRequest*(request: HttpRequestRef,
+                      headers: HttpTable): HttpResultMessage[void]
+```
+
+As you can see, all the HTTP request parameters can be modified: the request method, version, request path and request headers.
+
+A middleware can also use helpers to obtain the remote and local addresses of the request's connection (helpful when you need to do IP address filtering).
+
+```nim
+  proc remote*(request: HttpRequestRef): Opt[TransportAddress]
+  ## Returns remote address of HTTP request's connection.
+  proc local*(request: HttpRequestRef): Opt[TransportAddress]
+  ## Returns local address of HTTP request's connection.
+```
+
+Every middleware is a coroutine that looks like this:
+
+```nim
+  proc middlewareHandler(
+      middleware: HttpServerMiddlewareRef,
+      reqfence: RequestFence,
+      nextHandler: HttpProcessCallback2
+  ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
+```
+
+Here the `middleware` argument is an object that can hold middleware-specific values, `reqfence` is the HTTP request wrapped together with HTTP server error information, and `nextHandler` is a reference to the next request handler - either another middleware handler or the original request processing callback.
+
+```nim
+  await nextHandler(reqfence)
+```
+
+You should await the response from `nextHandler(reqfence)`. Usually you call the next handler when you do not want to handle the request, or do not know how to handle it, for example:
+
+```nim
+  proc middlewareHandler(
+      middleware: HttpServerMiddlewareRef,
+      reqfence: RequestFence,
+      nextHandler: HttpProcessCallback2
+  ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
+    if reqfence.isErr():
+      # We do not know about, or do not want to handle, failed requests, so we call the next handler.
+      return await nextHandler(reqfence)
+    let request = reqfence.get()
+    if request.uri.path == "/path/we/able/to/respond":
+      try:
+        # Sending some response.
+        await request.respond(Http200, "TEST")
+      except HttpWriteError as exc:
+        # We could also return the default response for exceptions or other kinds of errors.
+        defaultResponse(exc)
+    elif request.uri.path == "/path/for/rewrite":
+      # We are going to modify the request object; the next handler will receive it with a different request path.
+ let res = request.updateRequest("/path/to/new/location") + if res.isErr(): + return defaultResponse(res.error) + await nextHandler(reqfence) + elif request.uri.path == "/restricted/path": + if request.remote().isNone(): + # We can't obtain remote address, so we force HTTP server to respond with `401 Unauthorized` status code. + return codeResponse(Http401) + if $(request.remote().get()).startsWith("127.0.0.1"): + # Remote peer's address starts with "127.0.0.1", sending proper response. + await request.respond(Http200, "AUTHORIZED") + else: + # Force HTTP server to respond with `403 Forbidden` status code. + codeResponse(Http403) + elif request.uri.path == "/blackhole": + # Force HTTP server to drop connection with remote peer. + dropResponse() + else: + # All other requests should be handled by somebody else. + await nextHandler(reqfence) +``` + diff --git a/tests/testhttpserver.nim b/tests/testhttpserver.nim index 0183f1b..91064f5 100644 --- a/tests/testhttpserver.nim +++ b/tests/testhttpserver.nim @@ -18,9 +18,16 @@ suite "HTTP server testing suite": TooBigTest = enum GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest TestHttpResponse = object + status: int headers: HttpTable data: string + FirstMiddlewareRef = ref object of HttpServerMiddlewareRef + someInteger: int + + SecondMiddlewareRef = ref object of HttpServerMiddlewareRef + someString: string + proc httpClient(address: TransportAddress, data: string): Future[string] {.async.} = var transp: StreamTransport @@ -50,7 +57,7 @@ suite "HTTP server testing suite": zeroMem(addr buffer[0], len(buffer)) await transp.readExactly(addr buffer[0], length) let data = bytesToString(buffer.toOpenArray(0, length - 1)) - let headers = + let (status, headers) = block: let resp = parseResponse(hdata, false) if resp.failed(): @@ -58,8 +65,38 @@ suite "HTTP server testing suite": var res = HttpTable.init() for key, value in resp.headers(hdata): res.add(key, value) - res - return TestHttpResponse(headers: headers, data: data) + (resp.code, res) + TestHttpResponse(status: status, headers: headers, data: data) + + proc httpClient3(address: TransportAddress, + data: string): Future[TestHttpResponse] {.async.} = + var + transp: StreamTransport + buffer = newSeq[byte](4096) + sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8] + try: + transp = await connect(address) + if len(data) > 0: + let wres = await transp.write(data) + if wres != len(data): + raise newException(ValueError, "Unable to write full request") + let hres = await transp.readUntil(addr buffer[0], len(buffer), sep) + var hdata = @buffer + hdata.setLen(hres) + var rres = bytesToString(await transp.read()) + let (status, headers) = + block: + let resp = parseResponse(hdata, false) + if resp.failed(): + raise newException(ValueError, "Unable to decode response headers") + var res = HttpTable.init() + for key, value in resp.headers(hdata): + res.add(key, value) + (resp.code, res) + TestHttpResponse(status: status, headers: headers, data: rres) + finally: + if not(isNil(transp)): + await closeWait(transp) proc testTooBigBodyChunked(operation: TooBigTest): Future[bool] {.async.} = var serverRes = false @@ -1490,5 +1527,261 @@ suite "HTTP server testing suite": await server.stop() await server.closeWait() + asyncTest "HTTP middleware request filtering test": + proc init(t: typedesc[FirstMiddlewareRef], + data: int): HttpServerMiddlewareRef = + proc shandler( + middleware: HttpServerMiddlewareRef, + reqfence: RequestFence, + nextHandler: HttpProcessCallback2 + ): Future[HttpResponseRef] {.async: 
(raises: [CancelledError]).} =
+        let mw = FirstMiddlewareRef(middleware)
+        if reqfence.isErr():
+          # Our handler is not supposed to handle request errors, so we
+          # call the next handler in the sequence, which can process errors.
+          return await nextHandler(reqfence)
+
+        let request = reqfence.get()
+        if request.uri.path == "/first":
+          # This is the request we are waiting for, so we are going to process it.
+          try:
+            await request.respond(Http200, $mw.someInteger)
+          except HttpWriteError as exc:
+            defaultResponse(exc)
+        else:
+          # We know nothing about the request's URI, so we pass this request to
+          # the next handler, which may be able to process it.
+          await nextHandler(reqfence)
+
+      HttpServerMiddlewareRef(
+        FirstMiddlewareRef(someInteger: data, handler: shandler))
+
+    proc init(t: typedesc[SecondMiddlewareRef],
+              data: string): HttpServerMiddlewareRef =
+      proc shandler(
+          middleware: HttpServerMiddlewareRef,
+          reqfence: RequestFence,
+          nextHandler: HttpProcessCallback2
+      ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
+        let mw = SecondMiddlewareRef(middleware)
+        if reqfence.isErr():
+          # Our handler is not supposed to handle request errors, so we
+          # call the next handler in the sequence, which can process errors.
+          return await nextHandler(reqfence)
+
+        let request = reqfence.get()
+
+        if request.uri.path == "/second":
+          # This is the request we are waiting for, so we are going to process it.
+          try:
+            await request.respond(Http200, mw.someString)
+          except HttpWriteError as exc:
+            defaultResponse(exc)
+        else:
+          # We know nothing about the request's URI, so we pass this request to
+          # the next handler, which may be able to process it.
+          await nextHandler(reqfence)
+
+      HttpServerMiddlewareRef(
+        SecondMiddlewareRef(someString: data, handler: shandler))
+
+    proc process(r: RequestFence): Future[HttpResponseRef] {.
+        async: (raises: [CancelledError]).} =
+      if r.isOk():
+        let request = r.get()
+        if request.uri.path == "/test":
+          try:
+            await request.respond(Http200, "ORIGIN")
+          except HttpWriteError as exc:
+            defaultResponse(exc)
+        else:
+          defaultResponse()
+      else:
+        defaultResponse()
+
+    let
+      middlewares = [FirstMiddlewareRef.init(655370),
+                     SecondMiddlewareRef.init("SECOND")]
+      socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
+      res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
+                              socketFlags = socketFlags,
+                              middlewares = middlewares)
+    check res.isOk()
+
+    let server = res.get()
+    server.start()
+    let
+      address = server.instance.localAddress()
+      req1 = "GET /test HTTP/1.1\r\n\r\n"
+      req2 = "GET /first HTTP/1.1\r\n\r\n"
+      req3 = "GET /second HTTP/1.1\r\n\r\n"
+      req4 = "GET /noway HTTP/1.1\r\n\r\n"
+      resp1 = await httpClient3(address, req1)
+      resp2 = await httpClient3(address, req2)
+      resp3 = await httpClient3(address, req3)
+      resp4 = await httpClient3(address, req4)
+
+    check:
+      resp1.status == 200
+      resp1.data == "ORIGIN"
+      resp2.status == 200
+      resp2.data == "655370"
+      resp3.status == 200
+      resp3.data == "SECOND"
+      resp4.status == 404
+
+    await server.stop()
+    await server.closeWait()
+
+  asyncTest "HTTP middleware request modification test":
+    proc init(t: typedesc[FirstMiddlewareRef],
+              data: int): HttpServerMiddlewareRef =
+      proc shandler(
+          middleware: HttpServerMiddlewareRef,
+          reqfence: RequestFence,
+          nextHandler: HttpProcessCallback2
+      ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
+        let mw = FirstMiddlewareRef(middleware)
+        if reqfence.isErr():
+          # Our handler is not supposed to handle request errors, so we
+          # call the next handler in the sequence, which can process errors.
+          return await nextHandler(reqfence)
+
+        let
+          request = reqfence.get()
+          modifiedUri = "/modified/" & $mw.someInteger & request.rawPath
+        var modifiedHeaders = request.headers
+        modifiedHeaders.add("X-Modified", "test-value")
+
+        let res = request.updateRequest(modifiedUri, modifiedHeaders)
+        if res.isErr():
+          return defaultResponse(res.error)
+
+        # We send the modified request to the next handler.
+        await nextHandler(reqfence)
+
+      HttpServerMiddlewareRef(
+        FirstMiddlewareRef(someInteger: data, handler: shandler))
+
+    proc process(r: RequestFence): Future[HttpResponseRef] {.
+        async: (raises: [CancelledError]).} =
+      if r.isOk():
+        let request = r.get()
+        try:
+          await request.respond(Http200, request.rawPath & ":" &
+                                request.headers.getString("x-modified"))
+        except HttpWriteError as exc:
+          defaultResponse(exc)
+      else:
+        defaultResponse()
+
+    let
+      middlewares = [FirstMiddlewareRef.init(655370)]
+      socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
+      res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
+                              socketFlags = socketFlags,
+                              middlewares = middlewares)
+    check res.isOk()
+
+    let server = res.get()
+    server.start()
+    let
+      address = server.instance.localAddress()
+      req1 = "GET /test HTTP/1.1\r\n\r\n"
+      req2 = "GET /first HTTP/1.1\r\n\r\n"
+      req3 = "GET /second HTTP/1.1\r\n\r\n"
+      req4 = "GET /noway HTTP/1.1\r\n\r\n"
+      resp1 = await httpClient3(address, req1)
+      resp2 = await httpClient3(address, req2)
+      resp3 = await httpClient3(address, req3)
+      resp4 = await httpClient3(address, req4)
+
+    check:
+      resp1.status == 200
+      resp1.data == "/modified/655370/test:test-value"
+      resp2.status == 200
+      resp2.data == "/modified/655370/first:test-value"
+      resp3.status == 200
+      resp3.data == "/modified/655370/second:test-value"
+      resp4.status == 200
+      resp4.data == "/modified/655370/noway:test-value"
+
+    await server.stop()
+    await server.closeWait()
+
+  asyncTest "HTTP middleware request blocking test":
+    proc init(t: typedesc[FirstMiddlewareRef],
+              data: int): HttpServerMiddlewareRef =
+      proc shandler(
+          middleware: HttpServerMiddlewareRef,
+          reqfence: RequestFence,
+          nextHandler: HttpProcessCallback2
+      ): Future[HttpResponseRef] {.async: (raises: [CancelledError]).} =
+        if reqfence.isErr():
+          # Our handler is not supposed to handle request errors, so we
+          # call the next handler in the sequence, which can process errors.
+          return await nextHandler(reqfence)
+
+        let request = reqfence.get()
+        if request.uri.path == "/first":
+          # Block the request by disconnecting the remote peer.
+          dropResponse()
+        elif request.uri.path == "/second":
+          # Block the request by sending an HTTP error response with code 401.
+          codeResponse(Http401)
+        else:
+          # Allow all other requests to be processed by the next handler.
+          await nextHandler(reqfence)
+
+      HttpServerMiddlewareRef(
+        FirstMiddlewareRef(someInteger: data, handler: shandler))
+
+    proc process(r: RequestFence): Future[HttpResponseRef] {.
+ async: (raises: [CancelledError]).} = + if r.isOk(): + let request = r.get() + try: + await request.respond(Http200, "ORIGIN") + except HttpWriteError as exc: + defaultResponse(exc) + else: + defaultResponse() + + let + middlewares = [FirstMiddlewareRef.init(655370)] + socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, + socketFlags = socketFlags, + middlewares = middlewares) + check res.isOk() + + let server = res.get() + server.start() + let + address = server.instance.localAddress() + req1 = "GET /test HTTP/1.1\r\n\r\n" + req2 = "GET /first HTTP/1.1\r\n\r\n" + req3 = "GET /second HTTP/1.1\r\n\r\n" + resp1 = await httpClient3(address, req1) + resp3 = await httpClient3(address, req3) + + check: + resp1.status == 200 + resp1.data == "ORIGIN" + resp3.status == 401 + + let checked = + try: + let res {.used.} = await httpClient3(address, req2) + false + except TransportIncompleteError: + true + + check: + checked == true + + await server.stop() + await server.closeWait() + test "Leaks test": checkLeaks() From 92acf68b04070dfe8eb65bab71fbf63804979a16 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Fri, 12 Jan 2024 15:39:45 +0200 Subject: [PATCH 06/11] Fix examples documentation. --- docs/src/examples.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/src/examples.md b/docs/src/examples.md index 49c6dc4..0bcfc74 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -16,5 +16,4 @@ Examples are available in the [`docs/examples/`](https://github.com/status-im/ni * [httpget](https://github.com/status-im/nim-chronos/tree/master/docs/examples/httpget.nim) - Downloading a web page using the http client * [twogets](https://github.com/status-im/nim-chronos/tree/master/docs/examples/twogets.nim) - Download two pages concurrently -* [middleware](https://github.com/status-im/nim-chronos/tree/master/docs/examples/middleware.nim) -- Deploy multiple HTTP server middlewares +* [middleware](https://github.com/status-im/nim-chronos/tree/master/docs/examples/middleware.nim) - Deploy multiple HTTP server middlewares From 1021a7d29453ac184cc406483ff5fcdb73d48472 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Thu, 18 Jan 2024 13:34:16 +0100 Subject: [PATCH 07/11] check leaks after every test (#487) --- chronos/unittest2/asynctests.nim | 3 ++- tests/testasyncstream.nim | 22 +++++++++------------- tests/testdatagram.nim | 5 +++-- tests/testhttpclient.nim | 5 ++--- tests/testhttpserver.nim | 6 +++--- tests/testproc.nim | 6 +++--- tests/testshttpserver.nim | 5 ++--- tests/teststream.nim | 11 +++-------- 8 files changed, 27 insertions(+), 36 deletions(-) diff --git a/chronos/unittest2/asynctests.nim b/chronos/unittest2/asynctests.nim index 758e0a6..9e01dba 100644 --- a/chronos/unittest2/asynctests.nim +++ b/chronos/unittest2/asynctests.nim @@ -26,6 +26,7 @@ template checkLeaks*(name: string): untyped = ", closed = " & $ counter.closed check counter.opened == counter.closed -template checkLeaks*(): untyped = +proc checkLeaks*() = for key in getThreadDispatcher().trackerCounterKeys(): checkLeaks(key) + GC_fullCollect() diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index bd0207f..399eb63 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -84,6 +84,9 @@ proc createBigMessage(message: string, size: int): seq[byte] = res suite "AsyncStream test suite": + teardown: + checkLeaks() + test "AsyncStream(StreamTransport) readExactly() test": proc testReadExactly(): Future[bool] 
{.async.} = proc serveClient(server: StreamServer, @@ -256,9 +259,6 @@ suite "AsyncStream test suite": result = true check waitFor(testConsume()) == true - test "AsyncStream(StreamTransport) leaks test": - checkLeaks() - test "AsyncStream(AsyncStream) readExactly() test": proc testReadExactly2(): Future[bool] {.async.} = proc serveClient(server: StreamServer, @@ -581,10 +581,10 @@ suite "AsyncStream test suite": check waitFor(testWriteEof()) == true - test "AsyncStream(AsyncStream) leaks test": +suite "ChunkedStream test suite": + teardown: checkLeaks() -suite "ChunkedStream test suite": test "ChunkedStream test vectors": const ChunkedVectors = [ ["4\r\nWiki\r\n5\r\npedia\r\nE\r\n in\r\n\r\nchunks.\r\n0\r\n\r\n", @@ -890,10 +890,10 @@ suite "ChunkedStream test suite": check waitFor(testSmallChunk(262400, 4096, 61)) == true check waitFor(testSmallChunk(767309, 4457, 173)) == true - test "ChunkedStream leaks test": +suite "TLSStream test suite": + teardown: checkLeaks() -suite "TLSStream test suite": const HttpHeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] test "Simple HTTPS connection": proc headerClient(address: TransportAddress, @@ -1023,10 +1023,9 @@ suite "TLSStream test suite": let res = waitFor checkTrustAnchors("Some message") check res == "Some message\r\n" - test "TLSStream leaks test": - checkLeaks() - suite "BoundedStream test suite": + teardown: + checkLeaks() type BoundarySizeTest = enum @@ -1402,6 +1401,3 @@ suite "BoundedStream test suite": return (writer1Res and writer2Res and readerRes) check waitFor(checkEmptyStreams()) == true - - test "BoundedStream leaks test": - checkLeaks() diff --git a/tests/testdatagram.nim b/tests/testdatagram.nim index bd33ef3..7b27c34 100644 --- a/tests/testdatagram.nim +++ b/tests/testdatagram.nim @@ -13,6 +13,9 @@ import ".."/chronos {.used.} suite "Datagram Transport test suite": + teardown: + checkLeaks() + const TestsCount = 2000 ClientsCount = 20 @@ -727,5 +730,3 @@ suite "Datagram Transport test suite": DualStackType.Auto, initTAddress("[::1]:0"))) == true else: skip() - test "Transports leak test": - checkLeaks() diff --git a/tests/testhttpclient.nim b/tests/testhttpclient.nim index 967f896..a468aae 100644 --- a/tests/testhttpclient.nim +++ b/tests/testhttpclient.nim @@ -74,6 +74,8 @@ N8r5CwGcIX/XPC3lKazzbZ8baA== """ suite "HTTP client testing suite": + teardown: + checkLeaks() type TestResponseTuple = tuple[status: int, data: string, count: int] @@ -1516,6 +1518,3 @@ suite "HTTP client testing suite": res.isErr() and res.error == HttpAddressErrorType.NameLookupFailed res.error.isRecoverableError() not(res.error.isCriticalError()) - - test "Leaks test": - checkLeaks() diff --git a/tests/testhttpserver.nim b/tests/testhttpserver.nim index 91064f5..70cca33 100644 --- a/tests/testhttpserver.nim +++ b/tests/testhttpserver.nim @@ -14,6 +14,9 @@ import stew/base10 {.used.} suite "HTTP server testing suite": + teardown: + checkLeaks() + type TooBigTest = enum GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest @@ -1782,6 +1785,3 @@ suite "HTTP server testing suite": await server.stop() await server.closeWait() - - test "Leaks test": - checkLeaks() diff --git a/tests/testproc.nim b/tests/testproc.nim index 588e308..4d8accf 100644 --- a/tests/testproc.nim +++ b/tests/testproc.nim @@ -16,6 +16,9 @@ when defined(posix): when defined(nimHasUsed): {.used.} suite "Asynchronous process management test suite": + teardown: + checkLeaks() + const OutputTests = when defined(windows): [ @@ -463,6 +466,3 @@ suite "Asynchronous 
process management test suite": skip() else: check getCurrentFD() == markFD - - test "Leaks test": - checkLeaks() diff --git a/tests/testshttpserver.nim b/tests/testshttpserver.nim index 18e84a9..f846d8d 100644 --- a/tests/testshttpserver.nim +++ b/tests/testshttpserver.nim @@ -75,6 +75,8 @@ N8r5CwGcIX/XPC3lKazzbZ8baA== suite "Secure HTTP server testing suite": + teardown: + checkLeaks() proc httpsClient(address: TransportAddress, data: string, flags = {NoVerifyHost, NoVerifyServerName} @@ -184,6 +186,3 @@ suite "Secure HTTP server testing suite": return serverRes and data == "EXCEPTION" check waitFor(testHTTPS2(initTAddress("127.0.0.1:30080"))) == true - - test "Leaks test": - checkLeaks() diff --git a/tests/teststream.nim b/tests/teststream.nim index fb5534b..bf4c455 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -16,6 +16,9 @@ when defined(windows): importc: "_get_osfhandle", header:"".} suite "Stream Transport test suite": + teardown: + checkLeaks() + const ConstantMessage = "SOMEDATA" BigMessagePattern = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -1555,12 +1558,6 @@ suite "Stream Transport test suite": check waitFor(testAccept(addresses[i])) == true test prefixes[i] & "close() while in accept() waiting test": check waitFor(testAcceptClose(addresses[i])) == true - test prefixes[i] & "Intermediate transports leak test #1": - checkLeaks() - when defined(windows): - skip() - else: - checkLeaks(StreamTransportTrackerName) test prefixes[i] & "accept() too many file descriptors test": when defined(windows): skip() @@ -1671,8 +1668,6 @@ suite "Stream Transport test suite": DualStackType.Disabled, initTAddress("[::1]:0"))) == true else: skip() - test "Leaks test": - checkLeaks() test "File descriptors leak test": when defined(windows): # Windows handle numbers depends on many conditions, so we can't use From 3ca2c5e6b510c15ce88c94ed25731b30f7ad46b5 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Fri, 19 Jan 2024 09:21:10 +0100 Subject: [PATCH 08/11] deprecate `callback=`, UDP fixes (fixes #491) (#492) Using the callback setter may lead to callbacks owned by others being reset, which is unexpected. * don't crash on zero-length UDP writes --- chronos/internal/asyncfutures.nim | 8 +++++-- chronos/transports/datagram.nim | 38 +++++++++++++++---------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/chronos/internal/asyncfutures.nim b/chronos/internal/asyncfutures.nim index 5ce9da4..496a776 100644 --- a/chronos/internal/asyncfutures.nim +++ b/chronos/internal/asyncfutures.nim @@ -330,7 +330,8 @@ proc removeCallback*(future: FutureBase, cb: CallbackFunc, proc removeCallback*(future: FutureBase, cb: CallbackFunc) = future.removeCallback(cb, cast[pointer](future)) -proc `callback=`*(future: FutureBase, cb: CallbackFunc, udata: pointer) = +proc `callback=`*(future: FutureBase, cb: CallbackFunc, udata: pointer) {. + deprecated: "use addCallback/removeCallback/clearCallbacks to manage the callback list".} = ## Clears the list of callbacks and sets the callback proc to be called when ## the future completes. ## @@ -341,11 +342,14 @@ proc `callback=`*(future: FutureBase, cb: CallbackFunc, udata: pointer) = future.clearCallbacks future.addCallback(cb, udata) -proc `callback=`*(future: FutureBase, cb: CallbackFunc) = +proc `callback=`*(future: FutureBase, cb: CallbackFunc) {. + deprecated: "use addCallback/removeCallback/clearCallbacks instead to manage the callback list".} = ## Sets the callback proc to be called when the future completes. 
## ## If future has already completed then ``cb`` will be called immediately. + {.push warning[Deprecated]: off.} `callback=`(future, cb, cast[pointer](future)) + {.pop.} proc `cancelCallback=`*(future: FutureBase, cb: CallbackFunc) = ## Sets the callback procedure to be called when the future is cancelled. diff --git a/chronos/transports/datagram.nim b/chronos/transports/datagram.nim index 88db7ee..cd335df 100644 --- a/chronos/transports/datagram.nim +++ b/chronos/transports/datagram.nim @@ -13,6 +13,7 @@ import std/deques when not(defined(windows)): import ".."/selectors2 import ".."/[asyncloop, config, osdefs, oserrno, osutils, handles] import "."/common +import stew/ptrops type VectorKind = enum @@ -119,7 +120,7 @@ when defined(windows): ## Initiation transp.state.incl(WritePending) let fd = SocketHandle(transp.fd) - var vector = transp.queue.popFirst() + let vector = transp.queue.popFirst() transp.setWriterWSABuffer(vector) let ret = if vector.kind == WithAddress: @@ -365,7 +366,7 @@ when defined(windows): udata: cast[pointer](res)) res.wovl.data = CompletionData(cb: writeDatagramLoop, udata: cast[pointer](res)) - res.rwsabuf = WSABUF(buf: cast[cstring](addr res.buffer[0]), + res.rwsabuf = WSABUF(buf: cast[cstring](baseAddr res.buffer), len: ULONG(len(res.buffer))) GC_ref(res) # Start tracking transport @@ -392,7 +393,7 @@ else: else: while true: transp.ralen = SockLen(sizeof(Sockaddr_storage)) - var res = osdefs.recvfrom(fd, addr transp.buffer[0], + var res = osdefs.recvfrom(fd, baseAddr transp.buffer, cint(len(transp.buffer)), cint(0), cast[ptr SockAddr](addr transp.raddr), addr transp.ralen) @@ -424,7 +425,7 @@ else: transp.state.incl({WritePaused}) else: if len(transp.queue) > 0: - var vector = transp.queue.popFirst() + let vector = transp.queue.popFirst() while true: if vector.kind == WithAddress: toSAddr(vector.address, transp.waddr, transp.walen) @@ -826,7 +827,7 @@ proc newDatagramTransport6*[T](cbproc: UnsafeDatagramCallback, proc join*(transp: DatagramTransport): Future[void] {. async: (raw: true, raises: [CancelledError]).} = ## Wait until the transport ``transp`` will be closed. - var retFuture = newFuture[void]("datagram.transport.join") + let retFuture = newFuture[void]("datagram.transport.join") proc continuation(udata: pointer) {.gcsafe.} = retFuture.complete() @@ -858,12 +859,12 @@ proc send*(transp: DatagramTransport, pbytes: pointer, async: (raw: true, raises: [TransportError, CancelledError]).} = ## Send buffer with pointer ``pbytes`` and size ``nbytes`` using transport ## ``transp`` to remote destination address which was bounded on transport. - var retFuture = newFuture[void]("datagram.transport.send(pointer)") + let retFuture = newFuture[void]("datagram.transport.send(pointer)") transp.checkClosed(retFuture) if transp.remote.port == Port(0): retFuture.fail(newException(TransportError, "Remote peer not set!")) return retFuture - var vector = GramVector(kind: WithoutAddress, buf: pbytes, buflen: nbytes, + let vector = GramVector(kind: WithoutAddress, buf: pbytes, buflen: nbytes, writer: retFuture) transp.queue.addLast(vector) if WritePaused in transp.state: @@ -877,14 +878,14 @@ proc send*(transp: DatagramTransport, msg: sink string, async: (raw: true, raises: [TransportError, CancelledError]).} = ## Send string ``msg`` using transport ``transp`` to remote destination ## address which was bounded on transport. 
- var retFuture = newFuture[void]("datagram.transport.send(string)") + let retFuture = newFuture[void]("datagram.transport.send(string)") transp.checkClosed(retFuture) let length = if msglen <= 0: len(msg) else: msglen var localCopy = chronosMoveSink(msg) retFuture.addCallback(proc(_: pointer) = reset(localCopy)) - let vector = GramVector(kind: WithoutAddress, buf: addr localCopy[0], + let vector = GramVector(kind: WithoutAddress, buf: baseAddr localCopy, buflen: length, writer: retFuture) @@ -900,14 +901,14 @@ proc send*[T](transp: DatagramTransport, msg: sink seq[T], async: (raw: true, raises: [TransportError, CancelledError]).} = ## Send string ``msg`` using transport ``transp`` to remote destination ## address which was bounded on transport. - var retFuture = newFuture[void]("datagram.transport.send(seq)") + let retFuture = newFuture[void]("datagram.transport.send(seq)") transp.checkClosed(retFuture) let length = if msglen <= 0: (len(msg) * sizeof(T)) else: (msglen * sizeof(T)) var localCopy = chronosMoveSink(msg) retFuture.addCallback(proc(_: pointer) = reset(localCopy)) - let vector = GramVector(kind: WithoutAddress, buf: addr localCopy[0], + let vector = GramVector(kind: WithoutAddress, buf: baseAddr localCopy, buflen: length, writer: retFuture) transp.queue.addLast(vector) @@ -922,7 +923,7 @@ proc sendTo*(transp: DatagramTransport, remote: TransportAddress, async: (raw: true, raises: [TransportError, CancelledError]).} = ## Send buffer with pointer ``pbytes`` and size ``nbytes`` using transport ## ``transp`` to remote destination address ``remote``. - var retFuture = newFuture[void]("datagram.transport.sendTo(pointer)") + let retFuture = newFuture[void]("datagram.transport.sendTo(pointer)") transp.checkClosed(retFuture) let vector = GramVector(kind: WithAddress, buf: pbytes, buflen: nbytes, writer: retFuture, address: remote) @@ -938,14 +939,14 @@ proc sendTo*(transp: DatagramTransport, remote: TransportAddress, async: (raw: true, raises: [TransportError, CancelledError]).} = ## Send string ``msg`` using transport ``transp`` to remote destination ## address ``remote``. - var retFuture = newFuture[void]("datagram.transport.sendTo(string)") + let retFuture = newFuture[void]("datagram.transport.sendTo(string)") transp.checkClosed(retFuture) let length = if msglen <= 0: len(msg) else: msglen var localCopy = chronosMoveSink(msg) retFuture.addCallback(proc(_: pointer) = reset(localCopy)) - let vector = GramVector(kind: WithAddress, buf: addr localCopy[0], + let vector = GramVector(kind: WithAddress, buf: baseAddr localCopy, buflen: length, writer: retFuture, address: remote) @@ -961,15 +962,15 @@ proc sendTo*[T](transp: DatagramTransport, remote: TransportAddress, async: (raw: true, raises: [TransportError, CancelledError]).} = ## Send sequence ``msg`` using transport ``transp`` to remote destination ## address ``remote``. 
- var retFuture = newFuture[void]("datagram.transport.sendTo(seq)") + let retFuture = newFuture[void]("datagram.transport.sendTo(seq)") transp.checkClosed(retFuture) let length = if msglen <= 0: (len(msg) * sizeof(T)) else: (msglen * sizeof(T)) var localCopy = chronosMoveSink(msg) retFuture.addCallback(proc(_: pointer) = reset(localCopy)) - let vector = GramVector(kind: WithAddress, buf: addr localCopy[0], + let vector = GramVector(kind: WithAddress, buf: baseAddr localCopy, buflen: length, - writer: cast[Future[void]](retFuture), + writer: retFuture, address: remote) transp.queue.addLast(vector) if WritePaused in transp.state: @@ -993,7 +994,6 @@ proc peekMessage*(transp: DatagramTransport, msg: var seq[byte], proc getMessage*(transp: DatagramTransport): seq[byte] {. raises: [TransportError].} = ## Copy data from internal message buffer and return result. - var default: seq[byte] if ReadError in transp.state: transp.state.excl(ReadError) raise transp.getError() @@ -1002,7 +1002,7 @@ proc getMessage*(transp: DatagramTransport): seq[byte] {. copyMem(addr res[0], addr transp.buffer[0], transp.buflen) res else: - default + default(seq[byte]) proc getUserData*[T](transp: DatagramTransport): T {.inline.} = ## Obtain user data stored in ``transp`` object. From e296ae30c84bdd1f0b12c50ab551ed080f8a815c Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Sat, 20 Jan 2024 16:56:57 +0100 Subject: [PATCH 09/11] asyncraises for threadsync (#495) * asyncraises for threadsync * missing bracket * missing exception --- chronos/internal/asyncfutures.nim | 2 +- chronos/threadsync.nim | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/chronos/internal/asyncfutures.nim b/chronos/internal/asyncfutures.nim index 496a776..1a2be75 100644 --- a/chronos/internal/asyncfutures.nim +++ b/chronos/internal/asyncfutures.nim @@ -1553,7 +1553,7 @@ when defined(windows): proc waitForSingleObject*(handle: HANDLE, timeout: Duration): Future[WaitableResult] {. - raises: [].} = + async: (raises: [AsyncError, CancelledError], raw: true).} = ## Waits until the specified object is in the signaled state or the ## time-out interval elapses. WaitForSingleObject() for asynchronous world. let flags = WT_EXECUTEONLYONCE diff --git a/chronos/threadsync.nim b/chronos/threadsync.nim index bbff18b..f922c12 100644 --- a/chronos/threadsync.nim +++ b/chronos/threadsync.nim @@ -272,7 +272,8 @@ proc waitSync*(signal: ThreadSignalPtr, else: return ok(true) -proc fire*(signal: ThreadSignalPtr): Future[void] = +proc fire*(signal: ThreadSignalPtr): Future[void] {. + async: (raises: [AsyncError, CancelledError], raw: true).} = ## Set state of ``signal`` to signaled in asynchronous way. var retFuture = newFuture[void]("asyncthreadsignal.fire") when defined(windows): @@ -356,14 +357,17 @@ proc fire*(signal: ThreadSignalPtr): Future[void] = retFuture when defined(windows): - proc wait*(signal: ThreadSignalPtr) {.async.} = + proc wait*(signal: ThreadSignalPtr) {. + async: (raises: [AsyncError, CancelledError]).} = let handle = signal[].event let res = await waitForSingleObject(handle, InfiniteDuration) # There should be no other response, because we use `InfiniteDuration`. doAssert(res == WaitableResult.Ok) else: - proc wait*(signal: ThreadSignalPtr): Future[void] = - var retFuture = newFuture[void]("asyncthreadsignal.wait") + proc wait*(signal: ThreadSignalPtr): Future[void] {. 
+ async: (raises: [AsyncError, CancelledError], raw: true).} = + let retFuture = Future[void].Raising([AsyncError, CancelledError]).init( + "asyncthreadsignal.wait") var data = 1'u64 let eventFd = when defined(linux): From 09a0b117194ed41ee6cebf628404698006d238b4 Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Tue, 23 Jan 2024 09:34:10 +0200 Subject: [PATCH 10/11] Make asyncproc use asyncraises. (#497) * Make asyncproc use asyncraises. * Fix missing asyncraises for waitForExit(). --- chronos/asyncproc.nim | 75 ++++++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 25 deletions(-) diff --git a/chronos/asyncproc.nim b/chronos/asyncproc.nim index 8615c57..f008776 100644 --- a/chronos/asyncproc.nim +++ b/chronos/asyncproc.nim @@ -231,8 +231,9 @@ proc closeProcessHandles(pipes: var AsyncProcessPipes, lastError: OSErrorCode): OSErrorCode {.apforward.} proc closeProcessStreams(pipes: AsyncProcessPipes, options: set[AsyncProcessOption]): Future[void] {. - apforward.} -proc closeWait(holder: AsyncStreamHolder): Future[void] {.apforward.} + async: (raises: []).} +proc closeWait(holder: AsyncStreamHolder): Future[void] {. + async: (raises: []).} template isOk(code: OSErrorCode): bool = when defined(windows): @@ -391,7 +392,8 @@ when defined(windows): stdinHandle = ProcessStreamHandle(), stdoutHandle = ProcessStreamHandle(), stderrHandle = ProcessStreamHandle(), - ): Future[AsyncProcessRef] {.async.} = + ): Future[AsyncProcessRef] {. + async: (raises: [AsyncProcessError, CancelledError]).} = var pipes = preparePipes(options, stdinHandle, stdoutHandle, stderrHandle).valueOr: @@ -517,14 +519,16 @@ when defined(windows): ok(false) proc waitForExit*(p: AsyncProcessRef, - timeout = InfiniteDuration): Future[int] {.async.} = + timeout = InfiniteDuration): Future[int] {. + async: (raises: [AsyncProcessError, AsyncProcessTimeoutError, + CancelledError]).} = if p.exitStatus.isSome(): return p.exitStatus.get() let wres = try: await waitForSingleObject(p.processHandle, timeout) - except ValueError as exc: + except AsyncError as exc: raiseAsyncProcessError("Unable to wait for process handle", exc) if wres == WaitableResult.Timeout: @@ -537,7 +541,8 @@ when defined(windows): if exitCode >= 0: p.exitStatus = Opt.some(exitCode) - return exitCode + + exitCode proc peekExitCode(p: AsyncProcessRef): AsyncProcessResult[int] = if p.exitStatus.isSome(): @@ -787,7 +792,8 @@ else: stdinHandle = ProcessStreamHandle(), stdoutHandle = ProcessStreamHandle(), stderrHandle = ProcessStreamHandle(), - ): Future[AsyncProcessRef] {.async.} = + ): Future[AsyncProcessRef] {. + async: (raises: [AsyncProcessError, CancelledError]).} = var pid: Pid pipes = preparePipes(options, stdinHandle, stdoutHandle, @@ -887,7 +893,7 @@ else: ) trackCounter(AsyncProcessTrackerName) - return process + process proc peekProcessExitCode(p: AsyncProcessRef, reap = false): AsyncProcessResult[int] = @@ -948,7 +954,9 @@ else: ok(false) proc waitForExit*(p: AsyncProcessRef, - timeout = InfiniteDuration): Future[int] = + timeout = InfiniteDuration): Future[int] {. + async: (raw: true, raises: [ + AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} = var retFuture = newFuture[int]("chronos.waitForExit()") processHandle: ProcessHandle @@ -1050,7 +1058,7 @@ else: # Process is still running, so we going to wait for SIGCHLD. retFuture.cancelCallback = cancellation - return retFuture + retFuture proc peekExitCode(p: AsyncProcessRef): AsyncProcessResult[int] = let res = ? 
p.peekProcessExitCode()
@@ -1155,7 +1163,7 @@ proc preparePipes(options: set[AsyncProcessOption],
       stderrHandle: remoteStderr
     ))
 
-proc closeWait(holder: AsyncStreamHolder) {.async.} =
+proc closeWait(holder: AsyncStreamHolder) {.async: (raises: []).} =
   let (future, transp) =
     case holder.kind
     of StreamKind.None:
@@ -1182,10 +1190,11 @@ proc closeWait(holder: AsyncStreamHolder) {.async.} =
       res
 
   if len(pending) > 0:
-    await allFutures(pending)
+    await noCancel allFutures(pending)
 
 proc closeProcessStreams(pipes: AsyncProcessPipes,
-                         options: set[AsyncProcessOption]): Future[void] =
+                         options: set[AsyncProcessOption]): Future[void] {.
+    async: (raw: true, raises: []).} =
   let pending =
     block:
       var res: seq[Future[void]]
@@ -1196,10 +1205,12 @@ proc closeProcessStreams(pipes: AsyncProcessPipes,
       if ProcessFlag.AutoStderr in pipes.flags:
         res.add(pipes.stderrHolder.closeWait())
       res
-  allFutures(pending)
+  noCancel allFutures(pending)
 
 proc opAndWaitForExit(p: AsyncProcessRef, op: WaitOperation,
-                      timeout = InfiniteDuration): Future[int] {.async.} =
+                      timeout = InfiniteDuration): Future[int] {.
+    async: (raises: [
+      AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
   let
     timerFut =
       if timeout == InfiniteDuration:
        newFuture[void]("chronos.killAndwaitForExit")
@@ -1223,7 +1234,10 @@ proc opAndWaitForExit(p: AsyncProcessRef, op: WaitOperation,
       return exitCode
 
     let waitFut = p.waitForExit().wait(100.milliseconds)
-    discard await race(FutureBase(waitFut), FutureBase(timerFut))
+    try:
+      discard await race(FutureBase(waitFut), FutureBase(timerFut))
+    except ValueError:
+      raiseAssert "This should not happen!"
 
     if waitFut.finished() and not(waitFut.failed()):
       let res = p.peekExitCode()
@@ -1237,25 +1251,28 @@ proc opAndWaitForExit(p: AsyncProcessRef, op: WaitOperation,
       await waitFut.cancelAndWait()
   raiseAsyncProcessTimeoutError()
 
-proc closeWait*(p: AsyncProcessRef) {.async.} =
+proc closeWait*(p: AsyncProcessRef) {.async: (raises: []).} =
   # Here we ignore all possible errrors, because we do not want to raise
   # exceptions.
   discard closeProcessHandles(p.pipes, p.options, OSErrorCode(0))
-  await noCancel(p.pipes.closeProcessStreams(p.options))
+  await p.pipes.closeProcessStreams(p.options)
   discard p.closeThreadAndProcessHandle()
   untrackCounter(AsyncProcessTrackerName)
 
 proc stdinStream*(p: AsyncProcessRef): AsyncStreamWriter =
+  ## Returns STDIN async stream associated with process `p`.
   doAssert(p.pipes.stdinHolder.kind == StreamKind.Writer,
            "StdinStreamWriter is not available")
   p.pipes.stdinHolder.writer
 
 proc stdoutStream*(p: AsyncProcessRef): AsyncStreamReader =
+  ## Returns STDOUT async stream associated with process `p`.
  doAssert(p.pipes.stdoutHolder.kind == StreamKind.Reader,
           "StdoutStreamReader is not available")
  p.pipes.stdoutHolder.reader
 
 proc stderrStream*(p: AsyncProcessRef): AsyncStreamReader =
+  ## Returns STDERR async stream associated with process `p`.
   doAssert(p.pipes.stderrHolder.kind == StreamKind.Reader,
            "StderrStreamReader is not available")
   p.pipes.stderrHolder.reader
 
 proc execCommand*(command: string,
                   options = {AsyncProcessOption.EvalCommand},
                   timeout = InfiniteDuration
-                 ): Future[int] {.async.} =
+                 ): Future[int] {.
+    async: (raises: [
+      AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
   let
     poptions = options + {AsyncProcessOption.EvalCommand}
     process = await startProcess(command, options = poptions)
@@ -1277,7 +1296,9 @@ proc execCommand*(command: string,
 proc execCommandEx*(command: string,
                     options = {AsyncProcessOption.EvalCommand},
                     timeout = InfiniteDuration
-                   ): Future[CommandExResponse] {.async.} =
+                   ): Future[CommandExResponse] {.
+    async: (raises: [
+      AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
   let
     process = await startProcess(command, options = options,
                                  stdoutHandle = AsyncProcess.Pipe,
@@ -1291,13 +1312,13 @@ proc execCommandEx*(command: string,
     status = await process.waitForExit(timeout)
     output =
       try:
-        string.fromBytes(outputReader.read())
+        string.fromBytes(await outputReader)
       except AsyncStreamError as exc:
         raiseAsyncProcessError("Unable to read process' stdout channel", exc)
     error =
       try:
-        string.fromBytes(errorReader.read())
+        string.fromBytes(await errorReader)
       except AsyncStreamError as exc:
         raiseAsyncProcessError("Unable to read process' stderr channel", exc)
@@ -1308,13 +1329,15 @@ proc execCommandEx*(command: string,
   res
 
 proc pid*(p: AsyncProcessRef): int =
-  ## Returns process ``p`` identifier.
+  ## Returns the unique process identifier of process ``p``.
   int(p.processId)
 
 template processId*(p: AsyncProcessRef): int = pid(p)
 
 proc killAndWaitForExit*(p: AsyncProcessRef,
-                         timeout = InfiniteDuration): Future[int] =
+                         timeout = InfiniteDuration): Future[int] {.
+    async: (raw: true, raises: [
+      AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
   ## Perform continuous attempts to kill the ``p`` process for specified period
   ## of time ``timeout``.
   ##
@@ -1330,7 +1353,9 @@ proc killAndWaitForExit*(p: AsyncProcessRef,
   opAndWaitForExit(p, WaitOperation.Kill, timeout)
 
 proc terminateAndWaitForExit*(p: AsyncProcessRef,
-                              timeout = InfiniteDuration): Future[int] =
+                              timeout = InfiniteDuration): Future[int] {.
+    async: (raw: true, raises: [
+      AsyncProcessError, AsyncProcessTimeoutError, CancelledError]).} =
   ## Perform continuous attempts to terminate the ``p`` process for specified
   ## period of time ``timeout``.
  ##
From 672db137b7cad9b384b8f4fb551fb6bbeaabfe1b Mon Sep 17 00:00:00 2001
From: Jacek Sieka <jacek@status.im>
Date: Wed, 24 Jan 2024 18:33:13 +0100
Subject: [PATCH 11/11] v4.0.0 (#494)

Features:

* Exception effects / raises for async procedures helping you write more efficient, leak-free code
* Cross-thread notification mechanism suitable for building channels, queues and other multithreaded primitives
* Async process I/O
* IPv6 dual stack support
* HTTP middleware support allowing multiple services to share a single HTTP server
* A new [documentation web site](https://status-im.github.io/nim-chronos/) covering the basics, with several simple examples for getting started
* Implicit returns, support for `results.?` and other conveniences
* Rate limiter
* Revamped cancellation support with more control over the cancellation process
* Efficiency improvements with `lent` and `sink`

See the [porting](https://status-im.github.io/nim-chronos/porting.html) guide for porting code from earlier chronos releases (as well as asyncdispatch)
---
 chronos.nimble | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/chronos.nimble b/chronos.nimble
index e435883..01117b6 100644
--- a/chronos.nimble
+++ b/chronos.nimble
@@ -1,13 +1,13 @@
 mode = ScriptMode.Verbose
 
 packageName   = "chronos"
-version       = "3.2.0"
+version       = "4.0.0"
 author        = "Status Research & Development GmbH"
 description   = "Networking framework with async/await support"
 license       = "MIT or Apache License 2.0"
 skipDirs      = @["tests"]
 
-requires "nim >= 1.6.0",
+requires "nim >= 1.6.16",
          "results",
          "stew",
          "bearssl",