Trackers refactoring. (#416)
* Refactor chronos trackers to be more simple. * Refactor trackers. Add HTTP server trackers. Refactor HTTP main processing loop. * Compatibility fixes. Add checkLeaks(). * Fix posix test issue. * Add httpdebug module which introduces HTTP connection dumping helpers. Add tests for it. * Recover and deprecate old version of Trackers. * Make public iterators to iterate over all tracker counters available. Fix asynctests to use public iterators instead private one.
This commit is contained in:
parent
3d80ea9fc7
commit
155d89450e
|
@ -25,71 +25,6 @@ type
|
|||
bstate*: HttpState
|
||||
streams*: seq[AsyncStreamWriter]
|
||||
|
||||
HttpBodyTracker* = ref object of TrackerBase
|
||||
opened*: int64
|
||||
closed*: int64
|
||||
|
||||
proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe, raises: [].}
|
||||
proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe, raises: [].}
|
||||
|
||||
proc getHttpBodyWriterTracker(): HttpBodyTracker {.inline.} =
|
||||
var res = cast[HttpBodyTracker](getTracker(HttpBodyWriterTrackerName))
|
||||
if isNil(res):
|
||||
res = setupHttpBodyWriterTracker()
|
||||
res
|
||||
|
||||
proc getHttpBodyReaderTracker(): HttpBodyTracker {.inline.} =
|
||||
var res = cast[HttpBodyTracker](getTracker(HttpBodyReaderTrackerName))
|
||||
if isNil(res):
|
||||
res = setupHttpBodyReaderTracker()
|
||||
res
|
||||
|
||||
proc dumpHttpBodyWriterTracking(): string {.gcsafe.} =
|
||||
let tracker = getHttpBodyWriterTracker()
|
||||
"Opened HTTP body writers: " & $tracker.opened & "\n" &
|
||||
"Closed HTTP body writers: " & $tracker.closed
|
||||
|
||||
proc dumpHttpBodyReaderTracking(): string {.gcsafe.} =
|
||||
let tracker = getHttpBodyReaderTracker()
|
||||
"Opened HTTP body readers: " & $tracker.opened & "\n" &
|
||||
"Closed HTTP body readers: " & $tracker.closed
|
||||
|
||||
proc leakHttpBodyWriter(): bool {.gcsafe.} =
|
||||
var tracker = getHttpBodyWriterTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc leakHttpBodyReader(): bool {.gcsafe.} =
|
||||
var tracker = getHttpBodyReaderTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc trackHttpBodyWriter(t: HttpBodyWriter) {.inline.} =
|
||||
inc(getHttpBodyWriterTracker().opened)
|
||||
|
||||
proc untrackHttpBodyWriter*(t: HttpBodyWriter) {.inline.} =
|
||||
inc(getHttpBodyWriterTracker().closed)
|
||||
|
||||
proc trackHttpBodyReader(t: HttpBodyReader) {.inline.} =
|
||||
inc(getHttpBodyReaderTracker().opened)
|
||||
|
||||
proc untrackHttpBodyReader*(t: HttpBodyReader) {.inline.} =
|
||||
inc(getHttpBodyReaderTracker().closed)
|
||||
|
||||
proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe.} =
|
||||
var res = HttpBodyTracker(opened: 0, closed: 0,
|
||||
dump: dumpHttpBodyWriterTracking,
|
||||
isLeaked: leakHttpBodyWriter
|
||||
)
|
||||
addTracker(HttpBodyWriterTrackerName, res)
|
||||
res
|
||||
|
||||
proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe.} =
|
||||
var res = HttpBodyTracker(opened: 0, closed: 0,
|
||||
dump: dumpHttpBodyReaderTracking,
|
||||
isLeaked: leakHttpBodyReader
|
||||
)
|
||||
addTracker(HttpBodyReaderTrackerName, res)
|
||||
res
|
||||
|
||||
proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader =
|
||||
## HttpBodyReader is AsyncStreamReader which holds references to all the
|
||||
## ``streams``. Also on close it will close all the ``streams``.
|
||||
|
@ -98,7 +33,7 @@ proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader =
|
|||
doAssert(len(streams) > 0, "At least one stream must be added")
|
||||
var res = HttpBodyReader(bstate: HttpState.Alive, streams: @streams)
|
||||
res.init(streams[0])
|
||||
trackHttpBodyReader(res)
|
||||
trackCounter(HttpBodyReaderTrackerName)
|
||||
res
|
||||
|
||||
proc closeWait*(bstream: HttpBodyReader) {.async.} =
|
||||
|
@ -113,7 +48,7 @@ proc closeWait*(bstream: HttpBodyReader) {.async.} =
|
|||
await allFutures(res)
|
||||
await procCall(closeWait(AsyncStreamReader(bstream)))
|
||||
bstream.bstate = HttpState.Closed
|
||||
untrackHttpBodyReader(bstream)
|
||||
untrackCounter(HttpBodyReaderTrackerName)
|
||||
|
||||
proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter =
|
||||
## HttpBodyWriter is AsyncStreamWriter which holds references to all the
|
||||
|
@ -123,7 +58,7 @@ proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter =
|
|||
doAssert(len(streams) > 0, "At least one stream must be added")
|
||||
var res = HttpBodyWriter(bstate: HttpState.Alive, streams: @streams)
|
||||
res.init(streams[0])
|
||||
trackHttpBodyWriter(res)
|
||||
trackCounter(HttpBodyWriterTrackerName)
|
||||
res
|
||||
|
||||
proc closeWait*(bstream: HttpBodyWriter) {.async.} =
|
||||
|
@ -136,7 +71,7 @@ proc closeWait*(bstream: HttpBodyWriter) {.async.} =
|
|||
await allFutures(res)
|
||||
await procCall(closeWait(AsyncStreamWriter(bstream)))
|
||||
bstream.bstate = HttpState.Closed
|
||||
untrackHttpBodyWriter(bstream)
|
||||
untrackCounter(HttpBodyWriterTrackerName)
|
||||
|
||||
proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [].} =
|
||||
if len(bstream.streams) == 1:
|
||||
|
|
|
@ -190,10 +190,6 @@ type
|
|||
|
||||
HttpClientFlags* = set[HttpClientFlag]
|
||||
|
||||
HttpClientTracker* = ref object of TrackerBase
|
||||
opened*: int64
|
||||
closed*: int64
|
||||
|
||||
ServerSentEvent* = object
|
||||
name*: string
|
||||
data*: string
|
||||
|
@ -204,100 +200,6 @@ type
|
|||
# HttpClientResponseRef valid states are
|
||||
# Open -> (Finished, Error) -> (Closing, Closed)
|
||||
|
||||
proc setupHttpClientConnectionTracker(): HttpClientTracker {.
|
||||
gcsafe, raises: [].}
|
||||
proc setupHttpClientRequestTracker(): HttpClientTracker {.
|
||||
gcsafe, raises: [].}
|
||||
proc setupHttpClientResponseTracker(): HttpClientTracker {.
|
||||
gcsafe, raises: [].}
|
||||
|
||||
proc getHttpClientConnectionTracker(): HttpClientTracker {.inline.} =
|
||||
var res = cast[HttpClientTracker](getTracker(HttpClientConnectionTrackerName))
|
||||
if isNil(res):
|
||||
res = setupHttpClientConnectionTracker()
|
||||
res
|
||||
|
||||
proc getHttpClientRequestTracker(): HttpClientTracker {.inline.} =
|
||||
var res = cast[HttpClientTracker](getTracker(HttpClientRequestTrackerName))
|
||||
if isNil(res):
|
||||
res = setupHttpClientRequestTracker()
|
||||
res
|
||||
|
||||
proc getHttpClientResponseTracker(): HttpClientTracker {.inline.} =
|
||||
var res = cast[HttpClientTracker](getTracker(HttpClientResponseTrackerName))
|
||||
if isNil(res):
|
||||
res = setupHttpClientResponseTracker()
|
||||
res
|
||||
|
||||
proc dumpHttpClientConnectionTracking(): string {.gcsafe.} =
|
||||
let tracker = getHttpClientConnectionTracker()
|
||||
"Opened HTTP client connections: " & $tracker.opened & "\n" &
|
||||
"Closed HTTP client connections: " & $tracker.closed
|
||||
|
||||
proc dumpHttpClientRequestTracking(): string {.gcsafe.} =
|
||||
let tracker = getHttpClientRequestTracker()
|
||||
"Opened HTTP client requests: " & $tracker.opened & "\n" &
|
||||
"Closed HTTP client requests: " & $tracker.closed
|
||||
|
||||
proc dumpHttpClientResponseTracking(): string {.gcsafe.} =
|
||||
let tracker = getHttpClientResponseTracker()
|
||||
"Opened HTTP client responses: " & $tracker.opened & "\n" &
|
||||
"Closed HTTP client responses: " & $tracker.closed
|
||||
|
||||
proc leakHttpClientConnection(): bool {.gcsafe.} =
|
||||
var tracker = getHttpClientConnectionTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc leakHttpClientRequest(): bool {.gcsafe.} =
|
||||
var tracker = getHttpClientRequestTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc leakHttpClientResponse(): bool {.gcsafe.} =
|
||||
var tracker = getHttpClientResponseTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc trackHttpClientConnection(t: HttpClientConnectionRef) {.inline.} =
|
||||
inc(getHttpClientConnectionTracker().opened)
|
||||
|
||||
proc untrackHttpClientConnection*(t: HttpClientConnectionRef) {.inline.} =
|
||||
inc(getHttpClientConnectionTracker().closed)
|
||||
|
||||
proc trackHttpClientRequest(t: HttpClientRequestRef) {.inline.} =
|
||||
inc(getHttpClientRequestTracker().opened)
|
||||
|
||||
proc untrackHttpClientRequest*(t: HttpClientRequestRef) {.inline.} =
|
||||
inc(getHttpClientRequestTracker().closed)
|
||||
|
||||
proc trackHttpClientResponse(t: HttpClientResponseRef) {.inline.} =
|
||||
inc(getHttpClientResponseTracker().opened)
|
||||
|
||||
proc untrackHttpClientResponse*(t: HttpClientResponseRef) {.inline.} =
|
||||
inc(getHttpClientResponseTracker().closed)
|
||||
|
||||
proc setupHttpClientConnectionTracker(): HttpClientTracker {.gcsafe.} =
|
||||
var res = HttpClientTracker(opened: 0, closed: 0,
|
||||
dump: dumpHttpClientConnectionTracking,
|
||||
isLeaked: leakHttpClientConnection
|
||||
)
|
||||
addTracker(HttpClientConnectionTrackerName, res)
|
||||
res
|
||||
|
||||
proc setupHttpClientRequestTracker(): HttpClientTracker {.gcsafe.} =
|
||||
var res = HttpClientTracker(opened: 0, closed: 0,
|
||||
dump: dumpHttpClientRequestTracking,
|
||||
isLeaked: leakHttpClientRequest
|
||||
)
|
||||
addTracker(HttpClientRequestTrackerName, res)
|
||||
res
|
||||
|
||||
proc setupHttpClientResponseTracker(): HttpClientTracker {.gcsafe.} =
|
||||
var res = HttpClientTracker(opened: 0, closed: 0,
|
||||
dump: dumpHttpClientResponseTracking,
|
||||
isLeaked: leakHttpClientResponse
|
||||
)
|
||||
addTracker(HttpClientResponseTrackerName, res)
|
||||
res
|
||||
|
||||
template checkClosed(reqresp: untyped): untyped =
|
||||
if reqresp.connection.state in {HttpClientConnectionState.Closing,
|
||||
HttpClientConnectionState.Closed}:
|
||||
|
@ -556,7 +458,7 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef,
|
|||
state: HttpClientConnectionState.Connecting,
|
||||
remoteHostname: ha.id
|
||||
)
|
||||
trackHttpClientConnection(res)
|
||||
trackCounter(HttpClientConnectionTrackerName)
|
||||
res
|
||||
of HttpClientScheme.Secure:
|
||||
let treader = newAsyncStreamReader(transp)
|
||||
|
@ -575,7 +477,7 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef,
|
|||
state: HttpClientConnectionState.Connecting,
|
||||
remoteHostname: ha.id
|
||||
)
|
||||
trackHttpClientConnection(res)
|
||||
trackCounter(HttpClientConnectionTrackerName)
|
||||
res
|
||||
|
||||
proc setError(request: HttpClientRequestRef, error: ref HttpError) {.
|
||||
|
@ -615,7 +517,7 @@ proc closeWait(conn: HttpClientConnectionRef) {.async.} =
|
|||
discard
|
||||
await conn.transp.closeWait()
|
||||
conn.state = HttpClientConnectionState.Closed
|
||||
untrackHttpClientConnection(conn)
|
||||
untrackCounter(HttpClientConnectionTrackerName)
|
||||
|
||||
proc connect(session: HttpSessionRef,
|
||||
ha: HttpAddress): Future[HttpClientConnectionRef] {.async.} =
|
||||
|
@ -835,7 +737,7 @@ proc closeWait*(request: HttpClientRequestRef) {.async.} =
|
|||
request.session = nil
|
||||
request.error = nil
|
||||
request.state = HttpReqRespState.Closed
|
||||
untrackHttpClientRequest(request)
|
||||
untrackCounter(HttpClientRequestTrackerName)
|
||||
|
||||
proc closeWait*(response: HttpClientResponseRef) {.async.} =
|
||||
if response.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}:
|
||||
|
@ -848,7 +750,7 @@ proc closeWait*(response: HttpClientResponseRef) {.async.} =
|
|||
response.session = nil
|
||||
response.error = nil
|
||||
response.state = HttpReqRespState.Closed
|
||||
untrackHttpClientResponse(response)
|
||||
untrackCounter(HttpClientResponseTrackerName)
|
||||
|
||||
proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte]
|
||||
): HttpResult[HttpClientResponseRef] {.raises: [] .} =
|
||||
|
@ -958,7 +860,7 @@ proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte]
|
|||
httpPipeline:
|
||||
res.connection.flags.incl(HttpClientConnectionFlag.KeepAlive)
|
||||
res.connection.flags.incl(HttpClientConnectionFlag.Response)
|
||||
trackHttpClientResponse(res)
|
||||
trackCounter(HttpClientResponseTrackerName)
|
||||
ok(res)
|
||||
|
||||
proc getResponse(req: HttpClientRequestRef): Future[HttpClientResponseRef] {.
|
||||
|
@ -996,7 +898,7 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
|
|||
version: version, flags: flags, headers: HttpTable.init(headers),
|
||||
address: ha, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body
|
||||
)
|
||||
trackHttpClientRequest(res)
|
||||
trackCounter(HttpClientRequestTrackerName)
|
||||
res
|
||||
|
||||
proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
|
||||
|
@ -1012,7 +914,7 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
|
|||
version: version, flags: flags, headers: HttpTable.init(headers),
|
||||
address: address, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body
|
||||
)
|
||||
trackHttpClientRequest(res)
|
||||
trackCounter(HttpClientRequestTrackerName)
|
||||
ok(res)
|
||||
|
||||
proc get*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
|
||||
|
|
|
@ -13,6 +13,15 @@ import ../../streams/[asyncstream, boundstream]
|
|||
export asyncloop, asyncsync, results, httputils, strutils
|
||||
|
||||
const
|
||||
HttpServerUnsecureConnectionTrackerName* =
|
||||
"httpserver.unsecure.connection"
|
||||
HttpServerSecureConnectionTrackerName* =
|
||||
"httpserver.secure.connection"
|
||||
HttpServerRequestTrackerName* =
|
||||
"httpserver.request"
|
||||
HttpServerResponseTrackerName* =
|
||||
"httpserver.response"
|
||||
|
||||
HeadersMark* = @[0x0d'u8, 0x0a'u8, 0x0d'u8, 0x0a'u8]
|
||||
PostMethods* = {MethodPost, MethodPatch, MethodPut, MethodDelete}
|
||||
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
#
|
||||
# Chronos HTTP/S server implementation
|
||||
# (c) Copyright 2021-Present
|
||||
# Status Research & Development GmbH
|
||||
#
|
||||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/tables
|
||||
import stew/results
|
||||
import ../../timer
|
||||
import httpserver, shttpserver
|
||||
from httpclient import HttpClientScheme
|
||||
from httpcommon import HttpState
|
||||
from ../../osdefs import SocketHandle
|
||||
from ../../transports/common import TransportAddress, ServerFlags
|
||||
export HttpClientScheme, SocketHandle, TransportAddress, ServerFlags, HttpState
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
type
|
||||
ConnectionType* {.pure.} = enum
|
||||
NonSecure, Secure
|
||||
|
||||
ConnectionState* {.pure.} = enum
|
||||
Accepted, Alive, Closing, Closed
|
||||
|
||||
ServerConnectionInfo* = object
|
||||
handle*: SocketHandle
|
||||
connectionType*: ConnectionType
|
||||
connectionState*: ConnectionState
|
||||
remoteAddress*: Opt[TransportAddress]
|
||||
localAddress*: Opt[TransportAddress]
|
||||
acceptMoment*: Moment
|
||||
createMoment*: Opt[Moment]
|
||||
|
||||
ServerInfo* = object
|
||||
connectionType*: ConnectionType
|
||||
address*: TransportAddress
|
||||
state*: HttpServerState
|
||||
maxConnections*: int
|
||||
backlogSize*: int
|
||||
baseUri*: Uri
|
||||
serverIdent*: string
|
||||
flags*: set[HttpServerFlags]
|
||||
socketFlags*: set[ServerFlags]
|
||||
headersTimeout*: Duration
|
||||
bufferSize*: int
|
||||
maxHeadersSize*: int
|
||||
maxRequestBodySize*: int
|
||||
|
||||
proc getConnectionType*(
|
||||
server: HttpServerRef | SecureHttpServerRef): ConnectionType =
|
||||
when server is SecureHttpServerRef:
|
||||
ConnectionType.Secure
|
||||
else:
|
||||
if HttpServerFlags.Secure in server.flags:
|
||||
ConnectionType.Secure
|
||||
else:
|
||||
ConnectionType.NonSecure
|
||||
|
||||
proc getServerInfo*(server: HttpServerRef|SecureHttpServerRef): ServerInfo =
|
||||
ServerInfo(
|
||||
connectionType: server.getConnectionType(),
|
||||
address: server.address,
|
||||
state: server.state(),
|
||||
maxConnections: server.maxConnections,
|
||||
backlogSize: server.backlogSize,
|
||||
baseUri: server.baseUri,
|
||||
serverIdent: server.serverIdent,
|
||||
flags: server.flags,
|
||||
socketFlags: server.socketFlags,
|
||||
headersTimeout: server.headersTimeout,
|
||||
bufferSize: server.bufferSize,
|
||||
maxHeadersSize: server.maxHeadersSize,
|
||||
maxRequestBodySize: server.maxRequestBodySize
|
||||
)
|
||||
|
||||
proc getConnectionState*(holder: HttpConnectionHolderRef): ConnectionState =
|
||||
if not(isNil(holder.connection)):
|
||||
case holder.connection.state
|
||||
of HttpState.Alive: ConnectionState.Alive
|
||||
of HttpState.Closing: ConnectionState.Closing
|
||||
of HttpState.Closed: ConnectionState.Closed
|
||||
else:
|
||||
ConnectionState.Accepted
|
||||
|
||||
proc init*(t: typedesc[ServerConnectionInfo],
|
||||
holder: HttpConnectionHolderRef): ServerConnectionInfo =
|
||||
let
|
||||
localAddress =
|
||||
try:
|
||||
Opt.some(holder.transp.localAddress())
|
||||
except CatchableError:
|
||||
Opt.none(TransportAddress)
|
||||
remoteAddress =
|
||||
try:
|
||||
Opt.some(holder.transp.remoteAddress())
|
||||
except CatchableError:
|
||||
Opt.none(TransportAddress)
|
||||
|
||||
ServerConnectionInfo(
|
||||
handle: SocketHandle(holder.transp.fd),
|
||||
connectionType: holder.server.getConnectionType(),
|
||||
connectionState: holder.getConnectionState(),
|
||||
remoteAddress: remoteAddress,
|
||||
localAddress: localAddress,
|
||||
acceptMoment: holder.acceptMoment,
|
||||
createMoment:
|
||||
if not(isNil(holder.connection)):
|
||||
Opt.some(holder.connection.createMoment)
|
||||
else:
|
||||
Opt.none(Moment)
|
||||
)
|
||||
|
||||
proc getConnections*(server: HttpServerRef): seq[ServerConnectionInfo] =
|
||||
var res: seq[ServerConnectionInfo]
|
||||
for holder in server.connections.values():
|
||||
res.add(ServerConnectionInfo.init(holder))
|
||||
res
|
|
@ -29,18 +29,20 @@ type
|
|||
## Enable HTTP/1.1 pipelining.
|
||||
|
||||
HttpServerError* {.pure.} = enum
|
||||
TimeoutError, CatchableError, RecoverableError, CriticalError,
|
||||
DisconnectError
|
||||
InterruptError, TimeoutError, CatchableError, RecoverableError,
|
||||
CriticalError, DisconnectError
|
||||
|
||||
HttpServerState* {.pure.} = enum
|
||||
ServerRunning, ServerStopped, ServerClosed
|
||||
|
||||
HttpProcessError* = object
|
||||
error*: HttpServerError
|
||||
kind*: HttpServerError
|
||||
code*: HttpCode
|
||||
exc*: ref CatchableError
|
||||
remote*: TransportAddress
|
||||
remote*: Opt[TransportAddress]
|
||||
|
||||
ConnectionFence* = Result[HttpConnectionRef, HttpProcessError]
|
||||
ResponseFence* = Result[HttpResponseRef, HttpProcessError]
|
||||
RequestFence* = Result[HttpRequestRef, HttpProcessError]
|
||||
|
||||
HttpRequestFlags* {.pure.} = enum
|
||||
|
@ -53,7 +55,7 @@ type
|
|||
Plain, SSE, Chunked
|
||||
|
||||
HttpResponseState* {.pure.} = enum
|
||||
Empty, Prepared, Sending, Finished, Failed, Cancelled, Dumb
|
||||
Empty, Prepared, Sending, Finished, Failed, Cancelled, Default
|
||||
|
||||
HttpProcessCallback* =
|
||||
proc(req: RequestFence): Future[HttpResponseRef] {.
|
||||
|
@ -64,6 +66,20 @@ type
|
|||
transp: StreamTransport): Future[HttpConnectionRef] {.
|
||||
gcsafe, raises: [].}
|
||||
|
||||
HttpCloseConnectionCallback* =
|
||||
proc(connection: HttpConnectionRef): Future[void] {.
|
||||
gcsafe, raises: [].}
|
||||
|
||||
HttpConnectionHolder* = object of RootObj
|
||||
connection*: HttpConnectionRef
|
||||
server*: HttpServerRef
|
||||
future*: Future[void]
|
||||
transp*: StreamTransport
|
||||
acceptMoment*: Moment
|
||||
connectionId*: string
|
||||
|
||||
HttpConnectionHolderRef* = ref HttpConnectionHolder
|
||||
|
||||
HttpServer* = object of RootObj
|
||||
instance*: StreamServer
|
||||
address*: TransportAddress
|
||||
|
@ -74,7 +90,7 @@ type
|
|||
serverIdent*: string
|
||||
flags*: set[HttpServerFlags]
|
||||
socketFlags*: set[ServerFlags]
|
||||
connections*: Table[string, Future[void]]
|
||||
connections*: OrderedTable[string, HttpConnectionHolderRef]
|
||||
acceptLoop*: Future[void]
|
||||
lifetime*: Future[void]
|
||||
headersTimeout*: Duration
|
||||
|
@ -122,11 +138,13 @@ type
|
|||
HttpConnection* = object of RootObj
|
||||
state*: HttpState
|
||||
server*: HttpServerRef
|
||||
transp: StreamTransport
|
||||
transp*: StreamTransport
|
||||
mainReader*: AsyncStreamReader
|
||||
mainWriter*: AsyncStreamWriter
|
||||
reader*: AsyncStreamReader
|
||||
writer*: AsyncStreamWriter
|
||||
closeCb*: HttpCloseConnectionCallback
|
||||
createMoment*: Moment
|
||||
buffer: seq[byte]
|
||||
|
||||
HttpConnectionRef* = ref HttpConnection
|
||||
|
@ -134,9 +152,24 @@ type
|
|||
ByteChar* = string | seq[byte]
|
||||
|
||||
proc init(htype: typedesc[HttpProcessError], error: HttpServerError,
|
||||
exc: ref CatchableError, remote: TransportAddress,
|
||||
code: HttpCode): HttpProcessError {.raises: [].} =
|
||||
HttpProcessError(error: error, exc: exc, remote: remote, code: code)
|
||||
exc: ref CatchableError, remote: Opt[TransportAddress],
|
||||
code: HttpCode): HttpProcessError {.
|
||||
raises: [].} =
|
||||
HttpProcessError(kind: error, exc: exc, remote: remote, code: code)
|
||||
|
||||
proc init(htype: typedesc[HttpProcessError],
|
||||
error: HttpServerError): HttpProcessError {.
|
||||
raises: [].} =
|
||||
HttpProcessError(kind: error)
|
||||
|
||||
proc new(htype: typedesc[HttpConnectionHolderRef], server: HttpServerRef,
|
||||
transp: StreamTransport,
|
||||
connectionId: string): HttpConnectionHolderRef =
|
||||
HttpConnectionHolderRef(
|
||||
server: server, transp: transp, acceptMoment: Moment.now(),
|
||||
connectionId: connectionId)
|
||||
|
||||
proc error*(e: HttpProcessError): HttpServerError = e.kind
|
||||
|
||||
proc createConnection(server: HttpServerRef,
|
||||
transp: StreamTransport): Future[HttpConnectionRef] {.
|
||||
|
@ -176,7 +209,7 @@ proc new*(htype: typedesc[HttpServerRef],
|
|||
return err(exc.msg)
|
||||
|
||||
var res = HttpServerRef(
|
||||
address: address,
|
||||
address: serverInstance.localAddress(),
|
||||
instance: serverInstance,
|
||||
processCallback: processCallback,
|
||||
createConnCallback: createConnection,
|
||||
|
@ -196,15 +229,22 @@ proc new*(htype: typedesc[HttpServerRef],
|
|||
# else:
|
||||
# nil
|
||||
lifetime: newFuture[void]("http.server.lifetime"),
|
||||
connections: initTable[string, Future[void]]()
|
||||
connections: initOrderedTable[string, HttpConnectionHolderRef]()
|
||||
)
|
||||
ok(res)
|
||||
|
||||
proc getResponseFlags*(req: HttpRequestRef): set[HttpResponseFlags] =
|
||||
proc getServerFlags(req: HttpRequestRef): set[HttpServerFlags] =
|
||||
var defaultFlags: set[HttpServerFlags] = {}
|
||||
if isNil(req): return defaultFlags
|
||||
if isNil(req.connection): return defaultFlags
|
||||
if isNil(req.connection.server): return defaultFlags
|
||||
req.connection.server.flags
|
||||
|
||||
proc getResponseFlags(req: HttpRequestRef): set[HttpResponseFlags] =
|
||||
var defaultFlags: set[HttpResponseFlags] = {}
|
||||
case req.version
|
||||
of HttpVersion11:
|
||||
if HttpServerFlags.Http11Pipeline notin req.connection.server.flags:
|
||||
if HttpServerFlags.Http11Pipeline notin req.getServerFlags():
|
||||
return defaultFlags
|
||||
let header = req.headers.getString(ConnectionHeader, "keep-alive")
|
||||
if header == "keep-alive":
|
||||
|
@ -214,6 +254,12 @@ proc getResponseFlags*(req: HttpRequestRef): set[HttpResponseFlags] =
|
|||
else:
|
||||
defaultFlags
|
||||
|
||||
proc getResponseVersion(reqFence: RequestFence): HttpVersion {.raises: [].} =
|
||||
if reqFence.isErr():
|
||||
HttpVersion11
|
||||
else:
|
||||
reqFence.get().version
|
||||
|
||||
proc getResponse*(req: HttpRequestRef): HttpResponseRef {.raises: [].} =
|
||||
if req.response.isNone():
|
||||
var resp = HttpResponseRef(
|
||||
|
@ -235,9 +281,14 @@ proc getHostname*(server: HttpServerRef): string =
|
|||
else:
|
||||
server.baseUri.hostname
|
||||
|
||||
proc dumbResponse*(): HttpResponseRef {.raises: [].} =
|
||||
proc defaultResponse*(): HttpResponseRef {.raises: [].} =
|
||||
## Create an empty response to return when request processor got no request.
|
||||
HttpResponseRef(state: HttpResponseState.Dumb, version: HttpVersion11)
|
||||
HttpResponseRef(state: HttpResponseState.Default, version: HttpVersion11)
|
||||
|
||||
proc dumbResponse*(): HttpResponseRef {.raises: [],
|
||||
deprecated: "Please use defaultResponse() instead".} =
|
||||
## Create an empty response to return when request processor got no request.
|
||||
defaultResponse()
|
||||
|
||||
proc getId(transp: StreamTransport): Result[string, string] {.inline.} =
|
||||
## Returns string unique transport's identifier as string.
|
||||
|
@ -371,6 +422,7 @@ proc prepareRequest(conn: HttpConnectionRef,
|
|||
if strip(expectHeader).toLowerAscii() == "100-continue":
|
||||
request.requestFlags.incl(HttpRequestFlags.ClientExpect)
|
||||
|
||||
trackCounter(HttpServerRequestTrackerName)
|
||||
ok(request)
|
||||
|
||||
proc getBodyReader*(request: HttpRequestRef): HttpResult[HttpBodyReader] =
|
||||
|
@ -579,7 +631,7 @@ proc preferredContentType*(request: HttpRequestRef,
|
|||
proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion,
|
||||
code: HttpCode, keepAlive = true,
|
||||
datatype = "text/text",
|
||||
databody = ""): Future[bool] {.async.} =
|
||||
databody = "") {.async.} =
|
||||
var answer = $version & " " & $code & "\r\n"
|
||||
answer.add(DateHeader)
|
||||
answer.add(": ")
|
||||
|
@ -605,13 +657,90 @@ proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion,
|
|||
answer.add(databody)
|
||||
try:
|
||||
await conn.writer.write(answer)
|
||||
return true
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError:
|
||||
# We ignore errors here, because we indicating error already.
|
||||
discard
|
||||
|
||||
proc sendErrorResponse(conn: HttpConnectionRef, reqFence: RequestFence,
|
||||
respError: HttpProcessError): Future[bool] {.async.} =
|
||||
let version = getResponseVersion(reqFence)
|
||||
try:
|
||||
if reqFence.isOk():
|
||||
case respError.kind
|
||||
of HttpServerError.CriticalError:
|
||||
await conn.sendErrorResponse(version, respError.code, false)
|
||||
false
|
||||
of HttpServerError.RecoverableError:
|
||||
await conn.sendErrorResponse(version, respError.code, true)
|
||||
true
|
||||
of HttpServerError.CatchableError:
|
||||
await conn.sendErrorResponse(version, respError.code, false)
|
||||
false
|
||||
of HttpServerError.DisconnectError,
|
||||
HttpServerError.InterruptError,
|
||||
HttpServerError.TimeoutError:
|
||||
raiseAssert("Unexpected response error: " & $respError.kind)
|
||||
else:
|
||||
false
|
||||
except CancelledError:
|
||||
return false
|
||||
except AsyncStreamWriteError:
|
||||
return false
|
||||
except AsyncStreamIncompleteError:
|
||||
return false
|
||||
false
|
||||
|
||||
proc sendDefaultResponse(conn: HttpConnectionRef, reqFence: RequestFence,
|
||||
response: HttpResponseRef): Future[bool] {.async.} =
|
||||
let
|
||||
version = getResponseVersion(reqFence)
|
||||
keepConnection =
|
||||
if isNil(response):
|
||||
false
|
||||
else:
|
||||
HttpResponseFlags.KeepAlive in response.flags
|
||||
try:
|
||||
if reqFence.isOk():
|
||||
if isNil(response):
|
||||
await conn.sendErrorResponse(version, Http404, keepConnection)
|
||||
keepConnection
|
||||
else:
|
||||
case response.state
|
||||
of HttpResponseState.Empty:
|
||||
# Response was ignored, so we respond with not found.
|
||||
await conn.sendErrorResponse(version, Http404, keepConnection)
|
||||
keepConnection
|
||||
of HttpResponseState.Prepared:
|
||||
# Response was prepared but not sent, so we can respond with some
|
||||
# error code
|
||||
await conn.sendErrorResponse(HttpVersion11, Http409, keepConnection)
|
||||
keepConnection
|
||||
of HttpResponseState.Sending, HttpResponseState.Failed,
|
||||
HttpResponseState.Cancelled:
|
||||
# Just drop connection, because we dont know at what stage we are
|
||||
false
|
||||
of HttpResponseState.Default:
|
||||
# Response was ignored, so we respond with not found.
|
||||
await conn.sendErrorResponse(version, Http404, keepConnection)
|
||||
keepConnection
|
||||
of HttpResponseState.Finished:
|
||||
keepConnection
|
||||
else:
|
||||
case reqFence.error.kind
|
||||
of HttpServerError.TimeoutError:
|
||||
await conn.sendErrorResponse(version, reqFence.error.code, false)
|
||||
false
|
||||
of HttpServerError.CriticalError:
|
||||
await conn.sendErrorResponse(version, reqFence.error.code, false)
|
||||
false
|
||||
of HttpServerError.RecoverableError:
|
||||
await conn.sendErrorResponse(version, reqFence.error.code, true)
|
||||
false
|
||||
of HttpServerError.CatchableError:
|
||||
await conn.sendErrorResponse(version, reqFence.error.code, false)
|
||||
false
|
||||
of HttpServerError.DisconnectError,
|
||||
HttpServerError.InterruptError:
|
||||
raiseAssert("Unexpected request error: " & $reqFence.error.kind)
|
||||
except CancelledError:
|
||||
false
|
||||
|
||||
proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} =
|
||||
try:
|
||||
|
@ -644,31 +773,33 @@ proc init*(value: var HttpConnection, server: HttpServerRef,
|
|||
mainWriter: newAsyncStreamWriter(transp)
|
||||
)
|
||||
|
||||
proc closeUnsecureConnection(conn: HttpConnectionRef) {.async.} =
|
||||
if conn.state == HttpState.Alive:
|
||||
conn.state = HttpState.Closing
|
||||
var pending: seq[Future[void]]
|
||||
pending.add(conn.mainReader.closeWait())
|
||||
pending.add(conn.mainWriter.closeWait())
|
||||
pending.add(conn.transp.closeWait())
|
||||
try:
|
||||
await allFutures(pending)
|
||||
except CancelledError:
|
||||
await allFutures(pending)
|
||||
untrackCounter(HttpServerUnsecureConnectionTrackerName)
|
||||
conn.state = HttpState.Closed
|
||||
|
||||
proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef,
|
||||
transp: StreamTransport): HttpConnectionRef =
|
||||
var res = HttpConnectionRef()
|
||||
res[].init(server, transp)
|
||||
res.reader = res.mainReader
|
||||
res.writer = res.mainWriter
|
||||
res.closeCb = closeUnsecureConnection
|
||||
res.createMoment = Moment.now()
|
||||
trackCounter(HttpServerUnsecureConnectionTrackerName)
|
||||
res
|
||||
|
||||
proc closeWait*(conn: HttpConnectionRef) {.async.} =
|
||||
if conn.state == HttpState.Alive:
|
||||
conn.state = HttpState.Closing
|
||||
var pending: seq[Future[void]]
|
||||
if conn.reader != conn.mainReader:
|
||||
pending.add(conn.reader.closeWait())
|
||||
if conn.writer != conn.mainWriter:
|
||||
pending.add(conn.writer.closeWait())
|
||||
if len(pending) > 0:
|
||||
await allFutures(pending)
|
||||
# After we going to close everything else.
|
||||
pending.setLen(3)
|
||||
pending[0] = conn.mainReader.closeWait()
|
||||
pending[1] = conn.mainWriter.closeWait()
|
||||
pending[2] = conn.transp.closeWait()
|
||||
await allFutures(pending)
|
||||
conn.state = HttpState.Closed
|
||||
proc closeWait*(conn: HttpConnectionRef): Future[void] =
|
||||
conn.closeCb(conn)
|
||||
|
||||
proc closeWait*(req: HttpRequestRef) {.async.} =
|
||||
if req.state == HttpState.Alive:
|
||||
|
@ -676,7 +807,12 @@ proc closeWait*(req: HttpRequestRef) {.async.} =
|
|||
req.state = HttpState.Closing
|
||||
let resp = req.response.get()
|
||||
if (HttpResponseFlags.Stream in resp.flags) and not(isNil(resp.writer)):
|
||||
await resp.writer.closeWait()
|
||||
var writer = resp.writer.closeWait()
|
||||
try:
|
||||
await writer
|
||||
except CancelledError:
|
||||
await writer
|
||||
untrackCounter(HttpServerRequestTrackerName)
|
||||
req.state = HttpState.Closed
|
||||
|
||||
proc createConnection(server: HttpServerRef,
|
||||
|
@ -694,174 +830,168 @@ proc `keepalive=`*(resp: HttpResponseRef, value: bool) =
|
|||
proc keepalive*(resp: HttpResponseRef): bool {.raises: [].} =
|
||||
HttpResponseFlags.KeepAlive in resp.flags
|
||||
|
||||
proc processLoop(server: HttpServerRef, transp: StreamTransport,
|
||||
connId: string) {.async.} =
|
||||
var
|
||||
conn: HttpConnectionRef
|
||||
connArg: RequestFence
|
||||
runLoop = false
|
||||
|
||||
proc getRemoteAddress(transp: StreamTransport): Opt[TransportAddress] {.
|
||||
raises: [].} =
|
||||
if isNil(transp): return Opt.none(TransportAddress)
|
||||
try:
|
||||
conn = await server.createConnCallback(server, transp)
|
||||
runLoop = true
|
||||
Opt.some(transp.remoteAddress())
|
||||
except CatchableError:
|
||||
Opt.none(TransportAddress)
|
||||
|
||||
proc getRemoteAddress(connection: HttpConnectionRef): Opt[TransportAddress] {.
|
||||
raises: [].} =
|
||||
if isNil(connection): return Opt.none(TransportAddress)
|
||||
getRemoteAddress(connection.transp)
|
||||
|
||||
proc getResponseFence*(connection: HttpConnectionRef,
|
||||
reqFence: RequestFence): Future[ResponseFence] {.
|
||||
async.} =
|
||||
try:
|
||||
let res = await connection.server.processCallback(reqFence)
|
||||
ResponseFence.ok(res)
|
||||
except CancelledError:
|
||||
server.connections.del(connId)
|
||||
await transp.closeWait()
|
||||
return
|
||||
ResponseFence.err(HttpProcessError.init(
|
||||
HttpServerError.InterruptError))
|
||||
except HttpCriticalError as exc:
|
||||
let error = HttpProcessError.init(HttpServerError.CriticalError, exc,
|
||||
transp.remoteAddress(), exc.code)
|
||||
connArg = RequestFence.err(error)
|
||||
runLoop = false
|
||||
let address = connection.getRemoteAddress()
|
||||
ResponseFence.err(HttpProcessError.init(
|
||||
HttpServerError.CriticalError, exc, address, exc.code))
|
||||
except HttpRecoverableError as exc:
|
||||
let address = connection.getRemoteAddress()
|
||||
ResponseFence.err(HttpProcessError.init(
|
||||
HttpServerError.RecoverableError, exc, address, exc.code))
|
||||
except CatchableError as exc:
|
||||
let address = connection.getRemoteAddress()
|
||||
ResponseFence.err(HttpProcessError.init(
|
||||
HttpServerError.CatchableError, exc, address, Http503))
|
||||
|
||||
if not(runLoop):
|
||||
try:
|
||||
# We still want to notify process callback about failure, but we ignore
|
||||
# result.
|
||||
discard await server.processCallback(connArg)
|
||||
except CancelledError:
|
||||
runLoop = false
|
||||
except CatchableError as exc:
|
||||
# There should be no exceptions, so we will raise `Defect`.
|
||||
raiseHttpDefect("Unexpected exception catched [" & $exc.name & "]")
|
||||
proc getResponseFence*(server: HttpServerRef,
|
||||
connFence: ConnectionFence): Future[ResponseFence] {.
|
||||
async.} =
|
||||
doAssert(connFence.isErr())
|
||||
try:
|
||||
let
|
||||
reqFence = RequestFence.err(connFence.error)
|
||||
res = await server.processCallback(reqFence)
|
||||
ResponseFence.ok(res)
|
||||
except CancelledError:
|
||||
ResponseFence.err(HttpProcessError.init(
|
||||
HttpServerError.InterruptError))
|
||||
except HttpCriticalError as exc:
|
||||
let address = Opt.none(TransportAddress)
|
||||
ResponseFence.err(HttpProcessError.init(
|
||||
HttpServerError.CriticalError, exc, address, exc.code))
|
||||
except HttpRecoverableError as exc:
|
||||
let address = Opt.none(TransportAddress)
|
||||
ResponseFence.err(HttpProcessError.init(
|
||||
HttpServerError.RecoverableError, exc, address, exc.code))
|
||||
except CatchableError as exc:
|
||||
let address = Opt.none(TransportAddress)
|
||||
ResponseFence.err(HttpProcessError.init(
|
||||
HttpServerError.CatchableError, exc, address, Http503))
|
||||
|
||||
var breakLoop = false
|
||||
while runLoop:
|
||||
var
|
||||
arg: RequestFence
|
||||
resp: HttpResponseRef
|
||||
|
||||
try:
|
||||
let request =
|
||||
if server.headersTimeout.isInfinite():
|
||||
await conn.getRequest()
|
||||
else:
|
||||
await conn.getRequest().wait(server.headersTimeout)
|
||||
arg = RequestFence.ok(request)
|
||||
except CancelledError:
|
||||
breakLoop = true
|
||||
except AsyncTimeoutError as exc:
|
||||
let error = HttpProcessError.init(HttpServerError.TimeoutError, exc,
|
||||
transp.remoteAddress(), Http408)
|
||||
arg = RequestFence.err(error)
|
||||
except HttpRecoverableError as exc:
|
||||
let error = HttpProcessError.init(HttpServerError.RecoverableError, exc,
|
||||
transp.remoteAddress(), exc.code)
|
||||
arg = RequestFence.err(error)
|
||||
except HttpCriticalError as exc:
|
||||
let error = HttpProcessError.init(HttpServerError.CriticalError, exc,
|
||||
transp.remoteAddress(), exc.code)
|
||||
arg = RequestFence.err(error)
|
||||
except HttpDisconnectError as exc:
|
||||
if HttpServerFlags.NotifyDisconnect in server.flags:
|
||||
let error = HttpProcessError.init(HttpServerError.DisconnectError, exc,
|
||||
transp.remoteAddress(), Http400)
|
||||
arg = RequestFence.err(error)
|
||||
proc getRequestFence*(server: HttpServerRef,
|
||||
connection: HttpConnectionRef): Future[RequestFence] {.
|
||||
async.} =
|
||||
try:
|
||||
let res =
|
||||
if server.headersTimeout.isInfinite():
|
||||
await connection.getRequest()
|
||||
else:
|
||||
breakLoop = true
|
||||
except CatchableError as exc:
|
||||
let error = HttpProcessError.init(HttpServerError.CatchableError, exc,
|
||||
transp.remoteAddress(), Http500)
|
||||
arg = RequestFence.err(error)
|
||||
await connection.getRequest().wait(server.headersTimeout)
|
||||
RequestFence.ok(res)
|
||||
except CancelledError:
|
||||
RequestFence.err(HttpProcessError.init(HttpServerError.InterruptError))
|
||||
except AsyncTimeoutError as exc:
|
||||
let address = connection.getRemoteAddress()
|
||||
RequestFence.err(HttpProcessError.init(
|
||||
HttpServerError.TimeoutError, exc, address, Http408))
|
||||
except HttpRecoverableError as exc:
|
||||
let address = connection.getRemoteAddress()
|
||||
RequestFence.err(HttpProcessError.init(
|
||||
HttpServerError.RecoverableError, exc, address, exc.code))
|
||||
except HttpCriticalError as exc:
|
||||
let address = connection.getRemoteAddress()
|
||||
RequestFence.err(HttpProcessError.init(
|
||||
HttpServerError.CriticalError, exc, address, exc.code))
|
||||
except HttpDisconnectError as exc:
|
||||
let address = connection.getRemoteAddress()
|
||||
RequestFence.err(HttpProcessError.init(
|
||||
HttpServerError.DisconnectError, exc, address, Http400))
|
||||
except CatchableError as exc:
|
||||
let address = connection.getRemoteAddress()
|
||||
RequestFence.err(HttpProcessError.init(
|
||||
HttpServerError.CatchableError, exc, address, Http500))
|
||||
|
||||
if breakLoop:
|
||||
break
|
||||
proc getConnectionFence*(server: HttpServerRef,
|
||||
transp: StreamTransport): Future[ConnectionFence] {.
|
||||
async.} =
|
||||
try:
|
||||
let res = await server.createConnCallback(server, transp)
|
||||
ConnectionFence.ok(res)
|
||||
except CancelledError:
|
||||
await transp.closeWait()
|
||||
ConnectionFence.err(HttpProcessError.init(HttpServerError.InterruptError))
|
||||
except HttpCriticalError as exc:
|
||||
await transp.closeWait()
|
||||
let address = transp.getRemoteAddress()
|
||||
ConnectionFence.err(HttpProcessError.init(
|
||||
HttpServerError.CriticalError, exc, address, exc.code))
|
||||
|
||||
breakLoop = false
|
||||
var lastErrorCode: Opt[HttpCode]
|
||||
|
||||
try:
|
||||
resp = await conn.server.processCallback(arg)
|
||||
except CancelledError:
|
||||
breakLoop = true
|
||||
except HttpCriticalError as exc:
|
||||
lastErrorCode = Opt.some(exc.code)
|
||||
except HttpRecoverableError as exc:
|
||||
lastErrorCode = Opt.some(exc.code)
|
||||
except CatchableError:
|
||||
lastErrorCode = Opt.some(Http503)
|
||||
|
||||
if breakLoop:
|
||||
break
|
||||
|
||||
if arg.isErr():
|
||||
let code = arg.error().code
|
||||
try:
|
||||
case arg.error().error
|
||||
of HttpServerError.TimeoutError:
|
||||
discard await conn.sendErrorResponse(HttpVersion11, code, false)
|
||||
of HttpServerError.RecoverableError:
|
||||
discard await conn.sendErrorResponse(HttpVersion11, code, false)
|
||||
of HttpServerError.CriticalError:
|
||||
discard await conn.sendErrorResponse(HttpVersion11, code, false)
|
||||
of HttpServerError.CatchableError:
|
||||
discard await conn.sendErrorResponse(HttpVersion11, code, false)
|
||||
of HttpServerError.DisconnectError:
|
||||
discard
|
||||
except CancelledError:
|
||||
# We swallowing `CancelledError` in a loop, but we going to exit
|
||||
# loop ASAP.
|
||||
discard
|
||||
break
|
||||
proc processRequest(server: HttpServerRef,
|
||||
connection: HttpConnectionRef,
|
||||
connId: string): Future[bool] {.async.} =
|
||||
let requestFence = await getRequestFence(server, connection)
|
||||
if requestFence.isErr():
|
||||
case requestFence.error.kind
|
||||
of HttpServerError.InterruptError:
|
||||
return false
|
||||
of HttpServerError.DisconnectError:
|
||||
if HttpServerFlags.NotifyDisconnect notin server.flags:
|
||||
return false
|
||||
else:
|
||||
let request = arg.get()
|
||||
var keepConn = HttpResponseFlags.KeepAlive in request.getResponseFlags()
|
||||
if lastErrorCode.isNone():
|
||||
if isNil(resp):
|
||||
# Response was `nil`.
|
||||
try:
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http404, false)
|
||||
except CancelledError:
|
||||
keepConn = false
|
||||
else:
|
||||
try:
|
||||
case resp.state
|
||||
of HttpResponseState.Empty:
|
||||
# Response was ignored
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http404,
|
||||
keepConn)
|
||||
of HttpResponseState.Prepared:
|
||||
# Response was prepared but not sent.
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http409,
|
||||
keepConn)
|
||||
else:
|
||||
# some data was already sent to the client.
|
||||
keepConn = resp.keepalive()
|
||||
except CancelledError:
|
||||
keepConn = false
|
||||
else:
|
||||
try:
|
||||
discard await conn.sendErrorResponse(HttpVersion11,
|
||||
lastErrorCode.get(), false)
|
||||
except CancelledError:
|
||||
keepConn = false
|
||||
discard
|
||||
|
||||
# Closing and releasing all the request resources.
|
||||
try:
|
||||
await request.closeWait()
|
||||
except CancelledError:
|
||||
# We swallowing `CancelledError` in a loop, but we still need to close
|
||||
# `request` before exiting.
|
||||
await request.closeWait()
|
||||
defer:
|
||||
if requestFence.isOk():
|
||||
await requestFence.get().closeWait()
|
||||
|
||||
if not(keepConn):
|
||||
break
|
||||
let responseFence = await getResponseFence(connection, requestFence)
|
||||
if responseFence.isErr() and
|
||||
(responseFence.error.kind == HttpServerError.InterruptError):
|
||||
return false
|
||||
|
||||
# Connection could be `nil` only when secure handshake is failed.
|
||||
if not(isNil(conn)):
|
||||
try:
|
||||
await conn.closeWait()
|
||||
except CancelledError:
|
||||
# Cancellation could be happened while we closing `conn`. But we still
|
||||
# need to close it.
|
||||
await conn.closeWait()
|
||||
if responseFence.isErr():
|
||||
await connection.sendErrorResponse(requestFence, responseFence.error)
|
||||
else:
|
||||
await connection.sendDefaultResponse(requestFence, responseFence.get())
|
||||
|
||||
server.connections.del(connId)
|
||||
# if server.maxConnections > 0:
|
||||
# server.semaphore.release()
|
||||
proc processLoop(holder: HttpConnectionHolderRef) {.async.} =
|
||||
let
|
||||
server = holder.server
|
||||
transp = holder.transp
|
||||
connectionId = holder.connectionId
|
||||
connection =
|
||||
block:
|
||||
let res = await server.getConnectionFence(transp)
|
||||
if res.isErr():
|
||||
if res.error.kind != HttpServerError.InterruptError:
|
||||
discard await server.getResponseFence(res)
|
||||
server.connections.del(connectionId)
|
||||
return
|
||||
res.get()
|
||||
|
||||
holder.connection = connection
|
||||
|
||||
defer:
|
||||
server.connections.del(connectionId)
|
||||
await connection.closeWait()
|
||||
|
||||
var runLoop = true
|
||||
while runLoop:
|
||||
runLoop = await server.processRequest(connection, connectionId)
|
||||
|
||||
proc acceptClientLoop(server: HttpServerRef) {.async.} =
|
||||
var breakLoop = false
|
||||
while true:
|
||||
try:
|
||||
# if server.maxConnections > 0:
|
||||
|
@ -872,27 +1002,26 @@ proc acceptClientLoop(server: HttpServerRef) {.async.} =
|
|||
# We are unable to identify remote peer, it means that remote peer
|
||||
# disconnected before identification.
|
||||
await transp.closeWait()
|
||||
breakLoop = false
|
||||
break
|
||||
else:
|
||||
let connId = resId.get()
|
||||
server.connections[connId] = processLoop(server, transp, connId)
|
||||
let holder = HttpConnectionHolderRef.new(server, transp, resId.get())
|
||||
server.connections[connId] = holder
|
||||
holder.future = processLoop(holder)
|
||||
except CancelledError:
|
||||
# Server was stopped
|
||||
breakLoop = true
|
||||
break
|
||||
except TransportOsError:
|
||||
# This is some critical unrecoverable error.
|
||||
breakLoop = true
|
||||
break
|
||||
except TransportTooManyError:
|
||||
# Non critical error
|
||||
breakLoop = false
|
||||
discard
|
||||
except TransportAbortedError:
|
||||
# Non critical error
|
||||
breakLoop = false
|
||||
discard
|
||||
except CatchableError:
|
||||
# Unexpected error
|
||||
breakLoop = true
|
||||
|
||||
if breakLoop:
|
||||
break
|
||||
|
||||
proc state*(server: HttpServerRef): HttpServerState {.raises: [].} =
|
||||
|
@ -922,11 +1051,11 @@ proc drop*(server: HttpServerRef) {.async.} =
|
|||
## Drop all pending HTTP connections.
|
||||
var pending: seq[Future[void]]
|
||||
if server.state in {ServerStopped, ServerRunning}:
|
||||
for fut in server.connections.values():
|
||||
if not(fut.finished()):
|
||||
fut.cancel()
|
||||
pending.add(fut)
|
||||
for holder in server.connections.values():
|
||||
if not(isNil(holder.future)) and not(holder.future.finished()):
|
||||
pending.add(holder.future.cancelAndWait())
|
||||
await allFutures(pending)
|
||||
server.connections.clear()
|
||||
|
||||
proc closeWait*(server: HttpServerRef) {.async.} =
|
||||
## Stop HTTP server and drop all the pending connections.
|
||||
|
|
|
@ -24,6 +24,28 @@ type
|
|||
|
||||
SecureHttpConnectionRef* = ref SecureHttpConnection
|
||||
|
||||
proc closeSecConnection(conn: HttpConnectionRef) {.async.} =
|
||||
if conn.state == HttpState.Alive:
|
||||
conn.state = HttpState.Closing
|
||||
var pending: seq[Future[void]]
|
||||
pending.add(conn.writer.closeWait())
|
||||
pending.add(conn.reader.closeWait())
|
||||
try:
|
||||
await allFutures(pending)
|
||||
except CancelledError:
|
||||
await allFutures(pending)
|
||||
# After we going to close everything else.
|
||||
pending.setLen(3)
|
||||
pending[0] = conn.mainReader.closeWait()
|
||||
pending[1] = conn.mainWriter.closeWait()
|
||||
pending[2] = conn.transp.closeWait()
|
||||
try:
|
||||
await allFutures(pending)
|
||||
except CancelledError:
|
||||
await allFutures(pending)
|
||||
untrackCounter(HttpServerSecureConnectionTrackerName)
|
||||
conn.state = HttpState.Closed
|
||||
|
||||
proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef,
|
||||
transp: StreamTransport): SecureHttpConnectionRef =
|
||||
var res = SecureHttpConnectionRef()
|
||||
|
@ -37,6 +59,8 @@ proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef,
|
|||
res.tlsStream = tlsStream
|
||||
res.reader = AsyncStreamReader(tlsStream.reader)
|
||||
res.writer = AsyncStreamWriter(tlsStream.writer)
|
||||
res.closeCb = closeSecConnection
|
||||
trackCounter(HttpServerSecureConnectionTrackerName)
|
||||
res
|
||||
|
||||
proc createSecConnection(server: HttpServerRef,
|
||||
|
@ -100,7 +124,7 @@ proc new*(htype: typedesc[SecureHttpServerRef],
|
|||
createConnCallback: createSecConnection,
|
||||
baseUri: serverUri,
|
||||
serverIdent: serverIdent,
|
||||
flags: serverFlags,
|
||||
flags: serverFlags + {HttpServerFlags.Secure},
|
||||
socketFlags: socketFlags,
|
||||
maxConnections: maxConnections,
|
||||
bufferSize: bufferSize,
|
||||
|
@ -114,7 +138,7 @@ proc new*(htype: typedesc[SecureHttpServerRef],
|
|||
# else:
|
||||
# nil
|
||||
lifetime: newFuture[void]("http.server.lifetime"),
|
||||
connections: initTable[string, Future[void]](),
|
||||
connections: initOrderedTable[string, HttpConnectionHolderRef](),
|
||||
tlsCertificate: tlsCertificate,
|
||||
tlsPrivateKey: tlsPrivateKey,
|
||||
secureFlags: secureFlags
|
||||
|
|
|
@ -171,11 +171,16 @@ type
|
|||
dump*: proc(): string {.gcsafe, raises: [].}
|
||||
isLeaked*: proc(): bool {.gcsafe, raises: [].}
|
||||
|
||||
TrackerCounter* = object
|
||||
opened*: uint64
|
||||
closed*: uint64
|
||||
|
||||
PDispatcherBase = ref object of RootRef
|
||||
timers*: HeapQueue[TimerCallback]
|
||||
callbacks*: Deque[AsyncCallback]
|
||||
idlers*: Deque[AsyncCallback]
|
||||
trackers*: Table[string, TrackerBase]
|
||||
counters*: Table[string, TrackerCounter]
|
||||
|
||||
proc sentinelCallbackImpl(arg: pointer) {.gcsafe.} =
|
||||
raiseAssert "Sentinel callback MUST not be scheduled"
|
||||
|
@ -404,7 +409,8 @@ when defined(windows):
|
|||
timers: initHeapQueue[TimerCallback](),
|
||||
callbacks: initDeque[AsyncCallback](64),
|
||||
idlers: initDeque[AsyncCallback](),
|
||||
trackers: initTable[string, TrackerBase]()
|
||||
trackers: initTable[string, TrackerBase](),
|
||||
counters: initTable[string, TrackerCounter]()
|
||||
)
|
||||
res.callbacks.addLast(SentinelCallback)
|
||||
initAPI(res)
|
||||
|
@ -814,7 +820,8 @@ elif defined(macosx) or defined(freebsd) or defined(netbsd) or
|
|||
callbacks: initDeque[AsyncCallback](asyncEventsCount),
|
||||
idlers: initDeque[AsyncCallback](),
|
||||
keys: newSeq[ReadyKey](asyncEventsCount),
|
||||
trackers: initTable[string, TrackerBase]()
|
||||
trackers: initTable[string, TrackerBase](),
|
||||
counters: initTable[string, TrackerCounter]()
|
||||
)
|
||||
res.callbacks.addLast(SentinelCallback)
|
||||
initAPI(res)
|
||||
|
@ -1505,16 +1512,54 @@ proc waitFor*[T](fut: Future[T]): T {.raises: [CatchableError].} =
|
|||
|
||||
fut.read()
|
||||
|
||||
proc addTracker*[T](id: string, tracker: T) =
|
||||
proc addTracker*[T](id: string, tracker: T) {.
|
||||
deprecated: "Please use trackCounter facility instead".} =
|
||||
## Add new ``tracker`` object to current thread dispatcher with identifier
|
||||
## ``id``.
|
||||
let loop = getThreadDispatcher()
|
||||
loop.trackers[id] = tracker
|
||||
getThreadDispatcher().trackers[id] = tracker
|
||||
|
||||
proc getTracker*(id: string): TrackerBase =
|
||||
proc getTracker*(id: string): TrackerBase {.
|
||||
deprecated: "Please use getTrackerCounter() instead".} =
|
||||
## Get ``tracker`` from current thread dispatcher using identifier ``id``.
|
||||
let loop = getThreadDispatcher()
|
||||
result = loop.trackers.getOrDefault(id, nil)
|
||||
getThreadDispatcher().trackers.getOrDefault(id, nil)
|
||||
|
||||
proc trackCounter*(name: string) {.noinit.} =
|
||||
## Increase tracker counter with name ``name`` by 1.
|
||||
let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64)
|
||||
inc(getThreadDispatcher().counters.mgetOrPut(name, tracker).opened)
|
||||
|
||||
proc untrackCounter*(name: string) {.noinit.} =
|
||||
## Decrease tracker counter with name ``name`` by 1.
|
||||
let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64)
|
||||
inc(getThreadDispatcher().counters.mgetOrPut(name, tracker).closed)
|
||||
|
||||
proc getTrackerCounter*(name: string): TrackerCounter {.noinit.} =
|
||||
## Return value of counter with name ``name``.
|
||||
let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64)
|
||||
getThreadDispatcher().counters.getOrDefault(name, tracker)
|
||||
|
||||
proc isCounterLeaked*(name: string): bool {.noinit.} =
|
||||
## Returns ``true`` if leak is detected, number of `opened` not equal to
|
||||
## number of `closed` requests.
|
||||
let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64)
|
||||
let res = getThreadDispatcher().counters.getOrDefault(name, tracker)
|
||||
res.opened == res.closed
|
||||
|
||||
iterator trackerCounters*(
|
||||
loop: PDispatcher
|
||||
): tuple[name: string, value: TrackerCounter] =
|
||||
## Iterates over `loop` thread dispatcher tracker counter table, returns all
|
||||
## the tracker counter's names and values.
|
||||
doAssert(not(isNil(loop)))
|
||||
for key, value in loop.counters.pairs():
|
||||
yield (key, value)
|
||||
|
||||
iterator trackerCounterKeys*(loop: PDispatcher): string =
|
||||
doAssert(not(isNil(loop)))
|
||||
## Iterates over `loop` thread dispatcher tracker counter table, returns all
|
||||
## tracker names.
|
||||
for key in loop.counters.keys():
|
||||
yield key
|
||||
|
||||
when chronosFutureTracking:
|
||||
iterator pendingFutures*(): FutureBase =
|
||||
|
|
|
@ -23,8 +23,6 @@ const
|
|||
AsyncProcessTrackerName* = "async.process"
|
||||
## AsyncProcess leaks tracker name
|
||||
|
||||
|
||||
|
||||
type
|
||||
AsyncProcessError* = object of CatchableError
|
||||
|
||||
|
@ -109,49 +107,9 @@ type
|
|||
stdError*: string
|
||||
status*: int
|
||||
|
||||
AsyncProcessTracker* = ref object of TrackerBase
|
||||
opened*: int64
|
||||
closed*: int64
|
||||
|
||||
template Pipe*(t: typedesc[AsyncProcess]): ProcessStreamHandle =
|
||||
ProcessStreamHandle(kind: ProcessStreamHandleKind.Auto)
|
||||
|
||||
proc setupAsyncProcessTracker(): AsyncProcessTracker {.gcsafe.}
|
||||
|
||||
proc getAsyncProcessTracker(): AsyncProcessTracker {.inline.} =
|
||||
var res = cast[AsyncProcessTracker](getTracker(AsyncProcessTrackerName))
|
||||
if isNil(res):
|
||||
res = setupAsyncProcessTracker()
|
||||
res
|
||||
|
||||
proc dumpAsyncProcessTracking(): string {.gcsafe.} =
|
||||
var tracker = getAsyncProcessTracker()
|
||||
let res = "Started async processes: " & $tracker.opened & "\n" &
|
||||
"Closed async processes: " & $tracker.closed
|
||||
res
|
||||
|
||||
proc leakAsyncProccessTracker(): bool {.gcsafe.} =
|
||||
var tracker = getAsyncProcessTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc trackAsyncProccess(t: AsyncProcessRef) {.inline.} =
|
||||
var tracker = getAsyncProcessTracker()
|
||||
inc(tracker.opened)
|
||||
|
||||
proc untrackAsyncProcess(t: AsyncProcessRef) {.inline.} =
|
||||
var tracker = getAsyncProcessTracker()
|
||||
inc(tracker.closed)
|
||||
|
||||
proc setupAsyncProcessTracker(): AsyncProcessTracker {.gcsafe.} =
|
||||
var res = AsyncProcessTracker(
|
||||
opened: 0,
|
||||
closed: 0,
|
||||
dump: dumpAsyncProcessTracking,
|
||||
isLeaked: leakAsyncProccessTracker
|
||||
)
|
||||
addTracker(AsyncProcessTrackerName, res)
|
||||
res
|
||||
|
||||
proc init*(t: typedesc[AsyncFD], handle: ProcessStreamHandle): AsyncFD =
|
||||
case handle.kind
|
||||
of ProcessStreamHandleKind.ProcHandle:
|
||||
|
@ -502,7 +460,7 @@ when defined(windows):
|
|||
flags: pipes.flags
|
||||
)
|
||||
|
||||
trackAsyncProccess(process)
|
||||
trackCounter(AsyncProcessTrackerName)
|
||||
return process
|
||||
|
||||
proc peekProcessExitCode(p: AsyncProcessRef): AsyncProcessResult[int] =
|
||||
|
@ -919,7 +877,7 @@ else:
|
|||
flags: pipes.flags
|
||||
)
|
||||
|
||||
trackAsyncProccess(process)
|
||||
trackCounter(AsyncProcessTrackerName)
|
||||
return process
|
||||
|
||||
proc peekProcessExitCode(p: AsyncProcessRef,
|
||||
|
@ -1237,7 +1195,7 @@ proc closeWait*(p: AsyncProcessRef) {.async.} =
|
|||
discard closeProcessHandles(p.pipes, p.options, OSErrorCode(0))
|
||||
await p.pipes.closeProcessStreams(p.options)
|
||||
discard p.closeThreadAndProcessHandle()
|
||||
untrackAsyncProcess(p)
|
||||
untrackCounter(AsyncProcessTrackerName)
|
||||
|
||||
proc stdinStream*(p: AsyncProcessRef): AsyncStreamWriter =
|
||||
doAssert(p.pipes.stdinHolder.kind == StreamKind.Writer,
|
||||
|
|
|
@ -96,10 +96,6 @@ type
|
|||
reader*: AsyncStreamReader
|
||||
writer*: AsyncStreamWriter
|
||||
|
||||
AsyncStreamTracker* = ref object of TrackerBase
|
||||
opened*: int64
|
||||
closed*: int64
|
||||
|
||||
AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter
|
||||
|
||||
proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer =
|
||||
|
@ -332,79 +328,6 @@ template checkStreamClosed*(t: untyped) =
|
|||
template checkStreamFinished*(t: untyped) =
|
||||
if t.atEof(): raiseAsyncStreamWriteEOFError()
|
||||
|
||||
proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.
|
||||
gcsafe, raises: [].}
|
||||
proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.
|
||||
gcsafe, raises: [].}
|
||||
|
||||
proc getAsyncStreamReaderTracker(): AsyncStreamTracker {.inline.} =
|
||||
var res = cast[AsyncStreamTracker](getTracker(AsyncStreamReaderTrackerName))
|
||||
if isNil(res):
|
||||
res = setupAsyncStreamReaderTracker()
|
||||
res
|
||||
|
||||
proc getAsyncStreamWriterTracker(): AsyncStreamTracker {.inline.} =
|
||||
var res = cast[AsyncStreamTracker](getTracker(AsyncStreamWriterTrackerName))
|
||||
if isNil(res):
|
||||
res = setupAsyncStreamWriterTracker()
|
||||
res
|
||||
|
||||
proc dumpAsyncStreamReaderTracking(): string {.gcsafe.} =
|
||||
var tracker = getAsyncStreamReaderTracker()
|
||||
let res = "Opened async stream readers: " & $tracker.opened & "\n" &
|
||||
"Closed async stream readers: " & $tracker.closed
|
||||
res
|
||||
|
||||
proc dumpAsyncStreamWriterTracking(): string {.gcsafe.} =
|
||||
var tracker = getAsyncStreamWriterTracker()
|
||||
let res = "Opened async stream writers: " & $tracker.opened & "\n" &
|
||||
"Closed async stream writers: " & $tracker.closed
|
||||
res
|
||||
|
||||
proc leakAsyncStreamReader(): bool {.gcsafe.} =
|
||||
var tracker = getAsyncStreamReaderTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc leakAsyncStreamWriter(): bool {.gcsafe.} =
|
||||
var tracker = getAsyncStreamWriterTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc trackAsyncStreamReader(t: AsyncStreamReader) {.inline.} =
|
||||
var tracker = getAsyncStreamReaderTracker()
|
||||
inc(tracker.opened)
|
||||
|
||||
proc untrackAsyncStreamReader*(t: AsyncStreamReader) {.inline.} =
|
||||
var tracker = getAsyncStreamReaderTracker()
|
||||
inc(tracker.closed)
|
||||
|
||||
proc trackAsyncStreamWriter(t: AsyncStreamWriter) {.inline.} =
|
||||
var tracker = getAsyncStreamWriterTracker()
|
||||
inc(tracker.opened)
|
||||
|
||||
proc untrackAsyncStreamWriter*(t: AsyncStreamWriter) {.inline.} =
|
||||
var tracker = getAsyncStreamWriterTracker()
|
||||
inc(tracker.closed)
|
||||
|
||||
proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.gcsafe.} =
|
||||
var res = AsyncStreamTracker(
|
||||
opened: 0,
|
||||
closed: 0,
|
||||
dump: dumpAsyncStreamReaderTracking,
|
||||
isLeaked: leakAsyncStreamReader
|
||||
)
|
||||
addTracker(AsyncStreamReaderTrackerName, res)
|
||||
res
|
||||
|
||||
proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.gcsafe.} =
|
||||
var res = AsyncStreamTracker(
|
||||
opened: 0,
|
||||
closed: 0,
|
||||
dump: dumpAsyncStreamWriterTracking,
|
||||
isLeaked: leakAsyncStreamWriter
|
||||
)
|
||||
addTracker(AsyncStreamWriterTrackerName, res)
|
||||
res
|
||||
|
||||
template readLoop(body: untyped): untyped =
|
||||
while true:
|
||||
if rstream.buffer.dataLen() == 0:
|
||||
|
@ -977,9 +900,9 @@ proc close*(rw: AsyncStreamRW) =
|
|||
if not(rw.future.finished()):
|
||||
rw.future.complete()
|
||||
when rw is AsyncStreamReader:
|
||||
untrackAsyncStreamReader(rw)
|
||||
untrackCounter(AsyncStreamReaderTrackerName)
|
||||
elif rw is AsyncStreamWriter:
|
||||
untrackAsyncStreamWriter(rw)
|
||||
untrackCounter(AsyncStreamWriterTrackerName)
|
||||
rw.state = AsyncStreamState.Closed
|
||||
|
||||
when rw is AsyncStreamReader:
|
||||
|
@ -1028,7 +951,7 @@ proc init*(child, wsource: AsyncStreamWriter, loop: StreamWriterLoop,
|
|||
child.wsource = wsource
|
||||
child.tsource = wsource.tsource
|
||||
child.queue = newAsyncQueue[WriteItem](queueSize)
|
||||
trackAsyncStreamWriter(child)
|
||||
trackCounter(AsyncStreamWriterTrackerName)
|
||||
child.startWriter()
|
||||
|
||||
proc init*[T](child, wsource: AsyncStreamWriter, loop: StreamWriterLoop,
|
||||
|
@ -1042,7 +965,7 @@ proc init*[T](child, wsource: AsyncStreamWriter, loop: StreamWriterLoop,
|
|||
if not isNil(udata):
|
||||
GC_ref(udata)
|
||||
child.udata = cast[pointer](udata)
|
||||
trackAsyncStreamWriter(child)
|
||||
trackCounter(AsyncStreamWriterTrackerName)
|
||||
child.startWriter()
|
||||
|
||||
proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
|
||||
|
@ -1053,7 +976,7 @@ proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
|
|||
child.rsource = rsource
|
||||
child.tsource = rsource.tsource
|
||||
child.buffer = AsyncBuffer.init(bufferSize)
|
||||
trackAsyncStreamReader(child)
|
||||
trackCounter(AsyncStreamReaderTrackerName)
|
||||
child.startReader()
|
||||
|
||||
proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
|
||||
|
@ -1068,7 +991,7 @@ proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
|
|||
if not isNil(udata):
|
||||
GC_ref(udata)
|
||||
child.udata = cast[pointer](udata)
|
||||
trackAsyncStreamReader(child)
|
||||
trackCounter(AsyncStreamReaderTrackerName)
|
||||
child.startReader()
|
||||
|
||||
proc init*(child: AsyncStreamWriter, tsource: StreamTransport) =
|
||||
|
@ -1077,7 +1000,7 @@ proc init*(child: AsyncStreamWriter, tsource: StreamTransport) =
|
|||
child.writerLoop = nil
|
||||
child.wsource = nil
|
||||
child.tsource = tsource
|
||||
trackAsyncStreamWriter(child)
|
||||
trackCounter(AsyncStreamWriterTrackerName)
|
||||
child.startWriter()
|
||||
|
||||
proc init*[T](child: AsyncStreamWriter, tsource: StreamTransport,
|
||||
|
@ -1087,7 +1010,7 @@ proc init*[T](child: AsyncStreamWriter, tsource: StreamTransport,
|
|||
child.writerLoop = nil
|
||||
child.wsource = nil
|
||||
child.tsource = tsource
|
||||
trackAsyncStreamWriter(child)
|
||||
trackCounter(AsyncStreamWriterTrackerName)
|
||||
child.startWriter()
|
||||
|
||||
proc init*(child, wsource: AsyncStreamWriter) =
|
||||
|
@ -1096,7 +1019,7 @@ proc init*(child, wsource: AsyncStreamWriter) =
|
|||
child.writerLoop = nil
|
||||
child.wsource = wsource
|
||||
child.tsource = wsource.tsource
|
||||
trackAsyncStreamWriter(child)
|
||||
trackCounter(AsyncStreamWriterTrackerName)
|
||||
child.startWriter()
|
||||
|
||||
proc init*[T](child, wsource: AsyncStreamWriter, udata: ref T) =
|
||||
|
@ -1108,7 +1031,7 @@ proc init*[T](child, wsource: AsyncStreamWriter, udata: ref T) =
|
|||
if not isNil(udata):
|
||||
GC_ref(udata)
|
||||
child.udata = cast[pointer](udata)
|
||||
trackAsyncStreamWriter(child)
|
||||
trackCounter(AsyncStreamWriterTrackerName)
|
||||
child.startWriter()
|
||||
|
||||
proc init*(child: AsyncStreamReader, tsource: StreamTransport) =
|
||||
|
@ -1117,7 +1040,7 @@ proc init*(child: AsyncStreamReader, tsource: StreamTransport) =
|
|||
child.readerLoop = nil
|
||||
child.rsource = nil
|
||||
child.tsource = tsource
|
||||
trackAsyncStreamReader(child)
|
||||
trackCounter(AsyncStreamReaderTrackerName)
|
||||
child.startReader()
|
||||
|
||||
proc init*[T](child: AsyncStreamReader, tsource: StreamTransport,
|
||||
|
@ -1130,7 +1053,7 @@ proc init*[T](child: AsyncStreamReader, tsource: StreamTransport,
|
|||
if not isNil(udata):
|
||||
GC_ref(udata)
|
||||
child.udata = cast[pointer](udata)
|
||||
trackAsyncStreamReader(child)
|
||||
trackCounter(AsyncStreamReaderTrackerName)
|
||||
child.startReader()
|
||||
|
||||
proc init*(child, rsource: AsyncStreamReader) =
|
||||
|
@ -1139,7 +1062,7 @@ proc init*(child, rsource: AsyncStreamReader) =
|
|||
child.readerLoop = nil
|
||||
child.rsource = rsource
|
||||
child.tsource = rsource.tsource
|
||||
trackAsyncStreamReader(child)
|
||||
trackCounter(AsyncStreamReaderTrackerName)
|
||||
child.startReader()
|
||||
|
||||
proc init*[T](child, rsource: AsyncStreamReader, udata: ref T) =
|
||||
|
@ -1151,7 +1074,7 @@ proc init*[T](child, rsource: AsyncStreamReader, udata: ref T) =
|
|||
if not isNil(udata):
|
||||
GC_ref(udata)
|
||||
child.udata = cast[pointer](udata)
|
||||
trackAsyncStreamReader(child)
|
||||
trackCounter(AsyncStreamReaderTrackerName)
|
||||
child.startReader()
|
||||
|
||||
proc newAsyncStreamReader*[T](rsource: AsyncStreamReader,
|
||||
|
|
|
@ -53,10 +53,6 @@ type
|
|||
rwsabuf: WSABUF # Reader WSABUF structure
|
||||
wwsabuf: WSABUF # Writer WSABUF structure
|
||||
|
||||
DgramTransportTracker* = ref object of TrackerBase
|
||||
opened*: int64
|
||||
closed*: int64
|
||||
|
||||
const
|
||||
DgramTransportTrackerName* = "datagram.transport"
|
||||
|
||||
|
@ -88,39 +84,6 @@ template setReadError(t, e: untyped) =
|
|||
(t).state.incl(ReadError)
|
||||
(t).error = getTransportOsError(e)
|
||||
|
||||
proc setupDgramTransportTracker(): DgramTransportTracker {.
|
||||
gcsafe, raises: [].}
|
||||
|
||||
proc getDgramTransportTracker(): DgramTransportTracker {.inline.} =
|
||||
var res = cast[DgramTransportTracker](getTracker(DgramTransportTrackerName))
|
||||
if isNil(res):
|
||||
res = setupDgramTransportTracker()
|
||||
doAssert(not(isNil(res)))
|
||||
res
|
||||
|
||||
proc dumpTransportTracking(): string {.gcsafe.} =
|
||||
var tracker = getDgramTransportTracker()
|
||||
"Opened transports: " & $tracker.opened & "\n" &
|
||||
"Closed transports: " & $tracker.closed
|
||||
|
||||
proc leakTransport(): bool {.gcsafe.} =
|
||||
let tracker = getDgramTransportTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc trackDgram(t: DatagramTransport) {.inline.} =
|
||||
var tracker = getDgramTransportTracker()
|
||||
inc(tracker.opened)
|
||||
|
||||
proc untrackDgram(t: DatagramTransport) {.inline.} =
|
||||
var tracker = getDgramTransportTracker()
|
||||
inc(tracker.closed)
|
||||
|
||||
proc setupDgramTransportTracker(): DgramTransportTracker {.gcsafe.} =
|
||||
let res = DgramTransportTracker(
|
||||
opened: 0, closed: 0, dump: dumpTransportTracking, isLeaked: leakTransport)
|
||||
addTracker(DgramTransportTrackerName, res)
|
||||
res
|
||||
|
||||
when defined(windows):
|
||||
template setWriterWSABuffer(t, v: untyped) =
|
||||
(t).wwsabuf.buf = cast[cstring](v.buf)
|
||||
|
@ -213,7 +176,7 @@ when defined(windows):
|
|||
transp.state.incl(ReadPaused)
|
||||
if ReadClosed in transp.state and not(transp.future.finished()):
|
||||
# Stop tracking transport
|
||||
untrackDgram(transp)
|
||||
untrackCounter(DgramTransportTrackerName)
|
||||
# If `ReadClosed` present, then close(transport) was called.
|
||||
transp.future.complete()
|
||||
GC_unref(transp)
|
||||
|
@ -259,7 +222,7 @@ when defined(windows):
|
|||
# WSARecvFrom session.
|
||||
if ReadClosed in transp.state and not(transp.future.finished()):
|
||||
# Stop tracking transport
|
||||
untrackDgram(transp)
|
||||
untrackCounter(DgramTransportTrackerName)
|
||||
transp.future.complete()
|
||||
GC_unref(transp)
|
||||
break
|
||||
|
@ -394,7 +357,7 @@ when defined(windows):
|
|||
len: ULONG(len(res.buffer)))
|
||||
GC_ref(res)
|
||||
# Start tracking transport
|
||||
trackDgram(res)
|
||||
trackCounter(DgramTransportTrackerName)
|
||||
if NoAutoRead notin flags:
|
||||
let rres = res.resumeRead()
|
||||
if rres.isErr(): raiseTransportOsError(rres.error())
|
||||
|
@ -592,7 +555,7 @@ else:
|
|||
res.future = newFuture[void]("datagram.transport")
|
||||
GC_ref(res)
|
||||
# Start tracking transport
|
||||
trackDgram(res)
|
||||
trackCounter(DgramTransportTrackerName)
|
||||
if NoAutoRead notin flags:
|
||||
let rres = res.resumeRead()
|
||||
if rres.isErr(): raiseTransportOsError(rres.error())
|
||||
|
@ -603,7 +566,7 @@ proc close*(transp: DatagramTransport) =
|
|||
proc continuation(udata: pointer) {.raises: [].} =
|
||||
if not(transp.future.finished()):
|
||||
# Stop tracking transport
|
||||
untrackDgram(transp)
|
||||
untrackCounter(DgramTransportTrackerName)
|
||||
transp.future.complete()
|
||||
GC_unref(transp)
|
||||
|
||||
|
|
|
@ -54,15 +54,6 @@ type
|
|||
ReuseAddr,
|
||||
ReusePort
|
||||
|
||||
|
||||
StreamTransportTracker* = ref object of TrackerBase
|
||||
opened*: int64
|
||||
closed*: int64
|
||||
|
||||
StreamServerTracker* = ref object of TrackerBase
|
||||
opened*: int64
|
||||
closed*: int64
|
||||
|
||||
ReadMessagePredicate* = proc (data: openArray[byte]): tuple[consumed: int,
|
||||
done: bool] {.
|
||||
gcsafe, raises: [].}
|
||||
|
@ -199,71 +190,6 @@ template shiftVectorFile(v: var StreamVector, o: untyped) =
|
|||
(v).buf = cast[pointer](cast[uint]((v).buf) - uint(o))
|
||||
(v).offset += uint(o)
|
||||
|
||||
proc setupStreamTransportTracker(): StreamTransportTracker {.
|
||||
gcsafe, raises: [].}
|
||||
proc setupStreamServerTracker(): StreamServerTracker {.
|
||||
gcsafe, raises: [].}
|
||||
|
||||
proc getStreamTransportTracker(): StreamTransportTracker {.inline.} =
|
||||
var res = cast[StreamTransportTracker](getTracker(StreamTransportTrackerName))
|
||||
if isNil(res):
|
||||
res = setupStreamTransportTracker()
|
||||
doAssert(not(isNil(res)))
|
||||
res
|
||||
|
||||
proc getStreamServerTracker(): StreamServerTracker {.inline.} =
|
||||
var res = cast[StreamServerTracker](getTracker(StreamServerTrackerName))
|
||||
if isNil(res):
|
||||
res = setupStreamServerTracker()
|
||||
doAssert(not(isNil(res)))
|
||||
res
|
||||
|
||||
proc dumpTransportTracking(): string {.gcsafe.} =
|
||||
var tracker = getStreamTransportTracker()
|
||||
"Opened transports: " & $tracker.opened & "\n" &
|
||||
"Closed transports: " & $tracker.closed
|
||||
|
||||
proc dumpServerTracking(): string {.gcsafe.} =
|
||||
var tracker = getStreamServerTracker()
|
||||
"Opened servers: " & $tracker.opened & "\n" &
|
||||
"Closed servers: " & $tracker.closed
|
||||
|
||||
proc leakTransport(): bool {.gcsafe.} =
|
||||
var tracker = getStreamTransportTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc leakServer(): bool {.gcsafe.} =
|
||||
var tracker = getStreamServerTracker()
|
||||
tracker.opened != tracker.closed
|
||||
|
||||
proc trackStream(t: StreamTransport) {.inline.} =
|
||||
var tracker = getStreamTransportTracker()
|
||||
inc(tracker.opened)
|
||||
|
||||
proc untrackStream(t: StreamTransport) {.inline.} =
|
||||
var tracker = getStreamTransportTracker()
|
||||
inc(tracker.closed)
|
||||
|
||||
proc trackServer(s: StreamServer) {.inline.} =
|
||||
var tracker = getStreamServerTracker()
|
||||
inc(tracker.opened)
|
||||
|
||||
proc untrackServer(s: StreamServer) {.inline.} =
|
||||
var tracker = getStreamServerTracker()
|
||||
inc(tracker.closed)
|
||||
|
||||
proc setupStreamTransportTracker(): StreamTransportTracker {.gcsafe.} =
|
||||
let res = StreamTransportTracker(
|
||||
opened: 0, closed: 0, dump: dumpTransportTracking, isLeaked: leakTransport)
|
||||
addTracker(StreamTransportTrackerName, res)
|
||||
res
|
||||
|
||||
proc setupStreamServerTracker(): StreamServerTracker {.gcsafe.} =
|
||||
let res = StreamServerTracker(
|
||||
opened: 0, closed: 0, dump: dumpServerTracking, isLeaked: leakServer)
|
||||
addTracker(StreamServerTrackerName, res)
|
||||
res
|
||||
|
||||
proc completePendingWriteQueue(queue: var Deque[StreamVector],
|
||||
v: int) {.inline.} =
|
||||
while len(queue) > 0:
|
||||
|
@ -280,7 +206,7 @@ proc failPendingWriteQueue(queue: var Deque[StreamVector],
|
|||
|
||||
proc clean(server: StreamServer) {.inline.} =
|
||||
if not(server.loopFuture.finished()):
|
||||
untrackServer(server)
|
||||
untrackCounter(StreamServerTrackerName)
|
||||
server.loopFuture.complete()
|
||||
if not(isNil(server.udata)) and (GCUserData in server.flags):
|
||||
GC_unref(cast[ref int](server.udata))
|
||||
|
@ -288,7 +214,7 @@ proc clean(server: StreamServer) {.inline.} =
|
|||
|
||||
proc clean(transp: StreamTransport) {.inline.} =
|
||||
if not(transp.future.finished()):
|
||||
untrackStream(transp)
|
||||
untrackCounter(StreamTransportTrackerName)
|
||||
transp.future.complete()
|
||||
GC_unref(transp)
|
||||
|
||||
|
@ -784,7 +710,7 @@ when defined(windows):
|
|||
else:
|
||||
let transp = newStreamSocketTransport(sock, bufferSize, child)
|
||||
# Start tracking transport
|
||||
trackStream(transp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
retFuture.complete(transp)
|
||||
else:
|
||||
sock.closeSocket()
|
||||
|
@ -853,7 +779,7 @@ when defined(windows):
|
|||
let transp = newStreamPipeTransport(AsyncFD(pipeHandle),
|
||||
bufferSize, child)
|
||||
# Start tracking transport
|
||||
trackStream(transp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
retFuture.complete(transp)
|
||||
pipeContinuation(nil)
|
||||
|
||||
|
@ -909,7 +835,7 @@ when defined(windows):
|
|||
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
|
||||
nil, flags)
|
||||
# Start tracking transport
|
||||
trackStream(ntransp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
asyncSpawn server.function(server, ntransp)
|
||||
of ERROR_OPERATION_ABORTED:
|
||||
# CancelIO() interrupt or close call.
|
||||
|
@ -1013,7 +939,7 @@ when defined(windows):
|
|||
ntransp = newStreamSocketTransport(server.asock,
|
||||
server.bufferSize, nil)
|
||||
# Start tracking transport
|
||||
trackStream(ntransp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
asyncSpawn server.function(server, ntransp)
|
||||
|
||||
of ERROR_OPERATION_ABORTED:
|
||||
|
@ -1156,7 +1082,7 @@ when defined(windows):
|
|||
ntransp = newStreamSocketTransport(server.asock,
|
||||
server.bufferSize, nil)
|
||||
# Start tracking transport
|
||||
trackStream(ntransp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
retFuture.complete(ntransp)
|
||||
of ERROR_OPERATION_ABORTED:
|
||||
# CancelIO() interrupt or close.
|
||||
|
@ -1216,7 +1142,7 @@ when defined(windows):
|
|||
retFuture.fail(getTransportOsError(error))
|
||||
return
|
||||
|
||||
trackStream(ntransp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
retFuture.complete(ntransp)
|
||||
|
||||
of ERROR_OPERATION_ABORTED, ERROR_PIPE_NOT_CONNECTED:
|
||||
|
@ -1626,7 +1552,7 @@ else:
|
|||
|
||||
let transp = newStreamSocketTransport(sock, bufferSize, child)
|
||||
# Start tracking transport
|
||||
trackStream(transp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
retFuture.complete(transp)
|
||||
|
||||
proc cancel(udata: pointer) =
|
||||
|
@ -1639,7 +1565,7 @@ else:
|
|||
if res == 0:
|
||||
let transp = newStreamSocketTransport(sock, bufferSize, child)
|
||||
# Start tracking transport
|
||||
trackStream(transp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
retFuture.complete(transp)
|
||||
break
|
||||
else:
|
||||
|
@ -1694,7 +1620,7 @@ else:
|
|||
newStreamSocketTransport(sock, server.bufferSize, transp)
|
||||
else:
|
||||
newStreamSocketTransport(sock, server.bufferSize, nil)
|
||||
trackStream(ntransp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
asyncSpawn server.function(server, ntransp)
|
||||
else:
|
||||
# Client was accepted, so we not going to raise assertion, but
|
||||
|
@ -1782,7 +1708,7 @@ else:
|
|||
else:
|
||||
newStreamSocketTransport(sock, server.bufferSize, nil)
|
||||
# Start tracking transport
|
||||
trackStream(ntransp)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
retFuture.complete(ntransp)
|
||||
else:
|
||||
discard closeFd(cint(sock))
|
||||
|
@ -2098,7 +2024,7 @@ proc createStreamServer*(host: TransportAddress,
|
|||
sres.apending = false
|
||||
|
||||
# Start tracking server
|
||||
trackServer(sres)
|
||||
trackCounter(StreamServerTrackerName)
|
||||
GC_ref(sres)
|
||||
sres
|
||||
|
||||
|
@ -2671,7 +2597,7 @@ proc fromPipe2*(fd: AsyncFD, child: StreamTransport = nil,
|
|||
? register2(fd)
|
||||
var res = newStreamPipeTransport(fd, bufferSize, child)
|
||||
# Start tracking transport
|
||||
trackStream(res)
|
||||
trackCounter(StreamTransportTrackerName)
|
||||
ok(res)
|
||||
|
||||
proc fromPipe*(fd: AsyncFD, child: StreamTransport = nil,
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/tables
|
||||
import unittest2
|
||||
import ../../chronos
|
||||
|
||||
|
@ -17,3 +18,14 @@ template asyncTest*(name: string, body: untyped): untyped =
|
|||
proc() {.async, gcsafe.} =
|
||||
body
|
||||
)())
|
||||
|
||||
template checkLeaks*(name: string): untyped =
|
||||
let counter = getTrackerCounter(name)
|
||||
if counter.opened != counter.closed:
|
||||
echo "[" & name & "] opened = ", counter.opened,
|
||||
", closed = ", counter.closed
|
||||
check counter.opened == counter.closed
|
||||
|
||||
template checkLeaks*(): untyped =
|
||||
for key in getThreadDispatcher().trackerCounterKeys():
|
||||
checkLeaks(key)
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
# MIT license (LICENSE-MIT)
|
||||
import unittest2
|
||||
import bearssl/[x509]
|
||||
import ../chronos
|
||||
import ../chronos/streams/[tlsstream, chunkstream, boundstream]
|
||||
import ".."/chronos/unittest2/asynctests
|
||||
import ".."/chronos/streams/[tlsstream, chunkstream, boundstream]
|
||||
|
||||
{.used.}
|
||||
|
||||
|
@ -302,11 +302,7 @@ suite "AsyncStream test suite":
|
|||
check waitFor(testConsume()) == true
|
||||
|
||||
test "AsyncStream(StreamTransport) leaks test":
|
||||
check:
|
||||
getTracker("async.stream.reader").isLeaked() == false
|
||||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
checkLeaks()
|
||||
|
||||
test "AsyncStream(AsyncStream) readExactly() test":
|
||||
proc testReadExactly2(): Future[bool] {.async.} =
|
||||
|
@ -613,11 +609,7 @@ suite "AsyncStream test suite":
|
|||
check waitFor(testWriteEof()) == true
|
||||
|
||||
test "AsyncStream(AsyncStream) leaks test":
|
||||
check:
|
||||
getTracker("async.stream.reader").isLeaked() == false
|
||||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
checkLeaks()
|
||||
|
||||
suite "ChunkedStream test suite":
|
||||
test "ChunkedStream test vectors":
|
||||
|
@ -911,11 +903,7 @@ suite "ChunkedStream test suite":
|
|||
check waitFor(testSmallChunk(767309, 4457, 173)) == true
|
||||
|
||||
test "ChunkedStream leaks test":
|
||||
check:
|
||||
getTracker("async.stream.reader").isLeaked() == false
|
||||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
checkLeaks()
|
||||
|
||||
suite "TLSStream test suite":
|
||||
const HttpHeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
|
||||
|
@ -1039,11 +1027,7 @@ suite "TLSStream test suite":
|
|||
check res == "Some message\r\n"
|
||||
|
||||
test "TLSStream leaks test":
|
||||
check:
|
||||
getTracker("async.stream.reader").isLeaked() == false
|
||||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
checkLeaks()
|
||||
|
||||
suite "BoundedStream test suite":
|
||||
|
||||
|
@ -1411,8 +1395,4 @@ suite "BoundedStream test suite":
|
|||
check waitFor(checkEmptyStreams()) == true
|
||||
|
||||
test "BoundedStream leaks test":
|
||||
check:
|
||||
getTracker("async.stream.reader").isLeaked() == false
|
||||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
checkLeaks()
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/[strutils, net]
|
||||
import unittest2
|
||||
import ../chronos
|
||||
import ".."/chronos/unittest2/asynctests
|
||||
import ".."/chronos
|
||||
|
||||
{.used.}
|
||||
|
||||
|
@ -558,4 +558,4 @@ suite "Datagram Transport test suite":
|
|||
test "0.0.0.0/::0 (INADDR_ANY) test":
|
||||
check waitFor(testAnyAddress()) == 6
|
||||
test "Transports leak test":
|
||||
check getTracker("datagram.transport").isLeaked() == false
|
||||
checkLeaks()
|
||||
|
|
|
@ -6,8 +6,9 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/[strutils, sha1]
|
||||
import unittest2
|
||||
import ../chronos, ../chronos/apps/http/[httpserver, shttpserver, httpclient]
|
||||
import ".."/chronos/unittest2/asynctests
|
||||
import ".."/chronos,
|
||||
".."/chronos/apps/http/[httpserver, shttpserver, httpclient]
|
||||
import stew/base10
|
||||
|
||||
{.used.}
|
||||
|
@ -138,7 +139,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
|
||||
server.start()
|
||||
|
@ -241,7 +242,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
|
||||
server.start()
|
||||
|
@ -324,7 +325,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
|
||||
server.start()
|
||||
|
@ -394,7 +395,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
|
||||
server.start()
|
||||
|
@ -470,7 +471,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
|
||||
server.start()
|
||||
|
@ -569,7 +570,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
|
||||
server.start()
|
||||
|
@ -667,7 +668,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
|
||||
server.start()
|
||||
|
@ -778,7 +779,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, false)
|
||||
server.start()
|
||||
|
@ -909,7 +910,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, false)
|
||||
server.start()
|
||||
|
@ -971,7 +972,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, false)
|
||||
server.start()
|
||||
|
@ -1125,7 +1126,7 @@ suite "HTTP client testing suite":
|
|||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
|
||||
server.start()
|
||||
|
@ -1262,17 +1263,4 @@ suite "HTTP client testing suite":
|
|||
check waitFor(testServerSentEvents(false)) == true
|
||||
|
||||
test "Leaks test":
|
||||
proc getTrackerLeaks(tracker: string): bool =
|
||||
let tracker = getTracker(tracker)
|
||||
if isNil(tracker): false else: tracker.isLeaked()
|
||||
|
||||
check:
|
||||
getTrackerLeaks("http.body.reader") == false
|
||||
getTrackerLeaks("http.body.writer") == false
|
||||
getTrackerLeaks("httpclient.connection") == false
|
||||
getTrackerLeaks("httpclient.request") == false
|
||||
getTrackerLeaks("httpclient.response") == false
|
||||
getTrackerLeaks("async.stream.reader") == false
|
||||
getTrackerLeaks("async.stream.writer") == false
|
||||
getTrackerLeaks("stream.server") == false
|
||||
getTrackerLeaks("stream.transport") == false
|
||||
checkLeaks()
|
||||
|
|
|
@ -6,10 +6,10 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/[strutils, algorithm]
|
||||
import unittest2
|
||||
import ../chronos, ../chronos/apps/http/httpserver,
|
||||
../chronos/apps/http/httpcommon,
|
||||
../chronos/unittest2/asynctests
|
||||
import ".."/chronos/unittest2/asynctests,
|
||||
".."/chronos, ".."/chronos/apps/http/httpserver,
|
||||
".."/chronos/apps/http/httpcommon,
|
||||
".."/chronos/apps/http/httpdebug
|
||||
import stew/base10
|
||||
|
||||
{.used.}
|
||||
|
@ -84,7 +84,7 @@ suite "HTTP server testing suite":
|
|||
# Reraising exception, because processor should properly handle it.
|
||||
raise exc
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -100,14 +100,14 @@ suite "HTTP server testing suite":
|
|||
let request =
|
||||
case operation
|
||||
of GetBodyTest, ConsumeBodyTest, PostUrlTest:
|
||||
"POST / HTTP/1.0\r\n" &
|
||||
"POST / HTTP/1.1\r\n" &
|
||||
"Content-Type: application/x-www-form-urlencoded\r\n" &
|
||||
"Transfer-Encoding: chunked\r\n" &
|
||||
"Cookie: 2\r\n\r\n" &
|
||||
"5\r\na=a&b\r\n5\r\n=b&c=\r\n4\r\nc&d=\r\n4\r\n%D0%\r\n" &
|
||||
"2\r\n9F\r\n0\r\n\r\n"
|
||||
of PostMultipartTest:
|
||||
"POST / HTTP/1.0\r\n" &
|
||||
"POST / HTTP/1.1\r\n" &
|
||||
"Host: 127.0.0.1:30080\r\n" &
|
||||
"Transfer-Encoding: chunked\r\n" &
|
||||
"Content-Type: multipart/form-data; boundary=f98f0\r\n\r\n" &
|
||||
|
@ -134,9 +134,9 @@ suite "HTTP server testing suite":
|
|||
let request = r.get()
|
||||
return await request.respond(Http200, "TEST_OK", HttpTable.init())
|
||||
else:
|
||||
if r.error().error == HttpServerError.TimeoutError:
|
||||
if r.error.kind == HttpServerError.TimeoutError:
|
||||
serverRes = true
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"),
|
||||
|
@ -148,7 +148,6 @@ suite "HTTP server testing suite":
|
|||
let server = res.get()
|
||||
server.start()
|
||||
let address = server.instance.localAddress()
|
||||
|
||||
let data = await httpClient(address, "")
|
||||
await server.stop()
|
||||
await server.closeWait()
|
||||
|
@ -165,9 +164,9 @@ suite "HTTP server testing suite":
|
|||
let request = r.get()
|
||||
return await request.respond(Http200, "TEST_OK", HttpTable.init())
|
||||
else:
|
||||
if r.error().error == HttpServerError.CriticalError:
|
||||
if r.error.kind == HttpServerError.CriticalError:
|
||||
serverRes = true
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"),
|
||||
|
@ -195,9 +194,9 @@ suite "HTTP server testing suite":
|
|||
let request = r.get()
|
||||
return await request.respond(Http200, "TEST_OK", HttpTable.init())
|
||||
else:
|
||||
if r.error().error == HttpServerError.CriticalError:
|
||||
if r.error.error == HttpServerError.CriticalError:
|
||||
serverRes = true
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -225,9 +224,9 @@ suite "HTTP server testing suite":
|
|||
if r.isOk():
|
||||
discard
|
||||
else:
|
||||
if r.error().error == HttpServerError.CriticalError:
|
||||
if r.error.error == HttpServerError.CriticalError:
|
||||
serverRes = true
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -280,7 +279,7 @@ suite "HTTP server testing suite":
|
|||
HttpTable.init())
|
||||
else:
|
||||
serverRes = false
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -321,7 +320,7 @@ suite "HTTP server testing suite":
|
|||
HttpTable.init())
|
||||
else:
|
||||
serverRes = false
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -367,7 +366,7 @@ suite "HTTP server testing suite":
|
|||
HttpTable.init())
|
||||
else:
|
||||
serverRes = false
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -411,7 +410,7 @@ suite "HTTP server testing suite":
|
|||
HttpTable.init())
|
||||
else:
|
||||
serverRes = false
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -456,7 +455,7 @@ suite "HTTP server testing suite":
|
|||
HttpTable.init())
|
||||
else:
|
||||
serverRes = false
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -512,7 +511,7 @@ suite "HTTP server testing suite":
|
|||
HttpTable.init())
|
||||
else:
|
||||
serverRes = false
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -576,7 +575,7 @@ suite "HTTP server testing suite":
|
|||
await eventContinue.wait()
|
||||
return await request.respond(Http404, "", HttpTable.init())
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -1247,7 +1246,7 @@ suite "HTTP server testing suite":
|
|||
return response
|
||||
else:
|
||||
serverRes = false
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
|
@ -1311,7 +1310,7 @@ suite "HTTP server testing suite":
|
|||
let request = r.get()
|
||||
return await request.respond(Http200, "TEST_OK", HttpTable.init())
|
||||
else:
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
for test in TestMessages:
|
||||
let
|
||||
|
@ -1355,9 +1354,78 @@ suite "HTTP server testing suite":
|
|||
await server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
test "Leaks test":
|
||||
asyncTest "HTTP debug tests":
|
||||
const
|
||||
TestsCount = 10
|
||||
TestRequest = "GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n"
|
||||
|
||||
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isOk():
|
||||
let request = r.get()
|
||||
return await request.respond(Http200, "TEST_OK", HttpTable.init())
|
||||
else:
|
||||
return defaultResponse()
|
||||
|
||||
proc client(address: TransportAddress,
|
||||
data: string): Future[StreamTransport] {.async.} =
|
||||
var transp: StreamTransport
|
||||
var buffer = newSeq[byte](4096)
|
||||
var sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8]
|
||||
try:
|
||||
transp = await connect(address)
|
||||
let wres {.used.} =
|
||||
await transp.write(data)
|
||||
let hres {.used.} =
|
||||
await transp.readUntil(addr buffer[0], len(buffer), sep)
|
||||
transp
|
||||
except CatchableError:
|
||||
if not(isNil(transp)): await transp.closeWait()
|
||||
nil
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
|
||||
serverFlags = {HttpServerFlags.Http11Pipeline},
|
||||
socketFlags = socketFlags)
|
||||
check res.isOk()
|
||||
|
||||
let server = res.get()
|
||||
server.start()
|
||||
let address = server.instance.localAddress()
|
||||
|
||||
let info = server.getServerInfo()
|
||||
|
||||
check:
|
||||
getTracker("async.stream.reader").isLeaked() == false
|
||||
getTracker("async.stream.writer").isLeaked() == false
|
||||
getTracker("stream.server").isLeaked() == false
|
||||
getTracker("stream.transport").isLeaked() == false
|
||||
info.connectionType == ConnectionType.NonSecure
|
||||
info.address == address
|
||||
info.state == HttpServerState.ServerRunning
|
||||
info.flags == {HttpServerFlags.Http11Pipeline}
|
||||
info.socketFlags == socketFlags
|
||||
|
||||
try:
|
||||
var clientFutures: seq[Future[StreamTransport]]
|
||||
for i in 0 ..< TestsCount:
|
||||
clientFutures.add(client(address, TestRequest))
|
||||
await allFutures(clientFutures)
|
||||
|
||||
let connections = server.getConnections()
|
||||
check len(connections) == TestsCount
|
||||
let currentTime = Moment.now()
|
||||
for index, connection in connections.pairs():
|
||||
let transp = clientFutures[index].read()
|
||||
check:
|
||||
connection.remoteAddress.get() == transp.localAddress()
|
||||
connection.localAddress.get() == transp.remoteAddress()
|
||||
connection.connectionType == ConnectionType.NonSecure
|
||||
connection.connectionState == ConnectionState.Alive
|
||||
(currentTime - connection.createMoment.get()) != ZeroDuration
|
||||
(currentTime - connection.acceptMoment) != ZeroDuration
|
||||
var pending: seq[Future[void]]
|
||||
for transpFut in clientFutures:
|
||||
pending.add(closeWait(transpFut.read()))
|
||||
await allFutures(pending)
|
||||
finally:
|
||||
await server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
test "Leaks test":
|
||||
checkLeaks()
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/os
|
||||
import unittest2, stew/[base10, byteutils]
|
||||
import stew/[base10, byteutils]
|
||||
import ".."/chronos/unittest2/asynctests
|
||||
|
||||
when defined(posix):
|
||||
|
@ -414,12 +414,4 @@ suite "Asynchronous process management test suite":
|
|||
check getCurrentFD() == markFD
|
||||
|
||||
test "Leaks test":
|
||||
proc getTrackerLeaks(tracker: string): bool =
|
||||
let tracker = getTracker(tracker)
|
||||
if isNil(tracker): false else: tracker.isLeaked()
|
||||
|
||||
check:
|
||||
getTrackerLeaks("async.process") == false
|
||||
getTrackerLeaks("async.stream.reader") == false
|
||||
getTrackerLeaks("async.stream.writer") == false
|
||||
getTrackerLeaks("stream.transport") == false
|
||||
checkLeaks()
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/strutils
|
||||
import unittest2
|
||||
import ../chronos, ../chronos/apps/http/shttpserver
|
||||
import ".."/chronos/unittest2/asynctests
|
||||
import ".."/chronos, ".."/chronos/apps/http/shttpserver
|
||||
import stew/base10
|
||||
|
||||
{.used.}
|
||||
|
@ -115,7 +115,7 @@ suite "Secure HTTP server testing suite":
|
|||
HttpTable.init())
|
||||
else:
|
||||
serverRes = false
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let serverFlags = {Secure}
|
||||
|
@ -154,7 +154,7 @@ suite "Secure HTTP server testing suite":
|
|||
else:
|
||||
serverRes = true
|
||||
testFut.complete()
|
||||
return dumbResponse()
|
||||
return defaultResponse()
|
||||
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let serverFlags = {Secure}
|
||||
|
@ -178,3 +178,6 @@ suite "Secure HTTP server testing suite":
|
|||
return serverRes and data == "EXCEPTION"
|
||||
|
||||
check waitFor(testHTTPS2(initTAddress("127.0.0.1:30080"))) == true
|
||||
|
||||
test "Leaks test":
|
||||
checkLeaks()
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/[strutils, os]
|
||||
import unittest2
|
||||
import ".."/chronos/unittest2/asynctests
|
||||
import ".."/chronos, ".."/chronos/[osdefs, oserrno]
|
||||
|
||||
{.used.}
|
||||
|
@ -1370,10 +1370,11 @@ suite "Stream Transport test suite":
|
|||
test prefixes[i] & "close() while in accept() waiting test":
|
||||
check waitFor(testAcceptClose(addresses[i])) == true
|
||||
test prefixes[i] & "Intermediate transports leak test #1":
|
||||
checkLeaks()
|
||||
when defined(windows):
|
||||
skip()
|
||||
else:
|
||||
check getTracker("stream.transport").isLeaked() == false
|
||||
checkLeaks(StreamTransportTrackerName)
|
||||
test prefixes[i] & "accept() too many file descriptors test":
|
||||
when defined(windows):
|
||||
skip()
|
||||
|
@ -1389,10 +1390,8 @@ suite "Stream Transport test suite":
|
|||
check waitFor(testPipe()) == true
|
||||
test "[IP] bind connect to local address":
|
||||
waitFor(testConnectBindLocalAddress())
|
||||
test "Servers leak test":
|
||||
check getTracker("stream.server").isLeaked() == false
|
||||
test "Transports leak test":
|
||||
check getTracker("stream.transport").isLeaked() == false
|
||||
test "Leaks test":
|
||||
checkLeaks()
|
||||
test "File descriptors leak test":
|
||||
when defined(windows):
|
||||
# Windows handle numbers depends on many conditions, so we can't use
|
||||
|
|
Loading…
Reference in New Issue