diff --git a/chronos/apps/http/httpbodyrw.nim b/chronos/apps/http/httpbodyrw.nim index ba2b1d4..b948fbd 100644 --- a/chronos/apps/http/httpbodyrw.nim +++ b/chronos/apps/http/httpbodyrw.nim @@ -25,71 +25,6 @@ type bstate*: HttpState streams*: seq[AsyncStreamWriter] - HttpBodyTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - -proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe, raises: [].} -proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe, raises: [].} - -proc getHttpBodyWriterTracker(): HttpBodyTracker {.inline.} = - var res = cast[HttpBodyTracker](getTracker(HttpBodyWriterTrackerName)) - if isNil(res): - res = setupHttpBodyWriterTracker() - res - -proc getHttpBodyReaderTracker(): HttpBodyTracker {.inline.} = - var res = cast[HttpBodyTracker](getTracker(HttpBodyReaderTrackerName)) - if isNil(res): - res = setupHttpBodyReaderTracker() - res - -proc dumpHttpBodyWriterTracking(): string {.gcsafe.} = - let tracker = getHttpBodyWriterTracker() - "Opened HTTP body writers: " & $tracker.opened & "\n" & - "Closed HTTP body writers: " & $tracker.closed - -proc dumpHttpBodyReaderTracking(): string {.gcsafe.} = - let tracker = getHttpBodyReaderTracker() - "Opened HTTP body readers: " & $tracker.opened & "\n" & - "Closed HTTP body readers: " & $tracker.closed - -proc leakHttpBodyWriter(): bool {.gcsafe.} = - var tracker = getHttpBodyWriterTracker() - tracker.opened != tracker.closed - -proc leakHttpBodyReader(): bool {.gcsafe.} = - var tracker = getHttpBodyReaderTracker() - tracker.opened != tracker.closed - -proc trackHttpBodyWriter(t: HttpBodyWriter) {.inline.} = - inc(getHttpBodyWriterTracker().opened) - -proc untrackHttpBodyWriter*(t: HttpBodyWriter) {.inline.} = - inc(getHttpBodyWriterTracker().closed) - -proc trackHttpBodyReader(t: HttpBodyReader) {.inline.} = - inc(getHttpBodyReaderTracker().opened) - -proc untrackHttpBodyReader*(t: HttpBodyReader) {.inline.} = - inc(getHttpBodyReaderTracker().closed) - -proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe.} = - var res = HttpBodyTracker(opened: 0, closed: 0, - dump: dumpHttpBodyWriterTracking, - isLeaked: leakHttpBodyWriter - ) - addTracker(HttpBodyWriterTrackerName, res) - res - -proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe.} = - var res = HttpBodyTracker(opened: 0, closed: 0, - dump: dumpHttpBodyReaderTracking, - isLeaked: leakHttpBodyReader - ) - addTracker(HttpBodyReaderTrackerName, res) - res - proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader = ## HttpBodyReader is AsyncStreamReader which holds references to all the ## ``streams``. Also on close it will close all the ``streams``. @@ -98,7 +33,7 @@ proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader = doAssert(len(streams) > 0, "At least one stream must be added") var res = HttpBodyReader(bstate: HttpState.Alive, streams: @streams) res.init(streams[0]) - trackHttpBodyReader(res) + trackCounter(HttpBodyReaderTrackerName) res proc closeWait*(bstream: HttpBodyReader) {.async.} = @@ -113,7 +48,7 @@ proc closeWait*(bstream: HttpBodyReader) {.async.} = await allFutures(res) await procCall(closeWait(AsyncStreamReader(bstream))) bstream.bstate = HttpState.Closed - untrackHttpBodyReader(bstream) + untrackCounter(HttpBodyReaderTrackerName) proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter = ## HttpBodyWriter is AsyncStreamWriter which holds references to all the @@ -123,7 +58,7 @@ proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter = doAssert(len(streams) > 0, "At least one stream must be added") var res = HttpBodyWriter(bstate: HttpState.Alive, streams: @streams) res.init(streams[0]) - trackHttpBodyWriter(res) + trackCounter(HttpBodyWriterTrackerName) res proc closeWait*(bstream: HttpBodyWriter) {.async.} = @@ -136,7 +71,7 @@ proc closeWait*(bstream: HttpBodyWriter) {.async.} = await allFutures(res) await procCall(closeWait(AsyncStreamWriter(bstream))) bstream.bstate = HttpState.Closed - untrackHttpBodyWriter(bstream) + untrackCounter(HttpBodyWriterTrackerName) proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [].} = if len(bstream.streams) == 1: diff --git a/chronos/apps/http/httpclient.nim b/chronos/apps/http/httpclient.nim index 311ff1b..6e9ea0c 100644 --- a/chronos/apps/http/httpclient.nim +++ b/chronos/apps/http/httpclient.nim @@ -190,10 +190,6 @@ type HttpClientFlags* = set[HttpClientFlag] - HttpClientTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - ServerSentEvent* = object name*: string data*: string @@ -204,100 +200,6 @@ type # HttpClientResponseRef valid states are # Open -> (Finished, Error) -> (Closing, Closed) -proc setupHttpClientConnectionTracker(): HttpClientTracker {. - gcsafe, raises: [].} -proc setupHttpClientRequestTracker(): HttpClientTracker {. - gcsafe, raises: [].} -proc setupHttpClientResponseTracker(): HttpClientTracker {. - gcsafe, raises: [].} - -proc getHttpClientConnectionTracker(): HttpClientTracker {.inline.} = - var res = cast[HttpClientTracker](getTracker(HttpClientConnectionTrackerName)) - if isNil(res): - res = setupHttpClientConnectionTracker() - res - -proc getHttpClientRequestTracker(): HttpClientTracker {.inline.} = - var res = cast[HttpClientTracker](getTracker(HttpClientRequestTrackerName)) - if isNil(res): - res = setupHttpClientRequestTracker() - res - -proc getHttpClientResponseTracker(): HttpClientTracker {.inline.} = - var res = cast[HttpClientTracker](getTracker(HttpClientResponseTrackerName)) - if isNil(res): - res = setupHttpClientResponseTracker() - res - -proc dumpHttpClientConnectionTracking(): string {.gcsafe.} = - let tracker = getHttpClientConnectionTracker() - "Opened HTTP client connections: " & $tracker.opened & "\n" & - "Closed HTTP client connections: " & $tracker.closed - -proc dumpHttpClientRequestTracking(): string {.gcsafe.} = - let tracker = getHttpClientRequestTracker() - "Opened HTTP client requests: " & $tracker.opened & "\n" & - "Closed HTTP client requests: " & $tracker.closed - -proc dumpHttpClientResponseTracking(): string {.gcsafe.} = - let tracker = getHttpClientResponseTracker() - "Opened HTTP client responses: " & $tracker.opened & "\n" & - "Closed HTTP client responses: " & $tracker.closed - -proc leakHttpClientConnection(): bool {.gcsafe.} = - var tracker = getHttpClientConnectionTracker() - tracker.opened != tracker.closed - -proc leakHttpClientRequest(): bool {.gcsafe.} = - var tracker = getHttpClientRequestTracker() - tracker.opened != tracker.closed - -proc leakHttpClientResponse(): bool {.gcsafe.} = - var tracker = getHttpClientResponseTracker() - tracker.opened != tracker.closed - -proc trackHttpClientConnection(t: HttpClientConnectionRef) {.inline.} = - inc(getHttpClientConnectionTracker().opened) - -proc untrackHttpClientConnection*(t: HttpClientConnectionRef) {.inline.} = - inc(getHttpClientConnectionTracker().closed) - -proc trackHttpClientRequest(t: HttpClientRequestRef) {.inline.} = - inc(getHttpClientRequestTracker().opened) - -proc untrackHttpClientRequest*(t: HttpClientRequestRef) {.inline.} = - inc(getHttpClientRequestTracker().closed) - -proc trackHttpClientResponse(t: HttpClientResponseRef) {.inline.} = - inc(getHttpClientResponseTracker().opened) - -proc untrackHttpClientResponse*(t: HttpClientResponseRef) {.inline.} = - inc(getHttpClientResponseTracker().closed) - -proc setupHttpClientConnectionTracker(): HttpClientTracker {.gcsafe.} = - var res = HttpClientTracker(opened: 0, closed: 0, - dump: dumpHttpClientConnectionTracking, - isLeaked: leakHttpClientConnection - ) - addTracker(HttpClientConnectionTrackerName, res) - res - -proc setupHttpClientRequestTracker(): HttpClientTracker {.gcsafe.} = - var res = HttpClientTracker(opened: 0, closed: 0, - dump: dumpHttpClientRequestTracking, - isLeaked: leakHttpClientRequest - ) - addTracker(HttpClientRequestTrackerName, res) - res - -proc setupHttpClientResponseTracker(): HttpClientTracker {.gcsafe.} = - var res = HttpClientTracker(opened: 0, closed: 0, - dump: dumpHttpClientResponseTracking, - isLeaked: leakHttpClientResponse - ) - addTracker(HttpClientResponseTrackerName, res) - res - template checkClosed(reqresp: untyped): untyped = if reqresp.connection.state in {HttpClientConnectionState.Closing, HttpClientConnectionState.Closed}: @@ -556,7 +458,7 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef, state: HttpClientConnectionState.Connecting, remoteHostname: ha.id ) - trackHttpClientConnection(res) + trackCounter(HttpClientConnectionTrackerName) res of HttpClientScheme.Secure: let treader = newAsyncStreamReader(transp) @@ -575,7 +477,7 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef, state: HttpClientConnectionState.Connecting, remoteHostname: ha.id ) - trackHttpClientConnection(res) + trackCounter(HttpClientConnectionTrackerName) res proc setError(request: HttpClientRequestRef, error: ref HttpError) {. @@ -615,7 +517,7 @@ proc closeWait(conn: HttpClientConnectionRef) {.async.} = discard await conn.transp.closeWait() conn.state = HttpClientConnectionState.Closed - untrackHttpClientConnection(conn) + untrackCounter(HttpClientConnectionTrackerName) proc connect(session: HttpSessionRef, ha: HttpAddress): Future[HttpClientConnectionRef] {.async.} = @@ -835,7 +737,7 @@ proc closeWait*(request: HttpClientRequestRef) {.async.} = request.session = nil request.error = nil request.state = HttpReqRespState.Closed - untrackHttpClientRequest(request) + untrackCounter(HttpClientRequestTrackerName) proc closeWait*(response: HttpClientResponseRef) {.async.} = if response.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}: @@ -848,7 +750,7 @@ proc closeWait*(response: HttpClientResponseRef) {.async.} = response.session = nil response.error = nil response.state = HttpReqRespState.Closed - untrackHttpClientResponse(response) + untrackCounter(HttpClientResponseTrackerName) proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte] ): HttpResult[HttpClientResponseRef] {.raises: [] .} = @@ -958,7 +860,7 @@ proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte] httpPipeline: res.connection.flags.incl(HttpClientConnectionFlag.KeepAlive) res.connection.flags.incl(HttpClientConnectionFlag.Response) - trackHttpClientResponse(res) + trackCounter(HttpClientResponseTrackerName) ok(res) proc getResponse(req: HttpClientRequestRef): Future[HttpClientResponseRef] {. @@ -996,7 +898,7 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, version: version, flags: flags, headers: HttpTable.init(headers), address: ha, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body ) - trackHttpClientRequest(res) + trackCounter(HttpClientRequestTrackerName) res proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, @@ -1012,7 +914,7 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, version: version, flags: flags, headers: HttpTable.init(headers), address: address, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body ) - trackHttpClientRequest(res) + trackCounter(HttpClientRequestTrackerName) ok(res) proc get*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, diff --git a/chronos/apps/http/httpcommon.nim b/chronos/apps/http/httpcommon.nim index cc2478d..5a4a628 100644 --- a/chronos/apps/http/httpcommon.nim +++ b/chronos/apps/http/httpcommon.nim @@ -13,6 +13,15 @@ import ../../streams/[asyncstream, boundstream] export asyncloop, asyncsync, results, httputils, strutils const + HttpServerUnsecureConnectionTrackerName* = + "httpserver.unsecure.connection" + HttpServerSecureConnectionTrackerName* = + "httpserver.secure.connection" + HttpServerRequestTrackerName* = + "httpserver.request" + HttpServerResponseTrackerName* = + "httpserver.response" + HeadersMark* = @[0x0d'u8, 0x0a'u8, 0x0d'u8, 0x0a'u8] PostMethods* = {MethodPost, MethodPatch, MethodPut, MethodDelete} diff --git a/chronos/apps/http/httpdebug.nim b/chronos/apps/http/httpdebug.nim new file mode 100644 index 0000000..2f40674 --- /dev/null +++ b/chronos/apps/http/httpdebug.nim @@ -0,0 +1,120 @@ +# +# Chronos HTTP/S server implementation +# (c) Copyright 2021-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import std/tables +import stew/results +import ../../timer +import httpserver, shttpserver +from httpclient import HttpClientScheme +from httpcommon import HttpState +from ../../osdefs import SocketHandle +from ../../transports/common import TransportAddress, ServerFlags +export HttpClientScheme, SocketHandle, TransportAddress, ServerFlags, HttpState + +{.push raises: [].} + +type + ConnectionType* {.pure.} = enum + NonSecure, Secure + + ConnectionState* {.pure.} = enum + Accepted, Alive, Closing, Closed + + ServerConnectionInfo* = object + handle*: SocketHandle + connectionType*: ConnectionType + connectionState*: ConnectionState + remoteAddress*: Opt[TransportAddress] + localAddress*: Opt[TransportAddress] + acceptMoment*: Moment + createMoment*: Opt[Moment] + + ServerInfo* = object + connectionType*: ConnectionType + address*: TransportAddress + state*: HttpServerState + maxConnections*: int + backlogSize*: int + baseUri*: Uri + serverIdent*: string + flags*: set[HttpServerFlags] + socketFlags*: set[ServerFlags] + headersTimeout*: Duration + bufferSize*: int + maxHeadersSize*: int + maxRequestBodySize*: int + +proc getConnectionType*( + server: HttpServerRef | SecureHttpServerRef): ConnectionType = + when server is SecureHttpServerRef: + ConnectionType.Secure + else: + if HttpServerFlags.Secure in server.flags: + ConnectionType.Secure + else: + ConnectionType.NonSecure + +proc getServerInfo*(server: HttpServerRef|SecureHttpServerRef): ServerInfo = + ServerInfo( + connectionType: server.getConnectionType(), + address: server.address, + state: server.state(), + maxConnections: server.maxConnections, + backlogSize: server.backlogSize, + baseUri: server.baseUri, + serverIdent: server.serverIdent, + flags: server.flags, + socketFlags: server.socketFlags, + headersTimeout: server.headersTimeout, + bufferSize: server.bufferSize, + maxHeadersSize: server.maxHeadersSize, + maxRequestBodySize: server.maxRequestBodySize + ) + +proc getConnectionState*(holder: HttpConnectionHolderRef): ConnectionState = + if not(isNil(holder.connection)): + case holder.connection.state + of HttpState.Alive: ConnectionState.Alive + of HttpState.Closing: ConnectionState.Closing + of HttpState.Closed: ConnectionState.Closed + else: + ConnectionState.Accepted + +proc init*(t: typedesc[ServerConnectionInfo], + holder: HttpConnectionHolderRef): ServerConnectionInfo = + let + localAddress = + try: + Opt.some(holder.transp.localAddress()) + except CatchableError: + Opt.none(TransportAddress) + remoteAddress = + try: + Opt.some(holder.transp.remoteAddress()) + except CatchableError: + Opt.none(TransportAddress) + + ServerConnectionInfo( + handle: SocketHandle(holder.transp.fd), + connectionType: holder.server.getConnectionType(), + connectionState: holder.getConnectionState(), + remoteAddress: remoteAddress, + localAddress: localAddress, + acceptMoment: holder.acceptMoment, + createMoment: + if not(isNil(holder.connection)): + Opt.some(holder.connection.createMoment) + else: + Opt.none(Moment) + ) + +proc getConnections*(server: HttpServerRef): seq[ServerConnectionInfo] = + var res: seq[ServerConnectionInfo] + for holder in server.connections.values(): + res.add(ServerConnectionInfo.init(holder)) + res diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index 03aaaf9..b5b8bfc 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -29,18 +29,20 @@ type ## Enable HTTP/1.1 pipelining. HttpServerError* {.pure.} = enum - TimeoutError, CatchableError, RecoverableError, CriticalError, - DisconnectError + InterruptError, TimeoutError, CatchableError, RecoverableError, + CriticalError, DisconnectError HttpServerState* {.pure.} = enum ServerRunning, ServerStopped, ServerClosed HttpProcessError* = object - error*: HttpServerError + kind*: HttpServerError code*: HttpCode exc*: ref CatchableError - remote*: TransportAddress + remote*: Opt[TransportAddress] + ConnectionFence* = Result[HttpConnectionRef, HttpProcessError] + ResponseFence* = Result[HttpResponseRef, HttpProcessError] RequestFence* = Result[HttpRequestRef, HttpProcessError] HttpRequestFlags* {.pure.} = enum @@ -53,7 +55,7 @@ type Plain, SSE, Chunked HttpResponseState* {.pure.} = enum - Empty, Prepared, Sending, Finished, Failed, Cancelled, Dumb + Empty, Prepared, Sending, Finished, Failed, Cancelled, Default HttpProcessCallback* = proc(req: RequestFence): Future[HttpResponseRef] {. @@ -64,6 +66,20 @@ type transp: StreamTransport): Future[HttpConnectionRef] {. gcsafe, raises: [].} + HttpCloseConnectionCallback* = + proc(connection: HttpConnectionRef): Future[void] {. + gcsafe, raises: [].} + + HttpConnectionHolder* = object of RootObj + connection*: HttpConnectionRef + server*: HttpServerRef + future*: Future[void] + transp*: StreamTransport + acceptMoment*: Moment + connectionId*: string + + HttpConnectionHolderRef* = ref HttpConnectionHolder + HttpServer* = object of RootObj instance*: StreamServer address*: TransportAddress @@ -74,7 +90,7 @@ type serverIdent*: string flags*: set[HttpServerFlags] socketFlags*: set[ServerFlags] - connections*: Table[string, Future[void]] + connections*: OrderedTable[string, HttpConnectionHolderRef] acceptLoop*: Future[void] lifetime*: Future[void] headersTimeout*: Duration @@ -122,11 +138,13 @@ type HttpConnection* = object of RootObj state*: HttpState server*: HttpServerRef - transp: StreamTransport + transp*: StreamTransport mainReader*: AsyncStreamReader mainWriter*: AsyncStreamWriter reader*: AsyncStreamReader writer*: AsyncStreamWriter + closeCb*: HttpCloseConnectionCallback + createMoment*: Moment buffer: seq[byte] HttpConnectionRef* = ref HttpConnection @@ -134,9 +152,24 @@ type ByteChar* = string | seq[byte] proc init(htype: typedesc[HttpProcessError], error: HttpServerError, - exc: ref CatchableError, remote: TransportAddress, - code: HttpCode): HttpProcessError {.raises: [].} = - HttpProcessError(error: error, exc: exc, remote: remote, code: code) + exc: ref CatchableError, remote: Opt[TransportAddress], + code: HttpCode): HttpProcessError {. + raises: [].} = + HttpProcessError(kind: error, exc: exc, remote: remote, code: code) + +proc init(htype: typedesc[HttpProcessError], + error: HttpServerError): HttpProcessError {. + raises: [].} = + HttpProcessError(kind: error) + +proc new(htype: typedesc[HttpConnectionHolderRef], server: HttpServerRef, + transp: StreamTransport, + connectionId: string): HttpConnectionHolderRef = + HttpConnectionHolderRef( + server: server, transp: transp, acceptMoment: Moment.now(), + connectionId: connectionId) + +proc error*(e: HttpProcessError): HttpServerError = e.kind proc createConnection(server: HttpServerRef, transp: StreamTransport): Future[HttpConnectionRef] {. @@ -176,7 +209,7 @@ proc new*(htype: typedesc[HttpServerRef], return err(exc.msg) var res = HttpServerRef( - address: address, + address: serverInstance.localAddress(), instance: serverInstance, processCallback: processCallback, createConnCallback: createConnection, @@ -196,15 +229,22 @@ proc new*(htype: typedesc[HttpServerRef], # else: # nil lifetime: newFuture[void]("http.server.lifetime"), - connections: initTable[string, Future[void]]() + connections: initOrderedTable[string, HttpConnectionHolderRef]() ) ok(res) -proc getResponseFlags*(req: HttpRequestRef): set[HttpResponseFlags] = +proc getServerFlags(req: HttpRequestRef): set[HttpServerFlags] = + var defaultFlags: set[HttpServerFlags] = {} + if isNil(req): return defaultFlags + if isNil(req.connection): return defaultFlags + if isNil(req.connection.server): return defaultFlags + req.connection.server.flags + +proc getResponseFlags(req: HttpRequestRef): set[HttpResponseFlags] = var defaultFlags: set[HttpResponseFlags] = {} case req.version of HttpVersion11: - if HttpServerFlags.Http11Pipeline notin req.connection.server.flags: + if HttpServerFlags.Http11Pipeline notin req.getServerFlags(): return defaultFlags let header = req.headers.getString(ConnectionHeader, "keep-alive") if header == "keep-alive": @@ -214,6 +254,12 @@ proc getResponseFlags*(req: HttpRequestRef): set[HttpResponseFlags] = else: defaultFlags +proc getResponseVersion(reqFence: RequestFence): HttpVersion {.raises: [].} = + if reqFence.isErr(): + HttpVersion11 + else: + reqFence.get().version + proc getResponse*(req: HttpRequestRef): HttpResponseRef {.raises: [].} = if req.response.isNone(): var resp = HttpResponseRef( @@ -235,9 +281,14 @@ proc getHostname*(server: HttpServerRef): string = else: server.baseUri.hostname -proc dumbResponse*(): HttpResponseRef {.raises: [].} = +proc defaultResponse*(): HttpResponseRef {.raises: [].} = ## Create an empty response to return when request processor got no request. - HttpResponseRef(state: HttpResponseState.Dumb, version: HttpVersion11) + HttpResponseRef(state: HttpResponseState.Default, version: HttpVersion11) + +proc dumbResponse*(): HttpResponseRef {.raises: [], + deprecated: "Please use defaultResponse() instead".} = + ## Create an empty response to return when request processor got no request. + defaultResponse() proc getId(transp: StreamTransport): Result[string, string] {.inline.} = ## Returns string unique transport's identifier as string. @@ -371,6 +422,7 @@ proc prepareRequest(conn: HttpConnectionRef, if strip(expectHeader).toLowerAscii() == "100-continue": request.requestFlags.incl(HttpRequestFlags.ClientExpect) + trackCounter(HttpServerRequestTrackerName) ok(request) proc getBodyReader*(request: HttpRequestRef): HttpResult[HttpBodyReader] = @@ -579,7 +631,7 @@ proc preferredContentType*(request: HttpRequestRef, proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion, code: HttpCode, keepAlive = true, datatype = "text/text", - databody = ""): Future[bool] {.async.} = + databody = "") {.async.} = var answer = $version & " " & $code & "\r\n" answer.add(DateHeader) answer.add(": ") @@ -605,13 +657,90 @@ proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion, answer.add(databody) try: await conn.writer.write(answer) - return true + except CancelledError as exc: + raise exc + except CatchableError: + # We ignore errors here, because we indicating error already. + discard + +proc sendErrorResponse(conn: HttpConnectionRef, reqFence: RequestFence, + respError: HttpProcessError): Future[bool] {.async.} = + let version = getResponseVersion(reqFence) + try: + if reqFence.isOk(): + case respError.kind + of HttpServerError.CriticalError: + await conn.sendErrorResponse(version, respError.code, false) + false + of HttpServerError.RecoverableError: + await conn.sendErrorResponse(version, respError.code, true) + true + of HttpServerError.CatchableError: + await conn.sendErrorResponse(version, respError.code, false) + false + of HttpServerError.DisconnectError, + HttpServerError.InterruptError, + HttpServerError.TimeoutError: + raiseAssert("Unexpected response error: " & $respError.kind) + else: + false except CancelledError: - return false - except AsyncStreamWriteError: - return false - except AsyncStreamIncompleteError: - return false + false + +proc sendDefaultResponse(conn: HttpConnectionRef, reqFence: RequestFence, + response: HttpResponseRef): Future[bool] {.async.} = + let + version = getResponseVersion(reqFence) + keepConnection = + if isNil(response): + false + else: + HttpResponseFlags.KeepAlive in response.flags + try: + if reqFence.isOk(): + if isNil(response): + await conn.sendErrorResponse(version, Http404, keepConnection) + keepConnection + else: + case response.state + of HttpResponseState.Empty: + # Response was ignored, so we respond with not found. + await conn.sendErrorResponse(version, Http404, keepConnection) + keepConnection + of HttpResponseState.Prepared: + # Response was prepared but not sent, so we can respond with some + # error code + await conn.sendErrorResponse(HttpVersion11, Http409, keepConnection) + keepConnection + of HttpResponseState.Sending, HttpResponseState.Failed, + HttpResponseState.Cancelled: + # Just drop connection, because we dont know at what stage we are + false + of HttpResponseState.Default: + # Response was ignored, so we respond with not found. + await conn.sendErrorResponse(version, Http404, keepConnection) + keepConnection + of HttpResponseState.Finished: + keepConnection + else: + case reqFence.error.kind + of HttpServerError.TimeoutError: + await conn.sendErrorResponse(version, reqFence.error.code, false) + false + of HttpServerError.CriticalError: + await conn.sendErrorResponse(version, reqFence.error.code, false) + false + of HttpServerError.RecoverableError: + await conn.sendErrorResponse(version, reqFence.error.code, true) + false + of HttpServerError.CatchableError: + await conn.sendErrorResponse(version, reqFence.error.code, false) + false + of HttpServerError.DisconnectError, + HttpServerError.InterruptError: + raiseAssert("Unexpected request error: " & $reqFence.error.kind) + except CancelledError: + false proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} = try: @@ -644,31 +773,33 @@ proc init*(value: var HttpConnection, server: HttpServerRef, mainWriter: newAsyncStreamWriter(transp) ) +proc closeUnsecureConnection(conn: HttpConnectionRef) {.async.} = + if conn.state == HttpState.Alive: + conn.state = HttpState.Closing + var pending: seq[Future[void]] + pending.add(conn.mainReader.closeWait()) + pending.add(conn.mainWriter.closeWait()) + pending.add(conn.transp.closeWait()) + try: + await allFutures(pending) + except CancelledError: + await allFutures(pending) + untrackCounter(HttpServerUnsecureConnectionTrackerName) + conn.state = HttpState.Closed + proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef, transp: StreamTransport): HttpConnectionRef = var res = HttpConnectionRef() res[].init(server, transp) res.reader = res.mainReader res.writer = res.mainWriter + res.closeCb = closeUnsecureConnection + res.createMoment = Moment.now() + trackCounter(HttpServerUnsecureConnectionTrackerName) res -proc closeWait*(conn: HttpConnectionRef) {.async.} = - if conn.state == HttpState.Alive: - conn.state = HttpState.Closing - var pending: seq[Future[void]] - if conn.reader != conn.mainReader: - pending.add(conn.reader.closeWait()) - if conn.writer != conn.mainWriter: - pending.add(conn.writer.closeWait()) - if len(pending) > 0: - await allFutures(pending) - # After we going to close everything else. - pending.setLen(3) - pending[0] = conn.mainReader.closeWait() - pending[1] = conn.mainWriter.closeWait() - pending[2] = conn.transp.closeWait() - await allFutures(pending) - conn.state = HttpState.Closed +proc closeWait*(conn: HttpConnectionRef): Future[void] = + conn.closeCb(conn) proc closeWait*(req: HttpRequestRef) {.async.} = if req.state == HttpState.Alive: @@ -676,7 +807,12 @@ proc closeWait*(req: HttpRequestRef) {.async.} = req.state = HttpState.Closing let resp = req.response.get() if (HttpResponseFlags.Stream in resp.flags) and not(isNil(resp.writer)): - await resp.writer.closeWait() + var writer = resp.writer.closeWait() + try: + await writer + except CancelledError: + await writer + untrackCounter(HttpServerRequestTrackerName) req.state = HttpState.Closed proc createConnection(server: HttpServerRef, @@ -694,174 +830,168 @@ proc `keepalive=`*(resp: HttpResponseRef, value: bool) = proc keepalive*(resp: HttpResponseRef): bool {.raises: [].} = HttpResponseFlags.KeepAlive in resp.flags -proc processLoop(server: HttpServerRef, transp: StreamTransport, - connId: string) {.async.} = - var - conn: HttpConnectionRef - connArg: RequestFence - runLoop = false - +proc getRemoteAddress(transp: StreamTransport): Opt[TransportAddress] {. + raises: [].} = + if isNil(transp): return Opt.none(TransportAddress) try: - conn = await server.createConnCallback(server, transp) - runLoop = true + Opt.some(transp.remoteAddress()) + except CatchableError: + Opt.none(TransportAddress) + +proc getRemoteAddress(connection: HttpConnectionRef): Opt[TransportAddress] {. + raises: [].} = + if isNil(connection): return Opt.none(TransportAddress) + getRemoteAddress(connection.transp) + +proc getResponseFence*(connection: HttpConnectionRef, + reqFence: RequestFence): Future[ResponseFence] {. + async.} = + try: + let res = await connection.server.processCallback(reqFence) + ResponseFence.ok(res) except CancelledError: - server.connections.del(connId) - await transp.closeWait() - return + ResponseFence.err(HttpProcessError.init( + HttpServerError.InterruptError)) except HttpCriticalError as exc: - let error = HttpProcessError.init(HttpServerError.CriticalError, exc, - transp.remoteAddress(), exc.code) - connArg = RequestFence.err(error) - runLoop = false + let address = connection.getRemoteAddress() + ResponseFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, exc.code)) + except HttpRecoverableError as exc: + let address = connection.getRemoteAddress() + ResponseFence.err(HttpProcessError.init( + HttpServerError.RecoverableError, exc, address, exc.code)) + except CatchableError as exc: + let address = connection.getRemoteAddress() + ResponseFence.err(HttpProcessError.init( + HttpServerError.CatchableError, exc, address, Http503)) - if not(runLoop): - try: - # We still want to notify process callback about failure, but we ignore - # result. - discard await server.processCallback(connArg) - except CancelledError: - runLoop = false - except CatchableError as exc: - # There should be no exceptions, so we will raise `Defect`. - raiseHttpDefect("Unexpected exception catched [" & $exc.name & "]") +proc getResponseFence*(server: HttpServerRef, + connFence: ConnectionFence): Future[ResponseFence] {. + async.} = + doAssert(connFence.isErr()) + try: + let + reqFence = RequestFence.err(connFence.error) + res = await server.processCallback(reqFence) + ResponseFence.ok(res) + except CancelledError: + ResponseFence.err(HttpProcessError.init( + HttpServerError.InterruptError)) + except HttpCriticalError as exc: + let address = Opt.none(TransportAddress) + ResponseFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, exc.code)) + except HttpRecoverableError as exc: + let address = Opt.none(TransportAddress) + ResponseFence.err(HttpProcessError.init( + HttpServerError.RecoverableError, exc, address, exc.code)) + except CatchableError as exc: + let address = Opt.none(TransportAddress) + ResponseFence.err(HttpProcessError.init( + HttpServerError.CatchableError, exc, address, Http503)) - var breakLoop = false - while runLoop: - var - arg: RequestFence - resp: HttpResponseRef - - try: - let request = - if server.headersTimeout.isInfinite(): - await conn.getRequest() - else: - await conn.getRequest().wait(server.headersTimeout) - arg = RequestFence.ok(request) - except CancelledError: - breakLoop = true - except AsyncTimeoutError as exc: - let error = HttpProcessError.init(HttpServerError.TimeoutError, exc, - transp.remoteAddress(), Http408) - arg = RequestFence.err(error) - except HttpRecoverableError as exc: - let error = HttpProcessError.init(HttpServerError.RecoverableError, exc, - transp.remoteAddress(), exc.code) - arg = RequestFence.err(error) - except HttpCriticalError as exc: - let error = HttpProcessError.init(HttpServerError.CriticalError, exc, - transp.remoteAddress(), exc.code) - arg = RequestFence.err(error) - except HttpDisconnectError as exc: - if HttpServerFlags.NotifyDisconnect in server.flags: - let error = HttpProcessError.init(HttpServerError.DisconnectError, exc, - transp.remoteAddress(), Http400) - arg = RequestFence.err(error) +proc getRequestFence*(server: HttpServerRef, + connection: HttpConnectionRef): Future[RequestFence] {. + async.} = + try: + let res = + if server.headersTimeout.isInfinite(): + await connection.getRequest() else: - breakLoop = true - except CatchableError as exc: - let error = HttpProcessError.init(HttpServerError.CatchableError, exc, - transp.remoteAddress(), Http500) - arg = RequestFence.err(error) + await connection.getRequest().wait(server.headersTimeout) + RequestFence.ok(res) + except CancelledError: + RequestFence.err(HttpProcessError.init(HttpServerError.InterruptError)) + except AsyncTimeoutError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.TimeoutError, exc, address, Http408)) + except HttpRecoverableError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.RecoverableError, exc, address, exc.code)) + except HttpCriticalError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, exc.code)) + except HttpDisconnectError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.DisconnectError, exc, address, Http400)) + except CatchableError as exc: + let address = connection.getRemoteAddress() + RequestFence.err(HttpProcessError.init( + HttpServerError.CatchableError, exc, address, Http500)) - if breakLoop: - break +proc getConnectionFence*(server: HttpServerRef, + transp: StreamTransport): Future[ConnectionFence] {. + async.} = + try: + let res = await server.createConnCallback(server, transp) + ConnectionFence.ok(res) + except CancelledError: + await transp.closeWait() + ConnectionFence.err(HttpProcessError.init(HttpServerError.InterruptError)) + except HttpCriticalError as exc: + await transp.closeWait() + let address = transp.getRemoteAddress() + ConnectionFence.err(HttpProcessError.init( + HttpServerError.CriticalError, exc, address, exc.code)) - breakLoop = false - var lastErrorCode: Opt[HttpCode] - - try: - resp = await conn.server.processCallback(arg) - except CancelledError: - breakLoop = true - except HttpCriticalError as exc: - lastErrorCode = Opt.some(exc.code) - except HttpRecoverableError as exc: - lastErrorCode = Opt.some(exc.code) - except CatchableError: - lastErrorCode = Opt.some(Http503) - - if breakLoop: - break - - if arg.isErr(): - let code = arg.error().code - try: - case arg.error().error - of HttpServerError.TimeoutError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.RecoverableError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.CriticalError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.CatchableError: - discard await conn.sendErrorResponse(HttpVersion11, code, false) - of HttpServerError.DisconnectError: - discard - except CancelledError: - # We swallowing `CancelledError` in a loop, but we going to exit - # loop ASAP. - discard - break +proc processRequest(server: HttpServerRef, + connection: HttpConnectionRef, + connId: string): Future[bool] {.async.} = + let requestFence = await getRequestFence(server, connection) + if requestFence.isErr(): + case requestFence.error.kind + of HttpServerError.InterruptError: + return false + of HttpServerError.DisconnectError: + if HttpServerFlags.NotifyDisconnect notin server.flags: + return false else: - let request = arg.get() - var keepConn = HttpResponseFlags.KeepAlive in request.getResponseFlags() - if lastErrorCode.isNone(): - if isNil(resp): - # Response was `nil`. - try: - discard await conn.sendErrorResponse(HttpVersion11, Http404, false) - except CancelledError: - keepConn = false - else: - try: - case resp.state - of HttpResponseState.Empty: - # Response was ignored - discard await conn.sendErrorResponse(HttpVersion11, Http404, - keepConn) - of HttpResponseState.Prepared: - # Response was prepared but not sent. - discard await conn.sendErrorResponse(HttpVersion11, Http409, - keepConn) - else: - # some data was already sent to the client. - keepConn = resp.keepalive() - except CancelledError: - keepConn = false - else: - try: - discard await conn.sendErrorResponse(HttpVersion11, - lastErrorCode.get(), false) - except CancelledError: - keepConn = false + discard - # Closing and releasing all the request resources. - try: - await request.closeWait() - except CancelledError: - # We swallowing `CancelledError` in a loop, but we still need to close - # `request` before exiting. - await request.closeWait() + defer: + if requestFence.isOk(): + await requestFence.get().closeWait() - if not(keepConn): - break + let responseFence = await getResponseFence(connection, requestFence) + if responseFence.isErr() and + (responseFence.error.kind == HttpServerError.InterruptError): + return false - # Connection could be `nil` only when secure handshake is failed. - if not(isNil(conn)): - try: - await conn.closeWait() - except CancelledError: - # Cancellation could be happened while we closing `conn`. But we still - # need to close it. - await conn.closeWait() + if responseFence.isErr(): + await connection.sendErrorResponse(requestFence, responseFence.error) + else: + await connection.sendDefaultResponse(requestFence, responseFence.get()) - server.connections.del(connId) - # if server.maxConnections > 0: - # server.semaphore.release() +proc processLoop(holder: HttpConnectionHolderRef) {.async.} = + let + server = holder.server + transp = holder.transp + connectionId = holder.connectionId + connection = + block: + let res = await server.getConnectionFence(transp) + if res.isErr(): + if res.error.kind != HttpServerError.InterruptError: + discard await server.getResponseFence(res) + server.connections.del(connectionId) + return + res.get() + + holder.connection = connection + + defer: + server.connections.del(connectionId) + await connection.closeWait() + + var runLoop = true + while runLoop: + runLoop = await server.processRequest(connection, connectionId) proc acceptClientLoop(server: HttpServerRef) {.async.} = - var breakLoop = false while true: try: # if server.maxConnections > 0: @@ -872,27 +1002,26 @@ proc acceptClientLoop(server: HttpServerRef) {.async.} = # We are unable to identify remote peer, it means that remote peer # disconnected before identification. await transp.closeWait() - breakLoop = false + break else: let connId = resId.get() - server.connections[connId] = processLoop(server, transp, connId) + let holder = HttpConnectionHolderRef.new(server, transp, resId.get()) + server.connections[connId] = holder + holder.future = processLoop(holder) except CancelledError: # Server was stopped - breakLoop = true + break except TransportOsError: # This is some critical unrecoverable error. - breakLoop = true + break except TransportTooManyError: # Non critical error - breakLoop = false + discard except TransportAbortedError: # Non critical error - breakLoop = false + discard except CatchableError: # Unexpected error - breakLoop = true - - if breakLoop: break proc state*(server: HttpServerRef): HttpServerState {.raises: [].} = @@ -922,11 +1051,11 @@ proc drop*(server: HttpServerRef) {.async.} = ## Drop all pending HTTP connections. var pending: seq[Future[void]] if server.state in {ServerStopped, ServerRunning}: - for fut in server.connections.values(): - if not(fut.finished()): - fut.cancel() - pending.add(fut) + for holder in server.connections.values(): + if not(isNil(holder.future)) and not(holder.future.finished()): + pending.add(holder.future.cancelAndWait()) await allFutures(pending) + server.connections.clear() proc closeWait*(server: HttpServerRef) {.async.} = ## Stop HTTP server and drop all the pending connections. diff --git a/chronos/apps/http/shttpserver.nim b/chronos/apps/http/shttpserver.nim index 93f253b..927ca62 100644 --- a/chronos/apps/http/shttpserver.nim +++ b/chronos/apps/http/shttpserver.nim @@ -24,6 +24,28 @@ type SecureHttpConnectionRef* = ref SecureHttpConnection +proc closeSecConnection(conn: HttpConnectionRef) {.async.} = + if conn.state == HttpState.Alive: + conn.state = HttpState.Closing + var pending: seq[Future[void]] + pending.add(conn.writer.closeWait()) + pending.add(conn.reader.closeWait()) + try: + await allFutures(pending) + except CancelledError: + await allFutures(pending) + # After we going to close everything else. + pending.setLen(3) + pending[0] = conn.mainReader.closeWait() + pending[1] = conn.mainWriter.closeWait() + pending[2] = conn.transp.closeWait() + try: + await allFutures(pending) + except CancelledError: + await allFutures(pending) + untrackCounter(HttpServerSecureConnectionTrackerName) + conn.state = HttpState.Closed + proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef, transp: StreamTransport): SecureHttpConnectionRef = var res = SecureHttpConnectionRef() @@ -37,6 +59,8 @@ proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef, res.tlsStream = tlsStream res.reader = AsyncStreamReader(tlsStream.reader) res.writer = AsyncStreamWriter(tlsStream.writer) + res.closeCb = closeSecConnection + trackCounter(HttpServerSecureConnectionTrackerName) res proc createSecConnection(server: HttpServerRef, @@ -100,7 +124,7 @@ proc new*(htype: typedesc[SecureHttpServerRef], createConnCallback: createSecConnection, baseUri: serverUri, serverIdent: serverIdent, - flags: serverFlags, + flags: serverFlags + {HttpServerFlags.Secure}, socketFlags: socketFlags, maxConnections: maxConnections, bufferSize: bufferSize, @@ -114,7 +138,7 @@ proc new*(htype: typedesc[SecureHttpServerRef], # else: # nil lifetime: newFuture[void]("http.server.lifetime"), - connections: initTable[string, Future[void]](), + connections: initOrderedTable[string, HttpConnectionHolderRef](), tlsCertificate: tlsCertificate, tlsPrivateKey: tlsPrivateKey, secureFlags: secureFlags diff --git a/chronos/asyncloop.nim b/chronos/asyncloop.nim index 7743916..a603ee4 100644 --- a/chronos/asyncloop.nim +++ b/chronos/asyncloop.nim @@ -171,11 +171,16 @@ type dump*: proc(): string {.gcsafe, raises: [].} isLeaked*: proc(): bool {.gcsafe, raises: [].} + TrackerCounter* = object + opened*: uint64 + closed*: uint64 + PDispatcherBase = ref object of RootRef timers*: HeapQueue[TimerCallback] callbacks*: Deque[AsyncCallback] idlers*: Deque[AsyncCallback] trackers*: Table[string, TrackerBase] + counters*: Table[string, TrackerCounter] proc sentinelCallbackImpl(arg: pointer) {.gcsafe.} = raiseAssert "Sentinel callback MUST not be scheduled" @@ -404,7 +409,8 @@ when defined(windows): timers: initHeapQueue[TimerCallback](), callbacks: initDeque[AsyncCallback](64), idlers: initDeque[AsyncCallback](), - trackers: initTable[string, TrackerBase]() + trackers: initTable[string, TrackerBase](), + counters: initTable[string, TrackerCounter]() ) res.callbacks.addLast(SentinelCallback) initAPI(res) @@ -814,7 +820,8 @@ elif defined(macosx) or defined(freebsd) or defined(netbsd) or callbacks: initDeque[AsyncCallback](asyncEventsCount), idlers: initDeque[AsyncCallback](), keys: newSeq[ReadyKey](asyncEventsCount), - trackers: initTable[string, TrackerBase]() + trackers: initTable[string, TrackerBase](), + counters: initTable[string, TrackerCounter]() ) res.callbacks.addLast(SentinelCallback) initAPI(res) @@ -1505,16 +1512,54 @@ proc waitFor*[T](fut: Future[T]): T {.raises: [CatchableError].} = fut.read() -proc addTracker*[T](id: string, tracker: T) = +proc addTracker*[T](id: string, tracker: T) {. + deprecated: "Please use trackCounter facility instead".} = ## Add new ``tracker`` object to current thread dispatcher with identifier ## ``id``. - let loop = getThreadDispatcher() - loop.trackers[id] = tracker + getThreadDispatcher().trackers[id] = tracker -proc getTracker*(id: string): TrackerBase = +proc getTracker*(id: string): TrackerBase {. + deprecated: "Please use getTrackerCounter() instead".} = ## Get ``tracker`` from current thread dispatcher using identifier ``id``. - let loop = getThreadDispatcher() - result = loop.trackers.getOrDefault(id, nil) + getThreadDispatcher().trackers.getOrDefault(id, nil) + +proc trackCounter*(name: string) {.noinit.} = + ## Increase tracker counter with name ``name`` by 1. + let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64) + inc(getThreadDispatcher().counters.mgetOrPut(name, tracker).opened) + +proc untrackCounter*(name: string) {.noinit.} = + ## Decrease tracker counter with name ``name`` by 1. + let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64) + inc(getThreadDispatcher().counters.mgetOrPut(name, tracker).closed) + +proc getTrackerCounter*(name: string): TrackerCounter {.noinit.} = + ## Return value of counter with name ``name``. + let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64) + getThreadDispatcher().counters.getOrDefault(name, tracker) + +proc isCounterLeaked*(name: string): bool {.noinit.} = + ## Returns ``true`` if leak is detected, number of `opened` not equal to + ## number of `closed` requests. + let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64) + let res = getThreadDispatcher().counters.getOrDefault(name, tracker) + res.opened == res.closed + +iterator trackerCounters*( + loop: PDispatcher + ): tuple[name: string, value: TrackerCounter] = + ## Iterates over `loop` thread dispatcher tracker counter table, returns all + ## the tracker counter's names and values. + doAssert(not(isNil(loop))) + for key, value in loop.counters.pairs(): + yield (key, value) + +iterator trackerCounterKeys*(loop: PDispatcher): string = + doAssert(not(isNil(loop))) + ## Iterates over `loop` thread dispatcher tracker counter table, returns all + ## tracker names. + for key in loop.counters.keys(): + yield key when chronosFutureTracking: iterator pendingFutures*(): FutureBase = diff --git a/chronos/asyncproc.nim b/chronos/asyncproc.nim index 8d15b72..8d0cdb7 100644 --- a/chronos/asyncproc.nim +++ b/chronos/asyncproc.nim @@ -23,8 +23,6 @@ const AsyncProcessTrackerName* = "async.process" ## AsyncProcess leaks tracker name - - type AsyncProcessError* = object of CatchableError @@ -109,49 +107,9 @@ type stdError*: string status*: int - AsyncProcessTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - template Pipe*(t: typedesc[AsyncProcess]): ProcessStreamHandle = ProcessStreamHandle(kind: ProcessStreamHandleKind.Auto) -proc setupAsyncProcessTracker(): AsyncProcessTracker {.gcsafe.} - -proc getAsyncProcessTracker(): AsyncProcessTracker {.inline.} = - var res = cast[AsyncProcessTracker](getTracker(AsyncProcessTrackerName)) - if isNil(res): - res = setupAsyncProcessTracker() - res - -proc dumpAsyncProcessTracking(): string {.gcsafe.} = - var tracker = getAsyncProcessTracker() - let res = "Started async processes: " & $tracker.opened & "\n" & - "Closed async processes: " & $tracker.closed - res - -proc leakAsyncProccessTracker(): bool {.gcsafe.} = - var tracker = getAsyncProcessTracker() - tracker.opened != tracker.closed - -proc trackAsyncProccess(t: AsyncProcessRef) {.inline.} = - var tracker = getAsyncProcessTracker() - inc(tracker.opened) - -proc untrackAsyncProcess(t: AsyncProcessRef) {.inline.} = - var tracker = getAsyncProcessTracker() - inc(tracker.closed) - -proc setupAsyncProcessTracker(): AsyncProcessTracker {.gcsafe.} = - var res = AsyncProcessTracker( - opened: 0, - closed: 0, - dump: dumpAsyncProcessTracking, - isLeaked: leakAsyncProccessTracker - ) - addTracker(AsyncProcessTrackerName, res) - res - proc init*(t: typedesc[AsyncFD], handle: ProcessStreamHandle): AsyncFD = case handle.kind of ProcessStreamHandleKind.ProcHandle: @@ -502,7 +460,7 @@ when defined(windows): flags: pipes.flags ) - trackAsyncProccess(process) + trackCounter(AsyncProcessTrackerName) return process proc peekProcessExitCode(p: AsyncProcessRef): AsyncProcessResult[int] = @@ -919,7 +877,7 @@ else: flags: pipes.flags ) - trackAsyncProccess(process) + trackCounter(AsyncProcessTrackerName) return process proc peekProcessExitCode(p: AsyncProcessRef, @@ -1237,7 +1195,7 @@ proc closeWait*(p: AsyncProcessRef) {.async.} = discard closeProcessHandles(p.pipes, p.options, OSErrorCode(0)) await p.pipes.closeProcessStreams(p.options) discard p.closeThreadAndProcessHandle() - untrackAsyncProcess(p) + untrackCounter(AsyncProcessTrackerName) proc stdinStream*(p: AsyncProcessRef): AsyncStreamWriter = doAssert(p.pipes.stdinHolder.kind == StreamKind.Writer, diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index 9920fc7..7e6e5d2 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -96,10 +96,6 @@ type reader*: AsyncStreamReader writer*: AsyncStreamWriter - AsyncStreamTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer = @@ -332,79 +328,6 @@ template checkStreamClosed*(t: untyped) = template checkStreamFinished*(t: untyped) = if t.atEof(): raiseAsyncStreamWriteEOFError() -proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {. - gcsafe, raises: [].} -proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {. - gcsafe, raises: [].} - -proc getAsyncStreamReaderTracker(): AsyncStreamTracker {.inline.} = - var res = cast[AsyncStreamTracker](getTracker(AsyncStreamReaderTrackerName)) - if isNil(res): - res = setupAsyncStreamReaderTracker() - res - -proc getAsyncStreamWriterTracker(): AsyncStreamTracker {.inline.} = - var res = cast[AsyncStreamTracker](getTracker(AsyncStreamWriterTrackerName)) - if isNil(res): - res = setupAsyncStreamWriterTracker() - res - -proc dumpAsyncStreamReaderTracking(): string {.gcsafe.} = - var tracker = getAsyncStreamReaderTracker() - let res = "Opened async stream readers: " & $tracker.opened & "\n" & - "Closed async stream readers: " & $tracker.closed - res - -proc dumpAsyncStreamWriterTracking(): string {.gcsafe.} = - var tracker = getAsyncStreamWriterTracker() - let res = "Opened async stream writers: " & $tracker.opened & "\n" & - "Closed async stream writers: " & $tracker.closed - res - -proc leakAsyncStreamReader(): bool {.gcsafe.} = - var tracker = getAsyncStreamReaderTracker() - tracker.opened != tracker.closed - -proc leakAsyncStreamWriter(): bool {.gcsafe.} = - var tracker = getAsyncStreamWriterTracker() - tracker.opened != tracker.closed - -proc trackAsyncStreamReader(t: AsyncStreamReader) {.inline.} = - var tracker = getAsyncStreamReaderTracker() - inc(tracker.opened) - -proc untrackAsyncStreamReader*(t: AsyncStreamReader) {.inline.} = - var tracker = getAsyncStreamReaderTracker() - inc(tracker.closed) - -proc trackAsyncStreamWriter(t: AsyncStreamWriter) {.inline.} = - var tracker = getAsyncStreamWriterTracker() - inc(tracker.opened) - -proc untrackAsyncStreamWriter*(t: AsyncStreamWriter) {.inline.} = - var tracker = getAsyncStreamWriterTracker() - inc(tracker.closed) - -proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.gcsafe.} = - var res = AsyncStreamTracker( - opened: 0, - closed: 0, - dump: dumpAsyncStreamReaderTracking, - isLeaked: leakAsyncStreamReader - ) - addTracker(AsyncStreamReaderTrackerName, res) - res - -proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.gcsafe.} = - var res = AsyncStreamTracker( - opened: 0, - closed: 0, - dump: dumpAsyncStreamWriterTracking, - isLeaked: leakAsyncStreamWriter - ) - addTracker(AsyncStreamWriterTrackerName, res) - res - template readLoop(body: untyped): untyped = while true: if rstream.buffer.dataLen() == 0: @@ -977,9 +900,9 @@ proc close*(rw: AsyncStreamRW) = if not(rw.future.finished()): rw.future.complete() when rw is AsyncStreamReader: - untrackAsyncStreamReader(rw) + untrackCounter(AsyncStreamReaderTrackerName) elif rw is AsyncStreamWriter: - untrackAsyncStreamWriter(rw) + untrackCounter(AsyncStreamWriterTrackerName) rw.state = AsyncStreamState.Closed when rw is AsyncStreamReader: @@ -1028,7 +951,7 @@ proc init*(child, wsource: AsyncStreamWriter, loop: StreamWriterLoop, child.wsource = wsource child.tsource = wsource.tsource child.queue = newAsyncQueue[WriteItem](queueSize) - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*[T](child, wsource: AsyncStreamWriter, loop: StreamWriterLoop, @@ -1042,7 +965,7 @@ proc init*[T](child, wsource: AsyncStreamWriter, loop: StreamWriterLoop, if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop, @@ -1053,7 +976,7 @@ proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop, child.rsource = rsource child.tsource = rsource.tsource child.buffer = AsyncBuffer.init(bufferSize) - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop, @@ -1068,7 +991,7 @@ proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop, if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*(child: AsyncStreamWriter, tsource: StreamTransport) = @@ -1077,7 +1000,7 @@ proc init*(child: AsyncStreamWriter, tsource: StreamTransport) = child.writerLoop = nil child.wsource = nil child.tsource = tsource - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*[T](child: AsyncStreamWriter, tsource: StreamTransport, @@ -1087,7 +1010,7 @@ proc init*[T](child: AsyncStreamWriter, tsource: StreamTransport, child.writerLoop = nil child.wsource = nil child.tsource = tsource - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*(child, wsource: AsyncStreamWriter) = @@ -1096,7 +1019,7 @@ proc init*(child, wsource: AsyncStreamWriter) = child.writerLoop = nil child.wsource = wsource child.tsource = wsource.tsource - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*[T](child, wsource: AsyncStreamWriter, udata: ref T) = @@ -1108,7 +1031,7 @@ proc init*[T](child, wsource: AsyncStreamWriter, udata: ref T) = if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamWriter(child) + trackCounter(AsyncStreamWriterTrackerName) child.startWriter() proc init*(child: AsyncStreamReader, tsource: StreamTransport) = @@ -1117,7 +1040,7 @@ proc init*(child: AsyncStreamReader, tsource: StreamTransport) = child.readerLoop = nil child.rsource = nil child.tsource = tsource - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*[T](child: AsyncStreamReader, tsource: StreamTransport, @@ -1130,7 +1053,7 @@ proc init*[T](child: AsyncStreamReader, tsource: StreamTransport, if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*(child, rsource: AsyncStreamReader) = @@ -1139,7 +1062,7 @@ proc init*(child, rsource: AsyncStreamReader) = child.readerLoop = nil child.rsource = rsource child.tsource = rsource.tsource - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc init*[T](child, rsource: AsyncStreamReader, udata: ref T) = @@ -1151,7 +1074,7 @@ proc init*[T](child, rsource: AsyncStreamReader, udata: ref T) = if not isNil(udata): GC_ref(udata) child.udata = cast[pointer](udata) - trackAsyncStreamReader(child) + trackCounter(AsyncStreamReaderTrackerName) child.startReader() proc newAsyncStreamReader*[T](rsource: AsyncStreamReader, diff --git a/chronos/transports/datagram.nim b/chronos/transports/datagram.nim index 91a7e7a..3e10f76 100644 --- a/chronos/transports/datagram.nim +++ b/chronos/transports/datagram.nim @@ -53,10 +53,6 @@ type rwsabuf: WSABUF # Reader WSABUF structure wwsabuf: WSABUF # Writer WSABUF structure - DgramTransportTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - const DgramTransportTrackerName* = "datagram.transport" @@ -88,39 +84,6 @@ template setReadError(t, e: untyped) = (t).state.incl(ReadError) (t).error = getTransportOsError(e) -proc setupDgramTransportTracker(): DgramTransportTracker {. - gcsafe, raises: [].} - -proc getDgramTransportTracker(): DgramTransportTracker {.inline.} = - var res = cast[DgramTransportTracker](getTracker(DgramTransportTrackerName)) - if isNil(res): - res = setupDgramTransportTracker() - doAssert(not(isNil(res))) - res - -proc dumpTransportTracking(): string {.gcsafe.} = - var tracker = getDgramTransportTracker() - "Opened transports: " & $tracker.opened & "\n" & - "Closed transports: " & $tracker.closed - -proc leakTransport(): bool {.gcsafe.} = - let tracker = getDgramTransportTracker() - tracker.opened != tracker.closed - -proc trackDgram(t: DatagramTransport) {.inline.} = - var tracker = getDgramTransportTracker() - inc(tracker.opened) - -proc untrackDgram(t: DatagramTransport) {.inline.} = - var tracker = getDgramTransportTracker() - inc(tracker.closed) - -proc setupDgramTransportTracker(): DgramTransportTracker {.gcsafe.} = - let res = DgramTransportTracker( - opened: 0, closed: 0, dump: dumpTransportTracking, isLeaked: leakTransport) - addTracker(DgramTransportTrackerName, res) - res - when defined(windows): template setWriterWSABuffer(t, v: untyped) = (t).wwsabuf.buf = cast[cstring](v.buf) @@ -213,7 +176,7 @@ when defined(windows): transp.state.incl(ReadPaused) if ReadClosed in transp.state and not(transp.future.finished()): # Stop tracking transport - untrackDgram(transp) + untrackCounter(DgramTransportTrackerName) # If `ReadClosed` present, then close(transport) was called. transp.future.complete() GC_unref(transp) @@ -259,7 +222,7 @@ when defined(windows): # WSARecvFrom session. if ReadClosed in transp.state and not(transp.future.finished()): # Stop tracking transport - untrackDgram(transp) + untrackCounter(DgramTransportTrackerName) transp.future.complete() GC_unref(transp) break @@ -394,7 +357,7 @@ when defined(windows): len: ULONG(len(res.buffer))) GC_ref(res) # Start tracking transport - trackDgram(res) + trackCounter(DgramTransportTrackerName) if NoAutoRead notin flags: let rres = res.resumeRead() if rres.isErr(): raiseTransportOsError(rres.error()) @@ -592,7 +555,7 @@ else: res.future = newFuture[void]("datagram.transport") GC_ref(res) # Start tracking transport - trackDgram(res) + trackCounter(DgramTransportTrackerName) if NoAutoRead notin flags: let rres = res.resumeRead() if rres.isErr(): raiseTransportOsError(rres.error()) @@ -603,7 +566,7 @@ proc close*(transp: DatagramTransport) = proc continuation(udata: pointer) {.raises: [].} = if not(transp.future.finished()): # Stop tracking transport - untrackDgram(transp) + untrackCounter(DgramTransportTrackerName) transp.future.complete() GC_unref(transp) diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index 3abd942..a4190da 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -54,15 +54,6 @@ type ReuseAddr, ReusePort - - StreamTransportTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - - StreamServerTracker* = ref object of TrackerBase - opened*: int64 - closed*: int64 - ReadMessagePredicate* = proc (data: openArray[byte]): tuple[consumed: int, done: bool] {. gcsafe, raises: [].} @@ -199,71 +190,6 @@ template shiftVectorFile(v: var StreamVector, o: untyped) = (v).buf = cast[pointer](cast[uint]((v).buf) - uint(o)) (v).offset += uint(o) -proc setupStreamTransportTracker(): StreamTransportTracker {. - gcsafe, raises: [].} -proc setupStreamServerTracker(): StreamServerTracker {. - gcsafe, raises: [].} - -proc getStreamTransportTracker(): StreamTransportTracker {.inline.} = - var res = cast[StreamTransportTracker](getTracker(StreamTransportTrackerName)) - if isNil(res): - res = setupStreamTransportTracker() - doAssert(not(isNil(res))) - res - -proc getStreamServerTracker(): StreamServerTracker {.inline.} = - var res = cast[StreamServerTracker](getTracker(StreamServerTrackerName)) - if isNil(res): - res = setupStreamServerTracker() - doAssert(not(isNil(res))) - res - -proc dumpTransportTracking(): string {.gcsafe.} = - var tracker = getStreamTransportTracker() - "Opened transports: " & $tracker.opened & "\n" & - "Closed transports: " & $tracker.closed - -proc dumpServerTracking(): string {.gcsafe.} = - var tracker = getStreamServerTracker() - "Opened servers: " & $tracker.opened & "\n" & - "Closed servers: " & $tracker.closed - -proc leakTransport(): bool {.gcsafe.} = - var tracker = getStreamTransportTracker() - tracker.opened != tracker.closed - -proc leakServer(): bool {.gcsafe.} = - var tracker = getStreamServerTracker() - tracker.opened != tracker.closed - -proc trackStream(t: StreamTransport) {.inline.} = - var tracker = getStreamTransportTracker() - inc(tracker.opened) - -proc untrackStream(t: StreamTransport) {.inline.} = - var tracker = getStreamTransportTracker() - inc(tracker.closed) - -proc trackServer(s: StreamServer) {.inline.} = - var tracker = getStreamServerTracker() - inc(tracker.opened) - -proc untrackServer(s: StreamServer) {.inline.} = - var tracker = getStreamServerTracker() - inc(tracker.closed) - -proc setupStreamTransportTracker(): StreamTransportTracker {.gcsafe.} = - let res = StreamTransportTracker( - opened: 0, closed: 0, dump: dumpTransportTracking, isLeaked: leakTransport) - addTracker(StreamTransportTrackerName, res) - res - -proc setupStreamServerTracker(): StreamServerTracker {.gcsafe.} = - let res = StreamServerTracker( - opened: 0, closed: 0, dump: dumpServerTracking, isLeaked: leakServer) - addTracker(StreamServerTrackerName, res) - res - proc completePendingWriteQueue(queue: var Deque[StreamVector], v: int) {.inline.} = while len(queue) > 0: @@ -280,7 +206,7 @@ proc failPendingWriteQueue(queue: var Deque[StreamVector], proc clean(server: StreamServer) {.inline.} = if not(server.loopFuture.finished()): - untrackServer(server) + untrackCounter(StreamServerTrackerName) server.loopFuture.complete() if not(isNil(server.udata)) and (GCUserData in server.flags): GC_unref(cast[ref int](server.udata)) @@ -288,7 +214,7 @@ proc clean(server: StreamServer) {.inline.} = proc clean(transp: StreamTransport) {.inline.} = if not(transp.future.finished()): - untrackStream(transp) + untrackCounter(StreamTransportTrackerName) transp.future.complete() GC_unref(transp) @@ -784,7 +710,7 @@ when defined(windows): else: let transp = newStreamSocketTransport(sock, bufferSize, child) # Start tracking transport - trackStream(transp) + trackCounter(StreamTransportTrackerName) retFuture.complete(transp) else: sock.closeSocket() @@ -853,7 +779,7 @@ when defined(windows): let transp = newStreamPipeTransport(AsyncFD(pipeHandle), bufferSize, child) # Start tracking transport - trackStream(transp) + trackCounter(StreamTransportTrackerName) retFuture.complete(transp) pipeContinuation(nil) @@ -909,7 +835,7 @@ when defined(windows): ntransp = newStreamPipeTransport(server.sock, server.bufferSize, nil, flags) # Start tracking transport - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) asyncSpawn server.function(server, ntransp) of ERROR_OPERATION_ABORTED: # CancelIO() interrupt or close call. @@ -1013,7 +939,7 @@ when defined(windows): ntransp = newStreamSocketTransport(server.asock, server.bufferSize, nil) # Start tracking transport - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) asyncSpawn server.function(server, ntransp) of ERROR_OPERATION_ABORTED: @@ -1156,7 +1082,7 @@ when defined(windows): ntransp = newStreamSocketTransport(server.asock, server.bufferSize, nil) # Start tracking transport - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) retFuture.complete(ntransp) of ERROR_OPERATION_ABORTED: # CancelIO() interrupt or close. @@ -1216,7 +1142,7 @@ when defined(windows): retFuture.fail(getTransportOsError(error)) return - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) retFuture.complete(ntransp) of ERROR_OPERATION_ABORTED, ERROR_PIPE_NOT_CONNECTED: @@ -1626,7 +1552,7 @@ else: let transp = newStreamSocketTransport(sock, bufferSize, child) # Start tracking transport - trackStream(transp) + trackCounter(StreamTransportTrackerName) retFuture.complete(transp) proc cancel(udata: pointer) = @@ -1639,7 +1565,7 @@ else: if res == 0: let transp = newStreamSocketTransport(sock, bufferSize, child) # Start tracking transport - trackStream(transp) + trackCounter(StreamTransportTrackerName) retFuture.complete(transp) break else: @@ -1694,7 +1620,7 @@ else: newStreamSocketTransport(sock, server.bufferSize, transp) else: newStreamSocketTransport(sock, server.bufferSize, nil) - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) asyncSpawn server.function(server, ntransp) else: # Client was accepted, so we not going to raise assertion, but @@ -1782,7 +1708,7 @@ else: else: newStreamSocketTransport(sock, server.bufferSize, nil) # Start tracking transport - trackStream(ntransp) + trackCounter(StreamTransportTrackerName) retFuture.complete(ntransp) else: discard closeFd(cint(sock)) @@ -2098,7 +2024,7 @@ proc createStreamServer*(host: TransportAddress, sres.apending = false # Start tracking server - trackServer(sres) + trackCounter(StreamServerTrackerName) GC_ref(sres) sres @@ -2671,7 +2597,7 @@ proc fromPipe2*(fd: AsyncFD, child: StreamTransport = nil, ? register2(fd) var res = newStreamPipeTransport(fd, bufferSize, child) # Start tracking transport - trackStream(res) + trackCounter(StreamTransportTrackerName) ok(res) proc fromPipe*(fd: AsyncFD, child: StreamTransport = nil, diff --git a/chronos/unittest2/asynctests.nim b/chronos/unittest2/asynctests.nim index fda0353..bc703b7 100644 --- a/chronos/unittest2/asynctests.nim +++ b/chronos/unittest2/asynctests.nim @@ -6,6 +6,7 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) +import std/tables import unittest2 import ../../chronos @@ -17,3 +18,14 @@ template asyncTest*(name: string, body: untyped): untyped = proc() {.async, gcsafe.} = body )()) + +template checkLeaks*(name: string): untyped = + let counter = getTrackerCounter(name) + if counter.opened != counter.closed: + echo "[" & name & "] opened = ", counter.opened, + ", closed = ", counter.closed + check counter.opened == counter.closed + +template checkLeaks*(): untyped = + for key in getThreadDispatcher().trackerCounterKeys(): + checkLeaks(key) diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 09a0b7e..d90b688 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -7,8 +7,8 @@ # MIT license (LICENSE-MIT) import unittest2 import bearssl/[x509] -import ../chronos -import ../chronos/streams/[tlsstream, chunkstream, boundstream] +import ".."/chronos/unittest2/asynctests +import ".."/chronos/streams/[tlsstream, chunkstream, boundstream] {.used.} @@ -302,11 +302,7 @@ suite "AsyncStream test suite": check waitFor(testConsume()) == true test "AsyncStream(StreamTransport) leaks test": - check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + checkLeaks() test "AsyncStream(AsyncStream) readExactly() test": proc testReadExactly2(): Future[bool] {.async.} = @@ -613,11 +609,7 @@ suite "AsyncStream test suite": check waitFor(testWriteEof()) == true test "AsyncStream(AsyncStream) leaks test": - check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + checkLeaks() suite "ChunkedStream test suite": test "ChunkedStream test vectors": @@ -911,11 +903,7 @@ suite "ChunkedStream test suite": check waitFor(testSmallChunk(767309, 4457, 173)) == true test "ChunkedStream leaks test": - check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + checkLeaks() suite "TLSStream test suite": const HttpHeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] @@ -1039,11 +1027,7 @@ suite "TLSStream test suite": check res == "Some message\r\n" test "TLSStream leaks test": - check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + checkLeaks() suite "BoundedStream test suite": @@ -1411,8 +1395,4 @@ suite "BoundedStream test suite": check waitFor(checkEmptyStreams()) == true test "BoundedStream leaks test": - check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + checkLeaks() diff --git a/tests/testdatagram.nim b/tests/testdatagram.nim index 17385a3..7db04f9 100644 --- a/tests/testdatagram.nim +++ b/tests/testdatagram.nim @@ -6,8 +6,8 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/[strutils, net] -import unittest2 -import ../chronos +import ".."/chronos/unittest2/asynctests +import ".."/chronos {.used.} @@ -558,4 +558,4 @@ suite "Datagram Transport test suite": test "0.0.0.0/::0 (INADDR_ANY) test": check waitFor(testAnyAddress()) == 6 test "Transports leak test": - check getTracker("datagram.transport").isLeaked() == false + checkLeaks() diff --git a/tests/testhttpclient.nim b/tests/testhttpclient.nim index 2807ebc..1eacc21 100644 --- a/tests/testhttpclient.nim +++ b/tests/testhttpclient.nim @@ -6,8 +6,9 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/[strutils, sha1] -import unittest2 -import ../chronos, ../chronos/apps/http/[httpserver, shttpserver, httpclient] +import ".."/chronos/unittest2/asynctests +import ".."/chronos, + ".."/chronos/apps/http/[httpserver, shttpserver, httpclient] import stew/base10 {.used.} @@ -138,7 +139,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() @@ -241,7 +242,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() @@ -324,7 +325,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() @@ -394,7 +395,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() @@ -470,7 +471,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() @@ -569,7 +570,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() @@ -667,7 +668,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() @@ -778,7 +779,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, false) server.start() @@ -909,7 +910,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, false) server.start() @@ -971,7 +972,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, false) server.start() @@ -1125,7 +1126,7 @@ suite "HTTP client testing suite": else: return await request.respond(Http404, "Page not found") else: - return dumbResponse() + return defaultResponse() var server = createServer(initTAddress("127.0.0.1:0"), process, secure) server.start() @@ -1262,17 +1263,4 @@ suite "HTTP client testing suite": check waitFor(testServerSentEvents(false)) == true test "Leaks test": - proc getTrackerLeaks(tracker: string): bool = - let tracker = getTracker(tracker) - if isNil(tracker): false else: tracker.isLeaked() - - check: - getTrackerLeaks("http.body.reader") == false - getTrackerLeaks("http.body.writer") == false - getTrackerLeaks("httpclient.connection") == false - getTrackerLeaks("httpclient.request") == false - getTrackerLeaks("httpclient.response") == false - getTrackerLeaks("async.stream.reader") == false - getTrackerLeaks("async.stream.writer") == false - getTrackerLeaks("stream.server") == false - getTrackerLeaks("stream.transport") == false + checkLeaks() diff --git a/tests/testhttpserver.nim b/tests/testhttpserver.nim index 63c92b2..83372ea 100644 --- a/tests/testhttpserver.nim +++ b/tests/testhttpserver.nim @@ -6,10 +6,10 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/[strutils, algorithm] -import unittest2 -import ../chronos, ../chronos/apps/http/httpserver, - ../chronos/apps/http/httpcommon, - ../chronos/unittest2/asynctests +import ".."/chronos/unittest2/asynctests, + ".."/chronos, ".."/chronos/apps/http/httpserver, + ".."/chronos/apps/http/httpcommon, + ".."/chronos/apps/http/httpdebug import stew/base10 {.used.} @@ -84,7 +84,7 @@ suite "HTTP server testing suite": # Reraising exception, because processor should properly handle it. raise exc else: - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -100,14 +100,14 @@ suite "HTTP server testing suite": let request = case operation of GetBodyTest, ConsumeBodyTest, PostUrlTest: - "POST / HTTP/1.0\r\n" & + "POST / HTTP/1.1\r\n" & "Content-Type: application/x-www-form-urlencoded\r\n" & "Transfer-Encoding: chunked\r\n" & "Cookie: 2\r\n\r\n" & "5\r\na=a&b\r\n5\r\n=b&c=\r\n4\r\nc&d=\r\n4\r\n%D0%\r\n" & "2\r\n9F\r\n0\r\n\r\n" of PostMultipartTest: - "POST / HTTP/1.0\r\n" & + "POST / HTTP/1.1\r\n" & "Host: 127.0.0.1:30080\r\n" & "Transfer-Encoding: chunked\r\n" & "Content-Type: multipart/form-data; boundary=f98f0\r\n\r\n" & @@ -134,9 +134,9 @@ suite "HTTP server testing suite": let request = r.get() return await request.respond(Http200, "TEST_OK", HttpTable.init()) else: - if r.error().error == HttpServerError.TimeoutError: + if r.error.kind == HttpServerError.TimeoutError: serverRes = true - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), @@ -148,7 +148,6 @@ suite "HTTP server testing suite": let server = res.get() server.start() let address = server.instance.localAddress() - let data = await httpClient(address, "") await server.stop() await server.closeWait() @@ -165,9 +164,9 @@ suite "HTTP server testing suite": let request = r.get() return await request.respond(Http200, "TEST_OK", HttpTable.init()) else: - if r.error().error == HttpServerError.CriticalError: + if r.error.kind == HttpServerError.CriticalError: serverRes = true - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), @@ -195,9 +194,9 @@ suite "HTTP server testing suite": let request = r.get() return await request.respond(Http200, "TEST_OK", HttpTable.init()) else: - if r.error().error == HttpServerError.CriticalError: + if r.error.error == HttpServerError.CriticalError: serverRes = true - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -225,9 +224,9 @@ suite "HTTP server testing suite": if r.isOk(): discard else: - if r.error().error == HttpServerError.CriticalError: + if r.error.error == HttpServerError.CriticalError: serverRes = true - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -280,7 +279,7 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -321,7 +320,7 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -367,7 +366,7 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -411,7 +410,7 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -456,7 +455,7 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -512,7 +511,7 @@ suite "HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -576,7 +575,7 @@ suite "HTTP server testing suite": await eventContinue.wait() return await request.respond(Http404, "", HttpTable.init()) else: - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -1247,7 +1246,7 @@ suite "HTTP server testing suite": return response else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, @@ -1311,7 +1310,7 @@ suite "HTTP server testing suite": let request = r.get() return await request.respond(Http200, "TEST_OK", HttpTable.init()) else: - return dumbResponse() + return defaultResponse() for test in TestMessages: let @@ -1355,9 +1354,78 @@ suite "HTTP server testing suite": await server.stop() await server.closeWait() - test "Leaks test": + asyncTest "HTTP debug tests": + const + TestsCount = 10 + TestRequest = "GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n" + + proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isOk(): + let request = r.get() + return await request.respond(Http200, "TEST_OK", HttpTable.init()) + else: + return defaultResponse() + + proc client(address: TransportAddress, + data: string): Future[StreamTransport] {.async.} = + var transp: StreamTransport + var buffer = newSeq[byte](4096) + var sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8] + try: + transp = await connect(address) + let wres {.used.} = + await transp.write(data) + let hres {.used.} = + await transp.readUntil(addr buffer[0], len(buffer), sep) + transp + except CatchableError: + if not(isNil(transp)): await transp.closeWait() + nil + + let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process, + serverFlags = {HttpServerFlags.Http11Pipeline}, + socketFlags = socketFlags) + check res.isOk() + + let server = res.get() + server.start() + let address = server.instance.localAddress() + + let info = server.getServerInfo() + check: - getTracker("async.stream.reader").isLeaked() == false - getTracker("async.stream.writer").isLeaked() == false - getTracker("stream.server").isLeaked() == false - getTracker("stream.transport").isLeaked() == false + info.connectionType == ConnectionType.NonSecure + info.address == address + info.state == HttpServerState.ServerRunning + info.flags == {HttpServerFlags.Http11Pipeline} + info.socketFlags == socketFlags + + try: + var clientFutures: seq[Future[StreamTransport]] + for i in 0 ..< TestsCount: + clientFutures.add(client(address, TestRequest)) + await allFutures(clientFutures) + + let connections = server.getConnections() + check len(connections) == TestsCount + let currentTime = Moment.now() + for index, connection in connections.pairs(): + let transp = clientFutures[index].read() + check: + connection.remoteAddress.get() == transp.localAddress() + connection.localAddress.get() == transp.remoteAddress() + connection.connectionType == ConnectionType.NonSecure + connection.connectionState == ConnectionState.Alive + (currentTime - connection.createMoment.get()) != ZeroDuration + (currentTime - connection.acceptMoment) != ZeroDuration + var pending: seq[Future[void]] + for transpFut in clientFutures: + pending.add(closeWait(transpFut.read())) + await allFutures(pending) + finally: + await server.stop() + await server.closeWait() + + test "Leaks test": + checkLeaks() diff --git a/tests/testproc.nim b/tests/testproc.nim index 05f793d..b038325 100644 --- a/tests/testproc.nim +++ b/tests/testproc.nim @@ -6,7 +6,7 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/os -import unittest2, stew/[base10, byteutils] +import stew/[base10, byteutils] import ".."/chronos/unittest2/asynctests when defined(posix): @@ -414,12 +414,4 @@ suite "Asynchronous process management test suite": check getCurrentFD() == markFD test "Leaks test": - proc getTrackerLeaks(tracker: string): bool = - let tracker = getTracker(tracker) - if isNil(tracker): false else: tracker.isLeaked() - - check: - getTrackerLeaks("async.process") == false - getTrackerLeaks("async.stream.reader") == false - getTrackerLeaks("async.stream.writer") == false - getTrackerLeaks("stream.transport") == false + checkLeaks() diff --git a/tests/testshttpserver.nim b/tests/testshttpserver.nim index a258cc9..a83d0b2 100644 --- a/tests/testshttpserver.nim +++ b/tests/testshttpserver.nim @@ -6,8 +6,8 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/strutils -import unittest2 -import ../chronos, ../chronos/apps/http/shttpserver +import ".."/chronos/unittest2/asynctests +import ".."/chronos, ".."/chronos/apps/http/shttpserver import stew/base10 {.used.} @@ -115,7 +115,7 @@ suite "Secure HTTP server testing suite": HttpTable.init()) else: serverRes = false - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let serverFlags = {Secure} @@ -154,7 +154,7 @@ suite "Secure HTTP server testing suite": else: serverRes = true testFut.complete() - return dumbResponse() + return defaultResponse() let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} let serverFlags = {Secure} @@ -178,3 +178,6 @@ suite "Secure HTTP server testing suite": return serverRes and data == "EXCEPTION" check waitFor(testHTTPS2(initTAddress("127.0.0.1:30080"))) == true + + test "Leaks test": + checkLeaks() diff --git a/tests/teststream.nim b/tests/teststream.nim index 7601a39..f6bc99b 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -6,7 +6,7 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/[strutils, os] -import unittest2 +import ".."/chronos/unittest2/asynctests import ".."/chronos, ".."/chronos/[osdefs, oserrno] {.used.} @@ -1370,10 +1370,11 @@ suite "Stream Transport test suite": test prefixes[i] & "close() while in accept() waiting test": check waitFor(testAcceptClose(addresses[i])) == true test prefixes[i] & "Intermediate transports leak test #1": + checkLeaks() when defined(windows): skip() else: - check getTracker("stream.transport").isLeaked() == false + checkLeaks(StreamTransportTrackerName) test prefixes[i] & "accept() too many file descriptors test": when defined(windows): skip() @@ -1389,10 +1390,8 @@ suite "Stream Transport test suite": check waitFor(testPipe()) == true test "[IP] bind connect to local address": waitFor(testConnectBindLocalAddress()) - test "Servers leak test": - check getTracker("stream.server").isLeaked() == false - test "Transports leak test": - check getTracker("stream.transport").isLeaked() == false + test "Leaks test": + checkLeaks() test "File descriptors leak test": when defined(windows): # Windows handle numbers depends on many conditions, so we can't use