diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b78f2a12..b7aa0fa1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -96,7 +96,7 @@ jobs: - name: Restore Nim DLLs dependencies (Windows) from cache if: runner.os == 'Windows' id: windows-dlls-cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: external/dlls-${{ matrix.target.cpu }} key: 'dlls-${{ matrix.target.cpu }}' diff --git a/chronos.nim b/chronos.nim index 6801b289..8295924d 100644 --- a/chronos.nim +++ b/chronos.nim @@ -5,6 +5,5 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import chronos/[asyncloop, asyncsync, handles, transport, timer, - asyncproc, debugutils] -export asyncloop, asyncsync, handles, transport, timer, asyncproc, debugutils +import chronos/[asyncloop, asyncsync, handles, transport, timer, debugutils] +export asyncloop, asyncsync, handles, transport, timer, debugutils diff --git a/chronos.nimble b/chronos.nimble index 6b4ac58a..e9c1b11d 100644 --- a/chronos.nimble +++ b/chronos.nimble @@ -17,6 +17,22 @@ let nimc = getEnv("NIMC", "nim") # Which nim compiler to use let lang = getEnv("NIMLANG", "c") # Which backend (c/cpp/js) let flags = getEnv("NIMFLAGS", "") # Extra flags for the compiler let verbose = getEnv("V", "") notin ["", "0"] +let testArguments = + when defined(windows): + [ + "-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert", + "-d:debug -d:chronosPreviewV4", + "-d:release", + "-d:release -d:chronosPreviewV4" + ] + else: + [ + "-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert", + "-d:debug -d:chronosPreviewV4", + "-d:debug -d:chronosDebug -d:chronosEventEngine=poll -d:useSysAssert -d:useGcAssert", + "-d:release", + "-d:release -d:chronosPreviewV4" + ] let styleCheckStyle = if (NimMajor, NimMinor) < (1, 6): "hint" else: "error" let cfg = @@ -31,12 +47,7 @@ proc run(args, path: string) = build args & " -r", path task test, "Run all tests": - for args in [ - "-d:debug -d:chronosDebug", - "-d:debug -d:chronosPreviewV4", - "-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert", - "-d:release", - "-d:release -d:chronosPreviewV4"]: + for args in testArguments: run args, "tests/testall" if (NimMajor, NimMinor) > (1, 6): run args & " --mm:refc", "tests/testall" diff --git a/chronos/apps/http/httpbodyrw.nim b/chronos/apps/http/httpbodyrw.nim index ba2b1d4f..b948fbd3 100644 --- a/chronos/apps/http/httpbodyrw.nim +++ b/chronos/apps/http/httpbodyrw.nim @@ -25,71 +25,6 @@ type bstate*: HttpState streams*: seq[AsyncStreamWriter] - HttpBodyTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - -proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe, raises: [].} -proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe, raises: [].} - -proc getHttpBodyWriterTracker(): HttpBodyTracker {.inline.} = - var res = cast[HttpBodyTracker](getTracker(HttpBodyWriterTrackerName)) - if isNil(res): - res = setupHttpBodyWriterTracker() - res - -proc getHttpBodyReaderTracker(): HttpBodyTracker {.inline.} = - var res = cast[HttpBodyTracker](getTracker(HttpBodyReaderTrackerName)) - if isNil(res): - res = setupHttpBodyReaderTracker() - res - -proc dumpHttpBodyWriterTracking(): string {.gcsafe.} = - let tracker = getHttpBodyWriterTracker() - "Opened HTTP body writers: " & $tracker.opened & "\n" & - "Closed HTTP body writers: " & $tracker.closed - -proc dumpHttpBodyReaderTracking(): string {.gcsafe.} = - let tracker = getHttpBodyReaderTracker() - "Opened HTTP body readers: " & $tracker.opened & "\n" & - "Closed HTTP body readers: " & $tracker.closed - -proc leakHttpBodyWriter(): bool {.gcsafe.} = - var tracker = getHttpBodyWriterTracker() - tracker.opened != tracker.closed - -proc leakHttpBodyReader(): bool {.gcsafe.} = - var tracker = getHttpBodyReaderTracker() - tracker.opened != tracker.closed - -proc trackHttpBodyWriter(t: HttpBodyWriter) {.inline.} = - inc(getHttpBodyWriterTracker().opened) - -proc untrackHttpBodyWriter*(t: HttpBodyWriter) {.inline.} = - inc(getHttpBodyWriterTracker().closed) - -proc trackHttpBodyReader(t: HttpBodyReader) {.inline.} = - inc(getHttpBodyReaderTracker().opened) - -proc untrackHttpBodyReader*(t: HttpBodyReader) {.inline.} = - inc(getHttpBodyReaderTracker().closed) - -proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe.} = - var res = HttpBodyTracker(opened: 0, closed: 0, - dump: dumpHttpBodyWriterTracking, - isLeaked: leakHttpBodyWriter - ) - addTracker(HttpBodyWriterTrackerName, res) - res - -proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe.} = - var res = HttpBodyTracker(opened: 0, closed: 0, - dump: dumpHttpBodyReaderTracking, - isLeaked: leakHttpBodyReader - ) - addTracker(HttpBodyReaderTrackerName, res) - res - proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader = ## HttpBodyReader is AsyncStreamReader which holds references to all the ## ``streams``. Also on close it will close all the ``streams``. @@ -98,7 +33,7 @@ proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader = doAssert(len(streams) > 0, "At least one stream must be added") var res = HttpBodyReader(bstate: HttpState.Alive, streams: @streams) res.init(streams[0]) - trackHttpBodyReader(res) + trackCounter(HttpBodyReaderTrackerName) res proc closeWait*(bstream: HttpBodyReader) {.async.} = @@ -113,7 +48,7 @@ proc closeWait*(bstream: HttpBodyReader) {.async.} = await allFutures(res) await procCall(closeWait(AsyncStreamReader(bstream))) bstream.bstate = HttpState.Closed - untrackHttpBodyReader(bstream) + untrackCounter(HttpBodyReaderTrackerName) proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter = ## HttpBodyWriter is AsyncStreamWriter which holds references to all the @@ -123,7 +58,7 @@ proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter = doAssert(len(streams) > 0, "At least one stream must be added") var res = HttpBodyWriter(bstate: HttpState.Alive, streams: @streams) res.init(streams[0]) - trackHttpBodyWriter(res) + trackCounter(HttpBodyWriterTrackerName) res proc closeWait*(bstream: HttpBodyWriter) {.async.} = @@ -136,7 +71,7 @@ proc closeWait*(bstream: HttpBodyWriter) {.async.} = await allFutures(res) await procCall(closeWait(AsyncStreamWriter(bstream))) bstream.bstate = HttpState.Closed - untrackHttpBodyWriter(bstream) + untrackCounter(HttpBodyWriterTrackerName) proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [].} = if len(bstream.streams) == 1: diff --git a/chronos/apps/http/httpclient.nim b/chronos/apps/http/httpclient.nim index 14f34a30..f173fecc 100644 --- a/chronos/apps/http/httpclient.nim +++ b/chronos/apps/http/httpclient.nim @@ -108,6 +108,7 @@ type remoteHostname*: string flags*: set[HttpClientConnectionFlag] timestamp*: Moment + duration*: Duration HttpClientConnectionRef* = ref HttpClientConnection @@ -190,10 +191,6 @@ type HttpClientFlags* = set[HttpClientFlag] - HttpClientTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - ServerSentEvent* = object name*: string data*: string @@ -204,100 +201,6 @@ type # HttpClientResponseRef valid states are # Open -> (Finished, Error) -> (Closing, Closed) -proc setupHttpClientConnectionTracker(): HttpClientTracker {. - gcsafe, raises: [].} -proc setupHttpClientRequestTracker(): HttpClientTracker {. - gcsafe, raises: [].} -proc setupHttpClientResponseTracker(): HttpClientTracker {. - gcsafe, raises: [].} - -proc getHttpClientConnectionTracker(): HttpClientTracker {.inline.} = - var res = cast[HttpClientTracker](getTracker(HttpClientConnectionTrackerName)) - if isNil(res): - res = setupHttpClientConnectionTracker() - res - -proc getHttpClientRequestTracker(): HttpClientTracker {.inline.} = - var res = cast[HttpClientTracker](getTracker(HttpClientRequestTrackerName)) - if isNil(res): - res = setupHttpClientRequestTracker() - res - -proc getHttpClientResponseTracker(): HttpClientTracker {.inline.} = - var res = cast[HttpClientTracker](getTracker(HttpClientResponseTrackerName)) - if isNil(res): - res = setupHttpClientResponseTracker() - res - -proc dumpHttpClientConnectionTracking(): string {.gcsafe.} = - let tracker = getHttpClientConnectionTracker() - "Opened HTTP client connections: " & $tracker.opened & "\n" & - "Closed HTTP client connections: " & $tracker.closed - -proc dumpHttpClientRequestTracking(): string {.gcsafe.} = - let tracker = getHttpClientRequestTracker() - "Opened HTTP client requests: " & $tracker.opened & "\n" & - "Closed HTTP client requests: " & $tracker.closed - -proc dumpHttpClientResponseTracking(): string {.gcsafe.} = - let tracker = getHttpClientResponseTracker() - "Opened HTTP client responses: " & $tracker.opened & "\n" & - "Closed HTTP client responses: " & $tracker.closed - -proc leakHttpClientConnection(): bool {.gcsafe.} = - var tracker = getHttpClientConnectionTracker() - tracker.opened != tracker.closed - -proc leakHttpClientRequest(): bool {.gcsafe.} = - var tracker = getHttpClientRequestTracker() - tracker.opened != tracker.closed - -proc leakHttpClientResponse(): bool {.gcsafe.} = - var tracker = getHttpClientResponseTracker() - tracker.opened != tracker.closed - -proc trackHttpClientConnection(t: HttpClientConnectionRef) {.inline.} = - inc(getHttpClientConnectionTracker().opened) - -proc untrackHttpClientConnection*(t: HttpClientConnectionRef) {.inline.} = - inc(getHttpClientConnectionTracker().closed) - -proc trackHttpClientRequest(t: HttpClientRequestRef) {.inline.} = - inc(getHttpClientRequestTracker().opened) - -proc untrackHttpClientRequest*(t: HttpClientRequestRef) {.inline.} = - inc(getHttpClientRequestTracker().closed) - -proc trackHttpClientResponse(t: HttpClientResponseRef) {.inline.} = - inc(getHttpClientResponseTracker().opened) - -proc untrackHttpClientResponse*(t: HttpClientResponseRef) {.inline.} = - inc(getHttpClientResponseTracker().closed) - -proc setupHttpClientConnectionTracker(): HttpClientTracker {.gcsafe.} = - var res = HttpClientTracker(opened: 0, closed: 0, - dump: dumpHttpClientConnectionTracking, - isLeaked: leakHttpClientConnection - ) - addTracker(HttpClientConnectionTrackerName, res) - res - -proc setupHttpClientRequestTracker(): HttpClientTracker {.gcsafe.} = - var res = HttpClientTracker(opened: 0, closed: 0, - dump: dumpHttpClientRequestTracking, - isLeaked: leakHttpClientRequest - ) - addTracker(HttpClientRequestTrackerName, res) - res - -proc setupHttpClientResponseTracker(): HttpClientTracker {.gcsafe.} = - var res = HttpClientTracker(opened: 0, closed: 0, - dump: dumpHttpClientResponseTracking, - isLeaked: leakHttpClientResponse - ) - addTracker(HttpClientResponseTrackerName, res) - res - template checkClosed(reqresp: untyped): untyped = if reqresp.connection.state in {HttpClientConnectionState.Closing, HttpClientConnectionState.Closed}: @@ -331,6 +234,12 @@ template setDuration( reqresp.duration = timestamp - reqresp.timestamp reqresp.connection.setTimestamp(timestamp) +template setDuration(conn: HttpClientConnectionRef): untyped = + if not(isNil(conn)): + let timestamp = Moment.now() + conn.duration = timestamp - conn.timestamp + conn.setTimestamp(timestamp) + template isReady(conn: HttpClientConnectionRef): bool = (conn.state == HttpClientConnectionState.Ready) and (HttpClientConnectionFlag.KeepAlive in conn.flags) and @@ -556,7 +465,7 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef, state: HttpClientConnectionState.Connecting, remoteHostname: ha.id ) - trackHttpClientConnection(res) + trackCounter(HttpClientConnectionTrackerName) res of HttpClientScheme.Secure: let treader = newAsyncStreamReader(transp) @@ -575,7 +484,7 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef, state: HttpClientConnectionState.Connecting, remoteHostname: ha.id ) - trackHttpClientConnection(res) + trackCounter(HttpClientConnectionTrackerName) res proc setError(request: HttpClientRequestRef, error: ref HttpError) {. @@ -615,13 +524,13 @@ proc closeWait(conn: HttpClientConnectionRef) {.async.} = discard await conn.transp.closeWait() conn.state = HttpClientConnectionState.Closed - untrackHttpClientConnection(conn) + untrackCounter(HttpClientConnectionTrackerName) proc connect(session: HttpSessionRef, ha: HttpAddress): Future[HttpClientConnectionRef] {.async.} = ## Establish new connection with remote server using ``url`` and ``flags``. ## On success returns ``HttpClientConnectionRef`` object. - + var lastError = "" # Here we trying to connect to every possible remote host address we got after # DNS resolution. for address in ha.addresses: @@ -645,9 +554,14 @@ proc connect(session: HttpSessionRef, except CancelledError as exc: await res.closeWait() raise exc - except AsyncStreamError: + except TLSStreamProtocolError as exc: await res.closeWait() res.state = HttpClientConnectionState.Error + lastError = $exc.msg + except AsyncStreamError as exc: + await res.closeWait() + res.state = HttpClientConnectionState.Error + lastError = $exc.msg of HttpClientScheme.Nonsecure: res.state = HttpClientConnectionState.Ready res @@ -655,7 +569,11 @@ proc connect(session: HttpSessionRef, return conn # If all attempts to connect to the remote host have failed. - raiseHttpConnectionError("Could not connect to remote host") + if len(lastError) > 0: + raiseHttpConnectionError("Could not connect to remote host, reason: " & + lastError) + else: + raiseHttpConnectionError("Could not connect to remote host") proc removeConnection(session: HttpSessionRef, conn: HttpClientConnectionRef) {.async.} = @@ -685,9 +603,9 @@ proc acquireConnection( ): Future[HttpClientConnectionRef] {.async.} = ## Obtain connection from ``session`` or establish a new one. var default: seq[HttpClientConnectionRef] + let timestamp = Moment.now() if session.connectionPoolEnabled(flags): # Trying to reuse existing connection from our connection's pool. - let timestamp = Moment.now() # We looking for non-idle connection at `Ready` state, all idle connections # will be freed by sessionWatcher(). for connection in session.connections.getOrDefault(ha.id): @@ -704,6 +622,8 @@ proc acquireConnection( connection.state = HttpClientConnectionState.Acquired session.connections.mgetOrPut(ha.id, default).add(connection) inc(session.connectionsCount) + connection.setTimestamp(timestamp) + connection.setDuration() return connection proc releaseConnection(session: HttpSessionRef, @@ -835,7 +755,7 @@ proc closeWait*(request: HttpClientRequestRef) {.async.} = request.session = nil request.error = nil request.state = HttpReqRespState.Closed - untrackHttpClientRequest(request) + untrackCounter(HttpClientRequestTrackerName) proc closeWait*(response: HttpClientResponseRef) {.async.} = if response.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}: @@ -848,7 +768,7 @@ proc closeWait*(response: HttpClientResponseRef) {.async.} = response.session = nil response.error = nil response.state = HttpReqRespState.Closed - untrackHttpClientResponse(response) + untrackCounter(HttpClientResponseTrackerName) proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte] ): HttpResult[HttpClientResponseRef] {.raises: [] .} = @@ -958,7 +878,7 @@ proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte] httpPipeline: res.connection.flags.incl(HttpClientConnectionFlag.KeepAlive) res.connection.flags.incl(HttpClientConnectionFlag.Response) - trackHttpClientResponse(res) + trackCounter(HttpClientResponseTrackerName) ok(res) proc getResponse(req: HttpClientRequestRef): Future[HttpClientResponseRef] {. @@ -997,7 +917,7 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, version: version, flags: flags, headers: HttpTable.init(headers), address: ha, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body ) - trackHttpClientRequest(res) + trackCounter(HttpClientRequestTrackerName) res proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, @@ -1013,7 +933,7 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, version: version, flags: flags, headers: HttpTable.init(headers), address: address, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body ) - trackHttpClientRequest(res) + trackCounter(HttpClientRequestTrackerName) ok(res) proc get*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, diff --git a/chronos/apps/http/httpcommon.nim b/chronos/apps/http/httpcommon.nim index cc2478d4..5a4a628c 100644 --- a/chronos/apps/http/httpcommon.nim +++ b/chronos/apps/http/httpcommon.nim @@ -13,6 +13,15 @@ import ../../streams/[asyncstream, boundstream] export asyncloop, asyncsync, results, httputils, strutils const + HttpServerUnsecureConnectionTrackerName* = + "httpserver.unsecure.connection" + HttpServerSecureConnectionTrackerName* = + "httpserver.secure.connection" + HttpServerRequestTrackerName* = + "httpserver.request" + HttpServerResponseTrackerName* = + "httpserver.response" + HeadersMark* = @[0x0d'u8, 0x0a'u8, 0x0d'u8, 0x0a'u8] PostMethods* = {MethodPost, MethodPatch, MethodPut, MethodDelete} diff --git a/chronos/apps/http/httpdebug.nim b/chronos/apps/http/httpdebug.nim new file mode 100644 index 00000000..a1dc0228 --- /dev/null +++ b/chronos/apps/http/httpdebug.nim @@ -0,0 +1,129 @@ +# +# Chronos HTTP/S server implementation +# (c) Copyright 2021-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import std/tables +import stew/results +import ../../timer +import httpserver, shttpserver +from httpclient import HttpClientScheme +from httpcommon import HttpState +from ../../osdefs import SocketHandle +from ../../transports/common import TransportAddress, ServerFlags +export HttpClientScheme, SocketHandle, TransportAddress, ServerFlags, HttpState + +{.push raises: [].} + +type + ConnectionType* {.pure.} = enum + NonSecure, Secure + + ConnectionState* {.pure.} = enum + Accepted, Alive, Closing, Closed + + ServerConnectionInfo* = object + handle*: SocketHandle + connectionType*: ConnectionType + connectionState*: ConnectionState + query*: Opt[string] + remoteAddress*: Opt[TransportAddress] + localAddress*: Opt[TransportAddress] + acceptMoment*: Moment + createMoment*: Opt[Moment] + + ServerInfo* = object + connectionType*: ConnectionType + address*: TransportAddress + state*: HttpServerState + maxConnections*: int + backlogSize*: int + baseUri*: Uri + serverIdent*: string + flags*: set[HttpServerFlags] + socketFlags*: set[ServerFlags] + headersTimeout*: Duration + bufferSize*: int + maxHeadersSize*: int + maxRequestBodySize*: int + +proc getConnectionType*( + server: HttpServerRef | SecureHttpServerRef): ConnectionType = + when server is SecureHttpServerRef: + ConnectionType.Secure + else: + if HttpServerFlags.Secure in server.flags: + ConnectionType.Secure + else: + ConnectionType.NonSecure + +proc getServerInfo*(server: HttpServerRef|SecureHttpServerRef): ServerInfo = + ServerInfo( + connectionType: server.getConnectionType(), + address: server.address, + state: server.state(), + maxConnections: server.maxConnections, + backlogSize: server.backlogSize, + baseUri: server.baseUri, + serverIdent: server.serverIdent, + flags: server.flags, + socketFlags: server.socketFlags, + headersTimeout: server.headersTimeout, + bufferSize: server.bufferSize, + maxHeadersSize: server.maxHeadersSize, + maxRequestBodySize: server.maxRequestBodySize + ) + +proc getConnectionState*(holder: HttpConnectionHolderRef): ConnectionState = + if not(isNil(holder.connection)): + case holder.connection.state + of HttpState.Alive: ConnectionState.Alive + of HttpState.Closing: ConnectionState.Closing + of HttpState.Closed: ConnectionState.Closed + else: + ConnectionState.Accepted + +proc getQueryString*(holder: HttpConnectionHolderRef): Opt[string] = + if not(isNil(holder.connection)): + holder.connection.currentRawQuery + else: + Opt.none(string) + +proc init*(t: typedesc[ServerConnectionInfo], + holder: HttpConnectionHolderRef): ServerConnectionInfo = + let + localAddress = + try: + Opt.some(holder.transp.localAddress()) + except CatchableError: + Opt.none(TransportAddress) + remoteAddress = + try: + Opt.some(holder.transp.remoteAddress()) + except CatchableError: + Opt.none(TransportAddress) + queryString = holder.getQueryString() + + ServerConnectionInfo( + handle: SocketHandle(holder.transp.fd), + connectionType: holder.server.getConnectionType(), + connectionState: holder.getConnectionState(), + remoteAddress: remoteAddress, + localAddress: localAddress, + acceptMoment: holder.acceptMoment, + query: queryString, + createMoment: + if not(isNil(holder.connection)): + Opt.some(holder.connection.createMoment) + else: + Opt.none(Moment) + ) + +proc getConnections*(server: HttpServerRef): seq[ServerConnectionInfo] = + var res: seq[ServerConnectionInfo] + for holder in server.connections.values(): + res.add(ServerConnectionInfo.init(holder)) + res diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index 1da4b44c..eafa27c6 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -25,20 +25,24 @@ type QueryCommaSeparatedArray ## Enable usage of comma as an array item delimiter in url-encoded ## entities (e.g. query string or POST body). + Http11Pipeline + ## Enable HTTP/1.1 pipelining. HttpServerError* {.pure.} = enum - TimeoutError, CatchableError, RecoverableError, CriticalError, - DisconnectError + InterruptError, TimeoutError, CatchableError, RecoverableError, + CriticalError, DisconnectError HttpServerState* {.pure.} = enum ServerRunning, ServerStopped, ServerClosed HttpProcessError* = object - error*: HttpServerError + kind*: HttpServerError code*: HttpCode exc*: ref CatchableError - remote*: TransportAddress + remote*: Opt[TransportAddress] + ConnectionFence* = Result[HttpConnectionRef, HttpProcessError] + ResponseFence* = Result[HttpResponseRef, HttpProcessError] RequestFence* = Result[HttpRequestRef, HttpProcessError] HttpRequestFlags* {.pure.} = enum @@ -50,8 +54,11 @@ type HttpResponseStreamType* {.pure.} = enum Plain, SSE, Chunked + HttpProcessExitType* {.pure.} = enum + KeepAlive, Graceful, Immediate + HttpResponseState* {.pure.} = enum - Empty, Prepared, Sending, Finished, Failed, Cancelled, Dumb + Empty, Prepared, Sending, Finished, Failed, Cancelled, Default HttpProcessCallback* = proc(req: RequestFence): Future[HttpResponseRef] {. @@ -62,6 +69,20 @@ type transp: StreamTransport): Future[HttpConnectionRef] {. gcsafe, raises: [].} + HttpCloseConnectionCallback* = + proc(connection: HttpConnectionRef): Future[void] {. + gcsafe, raises: [].} + + HttpConnectionHolder* = object of RootObj + connection*: HttpConnectionRef + server*: HttpServerRef + future*: Future[void] + transp*: StreamTransport + acceptMoment*: Moment + connectionId*: string + + HttpConnectionHolderRef* = ref HttpConnectionHolder + HttpServer* = object of RootObj instance*: StreamServer address*: TransportAddress @@ -72,7 +93,7 @@ type serverIdent*: string flags*: set[HttpServerFlags] socketFlags*: set[ServerFlags] - connections*: Table[string, Future[void]] + connections*: OrderedTable[string, HttpConnectionHolderRef] acceptLoop*: Future[void] lifetime*: Future[void] headersTimeout*: Duration @@ -120,11 +141,14 @@ type HttpConnection* = object of RootObj state*: HttpState server*: HttpServerRef - transp: StreamTransport + transp*: StreamTransport mainReader*: AsyncStreamReader mainWriter*: AsyncStreamWriter reader*: AsyncStreamReader writer*: AsyncStreamWriter + closeCb*: HttpCloseConnectionCallback + createMoment*: Moment + currentRawQuery*: Opt[string] buffer: seq[byte] HttpConnectionRef* = ref HttpConnection @@ -132,9 +156,24 @@ type ByteChar* = string | seq[byte] proc init(htype: typedesc[HttpProcessError], error: HttpServerError, - exc: ref CatchableError, remote: TransportAddress, - code: HttpCode): HttpProcessError {.raises: [].} = - HttpProcessError(error: error, exc: exc, remote: remote, code: code) + exc: ref CatchableError, remote: Opt[TransportAddress], + code: HttpCode): HttpProcessError {. + raises: [].} = + HttpProcessError(kind: error, exc: exc, remote: remote, code: code) + +proc init(htype: typedesc[HttpProcessError], + error: HttpServerError): HttpProcessError {. + raises: [].} = + HttpProcessError(kind: error) + +proc new(htype: typedesc[HttpConnectionHolderRef], server: HttpServerRef, + transp: StreamTransport, + connectionId: string): HttpConnectionHolderRef = + HttpConnectionHolderRef( + server: server, transp: transp, acceptMoment: Moment.now(), + connectionId: connectionId) + +proc error*(e: HttpProcessError): HttpServerError = e.kind proc createConnection(server: HttpServerRef, transp: StreamTransport): Future[HttpConnectionRef] {. @@ -149,7 +188,7 @@ proc new*(htype: typedesc[HttpServerRef], serverIdent = "", maxConnections: int = -1, bufferSize: int = 4096, - backlogSize: int = 100, + backlogSize: int = DefaultBacklogSize, httpHeadersTimeout = 10.seconds, maxHeadersSize: int = 8192, maxRequestBodySize: int = 1_048_576): HttpResult[HttpServerRef] {. @@ -174,7 +213,7 @@ proc new*(htype: typedesc[HttpServerRef], return err(exc.msg) var res = HttpServerRef( - address: address, + address: serverInstance.localAddress(), instance: serverInstance, processCallback: processCallback, createConnCallback: createConnection, @@ -194,10 +233,37 @@ proc new*(htype: typedesc[HttpServerRef], # else: # nil lifetime: newFuture[void]("http.server.lifetime"), - connections: initTable[string, Future[void]]() + connections: initOrderedTable[string, HttpConnectionHolderRef]() ) ok(res) +proc getServerFlags(req: HttpRequestRef): set[HttpServerFlags] = + var defaultFlags: set[HttpServerFlags] = {} + if isNil(req): return defaultFlags + if isNil(req.connection): return defaultFlags + if isNil(req.connection.server): return defaultFlags + req.connection.server.flags + +proc getResponseFlags(req: HttpRequestRef): set[HttpResponseFlags] = + var defaultFlags: set[HttpResponseFlags] = {} + case req.version + of HttpVersion11: + if HttpServerFlags.Http11Pipeline notin req.getServerFlags(): + return defaultFlags + let header = req.headers.getString(ConnectionHeader, "keep-alive") + if header == "keep-alive": + {HttpResponseFlags.KeepAlive} + else: + defaultFlags + else: + defaultFlags + +proc getResponseVersion(reqFence: RequestFence): HttpVersion {.raises: [].} = + if reqFence.isErr(): + HttpVersion11 + else: + reqFence.get().version + proc getResponse*(req: HttpRequestRef): HttpResponseRef {.raises: [].} = if req.response.isNone(): var resp = HttpResponseRef( @@ -206,10 +272,7 @@ proc getResponse*(req: HttpRequestRef): HttpResponseRef {.raises: [].} = version: req.version, headersTable: HttpTable.init(), connection: req.connection, - flags: if req.version == HttpVersion11: - {HttpResponseFlags.KeepAlive} - else: - {} + flags: req.getResponseFlags() ) req.response = Opt.some(resp) resp @@ -222,9 +285,14 @@ proc getHostname*(server: HttpServerRef): string = else: server.baseUri.hostname -proc dumbResponse*(): HttpResponseRef {.raises: [].} = +proc defaultResponse*(): HttpResponseRef {.raises: [].} = ## Create an empty response to return when request processor got no request. - HttpResponseRef(state: HttpResponseState.Dumb, version: HttpVersion11) + HttpResponseRef(state: HttpResponseState.Default, version: HttpVersion11) + +proc dumbResponse*(): HttpResponseRef {.raises: [], + deprecated: "Please use defaultResponse() instead".} = + ## Create an empty response to return when request processor got no request. + defaultResponse() proc getId(transp: StreamTransport): Result[string, string] {.inline.} = ## Returns string unique transport's identifier as string. @@ -358,6 +426,7 @@ proc prepareRequest(conn: HttpConnectionRef, if strip(expectHeader).toLowerAscii() == "100-continue": request.requestFlags.incl(HttpRequestFlags.ClientExpect) + trackCounter(HttpServerRequestTrackerName) ok(request) proc getBodyReader*(request: HttpRequestRef): HttpResult[HttpBodyReader] = @@ -566,7 +635,7 @@ proc preferredContentType*(request: HttpRequestRef, proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion, code: HttpCode, keepAlive = true, datatype = "text/text", - databody = ""): Future[bool] {.async.} = + databody = "") {.async.} = var answer = $version & " " & $code & "\r\n" answer.add(DateHeader) answer.add(": ") @@ -592,13 +661,115 @@ proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion, answer.add(databody) try: await conn.writer.write(answer) - return true + except CancelledError as exc: + raise exc + except CatchableError: + # We ignore errors here, because we indicating error already. + discard + +proc sendErrorResponse( + conn: HttpConnectionRef, + reqFence: RequestFence, + respError: HttpProcessError + ): Future[HttpProcessExitType] {.async.} = + let version = getResponseVersion(reqFence) + try: + if reqFence.isOk(): + case respError.kind + of HttpServerError.CriticalError: + await conn.sendErrorResponse(version, respError.code, false) + HttpProcessExitType.Graceful + of HttpServerError.RecoverableError: + await conn.sendErrorResponse(version, respError.code, true) + HttpProcessExitType.Graceful + of HttpServerError.CatchableError: + await conn.sendErrorResponse(version, respError.code, false) + HttpProcessExitType.Graceful + of HttpServerError.DisconnectError, + HttpServerError.InterruptError, + HttpServerError.TimeoutError: + raiseAssert("Unexpected response error: " & $respError.kind) + else: + HttpProcessExitType.Graceful except CancelledError: - return false - except AsyncStreamWriteError: - return false - except AsyncStreamIncompleteError: - return false + HttpProcessExitType.Immediate + except CatchableError: + HttpProcessExitType.Immediate + +proc sendDefaultResponse( + conn: HttpConnectionRef, + reqFence: RequestFence, + response: HttpResponseRef + ): Future[HttpProcessExitType] {.async.} = + let + version = getResponseVersion(reqFence) + keepConnection = + if isNil(response) or (HttpResponseFlags.KeepAlive notin response.flags): + HttpProcessExitType.Graceful + else: + HttpProcessExitType.KeepAlive + + template toBool(hpet: HttpProcessExitType): bool = + case hpet + of HttpProcessExitType.KeepAlive: + true + of HttpProcessExitType.Immediate: + false + of HttpProcessExitType.Graceful: + false + + try: + if reqFence.isOk(): + if isNil(response): + await conn.sendErrorResponse(version, Http404, keepConnection.toBool()) + keepConnection + else: + case response.state + of HttpResponseState.Empty: + # Response was ignored, so we respond with not found. + await conn.sendErrorResponse(version, Http404, + keepConnection.toBool()) + keepConnection + of HttpResponseState.Prepared: + # Response was prepared but not sent, so we can respond with some + # error code + await conn.sendErrorResponse(HttpVersion11, Http409, + keepConnection.toBool()) + keepConnection + of HttpResponseState.Sending, HttpResponseState.Failed, + HttpResponseState.Cancelled: + # Just drop connection, because we dont know at what stage we are + HttpProcessExitType.Immediate + of HttpResponseState.Default: + # Response was ignored, so we respond with not found. + await conn.sendErrorResponse(version, Http404, + keepConnection.toBool()) + keepConnection + of HttpResponseState.Finished: + keepConnection + else: + case reqFence.error.kind + of HttpServerError.TimeoutError: + await conn.sendErrorResponse(version, reqFence.error.code, false) + HttpProcessExitType.Graceful + of HttpServerError.CriticalError: + await conn.sendErrorResponse(version, reqFence.error.code, false) + HttpProcessExitType.Graceful + of HttpServerError.RecoverableError: + await conn.sendErrorResponse(version, reqFence.error.code, false) + HttpProcessExitType.Graceful + of HttpServerError.CatchableError: + await conn.sendErrorResponse(version, reqFence.error.code, false) + HttpProcessExitType.Graceful + of HttpServerError.DisconnectError: + # When `HttpServerFlags.NotifyDisconnect` is set. + HttpProcessExitType.Immediate + of HttpServerError.InterruptError: + raiseAssert("Unexpected request error: " & $reqFence.error.kind) + except CancelledError: + HttpProcessExitType.Immediate + except CatchableError: + HttpProcessExitType.Immediate proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} = try: @@ -631,31 +802,38 @@ proc init*(value: var HttpConnection, server: HttpServerRef, mainWriter: newAsyncStreamWriter(transp) ) +proc closeUnsecureConnection(conn: HttpConnectionRef) {.async.} = + if conn.state == HttpState.Alive: + conn.state = HttpState.Closing + var pending: seq[Future[void]] + pending.add(conn.mainReader.closeWait()) + pending.add(conn.mainWriter.closeWait()) + pending.add(conn.transp.closeWait()) + try: + await allFutures(pending) + except CancelledError: + await allFutures(pending) + untrackCounter(HttpServerUnsecureConnectionTrackerName) + reset(conn[]) + conn.state = HttpState.Closed + proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef, transp: StreamTransport): HttpConnectionRef = var res = HttpConnectionRef() res[].init(server, transp) res.reader = res.mainReader res.writer = res.mainWriter + res.closeCb = closeUnsecureConnection + res.createMoment = Moment.now() + trackCounter(HttpServerUnsecureConnectionTrackerName) res -proc closeWait*(conn: HttpConnectionRef) {.async.} = - if conn.state == HttpState.Alive: - conn.state = HttpState.Closing - var pending: seq[Future[void]] - if conn.reader != conn.mainReader: - pending.add(conn.reader.closeWait()) - if conn.writer != conn.mainWriter: - pending.add(conn.writer.closeWait()) - if len(pending) > 0: - await allFutures(pending) - # After we going to close everything else. - pending.setLen(3) - pending[0] = conn.mainReader.closeWait() - pending[1] = conn.mainWriter.closeWait() - pending[2] = conn.transp.closeWait() - await allFutures(pending) - conn.state = HttpState.Closed +proc gracefulCloseWait*(conn: HttpConnectionRef) {.async.} = + await conn.transp.shutdownWait() + await conn.closeCb(conn) + +proc closeWait*(conn: HttpConnectionRef): Future[void] = + conn.closeCb(conn) proc closeWait*(req: HttpRequestRef) {.async.} = if req.state == HttpState.Alive: @@ -663,7 +841,14 @@ proc closeWait*(req: HttpRequestRef) {.async.} = req.state = HttpState.Closing let resp = req.response.get() if (HttpResponseFlags.Stream in resp.flags) and not(isNil(resp.writer)): - await resp.writer.closeWait() + var writer = resp.writer.closeWait() + try: + await writer + except CancelledError: + await writer + reset(resp[]) + untrackCounter(HttpServerRequestTrackerName) + reset(req[]) req.state = HttpState.Closed proc createConnection(server: HttpServerRef, @@ -681,175 +866,190 @@ proc `keepalive=`*(resp: HttpResponseRef, value: bool) = proc keepalive*(resp: HttpResponseRef): bool {.raises: [].} = HttpResponseFlags.KeepAlive in resp.flags -proc processLoop(server: HttpServerRef, transp: StreamTransport, - connId: string) {.async.} = - var - conn: HttpConnectionRef - connArg: RequestFence - runLoop = false - +proc getRemoteAddress(transp: StreamTransport): Opt[TransportAddress] {. + raises: [].} = + if isNil(transp): return Opt.none(TransportAddress) try: - conn = await server.createConnCallback(server, transp) - runLoop = true + Opt.some(transp.remoteAddress()) + except CatchableError: + Opt.none(TransportAddress) + +proc getRemoteAddress(connection: HttpConnectionRef): Opt[TransportAddress] {. + raises: [].} = + if isNil(connection): return Opt.none(TransportAddress) + getRemoteAddress(connection.transp) + +proc getResponseFence*(connection: HttpConnectionRef, + reqFence: RequestFence): Future[ResponseFence] {. + async.} = + try: + let res = await connection.server.processCallback(reqFence) + ResponseFence.ok(res) except CancelledError: - server.connections.del(connId) - await transp.closeWait() - return + ResponseFence.err(HttpProcessError.init( + HttpServerError.InterruptError)) except HttpCriticalError as exc: - let error = HttpProcessError.init(HttpServerError.CriticalError, exc, - transp.remoteAddress(), exc.code) - connArg = RequestFence.err(error) - runLoop = false + let address = connection.getRemoteAddress() + ResponseFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, exc.code)) + except HttpRecoverableError as exc: + let address = connection.getRemoteAddress() + ResponseFence.err(HttpProcessError.init( + HttpServerError.RecoverableError, exc, address, exc.code)) + except CatchableError as exc: + let address = connection.getRemoteAddress() + ResponseFence.err(HttpProcessError.init( + HttpServerError.CatchableError, exc, address, Http503)) - if not(runLoop): - try: - # We still want to notify process callback about failure, but we ignore - # result. - discard await server.processCallback(connArg) - except CancelledError: - runLoop = false - except CatchableError as exc: - # There should be no exceptions, so we will raise `Defect`. - raiseHttpDefect("Unexpected exception catched [" & $exc.name & "]") +proc getResponseFence*(server: HttpServerRef, + connFence: ConnectionFence): Future[ResponseFence] {. + async.} = + doAssert(connFence.isErr()) + try: + let + reqFence = RequestFence.err(connFence.error) + res = await server.processCallback(reqFence) + ResponseFence.ok(res) + except CancelledError: + ResponseFence.err(HttpProcessError.init( + HttpServerError.InterruptError)) + except HttpCriticalError as exc: + let address = Opt.none(TransportAddress) + ResponseFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, exc.code)) + except HttpRecoverableError as exc: + let address = Opt.none(TransportAddress) + ResponseFence.err(HttpProcessError.init( + HttpServerError.RecoverableError, exc, address, exc.code)) + except CatchableError as exc: + let address = Opt.none(TransportAddress) + ResponseFence.err(HttpProcessError.init( + HttpServerError.CatchableError, exc, address, Http503)) - var breakLoop = false - while runLoop: - var - arg: RequestFence - resp: HttpResponseRef - - try: - let request = - if server.headersTimeout.isInfinite(): - await conn.getRequest() - else: - await conn.getRequest().wait(server.headersTimeout) - arg = RequestFence.ok(request) - except CancelledError: - breakLoop = true - except AsyncTimeoutError as exc: - let error = HttpProcessError.init(HttpServerError.TimeoutError, exc, - transp.remoteAddress(), Http408) - arg = RequestFence.err(error) - except HttpRecoverableError as exc: - let error = HttpProcessError.init(HttpServerError.RecoverableError, exc, - transp.remoteAddress(), exc.code) - arg = RequestFence.err(error) - except HttpCriticalError as exc: - let error = HttpProcessError.init(HttpServerError.CriticalError, exc, - transp.remoteAddress(), exc.code) - arg = RequestFence.err(error) - except HttpDisconnectError as exc: - if HttpServerFlags.NotifyDisconnect in server.flags: - let error = HttpProcessError.init(HttpServerError.DisconnectError, exc, - transp.remoteAddress(), Http400) - arg = RequestFence.err(error) +proc getRequestFence*(server: HttpServerRef, + connection: HttpConnectionRef): Future[RequestFence] {. + async.} = + try: + let res = + if server.headersTimeout.isInfinite(): + await connection.getRequest() else: - breakLoop = true - except CatchableError as exc: - let error = HttpProcessError.init(HttpServerError.CatchableError, exc, - transp.remoteAddress(), Http500) - arg = RequestFence.err(error) + await connection.getRequest().wait(server.headersTimeout) + connection.currentRawQuery = Opt.some(res.rawPath) + RequestFence.ok(res) + except CancelledError: + RequestFence.err(HttpProcessError.init(HttpServerError.InterruptError)) + except AsyncTimeoutError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.TimeoutError, exc, address, Http408)) + except HttpRecoverableError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.RecoverableError, exc, address, exc.code)) + except HttpCriticalError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, exc.code)) + except HttpDisconnectError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.DisconnectError, exc, address, Http400)) + except CatchableError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.CatchableError, exc, address, Http500)) - if breakLoop: - break +proc getConnectionFence*(server: HttpServerRef, + transp: StreamTransport): Future[ConnectionFence] {. + async.} = + try: + let res = await server.createConnCallback(server, transp) + ConnectionFence.ok(res) + except CancelledError: + ConnectionFence.err(HttpProcessError.init(HttpServerError.InterruptError)) + except HttpCriticalError as exc: + # On error `transp` will be closed by `createConnCallback()` call. + let address = Opt.none(TransportAddress) + ConnectionFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, exc.code)) + except CatchableError as exc: + # On error `transp` will be closed by `createConnCallback()` call. + let address = Opt.none(TransportAddress) + ConnectionFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, Http503)) - breakLoop = false - var lastErrorCode: Opt[HttpCode] - - try: - resp = await conn.server.processCallback(arg) - except CancelledError: - breakLoop = true - except HttpCriticalError as exc: - lastErrorCode = Opt.some(exc.code) - except HttpRecoverableError as exc: - lastErrorCode = Opt.some(exc.code) - except CatchableError: - lastErrorCode = Opt.some(Http503) - - if breakLoop: - break - - if arg.isErr(): - let code = arg.error().code - try: - case arg.error().error - of HttpServerError.TimeoutError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.RecoverableError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.CriticalError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.CatchableError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.DisconnectError: - discard - except CancelledError: - # We swallowing `CancelledError` in a loop, but we going to exit - # loop ASAP. - discard - break +proc processRequest(server: HttpServerRef, + connection: HttpConnectionRef, + connId: string): Future[HttpProcessExitType] {.async.} = + let requestFence = await getRequestFence(server, connection) + if requestFence.isErr(): + case requestFence.error.kind + of HttpServerError.InterruptError: + return HttpProcessExitType.Immediate + of HttpServerError.DisconnectError: + if HttpServerFlags.NotifyDisconnect notin server.flags: + return HttpProcessExitType.Immediate else: - let request = arg.get() - var keepConn = if request.version == HttpVersion11: true else: false - if lastErrorCode.isNone(): - if isNil(resp): - # Response was `nil`. - try: - discard await conn.sendErrorResponse(HttpVersion11, Http404, false) - except CancelledError: - keepConn = false - else: - try: - case resp.state - of HttpResponseState.Empty: - # Response was ignored - discard await conn.sendErrorResponse(HttpVersion11, Http404, - keepConn) - of HttpResponseState.Prepared: - # Response was prepared but not sent. - discard await conn.sendErrorResponse(HttpVersion11, Http409, - keepConn) - else: - # some data was already sent to the client. - keepConn = resp.keepalive() - except CancelledError: - keepConn = false - else: - try: - discard await conn.sendErrorResponse(HttpVersion11, - lastErrorCode.get(), false) - except CancelledError: - keepConn = false + discard - # Closing and releasing all the request resources. + let responseFence = await getResponseFence(connection, requestFence) + if responseFence.isErr() and + (responseFence.error.kind == HttpServerError.InterruptError): + if requestFence.isOk(): + await requestFence.get().closeWait() + return HttpProcessExitType.Immediate + + let res = + if responseFence.isErr(): + await connection.sendErrorResponse(requestFence, responseFence.error) + else: + await connection.sendDefaultResponse(requestFence, responseFence.get()) + + if requestFence.isOk(): + await requestFence.get().closeWait() + + res + +proc processLoop(holder: HttpConnectionHolderRef) {.async.} = + let + server = holder.server + transp = holder.transp + connectionId = holder.connectionId + connection = + block: + let res = await server.getConnectionFence(transp) + if res.isErr(): + if res.error.kind != HttpServerError.InterruptError: + discard await server.getResponseFence(res) + server.connections.del(connectionId) + return + res.get() + + holder.connection = connection + + var runLoop = HttpProcessExitType.KeepAlive + while runLoop == HttpProcessExitType.KeepAlive: + runLoop = try: - await request.closeWait() + await server.processRequest(connection, connectionId) except CancelledError: - # We swallowing `CancelledError` in a loop, but we still need to close - # `request` before exiting. - await request.closeWait() + HttpProcessExitType.Immediate + except CatchableError as exc: + raiseAssert "Unexpected error [" & $exc.name & "] happens: " & $exc.msg - if not(keepConn): - break - - # Connection could be `nil` only when secure handshake is failed. - if not(isNil(conn)): - try: - await conn.closeWait() - except CancelledError: - # Cancellation could be happened while we closing `conn`. But we still - # need to close it. - await conn.closeWait() - - server.connections.del(connId) - # if server.maxConnections > 0: - # server.semaphore.release() + server.connections.del(connectionId) + case runLoop + of HttpProcessExitType.KeepAlive: + await connection.closeWait() + of HttpProcessExitType.Immediate: + await connection.closeWait() + of HttpProcessExitType.Graceful: + await connection.gracefulCloseWait() proc acceptClientLoop(server: HttpServerRef) {.async.} = - var breakLoop = false - while true: + var runLoop = true + while runLoop: try: # if server.maxConnections > 0: # await server.semaphore.acquire() @@ -859,28 +1059,18 @@ proc acceptClientLoop(server: HttpServerRef) {.async.} = # We are unable to identify remote peer, it means that remote peer # disconnected before identification. await transp.closeWait() - breakLoop = false + runLoop = false else: let connId = resId.get() - server.connections[connId] = processLoop(server, transp, connId) - except CancelledError: - # Server was stopped - breakLoop = true - except TransportOsError: - # This is some critical unrecoverable error. - breakLoop = true - except TransportTooManyError: - # Non critical error - breakLoop = false - except TransportAbortedError: - # Non critical error - breakLoop = false - except CatchableError: - # Unexpected error - breakLoop = true - - if breakLoop: - break + let holder = HttpConnectionHolderRef.new(server, transp, resId.get()) + server.connections[connId] = holder + holder.future = processLoop(holder) + except TransportTooManyError, TransportAbortedError: + # Non-critical error + discard + except CancelledError, TransportOsError, CatchableError: + # Critical, cancellation or unexpected error + runLoop = false proc state*(server: HttpServerRef): HttpServerState {.raises: [].} = ## Returns current HTTP server's state. @@ -909,11 +1099,11 @@ proc drop*(server: HttpServerRef) {.async.} = ## Drop all pending HTTP connections. var pending: seq[Future[void]] if server.state in {ServerStopped, ServerRunning}: - for fut in server.connections.values(): - if not(fut.finished()): - fut.cancel() - pending.add(fut) + for holder in server.connections.values(): + if not(isNil(holder.future)) and not(holder.future.finished()): + pending.add(holder.future.cancelAndWait()) await allFutures(pending) + server.connections.clear() proc closeWait*(server: HttpServerRef) {.async.} = ## Stop HTTP server and drop all the pending connections. diff --git a/chronos/apps/http/httptable.nim b/chronos/apps/http/httptable.nim index 86060de3..f44765ae 100644 --- a/chronos/apps/http/httptable.nim +++ b/chronos/apps/http/httptable.nim @@ -197,3 +197,7 @@ proc toList*(ht: HttpTables, normKey = false): auto = for key, value in ht.stringItems(normKey): res.add((key, value)) res + +proc clear*(ht: var HttpTables) = + ## Resets the HtppTable so that it is empty. + ht.table.clear() diff --git a/chronos/apps/http/shttpserver.nim b/chronos/apps/http/shttpserver.nim index 93f253b8..bc5c3fbe 100644 --- a/chronos/apps/http/shttpserver.nim +++ b/chronos/apps/http/shttpserver.nim @@ -24,6 +24,29 @@ type SecureHttpConnectionRef* = ref SecureHttpConnection +proc closeSecConnection(conn: HttpConnectionRef) {.async.} = + if conn.state == HttpState.Alive: + conn.state = HttpState.Closing + var pending: seq[Future[void]] + pending.add(conn.writer.closeWait()) + pending.add(conn.reader.closeWait()) + try: + await allFutures(pending) + except CancelledError: + await allFutures(pending) + # After we going to close everything else. + pending.setLen(3) + pending[0] = conn.mainReader.closeWait() + pending[1] = conn.mainWriter.closeWait() + pending[2] = conn.transp.closeWait() + try: + await allFutures(pending) + except CancelledError: + await allFutures(pending) + reset(cast[SecureHttpConnectionRef](conn)[]) + untrackCounter(HttpServerSecureConnectionTrackerName) + conn.state = HttpState.Closed + proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef, transp: StreamTransport): SecureHttpConnectionRef = var res = SecureHttpConnectionRef() @@ -37,6 +60,8 @@ proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef, res.tlsStream = tlsStream res.reader = AsyncStreamReader(tlsStream.reader) res.writer = AsyncStreamWriter(tlsStream.writer) + res.closeCb = closeSecConnection + trackCounter(HttpServerSecureConnectionTrackerName) res proc createSecConnection(server: HttpServerRef, @@ -50,9 +75,16 @@ proc createSecConnection(server: HttpServerRef, except CancelledError as exc: await HttpConnectionRef(sconn).closeWait() raise exc - except TLSStreamError: + except TLSStreamError as exc: await HttpConnectionRef(sconn).closeWait() - raiseHttpCriticalError("Unable to establish secure connection") + let msg = "Unable to establish secure connection, reason [" & + $exc.msg & "]" + raiseHttpCriticalError(msg) + except CatchableError as exc: + await HttpConnectionRef(sconn).closeWait() + let msg = "Unexpected error while trying to establish secure connection, " & + "reason [" & $exc.msg & "]" + raiseHttpCriticalError(msg) proc new*(htype: typedesc[SecureHttpServerRef], address: TransportAddress, @@ -66,7 +98,7 @@ proc new*(htype: typedesc[SecureHttpServerRef], secureFlags: set[TLSFlags] = {}, maxConnections: int = -1, bufferSize: int = 4096, - backlogSize: int = 100, + backlogSize: int = DefaultBacklogSize, httpHeadersTimeout = 10.seconds, maxHeadersSize: int = 8192, maxRequestBodySize: int = 1_048_576 @@ -100,7 +132,7 @@ proc new*(htype: typedesc[SecureHttpServerRef], createConnCallback: createSecConnection, baseUri: serverUri, serverIdent: serverIdent, - flags: serverFlags, + flags: serverFlags + {HttpServerFlags.Secure}, socketFlags: socketFlags, maxConnections: maxConnections, bufferSize: bufferSize, @@ -114,7 +146,7 @@ proc new*(htype: typedesc[SecureHttpServerRef], # else: # nil lifetime: newFuture[void]("http.server.lifetime"), - connections: initTable[string, Future[void]](), + connections: initOrderedTable[string, HttpConnectionHolderRef](), tlsCertificate: tlsCertificate, tlsPrivateKey: tlsPrivateKey, secureFlags: secureFlags diff --git a/chronos/asyncloop.nim b/chronos/asyncloop.nim index 77439162..a644b778 100644 --- a/chronos/asyncloop.nim +++ b/chronos/asyncloop.nim @@ -171,11 +171,16 @@ type dump*: proc(): string {.gcsafe, raises: [].} isLeaked*: proc(): bool {.gcsafe, raises: [].} + TrackerCounter* = object + opened*: uint64 + closed*: uint64 + PDispatcherBase = ref object of RootRef timers*: HeapQueue[TimerCallback] callbacks*: Deque[AsyncCallback] idlers*: Deque[AsyncCallback] trackers*: Table[string, TrackerBase] + counters*: Table[string, TrackerCounter] proc sentinelCallbackImpl(arg: pointer) {.gcsafe.} = raiseAssert "Sentinel callback MUST not be scheduled" @@ -308,6 +313,7 @@ when defined(windows): getAcceptExSockAddrs*: WSAPROC_GETACCEPTEXSOCKADDRS transmitFile*: WSAPROC_TRANSMITFILE getQueuedCompletionStatusEx*: LPFN_GETQUEUEDCOMPLETIONSTATUSEX + disconnectEx*: WSAPROC_DISCONNECTEX flags: set[DispatcherFlag] PtrCustomOverlapped* = ptr CustomOverlapped @@ -388,6 +394,13 @@ when defined(windows): "dispatcher's TransmitFile()") loop.transmitFile = cast[WSAPROC_TRANSMITFILE](funcPointer) + block: + let res = getFunc(sock, funcPointer, WSAID_DISCONNECTEX) + if not(res): + raiseOsDefect(osLastError(), "initAPI(): Unable to initialize " & + "dispatcher's DisconnectEx()") + loop.disconnectEx = cast[WSAPROC_DISCONNECTEX](funcPointer) + if closeFd(sock) != 0: raiseOsDefect(osLastError(), "initAPI(): Unable to close control socket") @@ -404,7 +417,8 @@ when defined(windows): timers: initHeapQueue[TimerCallback](), callbacks: initDeque[AsyncCallback](64), idlers: initDeque[AsyncCallback](), - trackers: initTable[string, TrackerBase]() + trackers: initTable[string, TrackerBase](), + counters: initTable[string, TrackerCounter]() ) res.callbacks.addLast(SentinelCallback) initAPI(res) @@ -811,10 +825,11 @@ elif defined(macosx) or defined(freebsd) or defined(netbsd) or var res = PDispatcher( selector: selector, timers: initHeapQueue[TimerCallback](), - callbacks: initDeque[AsyncCallback](asyncEventsCount), + callbacks: initDeque[AsyncCallback](chronosEventsCount), idlers: initDeque[AsyncCallback](), - keys: newSeq[ReadyKey](asyncEventsCount), - trackers: initTable[string, TrackerBase]() + keys: newSeq[ReadyKey](chronosEventsCount), + trackers: initTable[string, TrackerBase](), + counters: initTable[string, TrackerCounter]() ) res.callbacks.addLast(SentinelCallback) initAPI(res) @@ -994,7 +1009,7 @@ elif defined(macosx) or defined(freebsd) or defined(netbsd) or ## You can execute ``aftercb`` before actual socket close operation. closeSocket(fd, aftercb) - when asyncEventEngine in ["epoll", "kqueue"]: + when chronosEventEngine in ["epoll", "kqueue"]: type ProcessHandle* = distinct int SignalHandle* = distinct int @@ -1108,7 +1123,7 @@ elif defined(macosx) or defined(freebsd) or defined(netbsd) or if not isNil(adata.reader.function): loop.callbacks.addLast(adata.reader) - when asyncEventEngine in ["epoll", "kqueue"]: + when chronosEventEngine in ["epoll", "kqueue"]: let customSet = {Event.Timer, Event.Signal, Event.Process, Event.Vnode} if customSet * events != {}: @@ -1242,10 +1257,7 @@ proc callIdle*(cbproc: CallbackFunc) = include asyncfutures2 - -when defined(macosx) or defined(macos) or defined(freebsd) or - defined(netbsd) or defined(openbsd) or defined(dragonfly) or - defined(linux) or defined(windows): +when (chronosEventEngine in ["epoll", "kqueue"]) or defined(windows): proc waitSignal*(signal: int): Future[void] {.raises: [].} = var retFuture = newFuture[void]("chronos.waitSignal()") @@ -1505,16 +1517,54 @@ proc waitFor*[T](fut: Future[T]): T {.raises: [CatchableError].} = fut.read() -proc addTracker*[T](id: string, tracker: T) = +proc addTracker*[T](id: string, tracker: T) {. + deprecated: "Please use trackCounter facility instead".} = ## Add new ``tracker`` object to current thread dispatcher with identifier ## ``id``. - let loop = getThreadDispatcher() - loop.trackers[id] = tracker + getThreadDispatcher().trackers[id] = tracker -proc getTracker*(id: string): TrackerBase = +proc getTracker*(id: string): TrackerBase {. + deprecated: "Please use getTrackerCounter() instead".} = ## Get ``tracker`` from current thread dispatcher using identifier ``id``. - let loop = getThreadDispatcher() - result = loop.trackers.getOrDefault(id, nil) + getThreadDispatcher().trackers.getOrDefault(id, nil) + +proc trackCounter*(name: string) {.noinit.} = + ## Increase tracker counter with name ``name`` by 1. + let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64) + inc(getThreadDispatcher().counters.mgetOrPut(name, tracker).opened) + +proc untrackCounter*(name: string) {.noinit.} = + ## Decrease tracker counter with name ``name`` by 1. + let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64) + inc(getThreadDispatcher().counters.mgetOrPut(name, tracker).closed) + +proc getTrackerCounter*(name: string): TrackerCounter {.noinit.} = + ## Return value of counter with name ``name``. + let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64) + getThreadDispatcher().counters.getOrDefault(name, tracker) + +proc isCounterLeaked*(name: string): bool {.noinit.} = + ## Returns ``true`` if leak is detected, number of `opened` not equal to + ## number of `closed` requests. + let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64) + let res = getThreadDispatcher().counters.getOrDefault(name, tracker) + res.opened == res.closed + +iterator trackerCounters*( + loop: PDispatcher + ): tuple[name: string, value: TrackerCounter] = + ## Iterates over `loop` thread dispatcher tracker counter table, returns all + ## the tracker counter's names and values. + doAssert(not(isNil(loop))) + for key, value in loop.counters.pairs(): + yield (key, value) + +iterator trackerCounterKeys*(loop: PDispatcher): string = + doAssert(not(isNil(loop))) + ## Iterates over `loop` thread dispatcher tracker counter table, returns all + ## tracker names. + for key in loop.counters.keys(): + yield key when chronosFutureTracking: iterator pendingFutures*(): FutureBase = diff --git a/chronos/asyncmacro2.nim b/chronos/asyncmacro2.nim index 429e287c..45146a30 100644 --- a/chronos/asyncmacro2.nim +++ b/chronos/asyncmacro2.nim @@ -175,9 +175,25 @@ proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} = nnkElseExpr.newTree( newStmtList( quote do: {.push warning[resultshadowed]: off.}, - # var result: `baseType` - nnkVarSection.newTree( - nnkIdentDefs.newTree(ident "result", baseType, newEmptyNode())), + # var result {.used.}: `baseType` + # In the proc body, result may or may not end up being used + # depending on how the body is written - with implicit returns / + # expressions in particular, it is likely but not guaranteed that + # it is not used. Ideally, we would avoid emitting it in this + # case to avoid the default initializaiton. {.used.} typically + # works better than {.push.} which has a tendency to leak out of + # scope. + # TODO figure out if there's a way to detect `result` usage in + # the proc body _after_ template exapnsion, and therefore + # avoid creating this variable - one option is to create an + # addtional when branch witha fake `result` and check + # `compiles(procBody)` - this is not without cost though + nnkVarSection.newTree(nnkIdentDefs.newTree( + nnkPragmaExpr.newTree( + ident "result", + nnkPragma.newTree(ident "used")), + baseType, newEmptyNode()) + ), quote do: {.pop.}, ) ) diff --git a/chronos/asyncproc.nim b/chronos/asyncproc.nim index 8d15b72e..8df8e33e 100644 --- a/chronos/asyncproc.nim +++ b/chronos/asyncproc.nim @@ -23,10 +23,9 @@ const AsyncProcessTrackerName* = "async.process" ## AsyncProcess leaks tracker name - - type - AsyncProcessError* = object of CatchableError + AsyncProcessError* = object of AsyncError + AsyncProcessTimeoutError* = object of AsyncProcessError AsyncProcessResult*[T] = Result[T, OSErrorCode] @@ -109,49 +108,12 @@ type stdError*: string status*: int - AsyncProcessTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 + WaitOperation {.pure.} = enum + Kill, Terminate template Pipe*(t: typedesc[AsyncProcess]): ProcessStreamHandle = ProcessStreamHandle(kind: ProcessStreamHandleKind.Auto) -proc setupAsyncProcessTracker(): AsyncProcessTracker {.gcsafe.} - -proc getAsyncProcessTracker(): AsyncProcessTracker {.inline.} = - var res = cast[AsyncProcessTracker](getTracker(AsyncProcessTrackerName)) - if isNil(res): - res = setupAsyncProcessTracker() - res - -proc dumpAsyncProcessTracking(): string {.gcsafe.} = - var tracker = getAsyncProcessTracker() - let res = "Started async processes: " & $tracker.opened & "\n" & - "Closed async processes: " & $tracker.closed - res - -proc leakAsyncProccessTracker(): bool {.gcsafe.} = - var tracker = getAsyncProcessTracker() - tracker.opened != tracker.closed - -proc trackAsyncProccess(t: AsyncProcessRef) {.inline.} = - var tracker = getAsyncProcessTracker() - inc(tracker.opened) - -proc untrackAsyncProcess(t: AsyncProcessRef) {.inline.} = - var tracker = getAsyncProcessTracker() - inc(tracker.closed) - -proc setupAsyncProcessTracker(): AsyncProcessTracker {.gcsafe.} = - var res = AsyncProcessTracker( - opened: 0, - closed: 0, - dump: dumpAsyncProcessTracking, - isLeaked: leakAsyncProccessTracker - ) - addTracker(AsyncProcessTrackerName, res) - res - proc init*(t: typedesc[AsyncFD], handle: ProcessStreamHandle): AsyncFD = case handle.kind of ProcessStreamHandleKind.ProcHandle: @@ -336,6 +298,11 @@ proc raiseAsyncProcessError(msg: string, exc: ref CatchableError = nil) {. msg & " ([" & $exc.name & "]: " & $exc.msg & ")" raise newException(AsyncProcessError, message) +proc raiseAsyncProcessTimeoutError() {. + noreturn, noinit, noinline, raises: [AsyncProcessTimeoutError].} = + let message = "Operation timed out" + raise newException(AsyncProcessTimeoutError, message) + proc raiseAsyncProcessError(msg: string, error: OSErrorCode|cint) {. noreturn, noinit, noinline, raises: [AsyncProcessError].} = when error is OSErrorCode: @@ -502,7 +469,7 @@ when defined(windows): flags: pipes.flags ) - trackAsyncProccess(process) + trackCounter(AsyncProcessTrackerName) return process proc peekProcessExitCode(p: AsyncProcessRef): AsyncProcessResult[int] = @@ -919,7 +886,7 @@ else: flags: pipes.flags ) - trackAsyncProccess(process) + trackCounter(AsyncProcessTrackerName) return process proc peekProcessExitCode(p: AsyncProcessRef, @@ -1231,13 +1198,52 @@ proc closeProcessStreams(pipes: AsyncProcessPipes, res allFutures(pending) +proc opAndWaitForExit(p: AsyncProcessRef, op: WaitOperation, + timeout = InfiniteDuration): Future[int] {.async.} = + let timerFut = + if timeout == InfiniteDuration: + newFuture[void]("chronos.killAndwaitForExit") + else: + sleepAsync(timeout) + + while true: + if p.running().get(true): + # We ignore operation errors because we going to repeat calling + # operation until process will not exit. + case op + of WaitOperation.Kill: + discard p.kill() + of WaitOperation.Terminate: + discard p.terminate() + else: + let exitCode = p.peekExitCode().valueOr: + raiseAsyncProcessError("Unable to peek process exit code", error) + if not(timerFut.finished()): + await cancelAndWait(timerFut) + return exitCode + + let waitFut = p.waitForExit().wait(100.milliseconds) + discard await race(FutureBase(waitFut), FutureBase(timerFut)) + + if waitFut.finished() and not(waitFut.failed()): + let res = p.peekExitCode() + if res.isOk(): + if not(timerFut.finished()): + await cancelAndWait(timerFut) + return res.get() + + if timerFut.finished(): + if not(waitFut.finished()): + await waitFut.cancelAndWait() + raiseAsyncProcessTimeoutError() + proc closeWait*(p: AsyncProcessRef) {.async.} = # Here we ignore all possible errrors, because we do not want to raise # exceptions. discard closeProcessHandles(p.pipes, p.options, OSErrorCode(0)) await p.pipes.closeProcessStreams(p.options) discard p.closeThreadAndProcessHandle() - untrackAsyncProcess(p) + untrackCounter(AsyncProcessTrackerName) proc stdinStream*(p: AsyncProcessRef): AsyncStreamWriter = doAssert(p.pipes.stdinHolder.kind == StreamKind.Writer, @@ -1258,14 +1264,15 @@ proc execCommand*(command: string, options = {AsyncProcessOption.EvalCommand}, timeout = InfiniteDuration ): Future[int] {.async.} = - let poptions = options + {AsyncProcessOption.EvalCommand} - let process = await startProcess(command, options = poptions) - let res = - try: - await process.waitForExit(timeout) - finally: - await process.closeWait() - return res + let + poptions = options + {AsyncProcessOption.EvalCommand} + process = await startProcess(command, options = poptions) + res = + try: + await process.waitForExit(timeout) + finally: + await process.closeWait() + res proc execCommandEx*(command: string, options = {AsyncProcessOption.EvalCommand}, @@ -1298,10 +1305,43 @@ proc execCommandEx*(command: string, finally: await process.closeWait() - return res + res proc pid*(p: AsyncProcessRef): int = ## Returns process ``p`` identifier. int(p.processId) template processId*(p: AsyncProcessRef): int = pid(p) + +proc killAndWaitForExit*(p: AsyncProcessRef, + timeout = InfiniteDuration): Future[int] = + ## Perform continuous attempts to kill the ``p`` process for specified period + ## of time ``timeout``. + ## + ## On Posix systems, killing means sending ``SIGKILL`` to the process ``p``, + ## On Windows, it uses ``TerminateProcess`` to kill the process ``p``. + ## + ## If the process ``p`` fails to be killed within the ``timeout`` time, it + ## will raise ``AsyncProcessTimeoutError``. + ## + ## In case of error this it will raise ``AsyncProcessError``. + ## + ## Returns process ``p`` exit code. + opAndWaitForExit(p, WaitOperation.Kill, timeout) + +proc terminateAndWaitForExit*(p: AsyncProcessRef, + timeout = InfiniteDuration): Future[int] = + ## Perform continuous attempts to terminate the ``p`` process for specified + ## period of time ``timeout``. + ## + ## On Posix systems, terminating means sending ``SIGTERM`` to the process + ## ``p``, on Windows, it uses ``TerminateProcess`` to terminate the process + ## ``p``. + ## + ## If the process ``p`` fails to be terminated within the ``timeout`` time, it + ## will raise ``AsyncProcessTimeoutError``. + ## + ## In case of error this it will raise ``AsyncProcessError``. + ## + ## Returns process ``p`` exit code. + opAndWaitForExit(p, WaitOperation.Terminate, timeout) diff --git a/chronos/config.nim b/chronos/config.nim index 0a439a12..bd6c2b9d 100644 --- a/chronos/config.nim +++ b/chronos/config.nim @@ -49,6 +49,27 @@ when (NimMajor, NimMinor) >= (1, 4): ## using `AsyncProcessOption.EvalCommand` and API calls such as ## ``execCommand(command)`` and ``execCommandEx(command)``. + chronosEventsCount* {.intdefine.} = 64 + ## Number of OS poll events retrieved by syscall (epoll, kqueue, poll). + + chronosInitialSize* {.intdefine.} = 64 + ## Initial size of Selector[T]'s array of file descriptors. + + chronosEventEngine* {.strdefine.}: string = + when defined(linux) and not(defined(android) or defined(emscripten)): + "epoll" + elif defined(macosx) or defined(macos) or defined(ios) or + defined(freebsd) or defined(netbsd) or defined(openbsd) or + defined(dragonfly): + "kqueue" + elif defined(android) or defined(emscripten): + "poll" + elif defined(posix): + "poll" + else: + "" + ## OS polling engine type which is going to be used by chronos. + else: # 1.2 doesn't support `booldefine` in `when` properly const @@ -69,6 +90,21 @@ else: "/system/bin/sh" else: "/bin/sh" + chronosEventsCount*: int = 64 + chronosInitialSize*: int = 64 + chronosEventEngine* {.strdefine.}: string = + when defined(linux) and not(defined(android) or defined(emscripten)): + "epoll" + elif defined(macosx) or defined(macos) or defined(ios) or + defined(freebsd) or defined(netbsd) or defined(openbsd) or + defined(dragonfly): + "kqueue" + elif defined(android) or defined(emscripten): + "poll" + elif defined(posix): + "poll" + else: + "" when defined(debug) or defined(chronosConfig): import std/macros @@ -83,3 +119,6 @@ when defined(debug) or defined(chronosConfig): printOption("chronosFutureTracking", chronosFutureTracking) printOption("chronosDumpAsync", chronosDumpAsync) printOption("chronosProcShell", chronosProcShell) + printOption("chronosEventEngine", chronosEventEngine) + printOption("chronosEventsCount", chronosEventsCount) + printOption("chronosInitialSize", chronosInitialSize) diff --git a/chronos/ioselects/ioselectors_epoll.nim b/chronos/ioselects/ioselectors_epoll.nim index d438bac0..161a5dfb 100644 --- a/chronos/ioselects/ioselectors_epoll.nim +++ b/chronos/ioselects/ioselectors_epoll.nim @@ -97,12 +97,12 @@ proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] = var nmask: Sigset if sigemptyset(nmask) < 0: return err(osLastError()) - let epollFd = epoll_create(asyncEventsCount) + let epollFd = epoll_create(chronosEventsCount) if epollFd < 0: return err(osLastError()) let selector = Selector[T]( epollFd: epollFd, - fds: initTable[int32, SelectorKey[T]](asyncInitialSize), + fds: initTable[int32, SelectorKey[T]](chronosInitialSize), signalMask: nmask, virtualId: -1'i32, # Should start with -1, because `InvalidIdent` == -1 childrenExited: false, @@ -627,7 +627,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int, readyKeys: var openArray[ReadyKey] ): SelectResult[int] = var - queueEvents: array[asyncEventsCount, EpollEvent] + queueEvents: array[chronosEventsCount, EpollEvent] k: int = 0 verifySelectParams(timeout, -1, int(high(cint))) @@ -668,7 +668,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int, ok(k) proc select2*[T](s: Selector[T], timeout: int): SelectResult[seq[ReadyKey]] = - var res = newSeq[ReadyKey](asyncEventsCount) + var res = newSeq[ReadyKey](chronosEventsCount) let count = ? selectInto2(s, timeout, res) res.setLen(count) ok(res) diff --git a/chronos/ioselects/ioselectors_kqueue.nim b/chronos/ioselects/ioselectors_kqueue.nim index 9f0627aa..e39f9689 100644 --- a/chronos/ioselects/ioselectors_kqueue.nim +++ b/chronos/ioselects/ioselectors_kqueue.nim @@ -110,7 +110,7 @@ proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] = let selector = Selector[T]( kqFd: kqFd, - fds: initTable[int32, SelectorKey[T]](asyncInitialSize), + fds: initTable[int32, SelectorKey[T]](chronosInitialSize), virtualId: -1'i32, # Should start with -1, because `InvalidIdent` == -1 virtualHoles: initDeque[int32]() ) @@ -559,7 +559,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int, ): SelectResult[int] = var tv: Timespec - queueEvents: array[asyncEventsCount, KEvent] + queueEvents: array[chronosEventsCount, KEvent] verifySelectParams(timeout, -1, high(int)) @@ -575,7 +575,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int, addr tv else: nil - maxEventsCount = cint(min(asyncEventsCount, len(readyKeys))) + maxEventsCount = cint(min(chronosEventsCount, len(readyKeys))) eventsCount = block: var res = 0 @@ -601,7 +601,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int, proc select2*[T](s: Selector[T], timeout: int): Result[seq[ReadyKey], OSErrorCode] = - var res = newSeq[ReadyKey](asyncEventsCount) + var res = newSeq[ReadyKey](chronosEventsCount) let count = ? selectInto2(s, timeout, res) res.setLen(count) ok(res) diff --git a/chronos/ioselects/ioselectors_poll.nim b/chronos/ioselects/ioselectors_poll.nim index d0d533cd..25cc0351 100644 --- a/chronos/ioselects/ioselectors_poll.nim +++ b/chronos/ioselects/ioselectors_poll.nim @@ -16,7 +16,7 @@ import stew/base10 type SelectorImpl[T] = object fds: Table[int32, SelectorKey[T]] - pollfds: seq[TPollFd] + pollfds: seq[TPollfd] Selector*[T] = ref SelectorImpl[T] type @@ -50,7 +50,7 @@ proc freeKey[T](s: Selector[T], key: int32) = proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] = let selector = Selector[T]( - fds: initTable[int32, SelectorKey[T]](asyncInitialSize) + fds: initTable[int32, SelectorKey[T]](chronosInitialSize) ) ok(selector) @@ -72,7 +72,7 @@ proc trigger2*(event: SelectEvent): SelectResult[void] = if res == -1: err(osLastError()) elif res != sizeof(uint64): - err(OSErrorCode(osdefs.EINVAL)) + err(osdefs.EINVAL) else: ok() @@ -98,13 +98,14 @@ template toPollEvents(events: set[Event]): cshort = res template pollAdd[T](s: Selector[T], sock: cint, events: set[Event]) = - s.pollfds.add(TPollFd(fd: sock, events: toPollEvents(events), revents: 0)) + s.pollfds.add(TPollfd(fd: sock, events: toPollEvents(events), revents: 0)) template pollUpdate[T](s: Selector[T], sock: cint, events: set[Event]) = var updated = false for mitem in s.pollfds.mitems(): if mitem.fd == sock: mitem.events = toPollEvents(events) + updated = true break if not(updated): raiseAssert "Descriptor [" & $sock & "] is not registered in the queue!" @@ -177,7 +178,6 @@ proc unregister2*[T](s: Selector[T], event: SelectEvent): SelectResult[void] = proc prepareKey[T](s: Selector[T], event: var TPollfd): Opt[ReadyKey] = let - defaultKey = SelectorKey[T](ident: InvalidIdent) fdi32 = int32(event.fd) revents = event.revents @@ -224,7 +224,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int, eventsCount = if maxEventsCount > 0: let res = handleEintr(poll(addr(s.pollfds[0]), Tnfds(maxEventsCount), - timeout)) + cint(timeout))) if res < 0: return err(osLastError()) res @@ -241,7 +241,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int, ok(k) proc select2*[T](s: Selector[T], timeout: int): SelectResult[seq[ReadyKey]] = - var res = newSeq[ReadyKey](asyncEventsCount) + var res = newSeq[ReadyKey](chronosEventsCount) let count = ? selectInto2(s, timeout, res) res.setLen(count) ok(res) diff --git a/chronos/osdefs.nim b/chronos/osdefs.nim index 8106fb68..75ceb676 100644 --- a/chronos/osdefs.nim +++ b/chronos/osdefs.nim @@ -237,6 +237,10 @@ when defined(windows): GUID(D1: 0xb5367df0'u32, D2: 0xcbac'u16, D3: 0x11cf'u16, D4: [0x95'u8, 0xca'u8, 0x00'u8, 0x80'u8, 0x5f'u8, 0x48'u8, 0xa1'u8, 0x92'u8]) + WSAID_DISCONNECTEX* = + GUID(D1: 0x7fda2e11'u32, D2: 0x8630'u16, D3: 0x436f'u16, + D4: [0xa0'u8, 0x31'u8, 0xf5'u8, 0x36'u8, + 0xa6'u8, 0xee'u8, 0xc1'u8, 0x57'u8]) GAA_FLAG_INCLUDE_PREFIX* = 0x0010'u32 @@ -497,6 +501,11 @@ when defined(windows): lpTransmitBuffers: pointer, dwReserved: DWORD): WINBOOL {. stdcall, gcsafe, raises: [].} + WSAPROC_DISCONNECTEX* = proc ( + hSocket: SocketHandle, lpOverlapped: POVERLAPPED, dwFlags: DWORD, + dwReserved: DWORD): WINBOOL {. + stdcall, gcsafe, raises: [].} + LPFN_GETQUEUEDCOMPLETIONSTATUSEX* = proc ( completionPort: HANDLE, lpPortEntries: ptr OVERLAPPED_ENTRY, ulCount: ULONG, ulEntriesRemoved: var ULONG, @@ -699,7 +708,7 @@ when defined(windows): res: var ptr AddrInfo): cint {. stdcall, dynlib: "ws2_32", importc: "getaddrinfo", sideEffect.} - proc freeaddrinfo*(ai: ptr AddrInfo) {. + proc freeAddrInfo*(ai: ptr AddrInfo) {. stdcall, dynlib: "ws2_32", importc: "freeaddrinfo", sideEffect.} proc createIoCompletionPort*(fileHandle: HANDLE, @@ -870,16 +879,20 @@ elif defined(macos) or defined(macosx): setrlimit, getpid, pthread_sigmask, sigprocmask, sigemptyset, sigaddset, sigismember, fcntl, accept, pipe, write, signal, read, setsockopt, getsockopt, - getcwd, chdir, waitpid, kill, + getcwd, chdir, waitpid, kill, select, pselect, + socketpair, poll, freeAddrInfo, Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, - Sockaddr_un, SocketHandle, AddrInfo, RLimit, + Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet, + Suseconds, + FD_CLR, FD_ISSET, FD_SET, FD_ZERO, F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC, - O_NONBLOCK, SOL_SOCKET, SOCK_RAW, MSG_NOSIGNAL, - AF_INET, AF_INET6, SO_ERROR, SO_REUSEADDR, + O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM, + SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK, + AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE, - SIG_BLOCK, SIG_UNBLOCK, + SIG_BLOCK, SIG_UNBLOCK, SHUT_RD, SHUT_WR, SHUT_RDWR, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, @@ -891,16 +904,20 @@ elif defined(macos) or defined(macosx): setrlimit, getpid, pthread_sigmask, sigprocmask, sigemptyset, sigaddset, sigismember, fcntl, accept, pipe, write, signal, read, setsockopt, getsockopt, - getcwd, chdir, waitpid, kill, + getcwd, chdir, waitpid, kill, select, pselect, + socketpair, poll, freeAddrInfo, Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, - Sockaddr_un, SocketHandle, AddrInfo, RLimit, + Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet, + Suseconds, + FD_CLR, FD_ISSET, FD_SET, FD_ZERO, F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC, - O_NONBLOCK, SOL_SOCKET, SOCK_RAW, MSG_NOSIGNAL, - AF_INET, AF_INET6, SO_ERROR, SO_REUSEADDR, + O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM, + SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK, + AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE, - SIG_BLOCK, SIG_UNBLOCK, + SIG_BLOCK, SIG_UNBLOCK, SHUT_RD, SHUT_WR, SHUT_RDWR, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, @@ -912,6 +929,21 @@ elif defined(macos) or defined(macosx): numer*: uint32 denom*: uint32 + TPollfd* {.importc: "struct pollfd", pure, final, + header: "".} = object + fd*: cint + events*: cshort + revents*: cshort + + Tnfds* {.importc: "nfds_t", header: "".} = cuint + + const + POLLIN* = 0x0001 + POLLOUT* = 0x0004 + POLLERR* = 0x0008 + POLLHUP* = 0x0010 + POLLNVAL* = 0x0020 + proc posix_gettimeofday*(tp: var Timeval, unused: pointer = nil) {. importc: "gettimeofday", header: "".} @@ -921,6 +953,9 @@ elif defined(macos) or defined(macosx): proc mach_absolute_time*(): uint64 {. importc, header: "".} + proc poll*(a1: ptr TPollfd, a2: Tnfds, a3: cint): cint {. + importc, header: "", sideEffect.} + elif defined(linux): from std/posix import close, shutdown, sigemptyset, sigaddset, sigismember, sigdelset, write, read, waitid, getaddrinfo, @@ -929,17 +964,22 @@ elif defined(linux): recvfrom, sendto, send, bindSocket, recv, connect, unlink, listen, sendmsg, recvmsg, getpid, fcntl, pthread_sigmask, sigprocmask, clock_gettime, signal, - getcwd, chdir, waitpid, kill, + getcwd, chdir, waitpid, kill, select, pselect, + socketpair, poll, freeAddrInfo, ClockId, Itimerspec, Timespec, Sigset, Time, Pid, Mode, - SigInfo, Id, Tmsghdr, IOVec, RLimit, + SigInfo, Id, Tmsghdr, IOVec, RLimit, Timeval, TFdSet, SockAddr, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, Sockaddr_un, AddrInfo, SocketHandle, + Suseconds, TPollfd, Tnfds, + FD_CLR, FD_ISSET, FD_SET, FD_ZERO, CLOCK_MONOTONIC, F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC, O_NONBLOCK, SIG_BLOCK, SIG_UNBLOCK, SOL_SOCKET, SO_ERROR, RLIMIT_NOFILE, MSG_NOSIGNAL, - AF_INET, AF_INET6, SO_REUSEADDR, SO_REUSEPORT, + MSG_PEEK, + AF_INET, AF_INET6, AF_UNIX, SO_REUSEADDR, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS, - SOCK_DGRAM, + SOCK_DGRAM, SOCK_STREAM, SHUT_RD, SHUT_WR, SHUT_RDWR, + POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, @@ -952,17 +992,22 @@ elif defined(linux): recvfrom, sendto, send, bindSocket, recv, connect, unlink, listen, sendmsg, recvmsg, getpid, fcntl, pthread_sigmask, sigprocmask, clock_gettime, signal, - getcwd, chdir, waitpid, kill, + getcwd, chdir, waitpid, kill, select, pselect, + socketpair, poll, freeAddrInfo, ClockId, Itimerspec, Timespec, Sigset, Time, Pid, Mode, - SigInfo, Id, Tmsghdr, IOVec, RLimit, + SigInfo, Id, Tmsghdr, IOVec, RLimit, TFdSet, Timeval, SockAddr, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, Sockaddr_un, AddrInfo, SocketHandle, + Suseconds, TPollfd, Tnfds, + FD_CLR, FD_ISSET, FD_SET, FD_ZERO, CLOCK_MONOTONIC, F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC, O_NONBLOCK, SIG_BLOCK, SIG_UNBLOCK, SOL_SOCKET, SO_ERROR, RLIMIT_NOFILE, MSG_NOSIGNAL, - AF_INET, AF_INET6, SO_REUSEADDR, SO_REUSEPORT, + MSG_PEEK, + AF_INET, AF_INET6, AF_UNIX, SO_REUSEADDR, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS, - SOCK_DGRAM, + SOCK_DGRAM, SOCK_STREAM, SHUT_RD, SHUT_WR, SHUT_RDWR, + POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, @@ -1001,13 +1046,22 @@ elif defined(linux): EPOLL_CTL_DEL* = 2 EPOLL_CTL_MOD* = 3 + # https://github.com/torvalds/linux/blob/ff6992735ade75aae3e35d16b17da1008d753d28/include/uapi/linux/eventpoll.h#L77 + when defined(linux) and defined(amd64): + {.pragma: epollPacked, packed.} + else: + {.pragma: epollPacked.} + type - EpollData* {.importc: "union epoll_data", - header: "", pure, final.} = object + EpollData* {.importc: "epoll_data_t", + header: "", pure, final, union.} = object + `ptr`* {.importc: "ptr".}: pointer + fd* {.importc: "fd".}: cint + u32* {.importc: "u32".}: uint32 u64* {.importc: "u64".}: uint64 - EpollEvent* {.importc: "struct epoll_event", header: "", - pure, final.} = object + EpollEvent* {.importc: "struct epoll_event", + header: "", pure, final, epollPacked.} = object events*: uint32 # Epoll events data*: EpollData # User data variable @@ -1062,16 +1116,22 @@ elif defined(freebsd) or defined(openbsd) or defined(netbsd) or setrlimit, getpid, pthread_sigmask, sigemptyset, sigaddset, sigismember, fcntl, accept, pipe, write, signal, read, setsockopt, getsockopt, clock_gettime, - getcwd, chdir, waitpid, kill, + getcwd, chdir, waitpid, kill, select, pselect, + socketpair, poll, freeAddrInfo, Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, - Sockaddr_un, SocketHandle, AddrInfo, RLimit, + Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet, + Suseconds, TPollfd, Tnfds, + FD_CLR, FD_ISSET, FD_SET, FD_ZERO, F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC, - O_NONBLOCK, SOL_SOCKET, SOCK_RAW, MSG_NOSIGNAL, - AF_INET, AF_INET6, SO_ERROR, SO_REUSEADDR, + O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM, + SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK, + AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE, SIG_BLOCK, SIG_UNBLOCK, CLOCK_MONOTONIC, + SHUT_RD, SHUT_WR, SHUT_RDWR, + POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, @@ -1083,15 +1143,22 @@ elif defined(freebsd) or defined(openbsd) or defined(netbsd) or setrlimit, getpid, pthread_sigmask, sigemptyset, sigaddset, sigismember, fcntl, accept, pipe, write, signal, read, setsockopt, getsockopt, clock_gettime, + getcwd, chdir, waitpid, kill, select, pselect, + socketpair, poll, freeAddrInfo, Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr, SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6, - Sockaddr_un, SocketHandle, AddrInfo, RLimit, + Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet, + Suseconds, TPollfd, Tnfds, + FD_CLR, FD_ISSET, FD_SET, FD_ZERO, F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC, - O_NONBLOCK, SOL_SOCKET, SOCK_RAW, MSG_NOSIGNAL, - AF_INET, AF_INET6, SO_ERROR, SO_REUSEADDR, + O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM, + SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK, + AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR, SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE, SIG_BLOCK, SIG_UNBLOCK, CLOCK_MONOTONIC, + SHUT_RD, SHUT_WR, SHUT_RDWR, + POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL, SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT, SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2, SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP, diff --git a/chronos/ratelimit.nim b/chronos/ratelimit.nim index 02d80f51..4147db78 100644 --- a/chronos/ratelimit.nim +++ b/chronos/ratelimit.nim @@ -28,13 +28,15 @@ type pendingRequests: seq[BucketWaiter] manuallyReplenished: AsyncEvent -proc update(bucket: TokenBucket) = +proc update(bucket: TokenBucket, currentTime: Moment) = if bucket.fillDuration == default(Duration): bucket.budget = min(bucket.budgetCap, bucket.budget) return + if currentTime < bucket.lastUpdate: + return + let - currentTime = Moment.now() timeDelta = currentTime - bucket.lastUpdate fillPercent = timeDelta.milliseconds.float / bucket.fillDuration.milliseconds.float replenished = @@ -46,7 +48,7 @@ proc update(bucket: TokenBucket) = bucket.lastUpdate += milliseconds(deltaFromReplenished) bucket.budget = min(bucket.budgetCap, bucket.budget + replenished) -proc tryConsume*(bucket: TokenBucket, tokens: int): bool = +proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = ## If `tokens` are available, consume them, ## Otherwhise, return false. @@ -54,7 +56,7 @@ proc tryConsume*(bucket: TokenBucket, tokens: int): bool = bucket.budget -= tokens return true - bucket.update() + bucket.update(now) if bucket.budget >= tokens: bucket.budget -= tokens @@ -93,12 +95,12 @@ proc worker(bucket: TokenBucket) {.async.} = bucket.workFuture = nil -proc consume*(bucket: TokenBucket, tokens: int): Future[void] = +proc consume*(bucket: TokenBucket, tokens: int, now = Moment.now()): Future[void] = ## Wait for `tokens` to be available, and consume them. let retFuture = newFuture[void]("TokenBucket.consume") if isNil(bucket.workFuture) or bucket.workFuture.finished(): - if bucket.tryConsume(tokens): + if bucket.tryConsume(tokens, now): retFuture.complete() return retFuture @@ -119,10 +121,10 @@ proc consume*(bucket: TokenBucket, tokens: int): Future[void] = return retFuture -proc replenish*(bucket: TokenBucket, tokens: int) = +proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) = ## Add `tokens` to the budget (capped to the bucket capacity) bucket.budget += tokens - bucket.update() + bucket.update(now) bucket.manuallyReplenished.fire() proc new*( diff --git a/chronos/selectors2.nim b/chronos/selectors2.nim index 45c45330..c5918fdf 100644 --- a/chronos/selectors2.nim +++ b/chronos/selectors2.nim @@ -32,29 +32,9 @@ # backwards-compatible. import stew/results -import osdefs, osutils, oserrno +import config, osdefs, osutils, oserrno export results, oserrno -const - asyncEventsCount* {.intdefine.} = 64 - ## Number of epoll events retrieved by syscall. - asyncInitialSize* {.intdefine.} = 64 - ## Initial size of Selector[T]'s array of file descriptors. - asyncEventEngine* {.strdefine.} = - when defined(linux): - "epoll" - elif defined(macosx) or defined(macos) or defined(ios) or - defined(freebsd) or defined(netbsd) or defined(openbsd) or - defined(dragonfly): - "kqueue" - elif defined(posix): - "poll" - else: - "" - ## Engine type which is going to be used by module. - - hasThreadSupport = compileOption("threads") - when defined(nimdoc): type @@ -281,7 +261,9 @@ else: var err = newException(IOSelectorsException, msg) raise err - when asyncEventEngine in ["epoll", "kqueue"]: + when chronosEventEngine in ["epoll", "kqueue"]: + const hasThreadSupport = compileOption("threads") + proc blockSignals(newmask: Sigset, oldmask: var Sigset): Result[void, OSErrorCode] = var nmask = newmask @@ -324,11 +306,11 @@ else: doAssert((timeout >= min) and (timeout <= max), "Cannot select with incorrect timeout value, got " & $timeout) -when asyncEventEngine == "epoll": +when chronosEventEngine == "epoll": include ./ioselects/ioselectors_epoll -elif asyncEventEngine == "kqueue": +elif chronosEventEngine == "kqueue": include ./ioselects/ioselectors_kqueue -elif asyncEventEngine == "poll": +elif chronosEventEngine == "poll": include ./ioselects/ioselectors_poll else: - {.fatal: "Event engine `" & asyncEventEngine & "` is not supported!".} + {.fatal: "Event engine `" & chronosEventEngine & "` is not supported!".} diff --git a/chronos/sendfile.nim b/chronos/sendfile.nim index 8cba9e83..7afcb738 100644 --- a/chronos/sendfile.nim +++ b/chronos/sendfile.nim @@ -38,8 +38,12 @@ when defined(nimdoc): ## be prepared to retry the call if there were unsent bytes. ## ## On error, ``-1`` is returned. +elif defined(emscripten): -elif defined(linux) or defined(android): + proc sendfile*(outfd, infd: int, offset: int, count: var int): int = + raiseAssert "sendfile() is not implemented yet" + +elif (defined(linux) or defined(android)) and not(defined(emscripten)): proc osSendFile*(outfd, infd: cint, offset: ptr int, count: int): int {.importc: "sendfile", header: "".} diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index 8af35009..79809ba9 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -96,10 +96,6 @@ type reader*: AsyncStreamReader writer*: AsyncStreamWriter - AsyncStreamTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer = @@ -332,79 +328,6 @@ template checkStreamClosed*(t: untyped) = template checkStreamFinished*(t: untyped) = if t.atEof(): raiseAsyncStreamWriteEOFError() -proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {. - gcsafe, raises: [].} -proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {. - gcsafe, raises: [].} - -proc getAsyncStreamReaderTracker(): AsyncStreamTracker {.inline.} = - var res = cast[AsyncStreamTracker](getTracker(AsyncStreamReaderTrackerName)) - if isNil(res): - res = setupAsyncStreamReaderTracker() - res - -proc getAsyncStreamWriterTracker(): AsyncStreamTracker {.inline.} = - var res = cast[AsyncStreamTracker](getTracker(AsyncStreamWriterTrackerName)) - if isNil(res): - res = setupAsyncStreamWriterTracker() - res - -proc dumpAsyncStreamReaderTracking(): string {.gcsafe.} = - var tracker = getAsyncStreamReaderTracker() - let res = "Opened async stream readers: " & $tracker.opened & "\n" & - "Closed async stream readers: " & $tracker.closed - res - -proc dumpAsyncStreamWriterTracking(): string {.gcsafe.} = - var tracker = getAsyncStreamWriterTracker() - let res = "Opened async stream writers: " & $tracker.opened & "\n" & - "Closed async stream writers: " & $tracker.closed - res - -proc leakAsyncStreamReader(): bool {.gcsafe.} = - var tracker = getAsyncStreamReaderTracker() - tracker.opened != tracker.closed - -proc leakAsyncStreamWriter(): bool {.gcsafe.} = - var tracker = getAsyncStreamWriterTracker() - tracker.opened != tracker.closed - -proc trackAsyncStreamReader(t: AsyncStreamReader) {.inline.} = - var tracker = getAsyncStreamReaderTracker() - inc(tracker.opened) - -proc untrackAsyncStreamReader*(t: AsyncStreamReader) {.inline.} = - var tracker = getAsyncStreamReaderTracker() - inc(tracker.closed) - -proc trackAsyncStreamWriter(t: AsyncStreamWriter) {.inline.} = - var tracker = getAsyncStreamWriterTracker() - inc(tracker.opened) - -proc untrackAsyncStreamWriter*(t: AsyncStreamWriter) {.inline.} = - var tracker = getAsyncStreamWriterTracker() - inc(tracker.closed) - -proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.gcsafe.} = - var res = AsyncStreamTracker( - opened: 0, - closed: 0, - dump: dumpAsyncStreamReaderTracking, - isLeaked: leakAsyncStreamReader - ) - addTracker(AsyncStreamReaderTrackerName, res) - res - -proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.gcsafe.} = - var res = AsyncStreamTracker( - opened: 0, - closed: 0, - dump: dumpAsyncStreamWriterTracking, - isLeaked: leakAsyncStreamWriter - ) - addTracker(AsyncStreamWriterTrackerName, res) - res - template readLoop(body: untyped): untyped = while true: if rstream.buffer.dataLen() == 0: @@ -953,10 +876,10 @@ proc join*(rw: AsyncStreamRW): Future[void] = else: var retFuture = newFuture[void]("async.stream.writer.join") - proc continuation(udata: pointer) {.gcsafe.} = + proc continuation(udata: pointer) {.gcsafe, raises:[].} = retFuture.complete() - proc cancellation(udata: pointer) {.gcsafe.} = + proc cancellation(udata: pointer) {.gcsafe, raises:[].} = rw.future.removeCallback(continuation, cast[pointer](retFuture)) if not(rw.future.finished()): @@ -980,9 +903,9 @@ proc close*(rw: AsyncStreamRW) = if not(rw.future.finished()): rw.future.complete() when rw is AsyncStreamReader: - untrackAsyncStreamReader(rw) + untrackCounter(AsyncStreamReaderTrackerName) elif rw is AsyncStreamWriter: - untrackAsyncStreamWriter(rw) + untrackCounter(AsyncStreamWriterTrackerName) rw.state = AsyncStreamState.Closed when rw is AsyncStreamReader: @@ -1031,7 +954,7 @@ proc init*(child, wsource: AsyncStreamWriter, loop: StreamWriterLoop, child.wsource = wsource child.tsource = wsource.tsource child.queue = newAsyncQueue[WriteItem](queueSize) - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*[T](child, wsource: AsyncStreamWriter, loop: StreamWriterLoop, @@ -1045,7 +968,7 @@ proc init*[T](child, wsource: AsyncStreamWriter, loop: StreamWriterLoop, if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop, @@ -1056,7 +979,7 @@ proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop, child.rsource = rsource child.tsource = rsource.tsource child.buffer = AsyncBuffer.init(bufferSize) - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop, @@ -1071,7 +994,7 @@ proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop, if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*(child: AsyncStreamWriter, tsource: StreamTransport) = @@ -1080,7 +1003,7 @@ proc init*(child: AsyncStreamWriter, tsource: StreamTransport) = child.writerLoop = nil child.wsource = nil child.tsource = tsource - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*[T](child: AsyncStreamWriter, tsource: StreamTransport, @@ -1090,7 +1013,7 @@ proc init*[T](child: AsyncStreamWriter, tsource: StreamTransport, child.writerLoop = nil child.wsource = nil child.tsource = tsource - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*(child, wsource: AsyncStreamWriter) = @@ -1099,7 +1022,7 @@ proc init*(child, wsource: AsyncStreamWriter) = child.writerLoop = nil child.wsource = wsource child.tsource = wsource.tsource - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*[T](child, wsource: AsyncStreamWriter, udata: ref T) = @@ -1111,7 +1034,7 @@ proc init*[T](child, wsource: AsyncStreamWriter, udata: ref T) = if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*(child: AsyncStreamReader, tsource: StreamTransport) = @@ -1120,7 +1043,7 @@ proc init*(child: AsyncStreamReader, tsource: StreamTransport) = child.readerLoop = nil child.rsource = nil child.tsource = tsource - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*[T](child: AsyncStreamReader, tsource: StreamTransport, @@ -1133,7 +1056,7 @@ proc init*[T](child: AsyncStreamReader, tsource: StreamTransport, if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*(child, rsource: AsyncStreamReader) = @@ -1142,7 +1065,7 @@ proc init*(child, rsource: AsyncStreamReader) = child.readerLoop = nil child.rsource = rsource child.tsource = rsource.tsource - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*[T](child, rsource: AsyncStreamReader, udata: ref T) = @@ -1154,7 +1077,7 @@ proc init*[T](child, rsource: AsyncStreamReader, udata: ref T) = if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc newAsyncStreamReader*[T](rsource: AsyncStreamReader, diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index ceacaff7..2999f7af 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -95,6 +95,7 @@ type trustAnchors: TrustAnchorStore SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream + SomeTrustAnchorType* = TrustAnchorStore | openArray[X509TrustAnchor] TLSStreamError* = object of AsyncStreamError TLSStreamHandshakeError* = object of TLSStreamError @@ -139,12 +140,14 @@ proc newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = proc raiseTLSStreamProtocolError[T](message: T) {.noreturn, noinline.} = raise newTLSStreamProtocolImpl(message) -proc new*(T: typedesc[TrustAnchorStore], anchors: openArray[X509TrustAnchor]): TrustAnchorStore = +proc new*(T: typedesc[TrustAnchorStore], + anchors: openArray[X509TrustAnchor]): TrustAnchorStore = var res: seq[X509TrustAnchor] for anchor in anchors: res.add(anchor) - doAssert(unsafeAddr(anchor) != unsafeAddr(res[^1]), "Anchors should be copied") - return TrustAnchorStore(anchors: res) + doAssert(unsafeAddr(anchor) != unsafeAddr(res[^1]), + "Anchors should be copied") + TrustAnchorStore(anchors: res) proc tlsWriteRec(engine: ptr SslEngineContext, writer: TLSStreamWriter): Future[TLSResult] {.async.} = @@ -453,15 +456,16 @@ proc getSignerAlgo(xc: X509Certificate): int = else: int(x509DecoderGetSignerKeyType(dc)) -proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, - wsource: AsyncStreamWriter, - serverName: string, - bufferSize = SSL_BUFSIZE_BIDI, - minVersion = TLSVersion.TLS12, - maxVersion = TLSVersion.TLS12, - flags: set[TLSFlags] = {}, - trustAnchors: TrustAnchorStore | openArray[X509TrustAnchor] = MozillaTrustAnchors - ): TLSAsyncStream = +proc newTLSClientAsyncStream*( + rsource: AsyncStreamReader, + wsource: AsyncStreamWriter, + serverName: string, + bufferSize = SSL_BUFSIZE_BIDI, + minVersion = TLSVersion.TLS12, + maxVersion = TLSVersion.TLS12, + flags: set[TLSFlags] = {}, + trustAnchors: SomeTrustAnchorType = MozillaTrustAnchors + ): TLSAsyncStream = ## Create new TLS asynchronous stream for outbound (client) connections ## using reading stream ``rsource`` and writing stream ``wsource``. ## @@ -484,7 +488,8 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, ## a ``TrustAnchorStore`` you should reuse the same instance for ## every call to avoid making a copy of the trust anchors per call. when trustAnchors is TrustAnchorStore: - doAssert(len(trustAnchors.anchors) > 0, "Empty trust anchor list is invalid") + doAssert(len(trustAnchors.anchors) > 0, + "Empty trust anchor list is invalid") else: doAssert(len(trustAnchors) > 0, "Empty trust anchor list is invalid") var res = TLSAsyncStream() @@ -524,7 +529,7 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader, uint16(maxVersion)) if TLSFlags.NoVerifyServerName in flags: - let err = sslClientReset(res.ccontext, "", 0) + let err = sslClientReset(res.ccontext, nil, 0) if err == 0: raise newException(TLSStreamInitError, "Could not initialize TLS layer") else: diff --git a/chronos/threadsync.nim b/chronos/threadsync.nim new file mode 100644 index 00000000..d4141812 --- /dev/null +++ b/chronos/threadsync.nim @@ -0,0 +1,416 @@ +# +# Chronos multithreaded synchronization primitives +# +# (c) Copyright 2023-Present Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) + +## This module implements some core async thread synchronization primitives. +import stew/results +import "."/[timer, asyncloop] + +export results + +{.push raises: [].} + +const hasThreadSupport* = compileOption("threads") +when not(hasThreadSupport): + {.fatal: "Compile this program with threads enabled!".} + +import "."/[osdefs, osutils, oserrno] + +type + ThreadSignal* = object + when defined(windows): + event: HANDLE + elif defined(linux): + efd: AsyncFD + else: + rfd, wfd: AsyncFD + + ThreadSignalPtr* = ptr ThreadSignal + +proc new*(t: typedesc[ThreadSignalPtr]): Result[ThreadSignalPtr, string] = + ## Create new ThreadSignal object. + let res = cast[ptr ThreadSignal](allocShared0(sizeof(ThreadSignal))) + when defined(windows): + var sa = getSecurityAttributes() + let event = osdefs.createEvent(addr sa, DWORD(0), DWORD(0), nil) + if event == HANDLE(0): + deallocShared(res) + return err(osErrorMsg(osLastError())) + res[] = ThreadSignal(event: event) + elif defined(linux): + let efd = eventfd(0, EFD_CLOEXEC or EFD_NONBLOCK) + if efd == -1: + deallocShared(res) + return err(osErrorMsg(osLastError())) + res[] = ThreadSignal(efd: AsyncFD(efd)) + else: + var sockets: array[2, cint] + block: + let sres = socketpair(AF_UNIX, SOCK_DGRAM, 0, sockets) + if sres < 0: + deallocShared(res) + return err(osErrorMsg(osLastError())) + # MacOS do not have SOCK_NONBLOCK and SOCK_CLOEXEC, so we forced to use + # setDescriptorFlags() for every socket. + block: + let sres = setDescriptorFlags(sockets[0], true, true) + if sres.isErr(): + discard closeFd(sockets[0]) + discard closeFd(sockets[1]) + deallocShared(res) + return err(osErrorMsg(sres.error)) + block: + let sres = setDescriptorFlags(sockets[1], true, true) + if sres.isErr(): + discard closeFd(sockets[0]) + discard closeFd(sockets[1]) + deallocShared(res) + return err(osErrorMsg(sres.error)) + res[] = ThreadSignal(rfd: AsyncFD(sockets[0]), wfd: AsyncFD(sockets[1])) + ok(ThreadSignalPtr(res)) + +when not(defined(windows)): + type + WaitKind {.pure.} = enum + Read, Write + + when defined(linux): + proc checkBusy(fd: cint): bool = false + else: + proc checkBusy(fd: cint): bool = + var data = 0'u64 + let res = handleEintr(recv(SocketHandle(fd), + addr data, sizeof(uint64), MSG_PEEK)) + if res == sizeof(uint64): + true + else: + false + + func toTimeval(a: Duration): Timeval = + ## Convert Duration ``a`` to ``Timeval`` object. + let nanos = a.nanoseconds + let m = nanos mod Second.nanoseconds() + Timeval( + tv_sec: Time(nanos div Second.nanoseconds()), + tv_usec: Suseconds(m div Microsecond.nanoseconds()) + ) + + proc waitReady(fd: cint, kind: WaitKind, + timeout: Duration): Result[bool, OSErrorCode] = + var + tv: Timeval + fdset = + block: + var res: TFdSet + FD_ZERO(res) + FD_SET(SocketHandle(fd), res) + res + + let + ptv = + if not(timeout.isInfinite()): + tv = timeout.toTimeval() + addr tv + else: + nil + nfd = cint(fd) + 1 + res = + case kind + of WaitKind.Read: + handleEintr(select(nfd, addr fdset, nil, nil, ptv)) + of WaitKind.Write: + handleEintr(select(nfd, nil, addr fdset, nil, ptv)) + + if res > 0: + ok(true) + elif res == 0: + ok(false) + else: + err(osLastError()) + + proc safeUnregisterAndCloseFd(fd: AsyncFD): Result[void, OSErrorCode] = + let loop = getThreadDispatcher() + if loop.contains(fd): + ? unregister2(fd) + if closeFd(cint(fd)) != 0: + err(osLastError()) + else: + ok() + +proc close*(signal: ThreadSignalPtr): Result[void, string] = + ## Close ThreadSignal object and free all the resources. + defer: deallocShared(signal) + when defined(windows): + # We do not need to perform unregistering on Windows, we can only close it. + if closeHandle(signal[].event) == 0'u32: + return err(osErrorMsg(osLastError())) + elif defined(linux): + let res = safeUnregisterAndCloseFd(signal[].efd) + if res.isErr(): + return err(osErrorMsg(res.error)) + else: + let res1 = safeUnregisterAndCloseFd(signal[].rfd) + let res2 = safeUnregisterAndCloseFd(signal[].wfd) + if res1.isErr(): return err(osErrorMsg(res1.error)) + if res2.isErr(): return err(osErrorMsg(res2.error)) + ok() + +proc fireSync*(signal: ThreadSignalPtr, + timeout = InfiniteDuration): Result[bool, string] = + ## Set state of ``signal`` to signaled state in blocking way. + ## + ## Returns ``false`` if signal was not signalled in time, and ``true`` + ## if operation was successful. + when defined(windows): + if setEvent(signal[].event) == 0'u32: + return err(osErrorMsg(osLastError())) + ok(true) + else: + let + eventFd = + when defined(linux): + cint(signal[].efd) + else: + cint(signal[].wfd) + checkFd = + when defined(linux): + cint(signal[].efd) + else: + cint(signal[].rfd) + + if checkBusy(checkFd): + # Signal is already in signalled state + return ok(true) + + var data = 1'u64 + while true: + let res = + when defined(linux): + handleEintr(write(eventFd, addr data, sizeof(uint64))) + else: + handleEintr(send(SocketHandle(eventFd), addr data, sizeof(uint64), + MSG_NOSIGNAL)) + if res < 0: + let errorCode = osLastError() + case errorCode + of EAGAIN: + let wres = waitReady(eventFd, WaitKind.Write, timeout) + if wres.isErr(): + return err(osErrorMsg(wres.error)) + if not(wres.get()): + return ok(false) + else: + return err(osErrorMsg(errorCode)) + elif res != sizeof(data): + return err(osErrorMsg(EINVAL)) + else: + return ok(true) + +proc waitSync*(signal: ThreadSignalPtr, + timeout = InfiniteDuration): Result[bool, string] = + ## Wait until the signal become signaled. This procedure is ``NOT`` async, + ## so it blocks execution flow, but this procedure do not need asynchronous + ## event loop to be present. + when defined(windows): + let + timeoutWin = + if timeout.isInfinite(): + INFINITE + else: + DWORD(timeout.milliseconds()) + handle = signal[].event + res = waitForSingleObject(handle, timeoutWin) + if res == WAIT_OBJECT_0: + ok(true) + elif res == WAIT_TIMEOUT: + ok(false) + elif res == WAIT_ABANDONED: + err("The wait operation has been abandoned") + else: + err("The wait operation has been failed") + else: + let eventFd = + when defined(linux): + cint(signal[].efd) + else: + cint(signal[].rfd) + var + data = 0'u64 + timer = timeout + while true: + let wres = + block: + let + start = Moment.now() + res = waitReady(eventFd, WaitKind.Read, timer) + timer = timer - (Moment.now() - start) + res + if wres.isErr(): + return err(osErrorMsg(wres.error)) + if not(wres.get()): + return ok(false) + let res = + when defined(linux): + handleEintr(read(eventFd, addr data, sizeof(uint64))) + else: + handleEintr(recv(SocketHandle(eventFd), addr data, sizeof(uint64), + cint(0))) + if res < 0: + let errorCode = osLastError() + # If errorCode == EAGAIN it means that reading operation is already + # pending and so some other consumer reading eventfd or pipe end, in + # this case we going to ignore error and wait for another event. + if errorCode != EAGAIN: + return err(osErrorMsg(errorCode)) + elif res != sizeof(data): + return err(osErrorMsg(EINVAL)) + else: + return ok(true) + +proc fire*(signal: ThreadSignalPtr): Future[void] = + ## Set state of ``signal`` to signaled in asynchronous way. + var retFuture = newFuture[void]("asyncthreadsignal.fire") + when defined(windows): + if setEvent(signal[].event) == 0'u32: + retFuture.fail(newException(AsyncError, osErrorMsg(osLastError()))) + else: + retFuture.complete() + else: + var data = 1'u64 + let + eventFd = + when defined(linux): + cint(signal[].efd) + else: + cint(signal[].wfd) + checkFd = + when defined(linux): + cint(signal[].efd) + else: + cint(signal[].rfd) + + proc continuation(udata: pointer) {.gcsafe, raises: [].} = + if not(retFuture.finished()): + let res = + when defined(linux): + handleEintr(write(eventFd, addr data, sizeof(uint64))) + else: + handleEintr(send(SocketHandle(eventFd), addr data, sizeof(uint64), + MSG_NOSIGNAL)) + if res < 0: + let errorCode = osLastError() + discard removeWriter2(AsyncFD(eventFd)) + retFuture.fail(newException(AsyncError, osErrorMsg(errorCode))) + elif res != sizeof(data): + discard removeWriter2(AsyncFD(eventFd)) + retFuture.fail(newException(AsyncError, osErrorMsg(EINVAL))) + else: + let eres = removeWriter2(AsyncFD(eventFd)) + if eres.isErr(): + retFuture.fail(newException(AsyncError, osErrorMsg(eres.error))) + else: + retFuture.complete() + + proc cancellation(udata: pointer) {.gcsafe, raises: [].} = + if not(retFuture.finished()): + discard removeWriter2(AsyncFD(eventFd)) + + if checkBusy(checkFd): + # Signal is already in signalled state + retFuture.complete() + return retFuture + + let res = + when defined(linux): + handleEintr(write(eventFd, addr data, sizeof(uint64))) + else: + handleEintr(send(SocketHandle(eventFd), addr data, sizeof(uint64), + MSG_NOSIGNAL)) + if res < 0: + let errorCode = osLastError() + case errorCode + of EAGAIN: + let loop = getThreadDispatcher() + if not(loop.contains(AsyncFD(eventFd))): + let rres = register2(AsyncFD(eventFd)) + if rres.isErr(): + retFuture.fail(newException(AsyncError, osErrorMsg(rres.error))) + return retFuture + let wres = addWriter2(AsyncFD(eventFd), continuation) + if wres.isErr(): + retFuture.fail(newException(AsyncError, osErrorMsg(wres.error))) + else: + retFuture.cancelCallback = cancellation + else: + retFuture.fail(newException(AsyncError, osErrorMsg(errorCode))) + elif res != sizeof(data): + retFuture.fail(newException(AsyncError, osErrorMsg(EINVAL))) + else: + retFuture.complete() + + retFuture + +when defined(windows): + proc wait*(signal: ThreadSignalPtr) {.async.} = + let handle = signal[].event + let res = await waitForSingleObject(handle, InfiniteDuration) + # There should be no other response, because we use `InfiniteDuration`. + doAssert(res == WaitableResult.Ok) +else: + proc wait*(signal: ThreadSignalPtr): Future[void] = + var retFuture = newFuture[void]("asyncthreadsignal.wait") + var data = 1'u64 + let eventFd = + when defined(linux): + cint(signal[].efd) + else: + cint(signal[].rfd) + + proc continuation(udata: pointer) {.gcsafe, raises: [].} = + if not(retFuture.finished()): + let res = + when defined(linux): + handleEintr(read(eventFd, addr data, sizeof(uint64))) + else: + handleEintr(recv(SocketHandle(eventFd), addr data, sizeof(uint64), + cint(0))) + if res < 0: + let errorCode = osLastError() + # If errorCode == EAGAIN it means that reading operation is already + # pending and so some other consumer reading eventfd or pipe end, in + # this case we going to ignore error and wait for another event. + if errorCode != EAGAIN: + discard removeReader2(AsyncFD(eventFd)) + retFuture.fail(newException(AsyncError, osErrorMsg(errorCode))) + elif res != sizeof(data): + discard removeReader2(AsyncFD(eventFd)) + retFuture.fail(newException(AsyncError, osErrorMsg(EINVAL))) + else: + let eres = removeReader2(AsyncFD(eventFd)) + if eres.isErr(): + retFuture.fail(newException(AsyncError, osErrorMsg(eres.error))) + else: + retFuture.complete() + + proc cancellation(udata: pointer) {.gcsafe, raises: [].} = + if not(retFuture.finished()): + # Future is already cancelled so we ignore errors. + discard removeReader2(AsyncFD(eventFd)) + + let loop = getThreadDispatcher() + if not(loop.contains(AsyncFD(eventFd))): + let res = register2(AsyncFD(eventFd)) + if res.isErr(): + retFuture.fail(newException(AsyncError, osErrorMsg(res.error))) + return retFuture + let res = addReader2(AsyncFD(eventFd), continuation) + if res.isErr(): + retFuture.fail(newException(AsyncError, osErrorMsg(res.error))) + return retFuture + retFuture.cancelCallback = cancellation + retFuture diff --git a/chronos/transports/common.nim b/chronos/transports/common.nim index cbec5d6f..4b4be7de 100644 --- a/chronos/transports/common.nim +++ b/chronos/transports/common.nim @@ -298,6 +298,9 @@ proc getAddrInfo(address: string, port: Port, domain: Domain, raises: [TransportAddressError].} = ## We have this one copy of ``getAddrInfo()`` because of AI_V4MAPPED in ## ``net.nim:getAddrInfo()``, which is not cross-platform. + ## + ## Warning: `ptr AddrInfo` returned by `getAddrInfo()` needs to be freed by + ## calling `freeAddrInfo()`. var hints: AddrInfo var res: ptr AddrInfo = nil hints.ai_family = toInt(domain) @@ -420,6 +423,7 @@ proc resolveTAddress*(address: string, port: Port, if ta notin res: res.add(ta) it = it.ai_next + freeAddrInfo(aiList) res proc resolveTAddress*(address: string, domain: Domain): seq[TransportAddress] {. @@ -574,10 +578,8 @@ template getTransportUseClosedError*(): ref TransportUseClosedError = newException(TransportUseClosedError, "Transport is already closed!") template getTransportOsError*(err: OSErrorCode): ref TransportOsError = - var msg = "(" & $int(err) & ") " & osErrorMsg(err) - var tre = newException(TransportOsError, msg) - tre.code = err - tre + (ref TransportOsError)( + code: err, msg: "(" & $int(err) & ") " & osErrorMsg(err)) template getTransportOsError*(err: cint): ref TransportOsError = getTransportOsError(OSErrorCode(err)) @@ -608,15 +610,16 @@ template getTransportTooManyError*( ): ref TransportTooManyError = let msg = when defined(posix): - if code == OSErrorCode(0): + case code + of OSErrorCode(0): "Too many open transports" - elif code == oserrno.EMFILE: + of EMFILE: "[EMFILE] Too many open files in the process" - elif code == oserrno.ENFILE: + of ENFILE: "[ENFILE] Too many open files in system" - elif code == oserrno.ENOBUFS: + of ENOBUFS: "[ENOBUFS] No buffer space available" - elif code == oserrno.ENOMEM: + of ENOMEM: "[ENOMEM] Not enough memory availble" else: "[" & $int(code) & "] Too many open transports" @@ -649,23 +652,26 @@ template getConnectionAbortedError*( ): ref TransportAbortedError = let msg = when defined(posix): - if code == OSErrorCode(0): + case code + of OSErrorCode(0), ECONNABORTED: "[ECONNABORTED] Connection has been aborted before being accepted" - elif code == oserrno.EPERM: + of EPERM: "[EPERM] Firewall rules forbid connection" - elif code == oserrno.ETIMEDOUT: + of ETIMEDOUT: "[ETIMEDOUT] Operation has been timed out" + of ENOTCONN: + "[ENOTCONN] Transport endpoint is not connected" else: "[" & $int(code) & "] Connection has been aborted" elif defined(windows): case code - of OSErrorCode(0), oserrno.WSAECONNABORTED: + of OSErrorCode(0), WSAECONNABORTED: "[ECONNABORTED] Connection has been aborted before being accepted" of WSAENETDOWN: "[ENETDOWN] Network is down" - of oserrno.WSAENETRESET: + of WSAENETRESET: "[ENETRESET] Network dropped connection on reset" - of oserrno.WSAECONNRESET: + of WSAECONNRESET: "[ECONNRESET] Connection reset by peer" of WSAETIMEDOUT: "[ETIMEDOUT] Connection timed out" @@ -675,3 +681,42 @@ template getConnectionAbortedError*( "[" & $int(code) & "] Connection has been aborted" newException(TransportAbortedError, msg) + +template getTransportError*(ecode: OSErrorCode): untyped = + when defined(posix): + case ecode + of ECONNABORTED, EPERM, ETIMEDOUT, ENOTCONN: + getConnectionAbortedError(ecode) + of EMFILE, ENFILE, ENOBUFS, ENOMEM: + getTransportTooManyError(ecode) + else: + getTransportOsError(ecode) + else: + case ecode + of WSAECONNABORTED, WSAENETDOWN, WSAENETRESET, WSAECONNRESET, WSAETIMEDOUT: + getConnectionAbortedError(ecode) + of ERROR_TOO_MANY_OPEN_FILES, WSAENOBUFS, WSAEMFILE: + getTransportTooManyError(ecode) + else: + getTransportOsError(ecode) + +proc raiseTransportError*(ecode: OSErrorCode) {. + raises: [TransportAbortedError, TransportTooManyError, TransportOsError], + noreturn.} = + ## Raises transport specific OS error. + when defined(posix): + case ecode + of ECONNABORTED, EPERM, ETIMEDOUT, ENOTCONN: + raise getConnectionAbortedError(ecode) + of EMFILE, ENFILE, ENOBUFS, ENOMEM: + raise getTransportTooManyError(ecode) + else: + raise getTransportOsError(ecode) + else: + case ecode + of WSAECONNABORTED, WSAENETDOWN, WSAENETRESET, WSAECONNRESET, WSAETIMEDOUT: + raise getConnectionAbortedError(ecode) + of ERROR_TOO_MANY_OPEN_FILES, WSAENOBUFS, WSAEMFILE: + raise getTransportTooManyError(ecode) + else: + raise getTransportOsError(ecode) diff --git a/chronos/transports/datagram.nim b/chronos/transports/datagram.nim index 91a7e7a0..665bc0ed 100644 --- a/chronos/transports/datagram.nim +++ b/chronos/transports/datagram.nim @@ -53,10 +53,6 @@ type rwsabuf: WSABUF # Reader WSABUF structure wwsabuf: WSABUF # Writer WSABUF structure - DgramTransportTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - const DgramTransportTrackerName* = "datagram.transport" @@ -88,39 +84,6 @@ template setReadError(t, e: untyped) = (t).state.incl(ReadError) (t).error = getTransportOsError(e) -proc setupDgramTransportTracker(): DgramTransportTracker {. - gcsafe, raises: [].} - -proc getDgramTransportTracker(): DgramTransportTracker {.inline.} = - var res = cast[DgramTransportTracker](getTracker(DgramTransportTrackerName)) - if isNil(res): - res = setupDgramTransportTracker() - doAssert(not(isNil(res))) - res - -proc dumpTransportTracking(): string {.gcsafe.} = - var tracker = getDgramTransportTracker() - "Opened transports: " & $tracker.opened & "\n" & - "Closed transports: " & $tracker.closed - -proc leakTransport(): bool {.gcsafe.} = - let tracker = getDgramTransportTracker() - tracker.opened != tracker.closed - -proc trackDgram(t: DatagramTransport) {.inline.} = - var tracker = getDgramTransportTracker() - inc(tracker.opened) - -proc untrackDgram(t: DatagramTransport) {.inline.} = - var tracker = getDgramTransportTracker() - inc(tracker.closed) - -proc setupDgramTransportTracker(): DgramTransportTracker {.gcsafe.} = - let res = DgramTransportTracker( - opened: 0, closed: 0, dump: dumpTransportTracking, isLeaked: leakTransport) - addTracker(DgramTransportTrackerName, res) - res - when defined(windows): template setWriterWSABuffer(t, v: untyped) = (t).wwsabuf.buf = cast[cstring](v.buf) @@ -213,7 +176,7 @@ when defined(windows): transp.state.incl(ReadPaused) if ReadClosed in transp.state and not(transp.future.finished()): # Stop tracking transport - untrackDgram(transp) + untrackCounter(DgramTransportTrackerName) # If `ReadClosed` present, then close(transport) was called. transp.future.complete() GC_unref(transp) @@ -259,7 +222,7 @@ when defined(windows): # WSARecvFrom session. if ReadClosed in transp.state and not(transp.future.finished()): # Stop tracking transport - untrackDgram(transp) + untrackCounter(DgramTransportTrackerName) transp.future.complete() GC_unref(transp) break @@ -394,7 +357,7 @@ when defined(windows): len: ULONG(len(res.buffer))) GC_ref(res) # Start tracking transport - trackDgram(res) + trackCounter(DgramTransportTrackerName) if NoAutoRead notin flags: let rres = res.resumeRead() if rres.isErr(): raiseTransportOsError(rres.error()) @@ -503,11 +466,11 @@ else: var res = if isNil(child): DatagramTransport() else: child if sock == asyncInvalidSocket: - var proto = Protocol.IPPROTO_UDP - if local.family == AddressFamily.Unix: - # `Protocol` enum is missing `0` value, so we making here cast, until - # `Protocol` enum will not support IPPROTO_IP == 0. - proto = cast[Protocol](0) + let proto = + if local.family == AddressFamily.Unix: + Protocol.IPPROTO_IP + else: + Protocol.IPPROTO_UDP localSock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM, proto) if localSock == asyncInvalidSocket: @@ -592,7 +555,7 @@ else: res.future = newFuture[void]("datagram.transport") GC_ref(res) # Start tracking transport - trackDgram(res) + trackCounter(DgramTransportTrackerName) if NoAutoRead notin flags: let rres = res.resumeRead() if rres.isErr(): raiseTransportOsError(rres.error()) @@ -603,7 +566,7 @@ proc close*(transp: DatagramTransport) = proc continuation(udata: pointer) {.raises: [].} = if not(transp.future.finished()): # Stop tracking transport - untrackDgram(transp) + untrackCounter(DgramTransportTrackerName) transp.future.complete() GC_unref(transp) diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index 173bc69a..07574311 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -54,15 +54,6 @@ type ReuseAddr, ReusePort - - StreamTransportTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - - StreamServerTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - ReadMessagePredicate* = proc (data: openArray[byte]): tuple[consumed: int, done: bool] {. gcsafe, raises: [].} @@ -70,6 +61,7 @@ type const StreamTransportTrackerName* = "stream.transport" StreamServerTrackerName* = "stream.server" + DefaultBacklogSize* = high(int32) when defined(windows): type @@ -141,30 +133,28 @@ type # transport for new client proc remoteAddress*(transp: StreamTransport): TransportAddress {. - raises: [TransportError].} = + raises: [TransportAbortedError, TransportTooManyError, TransportOsError].} = ## Returns ``transp`` remote socket address. - if transp.kind != TransportKind.Socket: - raise newException(TransportError, "Socket required!") + doAssert(transp.kind == TransportKind.Socket, "Socket transport required!") if transp.remote.family == AddressFamily.None: var saddr: Sockaddr_storage var slen = SockLen(sizeof(saddr)) if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr), addr slen) != 0: - raiseTransportOsError(osLastError()) + raiseTransportError(osLastError()) fromSAddr(addr saddr, slen, transp.remote) transp.remote proc localAddress*(transp: StreamTransport): TransportAddress {. - raises: [TransportError].} = + raises: [TransportAbortedError, TransportTooManyError, TransportOsError].} = ## Returns ``transp`` local socket address. - if transp.kind != TransportKind.Socket: - raise newException(TransportError, "Socket required!") + doAssert(transp.kind == TransportKind.Socket, "Socket transport required!") if transp.local.family == AddressFamily.None: var saddr: Sockaddr_storage var slen = SockLen(sizeof(saddr)) if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr), addr slen) != 0: - raiseTransportOsError(osLastError()) + raiseTransportError(osLastError()) fromSAddr(addr saddr, slen, transp.local) transp.local @@ -201,71 +191,6 @@ template shiftVectorFile(v: var StreamVector, o: untyped) = (v).buf = cast[pointer](cast[uint]((v).buf) - uint(o)) (v).offset += uint(o) -proc setupStreamTransportTracker(): StreamTransportTracker {. - gcsafe, raises: [].} -proc setupStreamServerTracker(): StreamServerTracker {. - gcsafe, raises: [].} - -proc getStreamTransportTracker(): StreamTransportTracker {.inline.} = - var res = cast[StreamTransportTracker](getTracker(StreamTransportTrackerName)) - if isNil(res): - res = setupStreamTransportTracker() - doAssert(not(isNil(res))) - res - -proc getStreamServerTracker(): StreamServerTracker {.inline.} = - var res = cast[StreamServerTracker](getTracker(StreamServerTrackerName)) - if isNil(res): - res = setupStreamServerTracker() - doAssert(not(isNil(res))) - res - -proc dumpTransportTracking(): string {.gcsafe.} = - var tracker = getStreamTransportTracker() - "Opened transports: " & $tracker.opened & "\n" & - "Closed transports: " & $tracker.closed - -proc dumpServerTracking(): string {.gcsafe.} = - var tracker = getStreamServerTracker() - "Opened servers: " & $tracker.opened & "\n" & - "Closed servers: " & $tracker.closed - -proc leakTransport(): bool {.gcsafe.} = - var tracker = getStreamTransportTracker() - tracker.opened != tracker.closed - -proc leakServer(): bool {.gcsafe.} = - var tracker = getStreamServerTracker() - tracker.opened != tracker.closed - -proc trackStream(t: StreamTransport) {.inline.} = - var tracker = getStreamTransportTracker() - inc(tracker.opened) - -proc untrackStream(t: StreamTransport) {.inline.} = - var tracker = getStreamTransportTracker() - inc(tracker.closed) - -proc trackServer(s: StreamServer) {.inline.} = - var tracker = getStreamServerTracker() - inc(tracker.opened) - -proc untrackServer(s: StreamServer) {.inline.} = - var tracker = getStreamServerTracker() - inc(tracker.closed) - -proc setupStreamTransportTracker(): StreamTransportTracker {.gcsafe.} = - let res = StreamTransportTracker( - opened: 0, closed: 0, dump: dumpTransportTracking, isLeaked: leakTransport) - addTracker(StreamTransportTrackerName, res) - res - -proc setupStreamServerTracker(): StreamServerTracker {.gcsafe.} = - let res = StreamServerTracker( - opened: 0, closed: 0, dump: dumpServerTracking, isLeaked: leakServer) - addTracker(StreamServerTrackerName, res) - res - proc completePendingWriteQueue(queue: var Deque[StreamVector], v: int) {.inline.} = while len(queue) > 0: @@ -282,7 +207,7 @@ proc failPendingWriteQueue(queue: var Deque[StreamVector], proc clean(server: StreamServer) {.inline.} = if not(server.loopFuture.finished()): - untrackServer(server) + untrackCounter(StreamServerTrackerName) server.loopFuture.complete() if not(isNil(server.udata)) and (GCUserData in server.flags): GC_unref(cast[ref int](server.udata)) @@ -290,7 +215,7 @@ proc clean(server: StreamServer) {.inline.} = proc clean(transp: StreamTransport) {.inline.} = if not(transp.future.finished()): - untrackStream(transp) + untrackCounter(StreamTransportTrackerName) transp.future.complete() GC_unref(transp) @@ -786,7 +711,7 @@ when defined(windows): else: let transp = newStreamSocketTransport(sock, bufferSize, child) # Start tracking transport - trackStream(transp) + trackCounter(StreamTransportTrackerName) retFuture.complete(transp) else: sock.closeSocket() @@ -855,7 +780,7 @@ when defined(windows): let transp = newStreamPipeTransport(AsyncFD(pipeHandle), bufferSize, child) # Start tracking transport - trackStream(transp) + trackCounter(StreamTransportTrackerName) retFuture.complete(transp) pipeContinuation(nil) @@ -911,7 +836,7 @@ when defined(windows): ntransp = newStreamPipeTransport(server.sock, server.bufferSize, nil, flags) # Start tracking transport - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) asyncSpawn server.function(server, ntransp) of ERROR_OPERATION_ABORTED: # CancelIO() interrupt or close call. @@ -1015,7 +940,7 @@ when defined(windows): ntransp = newStreamSocketTransport(server.asock, server.bufferSize, nil) # Start tracking transport - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) asyncSpawn server.function(server, ntransp) of ERROR_OPERATION_ABORTED: @@ -1158,7 +1083,7 @@ when defined(windows): ntransp = newStreamSocketTransport(server.asock, server.bufferSize, nil) # Start tracking transport - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) retFuture.complete(ntransp) of ERROR_OPERATION_ABORTED: # CancelIO() interrupt or close. @@ -1218,7 +1143,7 @@ when defined(windows): retFuture.fail(getTransportOsError(error)) return - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) retFuture.complete(ntransp) of ERROR_OPERATION_ABORTED, ERROR_PIPE_NOT_CONNECTED: @@ -1550,14 +1475,13 @@ else: var saddr: Sockaddr_storage slen: SockLen - proto: Protocol var retFuture = newFuture[StreamTransport]("stream.transport.connect") address.toSAddr(saddr, slen) - proto = Protocol.IPPROTO_TCP - if address.family == AddressFamily.Unix: - # `Protocol` enum is missing `0` value, so we making here cast, until - # `Protocol` enum will not support IPPROTO_IP == 0. - proto = cast[Protocol](0) + let proto = + if address.family == AddressFamily.Unix: + Protocol.IPPROTO_IP + else: + Protocol.IPPROTO_TCP let sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto) @@ -1628,7 +1552,7 @@ else: let transp = newStreamSocketTransport(sock, bufferSize, child) # Start tracking transport - trackStream(transp) + trackCounter(StreamTransportTrackerName) retFuture.complete(transp) proc cancel(udata: pointer) = @@ -1641,7 +1565,7 @@ else: if res == 0: let transp = newStreamSocketTransport(sock, bufferSize, child) # Start tracking transport - trackStream(transp) + trackCounter(StreamTransportTrackerName) retFuture.complete(transp) break else: @@ -1696,7 +1620,7 @@ else: newStreamSocketTransport(sock, server.bufferSize, transp) else: newStreamSocketTransport(sock, server.bufferSize, nil) - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) asyncSpawn server.function(server, ntransp) else: # Client was accepted, so we not going to raise assertion, but @@ -1784,7 +1708,7 @@ else: else: newStreamSocketTransport(sock, server.bufferSize, nil) # Start tracking transport - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) retFuture.complete(ntransp) else: discard closeFd(cint(sock)) @@ -1895,11 +1819,32 @@ proc closeWait*(server: StreamServer): Future[void] = server.close() server.join() +proc getBacklogSize(backlog: int): cint = + doAssert(backlog >= 0 and backlog <= high(int32)) + when defined(windows): + # The maximum length of the queue of pending connections. If set to + # SOMAXCONN, the underlying service provider responsible for + # socket s will set the backlog to a maximum reasonable value. If set to + # SOMAXCONN_HINT(N) (where N is a number), the backlog value will be N, + # adjusted to be within the range (200, 65535). Note that SOMAXCONN_HINT + # can be used to set the backlog to a larger value than possible with + # SOMAXCONN. + # + # Microsoft SDK values are + # #define SOMAXCONN 0x7fffffff + # #define SOMAXCONN_HINT(b) (-(b)) + if backlog != high(int32): + cint(-backlog) + else: + cint(backlog) + else: + cint(backlog) + proc createStreamServer*(host: TransportAddress, cbproc: StreamCallback, flags: set[ServerFlags] = {}, sock: AsyncFD = asyncInvalidSocket, - backlog: int = 100, + backlog: int = DefaultBacklogSize, bufferSize: int = DefaultStreamBufferSize, child: StreamServer = nil, init: TransportInitCallback = nil, @@ -1982,7 +1927,7 @@ proc createStreamServer*(host: TransportAddress, raiseTransportOsError(err) fromSAddr(addr saddr, slen, localAddress) - if listen(SocketHandle(serverSocket), cint(backlog)) != 0: + if listen(SocketHandle(serverSocket), getBacklogSize(backlog)) != 0: let err = osLastError() if sock == asyncInvalidSocket: discard closeFd(SocketHandle(serverSocket)) @@ -1992,11 +1937,10 @@ proc createStreamServer*(host: TransportAddress, else: # Posix if sock == asyncInvalidSocket: - var proto = Protocol.IPPROTO_TCP - if host.family == AddressFamily.Unix: - # `Protocol` enum is missing `0` value, so we making here cast, until - # `Protocol` enum will not support IPPROTO_IP == 0. - proto = cast[Protocol](0) + let proto = if host.family == AddressFamily.Unix: + Protocol.IPPROTO_IP + else: + Protocol.IPPROTO_TCP serverSocket = createAsyncSocket(host.getDomain(), SockType.SOCK_STREAM, proto) @@ -2056,7 +2000,7 @@ proc createStreamServer*(host: TransportAddress, raiseTransportOsError(err) fromSAddr(addr saddr, slen, localAddress) - if listen(SocketHandle(serverSocket), cint(backlog)) != 0: + if listen(SocketHandle(serverSocket), getBacklogSize(backlog)) != 0: let err = osLastError() if sock == asyncInvalidSocket: discard unregisterAndCloseFd(serverSocket) @@ -2100,14 +2044,14 @@ proc createStreamServer*(host: TransportAddress, sres.apending = false # Start tracking server - trackServer(sres) + trackCounter(StreamServerTrackerName) GC_ref(sres) sres proc createStreamServer*(host: TransportAddress, flags: set[ServerFlags] = {}, sock: AsyncFD = asyncInvalidSocket, - backlog: int = 100, + backlog: int = DefaultBacklogSize, bufferSize: int = DefaultStreamBufferSize, child: StreamServer = nil, init: TransportInitCallback = nil, @@ -2121,7 +2065,7 @@ proc createStreamServer*[T](host: TransportAddress, flags: set[ServerFlags] = {}, udata: ref T, sock: AsyncFD = asyncInvalidSocket, - backlog: int = 100, + backlog: int = DefaultBacklogSize, bufferSize: int = DefaultStreamBufferSize, child: StreamServer = nil, init: TransportInitCallback = nil): StreamServer {. @@ -2135,7 +2079,7 @@ proc createStreamServer*[T](host: TransportAddress, flags: set[ServerFlags] = {}, udata: ref T, sock: AsyncFD = asyncInvalidSocket, - backlog: int = 100, + backlog: int = DefaultBacklogSize, bufferSize: int = DefaultStreamBufferSize, child: StreamServer = nil, init: TransportInitCallback = nil): StreamServer {. @@ -2650,6 +2594,57 @@ proc closeWait*(transp: StreamTransport): Future[void] = transp.close() transp.join() +proc shutdownWait*(transp: StreamTransport): Future[void] = + ## Perform graceful shutdown of TCP connection backed by transport ``transp``. + doAssert(transp.kind == TransportKind.Socket) + let retFuture = newFuture[void]("stream.transport.shutdown") + transp.checkClosed(retFuture) + transp.checkWriteEof(retFuture) + + when defined(windows): + let loop = getThreadDispatcher() + proc continuation(udata: pointer) {.gcsafe.} = + let ovl = cast[RefCustomOverlapped](udata) + if not(retFuture.finished()): + if ovl.data.errCode == OSErrorCode(-1): + retFuture.complete() + else: + transp.state.excl({WriteEof}) + retFuture.fail(getTransportOsError(ovl.data.errCode)) + GC_unref(ovl) + + let povl = RefCustomOverlapped(data: CompletionData(cb: continuation)) + GC_ref(povl) + + let res = loop.disconnectEx(SocketHandle(transp.fd), + cast[POVERLAPPED](povl), 0'u32, 0'u32) + if res == FALSE: + let err = osLastError() + case err + of ERROR_IO_PENDING: + transp.state.incl({WriteEof}) + else: + GC_unref(povl) + retFuture.fail(getTransportOsError(err)) + else: + transp.state.incl({WriteEof}) + retFuture.complete() + + retFuture + else: + proc continuation(udata: pointer) {.gcsafe.} = + if not(retFuture.finished()): + retFuture.complete() + + let res = osdefs.shutdown(SocketHandle(transp.fd), SHUT_WR) + if res < 0: + let err = osLastError() + retFuture.fail(getTransportOsError(err)) + else: + transp.state.incl({WriteEof}) + callSoon(continuation, nil) + retFuture + proc closed*(transp: StreamTransport): bool {.inline.} = ## Returns ``true`` if transport in closed state. ({ReadClosed, WriteClosed} * transp.state != {}) @@ -2676,7 +2671,7 @@ proc fromPipe2*(fd: AsyncFD, child: StreamTransport = nil, ? register2(fd) var res = newStreamPipeTransport(fd, bufferSize, child) # Start tracking transport - trackStream(res) + trackCounter(StreamTransportTrackerName) ok(res) proc fromPipe*(fd: AsyncFD, child: StreamTransport = nil, diff --git a/chronos/unittest2/asynctests.nim b/chronos/unittest2/asynctests.nim index fda03537..bc703b7e 100644 --- a/chronos/unittest2/asynctests.nim +++ b/chronos/unittest2/asynctests.nim @@ -6,6 +6,7 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) +import std/tables import unittest2 import ../../chronos @@ -17,3 +18,14 @@ template asyncTest*(name: string, body: untyped): untyped = proc() {.async, gcsafe.} = body )()) + +template checkLeaks*(name: string): untyped = + let counter = getTrackerCounter(name) + if counter.opened != counter.closed: + echo "[" & name & "] opened = ", counter.opened, + ", closed = ", counter.closed + check counter.opened == counter.closed + +template checkLeaks*(): untyped = + for key in getThreadDispatcher().trackerCounterKeys(): + checkLeaks(key) diff --git a/tests/testall.nim b/tests/testall.nim index bf0e98a9..6419f983 100644 --- a/tests/testall.nim +++ b/tests/testall.nim @@ -5,10 +5,22 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import testmacro, testsync, testsoon, testtime, testfut, testsignal, - testaddress, testdatagram, teststream, testserver, testbugs, testnet, - testasyncstream, testhttpserver, testshttpserver, testhttpclient, - testproc, testratelimit, testfutures +import ".."/chronos/config -# Must be imported last to check for Pending futures -import testutils +when (chronosEventEngine in ["epoll", "kqueue"]) or defined(windows): + import testmacro, testsync, testsoon, testtime, testfut, testsignal, + testaddress, testdatagram, teststream, testserver, testbugs, testnet, + testasyncstream, testhttpserver, testshttpserver, testhttpclient, + testproc, testratelimit, testfutures, testthreadsync + + # Must be imported last to check for Pending futures + import testutils +elif chronosEventEngine == "poll": + # `poll` engine do not support signals and processes + import testmacro, testsync, testsoon, testtime, testfut, testaddress, + testdatagram, teststream, testserver, testbugs, testnet, + testasyncstream, testhttpserver, testshttpserver, testhttpclient, + testratelimit, testfutures, testthreadsync + + # Must be imported last to check for Pending futures + import testutils diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 47a6c942..d90b6887 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -7,8 +7,8 @@ # MIT license (LICENSE-MIT) import unittest2 import bearssl/[x509] -import ../chronos -import ../chronos/streams/[tlsstream, chunkstream, boundstream] +import ".."/chronos/unittest2/asynctests +import ".."/chronos/streams/[tlsstream, chunkstream, boundstream] {.used.} @@ -145,7 +145,7 @@ proc createBigMessage(message: string, size: int): seq[byte] = suite "AsyncStream test suite": test "AsyncStream(StreamTransport) readExactly() test": - proc testReadExactly(address: TransportAddress): Future[bool] {.async.} = + proc testReadExactly(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -157,9 +157,10 @@ suite "AsyncStream test suite": server.close() var buffer = newSeq[byte](10) - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) await rstream.readExactly(addr buffer[0], 10) check cast[string](buffer) == "0000000000" @@ -171,9 +172,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testReadExactly(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testReadExactly()) == true + test "AsyncStream(StreamTransport) readUntil() test": - proc testReadUntil(address: TransportAddress): Future[bool] {.async.} = + proc testReadUntil(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -186,9 +188,10 @@ suite "AsyncStream test suite": var buffer = newSeq[byte](13) var sep = @[byte('N'), byte('N'), byte('z')] - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var r1 = await rstream.readUntil(addr buffer[0], len(buffer), sep) check: @@ -207,9 +210,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testReadUntil(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testReadUntil()) == true + test "AsyncStream(StreamTransport) readLine() test": - proc testReadLine(address: TransportAddress): Future[bool] {.async.} = + proc testReadLine(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -220,9 +224,10 @@ suite "AsyncStream test suite": server.stop() server.close() - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var r1 = await rstream.readLine() check r1 == "0000000000" @@ -234,9 +239,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testReadLine(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testReadLine()) == true + test "AsyncStream(StreamTransport) read() test": - proc testRead(address: TransportAddress): Future[bool] {.async.} = + proc testRead(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -247,9 +253,10 @@ suite "AsyncStream test suite": server.stop() server.close() - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var buf1 = await rstream.read(10) check cast[string](buf1) == "0000000000" @@ -259,9 +266,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testRead(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testRead()) == true + test "AsyncStream(StreamTransport) consume() test": - proc testConsume(address: TransportAddress): Future[bool] {.async.} = + proc testConsume(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -272,9 +280,10 @@ suite "AsyncStream test suite": server.stop() server.close() - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var res1 = await rstream.consume(10) check: @@ -290,16 +299,13 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testConsume(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testConsume()) == true + test "AsyncStream(StreamTransport) leaks test": - check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + checkLeaks() test "AsyncStream(AsyncStream) readExactly() test": - proc testReadExactly2(address: TransportAddress): Future[bool] {.async.} = + proc testReadExactly2(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -323,9 +329,10 @@ suite "AsyncStream test suite": server.close() var buffer = newSeq[byte](10) - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var rstream2 = newChunkedStreamReader(rstream) await rstream2.readExactly(addr buffer[0], 10) @@ -347,9 +354,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testReadExactly2(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testReadExactly2()) == true + test "AsyncStream(AsyncStream) readUntil() test": - proc testReadUntil2(address: TransportAddress): Future[bool] {.async.} = + proc testReadUntil2(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -373,9 +381,10 @@ suite "AsyncStream test suite": var buffer = newSeq[byte](13) var sep = @[byte('N'), byte('N'), byte('z')] - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var rstream2 = newChunkedStreamReader(rstream) @@ -404,9 +413,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testReadUntil2(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testReadUntil2()) == true + test "AsyncStream(AsyncStream) readLine() test": - proc testReadLine2(address: TransportAddress): Future[bool] {.async.} = + proc testReadLine2(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -425,9 +435,10 @@ suite "AsyncStream test suite": server.stop() server.close() - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var rstream2 = newChunkedStreamReader(rstream) var r1 = await rstream2.readLine() @@ -449,9 +460,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testReadLine2(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testReadLine2()) == true + test "AsyncStream(AsyncStream) read() test": - proc testRead2(address: TransportAddress): Future[bool] {.async.} = + proc testRead2(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -469,9 +481,10 @@ suite "AsyncStream test suite": server.stop() server.close() - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var rstream2 = newChunkedStreamReader(rstream) var buf1 = await rstream2.read(10) @@ -488,9 +501,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testRead2(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testRead2()) == true + test "AsyncStream(AsyncStream) consume() test": - proc testConsume2(address: TransportAddress): Future[bool] {.async.} = + proc testConsume2(): Future[bool] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = const @@ -518,9 +532,10 @@ suite "AsyncStream test suite": server.stop() server.close() - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var rstream2 = newChunkedStreamReader(rstream) @@ -547,9 +562,10 @@ suite "AsyncStream test suite": await transp.closeWait() await server.join() result = true - check waitFor(testConsume2(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testConsume2()) == true + test "AsyncStream(AsyncStream) write(eof) test": - proc testWriteEof(address: TransportAddress): Future[bool] {.async.} = + proc testWriteEof(): Future[bool] {.async.} = let size = 10240 message = createBigMessage("ABCDEFGHIJKLMNOP", size) @@ -578,7 +594,8 @@ suite "AsyncStream test suite": await transp.closeWait() let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay} - var server = createStreamServer(address, processClient, flags = flags) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + processClient, flags = flags) server.start() var conn = await connect(server.localAddress()) try: @@ -589,13 +606,10 @@ suite "AsyncStream test suite": await server.closeWait() return true - check waitFor(testWriteEof(initTAddress("127.0.0.1:46001"))) == true + check waitFor(testWriteEof()) == true + test "AsyncStream(AsyncStream) leaks test": - check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + checkLeaks() suite "ChunkedStream test suite": test "ChunkedStream test vectors": @@ -624,8 +638,7 @@ suite "ChunkedStream test suite": " in\r\n\r\nchunks.\r\n0;position=4\r\n\r\n", "Wikipedia in\r\n\r\nchunks."], ] - proc checkVector(address: TransportAddress, - inputstr: string): Future[string] {.async.} = + proc checkVector(inputstr: string): Future[string] {.async.} = proc serveClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) @@ -637,9 +650,10 @@ suite "ChunkedStream test suite": server.stop() server.close() - var server = createStreamServer(address, serveClient, {ReuseAddr}) + var server = createStreamServer(initTAddress("127.0.0.1:0"), + serveClient, {ReuseAddr}) server.start() - var transp = await connect(address) + var transp = await connect(server.localAddress()) var rstream = newAsyncStreamReader(transp) var rstream2 = newChunkedStreamReader(rstream) var res = await rstream2.read() @@ -650,15 +664,16 @@ suite "ChunkedStream test suite": await server.join() result = ress - proc testVectors(address: TransportAddress): Future[bool] {.async.} = + proc testVectors(): Future[bool] {.async.} = var res = true for i in 0..= 5: let (code, data) = await session.fetch(ha.getUri()) await session.closeWait() @@ -691,26 +704,22 @@ suite "HTTP client testing suite": await server.closeWait() return "redirect-" & $res - proc testBasicAuthorization(): Future[bool] {.async.} = - let session = HttpSessionRef.new({HttpClientFlag.NoVerifyHost}, - maxRedirections = 10) - let url = parseUri("https://guest:guest@jigsaw.w3.org/HTTP/Basic/") - let resp = await session.fetch(url) - await session.closeWait() - if (resp.status == 200) and - ("Your browser made it!" in bytesToString(resp.data)): - return true - else: - echo "RESPONSE STATUS = [", resp.status, "]" - echo "RESPONSE = [", bytesToString(resp.data), "]" - return false + # proc testBasicAuthorization(): Future[bool] {.async.} = + # let session = HttpSessionRef.new({HttpClientFlag.NoVerifyHost}, + # maxRedirections = 10) + # let url = parseUri("https://guest:guest@jigsaw.w3.org/HTTP/Basic/") + # let resp = await session.fetch(url) + # await session.closeWait() + # if (resp.status == 200) and + # ("Your browser made it!" in bytesToString(resp.data)): + # return true + # else: + # echo "RESPONSE STATUS = [", resp.status, "]" + # echo "RESPONSE = [", bytesToString(resp.data), "]" + # return false - proc testConnectionManagement(address: TransportAddress): Future[bool] {. + proc testConnectionManagement(): Future[bool] {. async.} = - let - keepHa = getAddress(address, HttpClientScheme.NonSecure, "/keep") - dropHa = getAddress(address, HttpClientScheme.NonSecure, "/drop") - proc test1( a1: HttpAddress, version: HttpVersion, @@ -770,10 +779,15 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() - var server = createServer(address, process, false) + var server = createServer(initTAddress("127.0.0.1:0"), process, false) server.start() + let address = server.instance.localAddress() + + let + keepHa = getAddress(address, HttpClientScheme.NonSecure, "/keep") + dropHa = getAddress(address, HttpClientScheme.NonSecure, "/drop") try: let @@ -872,11 +886,7 @@ suite "HTTP client testing suite": return true - proc testIdleConnection(address: TransportAddress): Future[bool] {. - async.} = - let - ha = getAddress(address, HttpClientScheme.NonSecure, "/test") - + proc testIdleConnection(): Future[bool] {.async.} = proc test( session: HttpSessionRef, a: HttpAddress @@ -900,13 +910,16 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() - var server = createServer(address, process, false) + var server = createServer(initTAddress("127.0.0.1:0"), process, false) server.start() - let session = HttpSessionRef.new({HttpClientFlag.Http11Pipeline}, - idleTimeout = 1.seconds, - idlePeriod = 200.milliseconds) + let + address = server.instance.localAddress() + ha = getAddress(address, HttpClientScheme.NonSecure, "/test") + session = HttpSessionRef.new({HttpClientFlag.Http11Pipeline}, + idleTimeout = 1.seconds, + idlePeriod = 200.milliseconds) try: var f1 = test(session, ha) var f2 = test(session, ha) @@ -932,12 +945,7 @@ suite "HTTP client testing suite": return true - proc testNoPipeline(address: TransportAddress): Future[bool] {. - async.} = - let - ha = getAddress(address, HttpClientScheme.NonSecure, "/test") - hb = getAddress(address, HttpClientScheme.NonSecure, "/keep-test") - + proc testNoPipeline(): Future[bool] {.async.} = proc test( session: HttpSessionRef, a: HttpAddress @@ -964,12 +972,16 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() - var server = createServer(address, process, false) + var server = createServer(initTAddress("127.0.0.1:0"), process, false) server.start() - let session = HttpSessionRef.new(idleTimeout = 100.seconds, - idlePeriod = 10.milliseconds) + let + address = server.instance.localAddress() + ha = getAddress(address, HttpClientScheme.NonSecure, "/test") + hb = getAddress(address, HttpClientScheme.NonSecure, "/keep-test") + session = HttpSessionRef.new(idleTimeout = 100.seconds, + idlePeriod = 10.milliseconds) try: var f1 = test(session, ha) var f2 = test(session, ha) @@ -1001,8 +1013,7 @@ suite "HTTP client testing suite": return true - proc testServerSentEvents(address: TransportAddress, - secure: bool): Future[bool] {.async.} = + proc testServerSentEvents(secure: bool): Future[bool] {.async.} = const SingleGoodTests = [ ("/test/single/1", "a:b\r\nc: d\re:f\n:comment\r\ng:\n h: j \n\n", @@ -1115,10 +1126,11 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() - var server = createServer(address, process, secure) + var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() + let address = server.instance.localAddress() var session = createSession(secure) @@ -1184,100 +1196,71 @@ suite "HTTP client testing suite": return true test "HTTP all request methods test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testMethods(address, false)) == 18 + check waitFor(testMethods(false)) == 18 test "HTTP(S) all request methods test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testMethods(address, true)) == 18 + check waitFor(testMethods(true)) == 18 test "HTTP client response streaming test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testResponseStreamReadingTest(address, false)) == 8 + check waitFor(testResponseStreamReadingTest(false)) == 8 test "HTTP(S) client response streaming test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testResponseStreamReadingTest(address, true)) == 8 + check waitFor(testResponseStreamReadingTest(true)) == 8 test "HTTP client (size) request streaming test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestSizeStreamWritingTest(address, false)) == 2 + check waitFor(testRequestSizeStreamWritingTest(false)) == 2 test "HTTP(S) client (size) request streaming test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestSizeStreamWritingTest(address, true)) == 2 + check waitFor(testRequestSizeStreamWritingTest(true)) == 2 test "HTTP client (chunked) request streaming test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestChunkedStreamWritingTest(address, false)) == 2 + check waitFor(testRequestChunkedStreamWritingTest(false)) == 2 test "HTTP(S) client (chunked) request streaming test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestChunkedStreamWritingTest(address, true)) == 2 + check waitFor(testRequestChunkedStreamWritingTest(true)) == 2 test "HTTP client (size + chunked) url-encoded POST test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestPostUrlEncodedTest(address, false)) == 2 + check waitFor(testRequestPostUrlEncodedTest(false)) == 2 test "HTTP(S) client (size + chunked) url-encoded POST test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestPostUrlEncodedTest(address, true)) == 2 + check waitFor(testRequestPostUrlEncodedTest(true)) == 2 test "HTTP client (size + chunked) multipart POST test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestPostMultipartTest(address, false)) == 2 + check waitFor(testRequestPostMultipartTest(false)) == 2 test "HTTP(S) client (size + chunked) multipart POST test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestPostMultipartTest(address, true)) == 2 + check waitFor(testRequestPostMultipartTest(true)) == 2 test "HTTP client redirection test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestRedirectTest(address, false, 5)) == "ok-5-200" + check waitFor(testRequestRedirectTest(false, 5)) == "ok-5-200" test "HTTP(S) client redirection test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestRedirectTest(address, true, 5)) == "ok-5-200" + check waitFor(testRequestRedirectTest(true, 5)) == "ok-5-200" test "HTTP client maximum redirections test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestRedirectTest(address, false, 4)) == "redirect-true" + check waitFor(testRequestRedirectTest(false, 4)) == "redirect-true" test "HTTP(S) client maximum redirections test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testRequestRedirectTest(address, true, 4)) == "redirect-true" + check waitFor(testRequestRedirectTest(true, 4)) == "redirect-true" test "HTTPS basic authorization test": - check waitFor(testBasicAuthorization()) == true + skip() + # This test disabled because remote service is pretty flaky and fails pretty + # often. As soon as more stable service will be found this test should be + # recovered + # check waitFor(testBasicAuthorization()) == true test "HTTP client connection management test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testConnectionManagement(address)) == true + check waitFor(testConnectionManagement()) == true test "HTTP client idle connection test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testIdleConnection(address)) == true + check waitFor(testIdleConnection()) == true test "HTTP client no-pipeline test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testNoPipeline(address)) == true + check waitFor(testNoPipeline()) == true test "HTTP client server-sent events test": - let address = initTAddress("127.0.0.1:30080") - check waitFor(testServerSentEvents(address, false)) == true + check waitFor(testServerSentEvents(false)) == true test "Leaks test": - proc getTrackerLeaks(tracker: string): bool = - let tracker = getTracker(tracker) - if isNil(tracker): false else: tracker.isLeaked() - - check: - getTrackerLeaks("http.body.reader") == false - getTrackerLeaks("http.body.writer") == false - getTrackerLeaks("httpclient.connection") == false - getTrackerLeaks("httpclient.request") == false - getTrackerLeaks("httpclient.response") == false - getTrackerLeaks("async.stream.reader") == false - getTrackerLeaks("async.stream.writer") == false - getTrackerLeaks("stream.server") == false - getTrackerLeaks("stream.transport") == false + checkLeaks() diff --git a/tests/testhttpserver.nim b/tests/testhttpserver.nim index acf8b20b..0ecc9aa4 100644 --- a/tests/testhttpserver.nim +++ b/tests/testhttpserver.nim @@ -6,9 +6,9 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/[strutils, algorithm] -import unittest2 -import ../chronos, ../chronos/apps/http/httpserver, - ../chronos/apps/http/httpcommon +import ".."/chronos/unittest2/asynctests, + ".."/chronos, + ".."/chronos/apps/http/[httpserver, httpcommon, httpdebug] import stew/base10 {.used.} @@ -17,6 +17,9 @@ suite "HTTP server testing suite": type TooBigTest = enum GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest + TestHttpResponse = object + headers: HttpTable + data: string proc httpClient(address: TransportAddress, data: string): Future[string] {.async.} = @@ -33,8 +36,32 @@ suite "HTTP server testing suite": if not(isNil(transp)): await closeWait(transp) - proc testTooBigBodyChunked(address: TransportAddress, - operation: TooBigTest): Future[bool] {.async.} = + proc httpClient2(transp: StreamTransport, + request: string, + length: int): Future[TestHttpResponse] {.async.} = + var buffer = newSeq[byte](4096) + var sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8] + let wres = await transp.write(request) + if wres != len(request): + raise newException(ValueError, "Unable to write full request") + let hres = await transp.readUntil(addr buffer[0], len(buffer), sep) + var hdata = @buffer + hdata.setLen(hres) + zeroMem(addr buffer[0], len(buffer)) + await transp.readExactly(addr buffer[0], length) + let data = bytesToString(buffer.toOpenArray(0, length - 1)) + let headers = + block: + let resp = parseResponse(hdata, false) + if resp.failed(): + raise newException(ValueError, "Unable to decode response headers") + var res = HttpTable.init() + for key, value in resp.headers(hdata): + res.add(key, value) + res + return TestHttpResponse(headers: headers, data: data) + + proc testTooBigBodyChunked(operation: TooBigTest): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -56,10 +83,10 @@ suite "HTTP server testing suite": # Reraising exception, because processor should properly handle it. raise exc else: - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, maxRequestBodySize = 10, socketFlags = socketFlags) if res.isErr(): @@ -67,18 +94,19 @@ suite "HTTP server testing suite": let server = res.get() server.start() + let address = server.instance.localAddress() let request = case operation of GetBodyTest, ConsumeBodyTest, PostUrlTest: - "POST / HTTP/1.0\r\n" & + "POST / HTTP/1.1\r\n" & "Content-Type: application/x-www-form-urlencoded\r\n" & "Transfer-Encoding: chunked\r\n" & "Cookie: 2\r\n\r\n" & "5\r\na=a&b\r\n5\r\n=b&c=\r\n4\r\nc&d=\r\n4\r\n%D0%\r\n" & "2\r\n9F\r\n0\r\n\r\n" of PostMultipartTest: - "POST / HTTP/1.0\r\n" & + "POST / HTTP/1.1\r\n" & "Host: 127.0.0.1:30080\r\n" & "Transfer-Encoding: chunked\r\n" & "Content-Type: multipart/form-data; boundary=f98f0\r\n\r\n" & @@ -97,7 +125,7 @@ suite "HTTP server testing suite": return serverRes and (data.startsWith("HTTP/1.1 413")) test "Request headers timeout test": - proc testTimeout(address: TransportAddress): Future[bool] {.async.} = + proc testTimeout(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -105,28 +133,29 @@ suite "HTTP server testing suite": let request = r.get() return await request.respond(Http200, "TEST_OK", HttpTable.init()) else: - if r.error().error == HttpServerError.TimeoutError: + if r.error.kind == HttpServerError.TimeoutError: serverRes = true - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, socketFlags = socketFlags, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), + process, socketFlags = socketFlags, httpHeadersTimeout = 100.milliseconds) if res.isErr(): return false let server = res.get() server.start() - + let address = server.instance.localAddress() let data = await httpClient(address, "") await server.stop() await server.closeWait() return serverRes and (data.startsWith("HTTP/1.1 408")) - check waitFor(testTimeout(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testTimeout()) == true test "Empty headers test": - proc testEmpty(address: TransportAddress): Future[bool] {.async.} = + proc testEmpty(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -134,27 +163,29 @@ suite "HTTP server testing suite": let request = r.get() return await request.respond(Http200, "TEST_OK", HttpTable.init()) else: - if r.error().error == HttpServerError.CriticalError: + if r.error.kind == HttpServerError.CriticalError: serverRes = true - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, socketFlags = socketFlags) + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), + process, socketFlags = socketFlags) if res.isErr(): return false let server = res.get() server.start() + let address = server.instance.localAddress() let data = await httpClient(address, "\r\n\r\n") await server.stop() await server.closeWait() return serverRes and (data.startsWith("HTTP/1.1 400")) - check waitFor(testEmpty(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testEmpty()) == true test "Too big headers test": - proc testTooBig(address: TransportAddress): Future[bool] {.async.} = + proc testTooBig(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -162,12 +193,12 @@ suite "HTTP server testing suite": let request = r.get() return await request.respond(Http200, "TEST_OK", HttpTable.init()) else: - if r.error().error == HttpServerError.CriticalError: + if r.error.error == HttpServerError.CriticalError: serverRes = true - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, maxHeadersSize = 10, socketFlags = socketFlags) if res.isErr(): @@ -175,28 +206,29 @@ suite "HTTP server testing suite": let server = res.get() server.start() + let address = server.instance.localAddress() let data = await httpClient(address, "GET / HTTP/1.1\r\n\r\n") await server.stop() await server.closeWait() return serverRes and (data.startsWith("HTTP/1.1 431")) - check waitFor(testTooBig(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testTooBig()) == true test "Too big request body test (content-length)": - proc testTooBigBody(address: TransportAddress): Future[bool] {.async.} = + proc testTooBigBody(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): discard else: - if r.error().error == HttpServerError.CriticalError: + if r.error.error == HttpServerError.CriticalError: serverRes = true - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, maxRequestBodySize = 10, socketFlags = socketFlags) if res.isErr(): @@ -204,6 +236,7 @@ suite "HTTP server testing suite": let server = res.get() server.start() + let address = server.instance.localAddress() let request = "GET / HTTP/1.1\r\nContent-Length: 20\r\n\r\n" let data = await httpClient(address, request) @@ -211,30 +244,26 @@ suite "HTTP server testing suite": await server.closeWait() return serverRes and (data.startsWith("HTTP/1.1 413")) - check waitFor(testTooBigBody(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testTooBigBody()) == true test "Too big request body test (getBody()/chunked encoding)": check: - waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), - GetBodyTest)) == true + waitFor(testTooBigBodyChunked(GetBodyTest)) == true test "Too big request body test (consumeBody()/chunked encoding)": check: - waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), - ConsumeBodyTest)) == true + waitFor(testTooBigBodyChunked(ConsumeBodyTest)) == true test "Too big request body test (post()/urlencoded/chunked encoding)": check: - waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), - PostUrlTest)) == true + waitFor(testTooBigBodyChunked(PostUrlTest)) == true test "Too big request body test (post()/multipart/chunked encoding)": check: - waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), - PostMultipartTest)) == true + waitFor(testTooBigBodyChunked(PostMultipartTest)) == true test "Query arguments test": - proc testQuery(address: TransportAddress): Future[bool] {.async.} = + proc testQuery(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -249,16 +278,17 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, socketFlags = socketFlags) if res.isErr(): return false let server = res.get() server.start() + let address = server.instance.localAddress() let data1 = await httpClient(address, "GET /?a=1&a=2&b=3&c=4 HTTP/1.0\r\n\r\n") @@ -271,10 +301,10 @@ suite "HTTP server testing suite": (data2.find("TEST_OK:a:П:b:Ц:c:Ю:Ф:Б") >= 0) return r - check waitFor(testQuery(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testQuery()) == true test "Headers test": - proc testHeaders(address: TransportAddress): Future[bool] {.async.} = + proc testHeaders(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -289,16 +319,17 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, socketFlags = socketFlags) if res.isErr(): return false let server = res.get() server.start() + let address = server.instance.localAddress() let message = "GET / HTTP/1.0\r\n" & @@ -314,10 +345,10 @@ suite "HTTP server testing suite": await server.closeWait() return serverRes and (data.find(expect) >= 0) - check waitFor(testHeaders(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testHeaders()) == true test "POST arguments (urlencoded/content-length) test": - proc testPostUrl(address: TransportAddress): Future[bool] {.async.} = + proc testPostUrl(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -334,16 +365,17 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, socketFlags = socketFlags) if res.isErr(): return false let server = res.get() server.start() + let address = server.instance.localAddress() let message = "POST / HTTP/1.0\r\n" & @@ -357,10 +389,10 @@ suite "HTTP server testing suite": await server.closeWait() return serverRes and (data.find(expect) >= 0) - check waitFor(testPostUrl(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testPostUrl()) == true test "POST arguments (urlencoded/chunked encoding) test": - proc testPostUrl2(address: TransportAddress): Future[bool] {.async.} = + proc testPostUrl2(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -377,16 +409,17 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, socketFlags = socketFlags) if res.isErr(): return false let server = res.get() server.start() + let address = server.instance.localAddress() let message = "POST / HTTP/1.0\r\n" & @@ -401,10 +434,10 @@ suite "HTTP server testing suite": await server.closeWait() return serverRes and (data.find(expect) >= 0) - check waitFor(testPostUrl2(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testPostUrl2()) == true test "POST arguments (multipart/content-length) test": - proc testPostMultipart(address: TransportAddress): Future[bool] {.async.} = + proc testPostMultipart(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -421,16 +454,17 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, socketFlags = socketFlags) if res.isErr(): return false let server = res.get() server.start() + let address = server.instance.localAddress() let message = "POST / HTTP/1.0\r\n" & @@ -456,10 +490,10 @@ suite "HTTP server testing suite": await server.closeWait() return serverRes and (data.find(expect) >= 0) - check waitFor(testPostMultipart(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testPostMultipart()) == true test "POST arguments (multipart/chunked encoding) test": - proc testPostMultipart2(address: TransportAddress): Future[bool] {.async.} = + proc testPostMultipart2(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -476,16 +510,17 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, socketFlags = socketFlags) if res.isErr(): return false let server = res.get() server.start() + let address = server.instance.localAddress() let message = "POST / HTTP/1.0\r\n" & @@ -520,12 +555,12 @@ suite "HTTP server testing suite": await server.closeWait() return serverRes and (data.find(expect) >= 0) - check waitFor(testPostMultipart2(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testPostMultipart2()) == true test "drop() connections test": const ClientsCount = 10 - proc testHTTPdrop(address: TransportAddress): Future[bool] {.async.} = + proc testHTTPdrop(): Future[bool] {.async.} = var eventWait = newAsyncEvent() var eventContinue = newAsyncEvent() var count = 0 @@ -539,10 +574,10 @@ suite "HTTP server testing suite": await eventContinue.wait() return await request.respond(Http404, "", HttpTable.init()) else: - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, socketFlags = socketFlags, maxConnections = 100) if res.isErr(): @@ -550,6 +585,7 @@ suite "HTTP server testing suite": let server = res.get() server.start() + let address = server.instance.localAddress() var clients: seq[Future[string]] let message = "GET / HTTP/1.0\r\nHost: https://127.0.0.1:80\r\n\r\n" @@ -572,7 +608,7 @@ suite "HTTP server testing suite": return false return true - check waitFor(testHTTPdrop(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testHTTPdrop()) == true test "Content-Type multipart boundary test": const AllowedCharacters = { @@ -1190,7 +1226,7 @@ suite "HTTP server testing suite": r6.get() == MediaType.init(req[1][6]) test "SSE server-side events stream test": - proc testPostMultipart2(address: TransportAddress): Future[bool] {.async.} = + proc testPostMultipart2(): Future[bool] {.async.} = var serverRes = false proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = @@ -1209,16 +1245,17 @@ suite "HTTP server testing suite": return response else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let res = HttpServerRef.new(address, process, + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, socketFlags = socketFlags) if res.isErr(): return false let server = res.get() server.start() + let address = server.instance.localAddress() let message = "GET / HTTP/1.1\r\n" & @@ -1237,12 +1274,158 @@ suite "HTTP server testing suite": await server.closeWait() return serverRes and (data.find(expect) >= 0) - check waitFor(testPostMultipart2(initTAddress("127.0.0.1:30080"))) == true + check waitFor(testPostMultipart2()) == true + asyncTest "HTTP/1.1 pipeline test": + const TestMessages = [ + ("GET / HTTP/1.0\r\n\r\n", + {HttpServerFlags.Http11Pipeline}, false, "close"), + ("GET / HTTP/1.0\r\nConnection: close\r\n\r\n", + {HttpServerFlags.Http11Pipeline}, false, "close"), + ("GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", + {HttpServerFlags.Http11Pipeline}, false, "close"), + ("GET / HTTP/1.0\r\n\r\n", + {}, false, "close"), + ("GET / HTTP/1.0\r\nConnection: close\r\n\r\n", + {}, false, "close"), + ("GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", + {}, false, "close"), + ("GET / HTTP/1.1\r\n\r\n", + {HttpServerFlags.Http11Pipeline}, true, "keep-alive"), + ("GET / HTTP/1.1\r\nConnection: close\r\n\r\n", + {HttpServerFlags.Http11Pipeline}, false, "close"), + ("GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n", + {HttpServerFlags.Http11Pipeline}, true, "keep-alive"), + ("GET / HTTP/1.1\r\n\r\n", + {}, false, "close"), + ("GET / HTTP/1.1\r\nConnection: close\r\n\r\n", + {}, false, "close"), + ("GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n", + {}, false, "close") + ] + + proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isOk(): + let request = r.get() + return await request.respond(Http200, "TEST_OK", HttpTable.init()) + else: + return defaultResponse() + + for test in TestMessages: + let + socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + serverFlags = test[1] + res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, + socketFlags = socketFlags, + serverFlags = serverFlags) + check res.isOk() + + let + server = res.get() + address = server.instance.localAddress() + + server.start() + var transp: StreamTransport + try: + transp = await connect(address) + block: + let response = await transp.httpClient2(test[0], 7) + check: + response.data == "TEST_OK" + response.headers.getString("connection") == test[3] + # We do this sleeping here just because we running both server and + # client in single process, so when we received response from server + # it does not mean that connection has been immediately closed - it + # takes some more calls, so we trying to get this calls happens. + await sleepAsync(50.milliseconds) + let connectionStillAvailable = + try: + let response {.used.} = await transp.httpClient2(test[0], 7) + true + except CatchableError: + false + + check connectionStillAvailable == test[2] + + finally: + if not(isNil(transp)): + await transp.closeWait() + await server.stop() + await server.closeWait() + + asyncTest "HTTP debug tests": + const + TestsCount = 10 + TestRequest = "GET /httpdebug HTTP/1.1\r\nConnection: keep-alive\r\n\r\n" + + proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isOk(): + let request = r.get() + return await request.respond(Http200, "TEST_OK", HttpTable.init()) + else: + return defaultResponse() + + proc client(address: TransportAddress, + data: string): Future[StreamTransport] {.async.} = + var transp: StreamTransport + var buffer = newSeq[byte](4096) + var sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8] + try: + transp = await connect(address) + let wres {.used.} = + await transp.write(data) + let hres {.used.} = + await transp.readUntil(addr buffer[0], len(buffer), sep) + transp + except CatchableError: + if not(isNil(transp)): await transp.closeWait() + nil + + let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, + serverFlags = {HttpServerFlags.Http11Pipeline}, + socketFlags = socketFlags) + check res.isOk() + + let server = res.get() + server.start() + let address = server.instance.localAddress() + + let info = server.getServerInfo() + + check: + info.connectionType == ConnectionType.NonSecure + info.address == address + info.state == HttpServerState.ServerRunning + info.flags == {HttpServerFlags.Http11Pipeline} + info.socketFlags == socketFlags + + try: + var clientFutures: seq[Future[StreamTransport]] + for i in 0 ..< TestsCount: + clientFutures.add(client(address, TestRequest)) + await allFutures(clientFutures) + + let connections = server.getConnections() + check len(connections) == TestsCount + let currentTime = Moment.now() + for index, connection in connections.pairs(): + let transp = clientFutures[index].read() + check: + connection.remoteAddress.get() == transp.localAddress() + connection.localAddress.get() == transp.remoteAddress() + connection.connectionType == ConnectionType.NonSecure + connection.connectionState == ConnectionState.Alive + connection.query.get("") == "/httpdebug" + (currentTime - connection.createMoment.get()) != ZeroDuration + (currentTime - connection.acceptMoment) != ZeroDuration + var pending: seq[Future[void]] + for transpFut in clientFutures: + pending.add(closeWait(transpFut.read())) + await allFutures(pending) + finally: + await server.stop() + await server.closeWait() test "Leaks test": - check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + checkLeaks() diff --git a/tests/testmacro.nim b/tests/testmacro.nim index 2526c5de..ad4c22f3 100644 --- a/tests/testmacro.nim +++ b/tests/testmacro.nim @@ -177,6 +177,10 @@ suite "Macro transformations test suite": of false: await implicit7(v) of true: 42 + proc implicit9(): Future[int] {.async.} = + result = 42 + result + let fin = new int check: waitFor(implicit()) == 42 @@ -193,6 +197,8 @@ suite "Macro transformations test suite": waitFor(implicit8(true)) == 42 waitFor(implicit8(false)) == 33 + waitFor(implicit9()) == 42 + suite "Closure iterator's exception transformation issues": test "Nested defer/finally not called on return": # issue #288 diff --git a/tests/testproc.bat b/tests/testproc.bat index 314bea73..11b4047e 100644 --- a/tests/testproc.bat +++ b/tests/testproc.bat @@ -2,6 +2,8 @@ IF /I "%1" == "STDIN" ( GOTO :STDINTEST +) ELSE IF /I "%1" == "TIMEOUT1" ( + GOTO :TIMEOUTTEST1 ) ELSE IF /I "%1" == "TIMEOUT2" ( GOTO :TIMEOUTTEST2 ) ELSE IF /I "%1" == "TIMEOUT10" ( @@ -19,6 +21,10 @@ SET /P "INPUTDATA=" ECHO STDIN DATA: %INPUTDATA% EXIT 0 +:TIMEOUTTEST1 +ping -n 1 127.0.0.1 > NUL +EXIT 1 + :TIMEOUTTEST2 ping -n 2 127.0.0.1 > NUL EXIT 2 diff --git a/tests/testproc.nim b/tests/testproc.nim index 05f793db..288ec181 100644 --- a/tests/testproc.nim +++ b/tests/testproc.nim @@ -6,8 +6,9 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/os -import unittest2, stew/[base10, byteutils] +import stew/[base10, byteutils] import ".."/chronos/unittest2/asynctests +import ".."/chronos/asyncproc when defined(posix): from ".."/chronos/osdefs import SIGKILL @@ -96,7 +97,11 @@ suite "Asynchronous process management test suite": let options = {AsyncProcessOption.EvalCommand} - command = "exit 1" + command = + when defined(windows): + "tests\\testproc.bat timeout1" + else: + "tests/testproc.sh timeout1" process = await startProcess(command, options = options) @@ -407,6 +412,52 @@ suite "Asynchronous process management test suite": finally: await process.closeWait() + asyncTest "killAndWaitForExit() test": + let command = + when defined(windows): + ("tests\\testproc.bat", "timeout10", 0) + else: + ("tests/testproc.sh", "timeout10", 128 + int(SIGKILL)) + let process = await startProcess(command[0], arguments = @[command[1]]) + try: + let exitCode = await process.killAndWaitForExit(10.seconds) + check exitCode == command[2] + finally: + await process.closeWait() + + asyncTest "terminateAndWaitForExit() test": + let command = + when defined(windows): + ("tests\\testproc.bat", "timeout10", 0) + else: + ("tests/testproc.sh", "timeout10", 128 + int(SIGTERM)) + let process = await startProcess(command[0], arguments = @[command[1]]) + try: + let exitCode = await process.terminateAndWaitForExit(10.seconds) + check exitCode == command[2] + finally: + await process.closeWait() + + asyncTest "terminateAndWaitForExit() timeout test": + when defined(windows): + skip() + else: + let + command = ("tests/testproc.sh", "noterm", 128 + int(SIGKILL)) + process = await startProcess(command[0], arguments = @[command[1]]) + # We should wait here to allow `bash` execute `trap` command, otherwise + # our test script will be killed with SIGTERM. Increase this timeout + # if test become flaky. + await sleepAsync(1.seconds) + try: + expect AsyncProcessTimeoutError: + let exitCode {.used.} = + await process.terminateAndWaitForExit(1.seconds) + let exitCode = await process.killAndWaitForExit(10.seconds) + check exitCode == command[2] + finally: + await process.closeWait() + test "File descriptors leaks test": when defined(windows): skip() @@ -414,12 +465,4 @@ suite "Asynchronous process management test suite": check getCurrentFD() == markFD test "Leaks test": - proc getTrackerLeaks(tracker: string): bool = - let tracker = getTracker(tracker) - if isNil(tracker): false else: tracker.isLeaked() - - check: - getTrackerLeaks("async.process") == false - getTrackerLeaks("async.stream.reader") == false - getTrackerLeaks("async.stream.writer") == false - getTrackerLeaks("stream.transport") == false + checkLeaks() diff --git a/tests/testproc.sh b/tests/testproc.sh index 1725d49d..c5e7e0ac 100755 --- a/tests/testproc.sh +++ b/tests/testproc.sh @@ -3,6 +3,9 @@ if [ "$1" == "stdin" ]; then read -r inputdata echo "STDIN DATA: $inputdata" +elif [ "$1" == "timeout1" ]; then + sleep 1 + exit 1 elif [ "$1" == "timeout2" ]; then sleep 2 exit 2 @@ -15,6 +18,11 @@ elif [ "$1" == "bigdata" ]; then done elif [ "$1" == "envtest" ]; then echo "$CHRONOSASYNC" +elif [ "$1" == "noterm" ]; then + trap -- '' SIGTERM + while true; do + sleep 1 + done else echo "arguments missing" fi diff --git a/tests/testratelimit.nim b/tests/testratelimit.nim index 4c78664d..bf281eec 100644 --- a/tests/testratelimit.nim +++ b/tests/testratelimit.nim @@ -15,22 +15,23 @@ import ../chronos/ratelimit suite "Token Bucket": test "Sync test": var bucket = TokenBucket.new(1000, 1.milliseconds) + let + start = Moment.now() + fullTime = start + 1.milliseconds check: - bucket.tryConsume(800) == true - bucket.tryConsume(200) == true + bucket.tryConsume(800, start) == true + bucket.tryConsume(200, start) == true # Out of budget - bucket.tryConsume(100) == false - waitFor(sleepAsync(10.milliseconds)) - check: - bucket.tryConsume(800) == true - bucket.tryConsume(200) == true + bucket.tryConsume(100, start) == false + bucket.tryConsume(800, fullTime) == true + bucket.tryConsume(200, fullTime) == true # Out of budget - bucket.tryConsume(100) == false + bucket.tryConsume(100, fullTime) == false test "Async test": - var bucket = TokenBucket.new(1000, 500.milliseconds) + var bucket = TokenBucket.new(1000, 1000.milliseconds) check: bucket.tryConsume(1000) == true var toWait = newSeq[Future[void]]() @@ -41,28 +42,26 @@ suite "Token Bucket": waitFor(allFutures(toWait)) let duration = Moment.now() - start - check: duration in 700.milliseconds .. 1100.milliseconds + check: duration in 1400.milliseconds .. 2200.milliseconds test "Over budget async": - var bucket = TokenBucket.new(100, 10.milliseconds) + var bucket = TokenBucket.new(100, 100.milliseconds) # Consume 10* the budget cap let beforeStart = Moment.now() - waitFor(bucket.consume(1000).wait(1.seconds)) - when not defined(macosx): - # CI's macos scheduler is so jittery that this tests sometimes takes >500ms - # the test will still fail if it's >1 seconds - check Moment.now() - beforeStart in 90.milliseconds .. 150.milliseconds + waitFor(bucket.consume(1000).wait(5.seconds)) + check Moment.now() - beforeStart in 900.milliseconds .. 1500.milliseconds test "Sync manual replenish": var bucket = TokenBucket.new(1000, 0.seconds) + let start = Moment.now() check: - bucket.tryConsume(1000) == true - bucket.tryConsume(1000) == false + bucket.tryConsume(1000, start) == true + bucket.tryConsume(1000, start) == false bucket.replenish(2000) check: - bucket.tryConsume(1000) == true + bucket.tryConsume(1000, start) == true # replenish is capped to the bucket max - bucket.tryConsume(1000) == false + bucket.tryConsume(1000, start) == false test "Async manual replenish": var bucket = TokenBucket.new(10 * 150, 0.seconds) @@ -102,24 +101,25 @@ suite "Token Bucket": test "Very long replenish": var bucket = TokenBucket.new(7000, 1.hours) - check bucket.tryConsume(7000) - check bucket.tryConsume(1) == false + let start = Moment.now() + check bucket.tryConsume(7000, start) + check bucket.tryConsume(1, start) == false # With this setting, it takes 514 milliseconds # to tick one. Check that we can eventually # consume, even if we update multiple time # before that - let start = Moment.now() - while Moment.now() - start >= 514.milliseconds: - check bucket.tryConsume(1) == false - waitFor(sleepAsync(10.milliseconds)) + var fakeNow = start + while fakeNow - start < 514.milliseconds: + check bucket.tryConsume(1, fakeNow) == false + fakeNow += 30.milliseconds - check bucket.tryConsume(1) == false + check bucket.tryConsume(1, fakeNow) == true test "Short replenish": var bucket = TokenBucket.new(15000, 1.milliseconds) - check bucket.tryConsume(15000) - check bucket.tryConsume(1) == false + let start = Moment.now() + check bucket.tryConsume(15000, start) + check bucket.tryConsume(1, start) == false - waitFor(sleepAsync(1.milliseconds)) - check bucket.tryConsume(15000) == true + check bucket.tryConsume(15000, start + 1.milliseconds) == true diff --git a/tests/testshttpserver.nim b/tests/testshttpserver.nim index a258cc95..8aacb8e4 100644 --- a/tests/testshttpserver.nim +++ b/tests/testshttpserver.nim @@ -6,8 +6,9 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/strutils -import unittest2 -import ../chronos, ../chronos/apps/http/shttpserver +import ".."/chronos/unittest2/asynctests +import ".."/chronos, + ".."/chronos/apps/http/shttpserver import stew/base10 {.used.} @@ -115,7 +116,7 @@ suite "Secure HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let serverFlags = {Secure} @@ -154,7 +155,7 @@ suite "Secure HTTP server testing suite": else: serverRes = true testFut.complete() - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let serverFlags = {Secure} @@ -178,3 +179,6 @@ suite "Secure HTTP server testing suite": return serverRes and data == "EXCEPTION" check waitFor(testHTTPS2(initTAddress("127.0.0.1:30080"))) == true + + test "Leaks test": + checkLeaks() diff --git a/tests/teststream.nim b/tests/teststream.nim index 7601a397..9e1ce557 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -6,7 +6,7 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/[strutils, os] -import unittest2 +import ".."/chronos/unittest2/asynctests import ".."/chronos, ".."/chronos/[osdefs, oserrno] {.used.} @@ -34,7 +34,7 @@ suite "Stream Transport test suite": ] else: let addresses = [ - initTAddress("127.0.0.1:33335"), + initTAddress("127.0.0.1:0"), initTAddress(r"/tmp/testpipe") ] @@ -43,7 +43,7 @@ suite "Stream Transport test suite": var markFD: int proc getCurrentFD(): int = - let local = initTAddress("127.0.0.1:33334") + let local = initTAddress("127.0.0.1:0") let sock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM, Protocol.IPPROTO_UDP) closeSocket(sock) @@ -348,7 +348,7 @@ suite "Stream Transport test suite": proc test1(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient1, {ReuseAddr}) server.start() - result = await swarmManager1(address) + result = await swarmManager1(server.local) server.stop() server.close() await server.join() @@ -356,7 +356,7 @@ suite "Stream Transport test suite": proc test2(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient2, {ReuseAddr}) server.start() - result = await swarmManager2(address) + result = await swarmManager2(server.local) server.stop() server.close() await server.join() @@ -364,7 +364,7 @@ suite "Stream Transport test suite": proc test3(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient3, {ReuseAddr}) server.start() - result = await swarmManager3(address) + result = await swarmManager3(server.local) server.stop() server.close() await server.join() @@ -372,7 +372,7 @@ suite "Stream Transport test suite": proc testSendFile(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient4, {ReuseAddr}) server.start() - result = await swarmManager4(address) + result = await swarmManager4(server.local) server.stop() server.close() await server.join() @@ -414,7 +414,7 @@ suite "Stream Transport test suite": var server = createStreamServer(address, serveClient, {ReuseAddr}) server.start() - result = await swarmManager(address) + result = await swarmManager(server.local) await server.join() proc testWCR(address: TransportAddress): Future[int] {.async.} = @@ -456,13 +456,13 @@ suite "Stream Transport test suite": var server = createStreamServer(address, serveClient, {ReuseAddr}) server.start() - result = await swarmManager(address) + result = await swarmManager(server.local) await server.join() proc test7(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient7, {ReuseAddr}) server.start() - result = await swarmWorker7(address) + result = await swarmWorker7(server.local) server.stop() server.close() await server.join() @@ -470,7 +470,7 @@ suite "Stream Transport test suite": proc test8(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient8, {ReuseAddr}) server.start() - result = await swarmWorker8(address) + result = await swarmWorker8(server.local) await server.join() # proc serveClient9(server: StreamServer, transp: StreamTransport) {.async.} = @@ -553,7 +553,7 @@ suite "Stream Transport test suite": proc test11(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient11, {ReuseAddr}) server.start() - result = await swarmWorker11(address) + result = await swarmWorker11(server.local) server.stop() server.close() await server.join() @@ -579,7 +579,7 @@ suite "Stream Transport test suite": proc test12(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient12, {ReuseAddr}) server.start() - result = await swarmWorker12(address) + result = await swarmWorker12(server.local) server.stop() server.close() await server.join() @@ -601,7 +601,7 @@ suite "Stream Transport test suite": proc test13(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient13, {ReuseAddr}) server.start() - result = await swarmWorker13(address) + result = await swarmWorker13(server.local) server.stop() server.close() await server.join() @@ -621,7 +621,7 @@ suite "Stream Transport test suite": subres = 0 server.start() - var transp = await connect(address) + var transp = await connect(server.local) var fut = swarmWorker(transp) # We perfrom shutdown(SHUT_RD/SD_RECEIVE) for the socket, in such way its # possible to emulate socket's EOF. @@ -674,7 +674,7 @@ suite "Stream Transport test suite": proc test16(address: TransportAddress): Future[int] {.async.} = var server = createStreamServer(address, serveClient16, {ReuseAddr}) server.start() - result = await swarmWorker16(address) + result = await swarmWorker16(server.local) server.stop() server.close() await server.join() @@ -701,7 +701,7 @@ suite "Stream Transport test suite": var server = createStreamServer(address, client, {ReuseAddr}) server.start() var msg = "HELLO" - var ntransp = await connect(address) + var ntransp = await connect(server.local) await syncFut while true: var res = await ntransp.write(msg) @@ -763,7 +763,7 @@ suite "Stream Transport test suite": var transp: StreamTransport try: - transp = await connect(address) + transp = await connect(server.local) flag = true except CatchableError: server.stop() @@ -796,31 +796,31 @@ suite "Stream Transport test suite": server.start() try: var r1, r2, r3, r4, r5: string - var t1 = await connect(address) + var t1 = await connect(server.local) try: r1 = await t1.readLine(4) finally: await t1.closeWait() - var t2 = await connect(address) + var t2 = await connect(server.local) try: r2 = await t2.readLine(6) finally: await t2.closeWait() - var t3 = await connect(address) + var t3 = await connect(server.local) try: r3 = await t3.readLine(8) finally: await t3.closeWait() - var t4 = await connect(address) + var t4 = await connect(server.local) try: r4 = await t4.readLine(8) finally: await t4.closeWait() - var t5 = await connect(address) + var t5 = await connect(server.local) try: r5 = await t5.readLine() finally: @@ -945,7 +945,7 @@ suite "Stream Transport test suite": var server = createStreamServer(address, serveClient, {ReuseAddr}) server.start() - var t1 = await connect(address) + var t1 = await connect(server.local) try: discard await t1.readLV(2000) except TransportIncompleteError: @@ -959,7 +959,7 @@ suite "Stream Transport test suite": await server.join() return false - var t2 = await connect(address) + var t2 = await connect(server.local) try: var r2 = await t2.readLV(2000) c2 = (r2 == @[]) @@ -972,7 +972,7 @@ suite "Stream Transport test suite": await server.join() return false - var t3 = await connect(address) + var t3 = await connect(server.local) try: discard await t3.readLV(2000) except TransportIncompleteError: @@ -986,7 +986,7 @@ suite "Stream Transport test suite": await server.join() return false - var t4 = await connect(address) + var t4 = await connect(server.local) try: discard await t4.readLV(2000) except TransportIncompleteError: @@ -1000,7 +1000,7 @@ suite "Stream Transport test suite": await server.join() return false - var t5 = await connect(address) + var t5 = await connect(server.local) try: discard await t5.readLV(1000) except ValueError: @@ -1014,7 +1014,7 @@ suite "Stream Transport test suite": await server.join() return false - var t6 = await connect(address) + var t6 = await connect(server.local) try: var expectMsg = createMessage(1024) var r6 = await t6.readLV(2000) @@ -1029,7 +1029,7 @@ suite "Stream Transport test suite": await server.join() return false - var t7 = await connect(address) + var t7 = await connect(server.local) try: var expectMsg = createMessage(1024) var expectDone = "DONE" @@ -1062,7 +1062,7 @@ suite "Stream Transport test suite": try: for i in 0 ..< TestsCount: - transp = await connect(address) + transp = await connect(server.local) await sleepAsync(10.milliseconds) await transp.closeWait() inc(connected) @@ -1117,7 +1117,7 @@ suite "Stream Transport test suite": try: for i in 0 ..< 3: try: - let transp = await connect(address) + let transp = await connect(server.local) await sleepAsync(10.milliseconds) await transp.closeWait() except TransportTooManyError: @@ -1166,7 +1166,7 @@ suite "Stream Transport test suite": await server.closeWait() var acceptFut = acceptTask(server) - var transp = await connect(address) + var transp = await connect(server.local) await server.join() await transp.closeWait() await acceptFut @@ -1187,7 +1187,7 @@ suite "Stream Transport test suite": await server.closeWait() var acceptFut = acceptTask(server) - var transp = await connect(address) + var transp = await connect(server.local) await server.join() await transp.closeWait() await acceptFut @@ -1259,46 +1259,39 @@ suite "Stream Transport test suite": return buffer == message proc testConnectBindLocalAddress() {.async.} = - let dst1 = initTAddress("127.0.0.1:33335") - let dst2 = initTAddress("127.0.0.1:33336") - let dst3 = initTAddress("127.0.0.1:33337") proc client(server: StreamServer, transp: StreamTransport) {.async.} = await transp.closeWait() - # We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when - # running the test multiple times or if a test ran previously used the same port. - let servers = - [createStreamServer(dst1, client, {ReuseAddr}), - createStreamServer(dst2, client, {ReuseAddr}), - createStreamServer(dst3, client, {ReusePort})] + let server1 = createStreamServer(initTAddress("127.0.0.1:0"), client) + let server2 = createStreamServer(initTAddress("127.0.0.1:0"), client) + let server3 = createStreamServer(initTAddress("127.0.0.1:0"), client, {ReusePort}) - for server in servers: - server.start() + server1.start() + server2.start() + server3.start() - let ta = initTAddress("0.0.0.0:35000") - - # It works cause there's no active listening socket bound to ta and we are using ReuseAddr - var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr}) - var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr}) - - # It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort - var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort}) + # It works cause even though there's an active listening socket bound to dst3, we are using ReusePort + var transp1 = await connect(server1.local, localAddress = server3.local, flags={SocketFlags.ReusePort}) + var transp2 = await connect(server2.local, localAddress = server3.local, flags={SocketFlags.ReusePort}) expect(TransportOsError): - var transp2 {.used.} = await connect(dst3, localAddress = ta) + var transp2 {.used.} = await connect(server2.local, localAddress = server3.local) expect(TransportOsError): - var transp3 {.used.} = - await connect(dst3, localAddress = initTAddress(":::35000")) + var transp3 {.used.} = await connect(server2.local, localAddress = initTAddress("::", server3.local.port)) await transp1.closeWait() await transp2.closeWait() - await transp3.closeWait() - for server in servers: - server.stop() - await server.closeWait() + server1.stop() + await server1.closeWait() + + server2.stop() + await server2.closeWait() + + server3.stop() + await server3.closeWait() markFD = getCurrentFD() @@ -1339,7 +1332,10 @@ suite "Stream Transport test suite": else: skip() else: - check waitFor(testSendFile(addresses[i])) == FilesCount + if defined(emscripten): + skip() + else: + check waitFor(testSendFile(addresses[i])) == FilesCount test prefixes[i] & "Connection refused test": var address: TransportAddress if addresses[i].family == AddressFamily.Unix: @@ -1370,10 +1366,11 @@ suite "Stream Transport test suite": test prefixes[i] & "close() while in accept() waiting test": check waitFor(testAcceptClose(addresses[i])) == true test prefixes[i] & "Intermediate transports leak test #1": + checkLeaks() when defined(windows): skip() else: - check getTracker("stream.transport").isLeaked() == false + checkLeaks(StreamTransportTrackerName) test prefixes[i] & "accept() too many file descriptors test": when defined(windows): skip() @@ -1389,10 +1386,8 @@ suite "Stream Transport test suite": check waitFor(testPipe()) == true test "[IP] bind connect to local address": waitFor(testConnectBindLocalAddress()) - test "Servers leak test": - check getTracker("stream.server").isLeaked() == false - test "Transports leak test": - check getTracker("stream.transport").isLeaked() == false + test "Leaks test": + checkLeaks() test "File descriptors leak test": when defined(windows): # Windows handle numbers depends on many conditions, so we can't use diff --git a/tests/testthreadsync.nim b/tests/testthreadsync.nim new file mode 100644 index 00000000..fc85dc8c --- /dev/null +++ b/tests/testthreadsync.nim @@ -0,0 +1,369 @@ +# Chronos Test Suite +# (c) Copyright 2023-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import std/[cpuinfo, locks, strutils] +import ../chronos/unittest2/asynctests +import ../chronos/threadsync + +{.used.} + +type + ThreadResult = object + value: int + + ThreadResultPtr = ptr ThreadResult + + LockPtr = ptr Lock + + ThreadArg = object + signal: ThreadSignalPtr + retval: ThreadResultPtr + index: int + + ThreadArg2 = object + signal1: ThreadSignalPtr + signal2: ThreadSignalPtr + retval: ThreadResultPtr + + ThreadArg3 = object + lock: LockPtr + signal: ThreadSignalPtr + retval: ThreadResultPtr + index: int + + WaitSendKind {.pure.} = enum + Sync, Async + +const + TestsCount = 1000 + +suite "Asynchronous multi-threading sync primitives test suite": + proc setResult(thr: ThreadResultPtr, value: int) = + thr[].value = value + + proc new(t: typedesc[ThreadResultPtr], value: int = 0): ThreadResultPtr = + var res = cast[ThreadResultPtr](allocShared0(sizeof(ThreadResult))) + res[].value = value + res + + proc free(thr: ThreadResultPtr) = + doAssert(not(isNil(thr))) + deallocShared(thr) + + let numProcs = countProcessors() * 2 + + template threadSignalTest(sendFlag, waitFlag: WaitSendKind) = + proc testSyncThread(arg: ThreadArg) {.thread.} = + let res = waitSync(arg.signal, 1500.milliseconds) + if res.isErr(): + arg.retval.setResult(1) + else: + if res.get(): + arg.retval.setResult(2) + else: + arg.retval.setResult(3) + + proc testAsyncThread(arg: ThreadArg) {.thread.} = + proc testAsyncCode(arg: ThreadArg) {.async.} = + try: + await wait(arg.signal).wait(1500.milliseconds) + arg.retval.setResult(2) + except AsyncTimeoutError: + arg.retval.setResult(3) + except CatchableError: + arg.retval.setResult(1) + + waitFor testAsyncCode(arg) + + let signal = ThreadSignalPtr.new().tryGet() + var args: seq[ThreadArg] + var threads = newSeq[Thread[ThreadArg]](numProcs) + for i in 0 ..< numProcs: + let + res = ThreadResultPtr.new() + arg = ThreadArg(signal: signal, retval: res, index: i) + args.add(arg) + case waitFlag + of WaitSendKind.Sync: + createThread(threads[i], testSyncThread, arg) + of WaitSendKind.Async: + createThread(threads[i], testAsyncThread, arg) + + await sleepAsync(500.milliseconds) + case sendFlag + of WaitSendKind.Sync: + check signal.fireSync().isOk() + of WaitSendKind.Async: + await signal.fire() + + joinThreads(threads) + + var ncheck: array[3, int] + for item in args: + if item.retval[].value == 1: + inc(ncheck[0]) + elif item.retval[].value == 2: + inc(ncheck[1]) + elif item.retval[].value == 3: + inc(ncheck[2]) + free(item.retval) + check: + signal.close().isOk() + ncheck[0] == 0 + ncheck[1] == 1 + ncheck[2] == numProcs - 1 + + template threadSignalTest2(testsCount: int, + sendFlag, waitFlag: WaitSendKind) = + proc testSyncThread(arg: ThreadArg2) {.thread.} = + for i in 0 ..< testsCount: + block: + let res = waitSync(arg.signal1, 1500.milliseconds) + if res.isErr(): + arg.retval.setResult(-1) + return + if not(res.get()): + arg.retval.setResult(-2) + return + + block: + let res = arg.signal2.fireSync() + if res.isErr(): + arg.retval.setResult(-3) + return + + arg.retval.setResult(i + 1) + + proc testAsyncThread(arg: ThreadArg2) {.thread.} = + proc testAsyncCode(arg: ThreadArg2) {.async.} = + for i in 0 ..< testsCount: + try: + await wait(arg.signal1).wait(1500.milliseconds) + except AsyncTimeoutError: + arg.retval.setResult(-2) + return + except AsyncError: + arg.retval.setResult(-1) + return + except CatchableError: + arg.retval.setResult(-3) + return + + try: + await arg.signal2.fire() + except AsyncError: + arg.retval.setResult(-4) + return + except CatchableError: + arg.retval.setResult(-5) + return + + arg.retval.setResult(i + 1) + + waitFor testAsyncCode(arg) + + let + signal1 = ThreadSignalPtr.new().tryGet() + signal2 = ThreadSignalPtr.new().tryGet() + retval = ThreadResultPtr.new() + arg = ThreadArg2(signal1: signal1, signal2: signal2, retval: retval) + var thread: Thread[ThreadArg2] + + case waitFlag + of WaitSendKind.Sync: + createThread(thread, testSyncThread, arg) + of WaitSendKind.Async: + createThread(thread, testAsyncThread, arg) + + let start = Moment.now() + for i in 0 ..< testsCount: + case sendFlag + of WaitSendKind.Sync: + block: + let res = signal1.fireSync() + check res.isOk() + block: + let res = waitSync(arg.signal2, 1500.milliseconds) + check: + res.isOk() + res.get() == true + of WaitSendKind.Async: + await arg.signal1.fire() + await wait(arg.signal2).wait(1500.milliseconds) + joinThreads(thread) + let finish = Moment.now() + let perf = (float64(nanoseconds(1.seconds)) / + float64(nanoseconds(finish - start))) * float64(testsCount) + echo "Switches tested: ", testsCount, ", elapsed time: ", (finish - start), + ", performance = ", formatFloat(perf, ffDecimal, 4), + " switches/second" + + check: + arg.retval[].value == testsCount + + template threadSignalTest3(testsCount: int, + sendFlag, waitFlag: WaitSendKind) = + proc testSyncThread(arg: ThreadArg3) {.thread.} = + withLock(arg.lock[]): + let res = waitSync(arg.signal, 10.milliseconds) + if res.isErr(): + arg.retval.setResult(1) + else: + if res.get(): + arg.retval.setResult(2) + else: + arg.retval.setResult(3) + + proc testAsyncThread(arg: ThreadArg3) {.thread.} = + proc testAsyncCode(arg: ThreadArg3) {.async.} = + withLock(arg.lock[]): + try: + await wait(arg.signal).wait(10.milliseconds) + arg.retval.setResult(2) + except AsyncTimeoutError: + arg.retval.setResult(3) + except CatchableError: + arg.retval.setResult(1) + + waitFor testAsyncCode(arg) + + let signal = ThreadSignalPtr.new().tryGet() + var args: seq[ThreadArg3] + var threads = newSeq[Thread[ThreadArg3]](numProcs) + var lockPtr = cast[LockPtr](allocShared0(sizeof(Lock))) + initLock(lockPtr[]) + acquire(lockPtr[]) + + for i in 0 ..< numProcs: + let + res = ThreadResultPtr.new() + arg = ThreadArg3(signal: signal, retval: res, index: i, lock: lockPtr) + args.add(arg) + case waitFlag + of WaitSendKind.Sync: + createThread(threads[i], testSyncThread, arg) + of WaitSendKind.Async: + createThread(threads[i], testAsyncThread, arg) + + await sleepAsync(500.milliseconds) + case sendFlag + of WaitSendKind.Sync: + for i in 0 ..< testsCount: + check signal.fireSync().isOk() + of WaitSendKind.Async: + for i in 0 ..< testsCount: + await signal.fire() + + release(lockPtr[]) + joinThreads(threads) + deinitLock(lockPtr[]) + deallocShared(lockPtr) + + var ncheck: array[3, int] + for item in args: + if item.retval[].value == 1: + inc(ncheck[0]) + elif item.retval[].value == 2: + inc(ncheck[1]) + elif item.retval[].value == 3: + inc(ncheck[2]) + free(item.retval) + check: + signal.close().isOk() + ncheck[0] == 0 + ncheck[1] == 1 + ncheck[2] == numProcs - 1 + + template threadSignalTest4(testsCount: int, + sendFlag, waitFlag: WaitSendKind) = + let signal = ThreadSignalPtr.new().tryGet() + let start = Moment.now() + for i in 0 ..< testsCount: + case sendFlag + of WaitSendKind.Sync: + check signal.fireSync().isOk() + of WaitSendKind.Async: + await signal.fire() + + case waitFlag + of WaitSendKind.Sync: + check waitSync(signal).isOk() + of WaitSendKind.Async: + await wait(signal) + let finish = Moment.now() + let perf = (float64(nanoseconds(1.seconds)) / + float64(nanoseconds(finish - start))) * float64(testsCount) + echo "Switches tested: ", testsCount, ", elapsed time: ", (finish - start), + ", performance = ", formatFloat(perf, ffDecimal, 4), + " switches/second" + + check: + signal.close.isOk() + + asyncTest "ThreadSignal: Multiple [" & $numProcs & + "] threads waiting test [sync -> sync]": + threadSignalTest(WaitSendKind.Sync, WaitSendKind.Sync) + + asyncTest "ThreadSignal: Multiple [" & $numProcs & + "] threads waiting test [async -> async]": + threadSignalTest(WaitSendKind.Async, WaitSendKind.Async) + + asyncTest "ThreadSignal: Multiple [" & $numProcs & + "] threads waiting test [async -> sync]": + threadSignalTest(WaitSendKind.Async, WaitSendKind.Sync) + + asyncTest "ThreadSignal: Multiple [" & $numProcs & + "] threads waiting test [sync -> async]": + threadSignalTest(WaitSendKind.Sync, WaitSendKind.Async) + + asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount & + "] test [sync -> sync]": + threadSignalTest2(TestsCount, WaitSendKind.Sync, WaitSendKind.Sync) + + asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount & + "] test [async -> async]": + threadSignalTest2(TestsCount, WaitSendKind.Async, WaitSendKind.Async) + + asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount & + "] test [sync -> async]": + threadSignalTest2(TestsCount, WaitSendKind.Sync, WaitSendKind.Async) + + asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount & + "] test [async -> sync]": + threadSignalTest2(TestsCount, WaitSendKind.Async, WaitSendKind.Sync) + + asyncTest "ThreadSignal: Multiple signals [" & $TestsCount & + "] to multiple threads [" & $numProcs & "] test [sync -> sync]": + threadSignalTest3(TestsCount, WaitSendKind.Sync, WaitSendKind.Sync) + + asyncTest "ThreadSignal: Multiple signals [" & $TestsCount & + "] to multiple threads [" & $numProcs & "] test [async -> async]": + threadSignalTest3(TestsCount, WaitSendKind.Async, WaitSendKind.Async) + + asyncTest "ThreadSignal: Multiple signals [" & $TestsCount & + "] to multiple threads [" & $numProcs & "] test [sync -> async]": + threadSignalTest3(TestsCount, WaitSendKind.Sync, WaitSendKind.Async) + + asyncTest "ThreadSignal: Multiple signals [" & $TestsCount & + "] to multiple threads [" & $numProcs & "] test [async -> sync]": + threadSignalTest3(TestsCount, WaitSendKind.Async, WaitSendKind.Sync) + + asyncTest "ThreadSignal: Single threaded switches [" & $TestsCount & + "] test [sync -> sync]": + threadSignalTest4(TestsCount, WaitSendKind.Sync, WaitSendKind.Sync) + + asyncTest "ThreadSignal: Single threaded switches [" & $TestsCount & + "] test [sync -> sync]": + threadSignalTest4(TestsCount, WaitSendKind.Async, WaitSendKind.Async) + + asyncTest "ThreadSignal: Single threaded switches [" & $TestsCount & + "] test [sync -> async]": + threadSignalTest4(TestsCount, WaitSendKind.Sync, WaitSendKind.Async) + + asyncTest "ThreadSignal: Single threaded switches [" & $TestsCount & + "] test [async -> sync]": + threadSignalTest4(TestsCount, WaitSendKind.Async, WaitSendKind.Sync)