Merge commit '6c2ea675123ed0bf5c5d76c92ed4985bacd1a9ec' into dev/etan/zz-dbg

This commit is contained in:
Etan Kissling 2023-08-29 14:08:22 +02:00
commit 69a8112054
No known key found for this signature in database
GPG Key ID: B21DA824C5A3D03D
41 changed files with 2734 additions and 1344 deletions

View File

@ -96,7 +96,7 @@ jobs:
- name: Restore Nim DLLs dependencies (Windows) from cache
if: runner.os == 'Windows'
id: windows-dlls-cache
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: external/dlls-${{ matrix.target.cpu }}
key: 'dlls-${{ matrix.target.cpu }}'

View File

@ -5,6 +5,5 @@
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import chronos/[asyncloop, asyncsync, handles, transport, timer,
asyncproc, debugutils]
export asyncloop, asyncsync, handles, transport, timer, asyncproc, debugutils
import chronos/[asyncloop, asyncsync, handles, transport, timer, debugutils]
export asyncloop, asyncsync, handles, transport, timer, debugutils

View File

@ -17,6 +17,22 @@ let nimc = getEnv("NIMC", "nim") # Which nim compiler to use
let lang = getEnv("NIMLANG", "c") # Which backend (c/cpp/js)
let flags = getEnv("NIMFLAGS", "") # Extra flags for the compiler
let verbose = getEnv("V", "") notin ["", "0"]
let testArguments =
when defined(windows):
[
"-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert",
"-d:debug -d:chronosPreviewV4",
"-d:release",
"-d:release -d:chronosPreviewV4"
]
else:
[
"-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert",
"-d:debug -d:chronosPreviewV4",
"-d:debug -d:chronosDebug -d:chronosEventEngine=poll -d:useSysAssert -d:useGcAssert",
"-d:release",
"-d:release -d:chronosPreviewV4"
]
let styleCheckStyle = if (NimMajor, NimMinor) < (1, 6): "hint" else: "error"
let cfg =
@ -31,12 +47,7 @@ proc run(args, path: string) =
build args & " -r", path
task test, "Run all tests":
for args in [
"-d:debug -d:chronosDebug",
"-d:debug -d:chronosPreviewV4",
"-d:debug -d:chronosDebug -d:useSysAssert -d:useGcAssert",
"-d:release",
"-d:release -d:chronosPreviewV4"]:
for args in testArguments:
run args, "tests/testall"
if (NimMajor, NimMinor) > (1, 6):
run args & " --mm:refc", "tests/testall"

View File

@ -25,71 +25,6 @@ type
bstate*: HttpState
streams*: seq[AsyncStreamWriter]
HttpBodyTracker* = ref object of TrackerBase
opened*: int64
closed*: int64
proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe, raises: [].}
proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe, raises: [].}
proc getHttpBodyWriterTracker(): HttpBodyTracker {.inline.} =
var res = cast[HttpBodyTracker](getTracker(HttpBodyWriterTrackerName))
if isNil(res):
res = setupHttpBodyWriterTracker()
res
proc getHttpBodyReaderTracker(): HttpBodyTracker {.inline.} =
var res = cast[HttpBodyTracker](getTracker(HttpBodyReaderTrackerName))
if isNil(res):
res = setupHttpBodyReaderTracker()
res
proc dumpHttpBodyWriterTracking(): string {.gcsafe.} =
let tracker = getHttpBodyWriterTracker()
"Opened HTTP body writers: " & $tracker.opened & "\n" &
"Closed HTTP body writers: " & $tracker.closed
proc dumpHttpBodyReaderTracking(): string {.gcsafe.} =
let tracker = getHttpBodyReaderTracker()
"Opened HTTP body readers: " & $tracker.opened & "\n" &
"Closed HTTP body readers: " & $tracker.closed
proc leakHttpBodyWriter(): bool {.gcsafe.} =
var tracker = getHttpBodyWriterTracker()
tracker.opened != tracker.closed
proc leakHttpBodyReader(): bool {.gcsafe.} =
var tracker = getHttpBodyReaderTracker()
tracker.opened != tracker.closed
proc trackHttpBodyWriter(t: HttpBodyWriter) {.inline.} =
inc(getHttpBodyWriterTracker().opened)
proc untrackHttpBodyWriter*(t: HttpBodyWriter) {.inline.} =
inc(getHttpBodyWriterTracker().closed)
proc trackHttpBodyReader(t: HttpBodyReader) {.inline.} =
inc(getHttpBodyReaderTracker().opened)
proc untrackHttpBodyReader*(t: HttpBodyReader) {.inline.} =
inc(getHttpBodyReaderTracker().closed)
proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe.} =
var res = HttpBodyTracker(opened: 0, closed: 0,
dump: dumpHttpBodyWriterTracking,
isLeaked: leakHttpBodyWriter
)
addTracker(HttpBodyWriterTrackerName, res)
res
proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe.} =
var res = HttpBodyTracker(opened: 0, closed: 0,
dump: dumpHttpBodyReaderTracking,
isLeaked: leakHttpBodyReader
)
addTracker(HttpBodyReaderTrackerName, res)
res
proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader =
## HttpBodyReader is AsyncStreamReader which holds references to all the
## ``streams``. Also on close it will close all the ``streams``.
@ -98,7 +33,7 @@ proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader =
doAssert(len(streams) > 0, "At least one stream must be added")
var res = HttpBodyReader(bstate: HttpState.Alive, streams: @streams)
res.init(streams[0])
trackHttpBodyReader(res)
trackCounter(HttpBodyReaderTrackerName)
res
proc closeWait*(bstream: HttpBodyReader) {.async.} =
@ -113,7 +48,7 @@ proc closeWait*(bstream: HttpBodyReader) {.async.} =
await allFutures(res)
await procCall(closeWait(AsyncStreamReader(bstream)))
bstream.bstate = HttpState.Closed
untrackHttpBodyReader(bstream)
untrackCounter(HttpBodyReaderTrackerName)
proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter =
## HttpBodyWriter is AsyncStreamWriter which holds references to all the
@ -123,7 +58,7 @@ proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter =
doAssert(len(streams) > 0, "At least one stream must be added")
var res = HttpBodyWriter(bstate: HttpState.Alive, streams: @streams)
res.init(streams[0])
trackHttpBodyWriter(res)
trackCounter(HttpBodyWriterTrackerName)
res
proc closeWait*(bstream: HttpBodyWriter) {.async.} =
@ -136,7 +71,7 @@ proc closeWait*(bstream: HttpBodyWriter) {.async.} =
await allFutures(res)
await procCall(closeWait(AsyncStreamWriter(bstream)))
bstream.bstate = HttpState.Closed
untrackHttpBodyWriter(bstream)
untrackCounter(HttpBodyWriterTrackerName)
proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [].} =
if len(bstream.streams) == 1:

View File

@ -108,6 +108,7 @@ type
remoteHostname*: string
flags*: set[HttpClientConnectionFlag]
timestamp*: Moment
duration*: Duration
HttpClientConnectionRef* = ref HttpClientConnection
@ -190,10 +191,6 @@ type
HttpClientFlags* = set[HttpClientFlag]
HttpClientTracker* = ref object of TrackerBase
opened*: int64
closed*: int64
ServerSentEvent* = object
name*: string
data*: string
@ -204,100 +201,6 @@ type
# HttpClientResponseRef valid states are
# Open -> (Finished, Error) -> (Closing, Closed)
proc setupHttpClientConnectionTracker(): HttpClientTracker {.
gcsafe, raises: [].}
proc setupHttpClientRequestTracker(): HttpClientTracker {.
gcsafe, raises: [].}
proc setupHttpClientResponseTracker(): HttpClientTracker {.
gcsafe, raises: [].}
proc getHttpClientConnectionTracker(): HttpClientTracker {.inline.} =
var res = cast[HttpClientTracker](getTracker(HttpClientConnectionTrackerName))
if isNil(res):
res = setupHttpClientConnectionTracker()
res
proc getHttpClientRequestTracker(): HttpClientTracker {.inline.} =
var res = cast[HttpClientTracker](getTracker(HttpClientRequestTrackerName))
if isNil(res):
res = setupHttpClientRequestTracker()
res
proc getHttpClientResponseTracker(): HttpClientTracker {.inline.} =
var res = cast[HttpClientTracker](getTracker(HttpClientResponseTrackerName))
if isNil(res):
res = setupHttpClientResponseTracker()
res
proc dumpHttpClientConnectionTracking(): string {.gcsafe.} =
let tracker = getHttpClientConnectionTracker()
"Opened HTTP client connections: " & $tracker.opened & "\n" &
"Closed HTTP client connections: " & $tracker.closed
proc dumpHttpClientRequestTracking(): string {.gcsafe.} =
let tracker = getHttpClientRequestTracker()
"Opened HTTP client requests: " & $tracker.opened & "\n" &
"Closed HTTP client requests: " & $tracker.closed
proc dumpHttpClientResponseTracking(): string {.gcsafe.} =
let tracker = getHttpClientResponseTracker()
"Opened HTTP client responses: " & $tracker.opened & "\n" &
"Closed HTTP client responses: " & $tracker.closed
proc leakHttpClientConnection(): bool {.gcsafe.} =
var tracker = getHttpClientConnectionTracker()
tracker.opened != tracker.closed
proc leakHttpClientRequest(): bool {.gcsafe.} =
var tracker = getHttpClientRequestTracker()
tracker.opened != tracker.closed
proc leakHttpClientResponse(): bool {.gcsafe.} =
var tracker = getHttpClientResponseTracker()
tracker.opened != tracker.closed
proc trackHttpClientConnection(t: HttpClientConnectionRef) {.inline.} =
inc(getHttpClientConnectionTracker().opened)
proc untrackHttpClientConnection*(t: HttpClientConnectionRef) {.inline.} =
inc(getHttpClientConnectionTracker().closed)
proc trackHttpClientRequest(t: HttpClientRequestRef) {.inline.} =
inc(getHttpClientRequestTracker().opened)
proc untrackHttpClientRequest*(t: HttpClientRequestRef) {.inline.} =
inc(getHttpClientRequestTracker().closed)
proc trackHttpClientResponse(t: HttpClientResponseRef) {.inline.} =
inc(getHttpClientResponseTracker().opened)
proc untrackHttpClientResponse*(t: HttpClientResponseRef) {.inline.} =
inc(getHttpClientResponseTracker().closed)
proc setupHttpClientConnectionTracker(): HttpClientTracker {.gcsafe.} =
var res = HttpClientTracker(opened: 0, closed: 0,
dump: dumpHttpClientConnectionTracking,
isLeaked: leakHttpClientConnection
)
addTracker(HttpClientConnectionTrackerName, res)
res
proc setupHttpClientRequestTracker(): HttpClientTracker {.gcsafe.} =
var res = HttpClientTracker(opened: 0, closed: 0,
dump: dumpHttpClientRequestTracking,
isLeaked: leakHttpClientRequest
)
addTracker(HttpClientRequestTrackerName, res)
res
proc setupHttpClientResponseTracker(): HttpClientTracker {.gcsafe.} =
var res = HttpClientTracker(opened: 0, closed: 0,
dump: dumpHttpClientResponseTracking,
isLeaked: leakHttpClientResponse
)
addTracker(HttpClientResponseTrackerName, res)
res
template checkClosed(reqresp: untyped): untyped =
if reqresp.connection.state in {HttpClientConnectionState.Closing,
HttpClientConnectionState.Closed}:
@ -331,6 +234,12 @@ template setDuration(
reqresp.duration = timestamp - reqresp.timestamp
reqresp.connection.setTimestamp(timestamp)
template setDuration(conn: HttpClientConnectionRef): untyped =
if not(isNil(conn)):
let timestamp = Moment.now()
conn.duration = timestamp - conn.timestamp
conn.setTimestamp(timestamp)
template isReady(conn: HttpClientConnectionRef): bool =
(conn.state == HttpClientConnectionState.Ready) and
(HttpClientConnectionFlag.KeepAlive in conn.flags) and
@ -556,7 +465,7 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef,
state: HttpClientConnectionState.Connecting,
remoteHostname: ha.id
)
trackHttpClientConnection(res)
trackCounter(HttpClientConnectionTrackerName)
res
of HttpClientScheme.Secure:
let treader = newAsyncStreamReader(transp)
@ -575,7 +484,7 @@ proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef,
state: HttpClientConnectionState.Connecting,
remoteHostname: ha.id
)
trackHttpClientConnection(res)
trackCounter(HttpClientConnectionTrackerName)
res
proc setError(request: HttpClientRequestRef, error: ref HttpError) {.
@ -615,13 +524,13 @@ proc closeWait(conn: HttpClientConnectionRef) {.async.} =
discard
await conn.transp.closeWait()
conn.state = HttpClientConnectionState.Closed
untrackHttpClientConnection(conn)
untrackCounter(HttpClientConnectionTrackerName)
proc connect(session: HttpSessionRef,
ha: HttpAddress): Future[HttpClientConnectionRef] {.async.} =
## Establish new connection with remote server using ``url`` and ``flags``.
## On success returns ``HttpClientConnectionRef`` object.
var lastError = ""
# Here we trying to connect to every possible remote host address we got after
# DNS resolution.
for address in ha.addresses:
@ -645,9 +554,14 @@ proc connect(session: HttpSessionRef,
except CancelledError as exc:
await res.closeWait()
raise exc
except AsyncStreamError:
except TLSStreamProtocolError as exc:
await res.closeWait()
res.state = HttpClientConnectionState.Error
lastError = $exc.msg
except AsyncStreamError as exc:
await res.closeWait()
res.state = HttpClientConnectionState.Error
lastError = $exc.msg
of HttpClientScheme.Nonsecure:
res.state = HttpClientConnectionState.Ready
res
@ -655,7 +569,11 @@ proc connect(session: HttpSessionRef,
return conn
# If all attempts to connect to the remote host have failed.
raiseHttpConnectionError("Could not connect to remote host")
if len(lastError) > 0:
raiseHttpConnectionError("Could not connect to remote host, reason: " &
lastError)
else:
raiseHttpConnectionError("Could not connect to remote host")
proc removeConnection(session: HttpSessionRef,
conn: HttpClientConnectionRef) {.async.} =
@ -685,9 +603,9 @@ proc acquireConnection(
): Future[HttpClientConnectionRef] {.async.} =
## Obtain connection from ``session`` or establish a new one.
var default: seq[HttpClientConnectionRef]
let timestamp = Moment.now()
if session.connectionPoolEnabled(flags):
# Trying to reuse existing connection from our connection's pool.
let timestamp = Moment.now()
# We looking for non-idle connection at `Ready` state, all idle connections
# will be freed by sessionWatcher().
for connection in session.connections.getOrDefault(ha.id):
@ -704,6 +622,8 @@ proc acquireConnection(
connection.state = HttpClientConnectionState.Acquired
session.connections.mgetOrPut(ha.id, default).add(connection)
inc(session.connectionsCount)
connection.setTimestamp(timestamp)
connection.setDuration()
return connection
proc releaseConnection(session: HttpSessionRef,
@ -835,7 +755,7 @@ proc closeWait*(request: HttpClientRequestRef) {.async.} =
request.session = nil
request.error = nil
request.state = HttpReqRespState.Closed
untrackHttpClientRequest(request)
untrackCounter(HttpClientRequestTrackerName)
proc closeWait*(response: HttpClientResponseRef) {.async.} =
if response.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}:
@ -848,7 +768,7 @@ proc closeWait*(response: HttpClientResponseRef) {.async.} =
response.session = nil
response.error = nil
response.state = HttpReqRespState.Closed
untrackHttpClientResponse(response)
untrackCounter(HttpClientResponseTrackerName)
proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte]
): HttpResult[HttpClientResponseRef] {.raises: [] .} =
@ -958,7 +878,7 @@ proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte]
httpPipeline:
res.connection.flags.incl(HttpClientConnectionFlag.KeepAlive)
res.connection.flags.incl(HttpClientConnectionFlag.Response)
trackHttpClientResponse(res)
trackCounter(HttpClientResponseTrackerName)
ok(res)
proc getResponse(req: HttpClientRequestRef): Future[HttpClientResponseRef] {.
@ -997,7 +917,7 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
version: version, flags: flags, headers: HttpTable.init(headers),
address: ha, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body
)
trackHttpClientRequest(res)
trackCounter(HttpClientRequestTrackerName)
res
proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
@ -1013,7 +933,7 @@ proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,
version: version, flags: flags, headers: HttpTable.init(headers),
address: address, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body
)
trackHttpClientRequest(res)
trackCounter(HttpClientRequestTrackerName)
ok(res)
proc get*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef,

View File

@ -13,6 +13,15 @@ import ../../streams/[asyncstream, boundstream]
export asyncloop, asyncsync, results, httputils, strutils
const
HttpServerUnsecureConnectionTrackerName* =
"httpserver.unsecure.connection"
HttpServerSecureConnectionTrackerName* =
"httpserver.secure.connection"
HttpServerRequestTrackerName* =
"httpserver.request"
HttpServerResponseTrackerName* =
"httpserver.response"
HeadersMark* = @[0x0d'u8, 0x0a'u8, 0x0d'u8, 0x0a'u8]
PostMethods* = {MethodPost, MethodPatch, MethodPut, MethodDelete}

View File

@ -0,0 +1,129 @@
#
# Chronos HTTP/S server implementation
# (c) Copyright 2021-Present
# Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/tables
import stew/results
import ../../timer
import httpserver, shttpserver
from httpclient import HttpClientScheme
from httpcommon import HttpState
from ../../osdefs import SocketHandle
from ../../transports/common import TransportAddress, ServerFlags
export HttpClientScheme, SocketHandle, TransportAddress, ServerFlags, HttpState
{.push raises: [].}
type
ConnectionType* {.pure.} = enum
NonSecure, Secure
ConnectionState* {.pure.} = enum
Accepted, Alive, Closing, Closed
ServerConnectionInfo* = object
handle*: SocketHandle
connectionType*: ConnectionType
connectionState*: ConnectionState
query*: Opt[string]
remoteAddress*: Opt[TransportAddress]
localAddress*: Opt[TransportAddress]
acceptMoment*: Moment
createMoment*: Opt[Moment]
ServerInfo* = object
connectionType*: ConnectionType
address*: TransportAddress
state*: HttpServerState
maxConnections*: int
backlogSize*: int
baseUri*: Uri
serverIdent*: string
flags*: set[HttpServerFlags]
socketFlags*: set[ServerFlags]
headersTimeout*: Duration
bufferSize*: int
maxHeadersSize*: int
maxRequestBodySize*: int
proc getConnectionType*(
server: HttpServerRef | SecureHttpServerRef): ConnectionType =
when server is SecureHttpServerRef:
ConnectionType.Secure
else:
if HttpServerFlags.Secure in server.flags:
ConnectionType.Secure
else:
ConnectionType.NonSecure
proc getServerInfo*(server: HttpServerRef|SecureHttpServerRef): ServerInfo =
ServerInfo(
connectionType: server.getConnectionType(),
address: server.address,
state: server.state(),
maxConnections: server.maxConnections,
backlogSize: server.backlogSize,
baseUri: server.baseUri,
serverIdent: server.serverIdent,
flags: server.flags,
socketFlags: server.socketFlags,
headersTimeout: server.headersTimeout,
bufferSize: server.bufferSize,
maxHeadersSize: server.maxHeadersSize,
maxRequestBodySize: server.maxRequestBodySize
)
proc getConnectionState*(holder: HttpConnectionHolderRef): ConnectionState =
if not(isNil(holder.connection)):
case holder.connection.state
of HttpState.Alive: ConnectionState.Alive
of HttpState.Closing: ConnectionState.Closing
of HttpState.Closed: ConnectionState.Closed
else:
ConnectionState.Accepted
proc getQueryString*(holder: HttpConnectionHolderRef): Opt[string] =
if not(isNil(holder.connection)):
holder.connection.currentRawQuery
else:
Opt.none(string)
proc init*(t: typedesc[ServerConnectionInfo],
holder: HttpConnectionHolderRef): ServerConnectionInfo =
let
localAddress =
try:
Opt.some(holder.transp.localAddress())
except CatchableError:
Opt.none(TransportAddress)
remoteAddress =
try:
Opt.some(holder.transp.remoteAddress())
except CatchableError:
Opt.none(TransportAddress)
queryString = holder.getQueryString()
ServerConnectionInfo(
handle: SocketHandle(holder.transp.fd),
connectionType: holder.server.getConnectionType(),
connectionState: holder.getConnectionState(),
remoteAddress: remoteAddress,
localAddress: localAddress,
acceptMoment: holder.acceptMoment,
query: queryString,
createMoment:
if not(isNil(holder.connection)):
Opt.some(holder.connection.createMoment)
else:
Opt.none(Moment)
)
proc getConnections*(server: HttpServerRef): seq[ServerConnectionInfo] =
var res: seq[ServerConnectionInfo]
for holder in server.connections.values():
res.add(ServerConnectionInfo.init(holder))
res

View File

@ -25,20 +25,24 @@ type
QueryCommaSeparatedArray
## Enable usage of comma as an array item delimiter in url-encoded
## entities (e.g. query string or POST body).
Http11Pipeline
## Enable HTTP/1.1 pipelining.
HttpServerError* {.pure.} = enum
TimeoutError, CatchableError, RecoverableError, CriticalError,
DisconnectError
InterruptError, TimeoutError, CatchableError, RecoverableError,
CriticalError, DisconnectError
HttpServerState* {.pure.} = enum
ServerRunning, ServerStopped, ServerClosed
HttpProcessError* = object
error*: HttpServerError
kind*: HttpServerError
code*: HttpCode
exc*: ref CatchableError
remote*: TransportAddress
remote*: Opt[TransportAddress]
ConnectionFence* = Result[HttpConnectionRef, HttpProcessError]
ResponseFence* = Result[HttpResponseRef, HttpProcessError]
RequestFence* = Result[HttpRequestRef, HttpProcessError]
HttpRequestFlags* {.pure.} = enum
@ -50,8 +54,11 @@ type
HttpResponseStreamType* {.pure.} = enum
Plain, SSE, Chunked
HttpProcessExitType* {.pure.} = enum
KeepAlive, Graceful, Immediate
HttpResponseState* {.pure.} = enum
Empty, Prepared, Sending, Finished, Failed, Cancelled, Dumb
Empty, Prepared, Sending, Finished, Failed, Cancelled, Default
HttpProcessCallback* =
proc(req: RequestFence): Future[HttpResponseRef] {.
@ -62,6 +69,20 @@ type
transp: StreamTransport): Future[HttpConnectionRef] {.
gcsafe, raises: [].}
HttpCloseConnectionCallback* =
proc(connection: HttpConnectionRef): Future[void] {.
gcsafe, raises: [].}
HttpConnectionHolder* = object of RootObj
connection*: HttpConnectionRef
server*: HttpServerRef
future*: Future[void]
transp*: StreamTransport
acceptMoment*: Moment
connectionId*: string
HttpConnectionHolderRef* = ref HttpConnectionHolder
HttpServer* = object of RootObj
instance*: StreamServer
address*: TransportAddress
@ -72,7 +93,7 @@ type
serverIdent*: string
flags*: set[HttpServerFlags]
socketFlags*: set[ServerFlags]
connections*: Table[string, Future[void]]
connections*: OrderedTable[string, HttpConnectionHolderRef]
acceptLoop*: Future[void]
lifetime*: Future[void]
headersTimeout*: Duration
@ -120,11 +141,14 @@ type
HttpConnection* = object of RootObj
state*: HttpState
server*: HttpServerRef
transp: StreamTransport
transp*: StreamTransport
mainReader*: AsyncStreamReader
mainWriter*: AsyncStreamWriter
reader*: AsyncStreamReader
writer*: AsyncStreamWriter
closeCb*: HttpCloseConnectionCallback
createMoment*: Moment
currentRawQuery*: Opt[string]
buffer: seq[byte]
HttpConnectionRef* = ref HttpConnection
@ -132,9 +156,24 @@ type
ByteChar* = string | seq[byte]
proc init(htype: typedesc[HttpProcessError], error: HttpServerError,
exc: ref CatchableError, remote: TransportAddress,
code: HttpCode): HttpProcessError {.raises: [].} =
HttpProcessError(error: error, exc: exc, remote: remote, code: code)
exc: ref CatchableError, remote: Opt[TransportAddress],
code: HttpCode): HttpProcessError {.
raises: [].} =
HttpProcessError(kind: error, exc: exc, remote: remote, code: code)
proc init(htype: typedesc[HttpProcessError],
error: HttpServerError): HttpProcessError {.
raises: [].} =
HttpProcessError(kind: error)
proc new(htype: typedesc[HttpConnectionHolderRef], server: HttpServerRef,
transp: StreamTransport,
connectionId: string): HttpConnectionHolderRef =
HttpConnectionHolderRef(
server: server, transp: transp, acceptMoment: Moment.now(),
connectionId: connectionId)
proc error*(e: HttpProcessError): HttpServerError = e.kind
proc createConnection(server: HttpServerRef,
transp: StreamTransport): Future[HttpConnectionRef] {.
@ -149,7 +188,7 @@ proc new*(htype: typedesc[HttpServerRef],
serverIdent = "",
maxConnections: int = -1,
bufferSize: int = 4096,
backlogSize: int = 100,
backlogSize: int = DefaultBacklogSize,
httpHeadersTimeout = 10.seconds,
maxHeadersSize: int = 8192,
maxRequestBodySize: int = 1_048_576): HttpResult[HttpServerRef] {.
@ -174,7 +213,7 @@ proc new*(htype: typedesc[HttpServerRef],
return err(exc.msg)
var res = HttpServerRef(
address: address,
address: serverInstance.localAddress(),
instance: serverInstance,
processCallback: processCallback,
createConnCallback: createConnection,
@ -194,10 +233,37 @@ proc new*(htype: typedesc[HttpServerRef],
# else:
# nil
lifetime: newFuture[void]("http.server.lifetime"),
connections: initTable[string, Future[void]]()
connections: initOrderedTable[string, HttpConnectionHolderRef]()
)
ok(res)
proc getServerFlags(req: HttpRequestRef): set[HttpServerFlags] =
var defaultFlags: set[HttpServerFlags] = {}
if isNil(req): return defaultFlags
if isNil(req.connection): return defaultFlags
if isNil(req.connection.server): return defaultFlags
req.connection.server.flags
proc getResponseFlags(req: HttpRequestRef): set[HttpResponseFlags] =
var defaultFlags: set[HttpResponseFlags] = {}
case req.version
of HttpVersion11:
if HttpServerFlags.Http11Pipeline notin req.getServerFlags():
return defaultFlags
let header = req.headers.getString(ConnectionHeader, "keep-alive")
if header == "keep-alive":
{HttpResponseFlags.KeepAlive}
else:
defaultFlags
else:
defaultFlags
proc getResponseVersion(reqFence: RequestFence): HttpVersion {.raises: [].} =
if reqFence.isErr():
HttpVersion11
else:
reqFence.get().version
proc getResponse*(req: HttpRequestRef): HttpResponseRef {.raises: [].} =
if req.response.isNone():
var resp = HttpResponseRef(
@ -206,10 +272,7 @@ proc getResponse*(req: HttpRequestRef): HttpResponseRef {.raises: [].} =
version: req.version,
headersTable: HttpTable.init(),
connection: req.connection,
flags: if req.version == HttpVersion11:
{HttpResponseFlags.KeepAlive}
else:
{}
flags: req.getResponseFlags()
)
req.response = Opt.some(resp)
resp
@ -222,9 +285,14 @@ proc getHostname*(server: HttpServerRef): string =
else:
server.baseUri.hostname
proc dumbResponse*(): HttpResponseRef {.raises: [].} =
proc defaultResponse*(): HttpResponseRef {.raises: [].} =
## Create an empty response to return when request processor got no request.
HttpResponseRef(state: HttpResponseState.Dumb, version: HttpVersion11)
HttpResponseRef(state: HttpResponseState.Default, version: HttpVersion11)
proc dumbResponse*(): HttpResponseRef {.raises: [],
deprecated: "Please use defaultResponse() instead".} =
## Create an empty response to return when request processor got no request.
defaultResponse()
proc getId(transp: StreamTransport): Result[string, string] {.inline.} =
## Returns string unique transport's identifier as string.
@ -358,6 +426,7 @@ proc prepareRequest(conn: HttpConnectionRef,
if strip(expectHeader).toLowerAscii() == "100-continue":
request.requestFlags.incl(HttpRequestFlags.ClientExpect)
trackCounter(HttpServerRequestTrackerName)
ok(request)
proc getBodyReader*(request: HttpRequestRef): HttpResult[HttpBodyReader] =
@ -566,7 +635,7 @@ proc preferredContentType*(request: HttpRequestRef,
proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion,
code: HttpCode, keepAlive = true,
datatype = "text/text",
databody = ""): Future[bool] {.async.} =
databody = "") {.async.} =
var answer = $version & " " & $code & "\r\n"
answer.add(DateHeader)
answer.add(": ")
@ -592,13 +661,115 @@ proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion,
answer.add(databody)
try:
await conn.writer.write(answer)
return true
except CancelledError as exc:
raise exc
except CatchableError:
# We ignore errors here, because we indicating error already.
discard
proc sendErrorResponse(
conn: HttpConnectionRef,
reqFence: RequestFence,
respError: HttpProcessError
): Future[HttpProcessExitType] {.async.} =
let version = getResponseVersion(reqFence)
try:
if reqFence.isOk():
case respError.kind
of HttpServerError.CriticalError:
await conn.sendErrorResponse(version, respError.code, false)
HttpProcessExitType.Graceful
of HttpServerError.RecoverableError:
await conn.sendErrorResponse(version, respError.code, true)
HttpProcessExitType.Graceful
of HttpServerError.CatchableError:
await conn.sendErrorResponse(version, respError.code, false)
HttpProcessExitType.Graceful
of HttpServerError.DisconnectError,
HttpServerError.InterruptError,
HttpServerError.TimeoutError:
raiseAssert("Unexpected response error: " & $respError.kind)
else:
HttpProcessExitType.Graceful
except CancelledError:
return false
except AsyncStreamWriteError:
return false
except AsyncStreamIncompleteError:
return false
HttpProcessExitType.Immediate
except CatchableError:
HttpProcessExitType.Immediate
proc sendDefaultResponse(
conn: HttpConnectionRef,
reqFence: RequestFence,
response: HttpResponseRef
): Future[HttpProcessExitType] {.async.} =
let
version = getResponseVersion(reqFence)
keepConnection =
if isNil(response) or (HttpResponseFlags.KeepAlive notin response.flags):
HttpProcessExitType.Graceful
else:
HttpProcessExitType.KeepAlive
template toBool(hpet: HttpProcessExitType): bool =
case hpet
of HttpProcessExitType.KeepAlive:
true
of HttpProcessExitType.Immediate:
false
of HttpProcessExitType.Graceful:
false
try:
if reqFence.isOk():
if isNil(response):
await conn.sendErrorResponse(version, Http404, keepConnection.toBool())
keepConnection
else:
case response.state
of HttpResponseState.Empty:
# Response was ignored, so we respond with not found.
await conn.sendErrorResponse(version, Http404,
keepConnection.toBool())
keepConnection
of HttpResponseState.Prepared:
# Response was prepared but not sent, so we can respond with some
# error code
await conn.sendErrorResponse(HttpVersion11, Http409,
keepConnection.toBool())
keepConnection
of HttpResponseState.Sending, HttpResponseState.Failed,
HttpResponseState.Cancelled:
# Just drop connection, because we dont know at what stage we are
HttpProcessExitType.Immediate
of HttpResponseState.Default:
# Response was ignored, so we respond with not found.
await conn.sendErrorResponse(version, Http404,
keepConnection.toBool())
keepConnection
of HttpResponseState.Finished:
keepConnection
else:
case reqFence.error.kind
of HttpServerError.TimeoutError:
await conn.sendErrorResponse(version, reqFence.error.code, false)
HttpProcessExitType.Graceful
of HttpServerError.CriticalError:
await conn.sendErrorResponse(version, reqFence.error.code, false)
HttpProcessExitType.Graceful
of HttpServerError.RecoverableError:
await conn.sendErrorResponse(version, reqFence.error.code, false)
HttpProcessExitType.Graceful
of HttpServerError.CatchableError:
await conn.sendErrorResponse(version, reqFence.error.code, false)
HttpProcessExitType.Graceful
of HttpServerError.DisconnectError:
# When `HttpServerFlags.NotifyDisconnect` is set.
HttpProcessExitType.Immediate
of HttpServerError.InterruptError:
raiseAssert("Unexpected request error: " & $reqFence.error.kind)
except CancelledError:
HttpProcessExitType.Immediate
except CatchableError:
HttpProcessExitType.Immediate
proc getRequest(conn: HttpConnectionRef): Future[HttpRequestRef] {.async.} =
try:
@ -631,31 +802,38 @@ proc init*(value: var HttpConnection, server: HttpServerRef,
mainWriter: newAsyncStreamWriter(transp)
)
proc closeUnsecureConnection(conn: HttpConnectionRef) {.async.} =
if conn.state == HttpState.Alive:
conn.state = HttpState.Closing
var pending: seq[Future[void]]
pending.add(conn.mainReader.closeWait())
pending.add(conn.mainWriter.closeWait())
pending.add(conn.transp.closeWait())
try:
await allFutures(pending)
except CancelledError:
await allFutures(pending)
untrackCounter(HttpServerUnsecureConnectionTrackerName)
reset(conn[])
conn.state = HttpState.Closed
proc new(ht: typedesc[HttpConnectionRef], server: HttpServerRef,
transp: StreamTransport): HttpConnectionRef =
var res = HttpConnectionRef()
res[].init(server, transp)
res.reader = res.mainReader
res.writer = res.mainWriter
res.closeCb = closeUnsecureConnection
res.createMoment = Moment.now()
trackCounter(HttpServerUnsecureConnectionTrackerName)
res
proc closeWait*(conn: HttpConnectionRef) {.async.} =
if conn.state == HttpState.Alive:
conn.state = HttpState.Closing
var pending: seq[Future[void]]
if conn.reader != conn.mainReader:
pending.add(conn.reader.closeWait())
if conn.writer != conn.mainWriter:
pending.add(conn.writer.closeWait())
if len(pending) > 0:
await allFutures(pending)
# After we going to close everything else.
pending.setLen(3)
pending[0] = conn.mainReader.closeWait()
pending[1] = conn.mainWriter.closeWait()
pending[2] = conn.transp.closeWait()
await allFutures(pending)
conn.state = HttpState.Closed
proc gracefulCloseWait*(conn: HttpConnectionRef) {.async.} =
await conn.transp.shutdownWait()
await conn.closeCb(conn)
proc closeWait*(conn: HttpConnectionRef): Future[void] =
conn.closeCb(conn)
proc closeWait*(req: HttpRequestRef) {.async.} =
if req.state == HttpState.Alive:
@ -663,7 +841,14 @@ proc closeWait*(req: HttpRequestRef) {.async.} =
req.state = HttpState.Closing
let resp = req.response.get()
if (HttpResponseFlags.Stream in resp.flags) and not(isNil(resp.writer)):
await resp.writer.closeWait()
var writer = resp.writer.closeWait()
try:
await writer
except CancelledError:
await writer
reset(resp[])
untrackCounter(HttpServerRequestTrackerName)
reset(req[])
req.state = HttpState.Closed
proc createConnection(server: HttpServerRef,
@ -681,175 +866,190 @@ proc `keepalive=`*(resp: HttpResponseRef, value: bool) =
proc keepalive*(resp: HttpResponseRef): bool {.raises: [].} =
HttpResponseFlags.KeepAlive in resp.flags
proc processLoop(server: HttpServerRef, transp: StreamTransport,
connId: string) {.async.} =
var
conn: HttpConnectionRef
connArg: RequestFence
runLoop = false
proc getRemoteAddress(transp: StreamTransport): Opt[TransportAddress] {.
raises: [].} =
if isNil(transp): return Opt.none(TransportAddress)
try:
conn = await server.createConnCallback(server, transp)
runLoop = true
Opt.some(transp.remoteAddress())
except CatchableError:
Opt.none(TransportAddress)
proc getRemoteAddress(connection: HttpConnectionRef): Opt[TransportAddress] {.
raises: [].} =
if isNil(connection): return Opt.none(TransportAddress)
getRemoteAddress(connection.transp)
proc getResponseFence*(connection: HttpConnectionRef,
reqFence: RequestFence): Future[ResponseFence] {.
async.} =
try:
let res = await connection.server.processCallback(reqFence)
ResponseFence.ok(res)
except CancelledError:
server.connections.del(connId)
await transp.closeWait()
return
ResponseFence.err(HttpProcessError.init(
HttpServerError.InterruptError))
except HttpCriticalError as exc:
let error = HttpProcessError.init(HttpServerError.CriticalError, exc,
transp.remoteAddress(), exc.code)
connArg = RequestFence.err(error)
runLoop = false
let address = connection.getRemoteAddress()
ResponseFence.err(HttpProcessError.init(
HttpServerError.CriticalError, exc, address, exc.code))
except HttpRecoverableError as exc:
let address = connection.getRemoteAddress()
ResponseFence.err(HttpProcessError.init(
HttpServerError.RecoverableError, exc, address, exc.code))
except CatchableError as exc:
let address = connection.getRemoteAddress()
ResponseFence.err(HttpProcessError.init(
HttpServerError.CatchableError, exc, address, Http503))
if not(runLoop):
try:
# We still want to notify process callback about failure, but we ignore
# result.
discard await server.processCallback(connArg)
except CancelledError:
runLoop = false
except CatchableError as exc:
# There should be no exceptions, so we will raise `Defect`.
raiseHttpDefect("Unexpected exception catched [" & $exc.name & "]")
proc getResponseFence*(server: HttpServerRef,
connFence: ConnectionFence): Future[ResponseFence] {.
async.} =
doAssert(connFence.isErr())
try:
let
reqFence = RequestFence.err(connFence.error)
res = await server.processCallback(reqFence)
ResponseFence.ok(res)
except CancelledError:
ResponseFence.err(HttpProcessError.init(
HttpServerError.InterruptError))
except HttpCriticalError as exc:
let address = Opt.none(TransportAddress)
ResponseFence.err(HttpProcessError.init(
HttpServerError.CriticalError, exc, address, exc.code))
except HttpRecoverableError as exc:
let address = Opt.none(TransportAddress)
ResponseFence.err(HttpProcessError.init(
HttpServerError.RecoverableError, exc, address, exc.code))
except CatchableError as exc:
let address = Opt.none(TransportAddress)
ResponseFence.err(HttpProcessError.init(
HttpServerError.CatchableError, exc, address, Http503))
var breakLoop = false
while runLoop:
var
arg: RequestFence
resp: HttpResponseRef
try:
let request =
if server.headersTimeout.isInfinite():
await conn.getRequest()
else:
await conn.getRequest().wait(server.headersTimeout)
arg = RequestFence.ok(request)
except CancelledError:
breakLoop = true
except AsyncTimeoutError as exc:
let error = HttpProcessError.init(HttpServerError.TimeoutError, exc,
transp.remoteAddress(), Http408)
arg = RequestFence.err(error)
except HttpRecoverableError as exc:
let error = HttpProcessError.init(HttpServerError.RecoverableError, exc,
transp.remoteAddress(), exc.code)
arg = RequestFence.err(error)
except HttpCriticalError as exc:
let error = HttpProcessError.init(HttpServerError.CriticalError, exc,
transp.remoteAddress(), exc.code)
arg = RequestFence.err(error)
except HttpDisconnectError as exc:
if HttpServerFlags.NotifyDisconnect in server.flags:
let error = HttpProcessError.init(HttpServerError.DisconnectError, exc,
transp.remoteAddress(), Http400)
arg = RequestFence.err(error)
proc getRequestFence*(server: HttpServerRef,
connection: HttpConnectionRef): Future[RequestFence] {.
async.} =
try:
let res =
if server.headersTimeout.isInfinite():
await connection.getRequest()
else:
breakLoop = true
except CatchableError as exc:
let error = HttpProcessError.init(HttpServerError.CatchableError, exc,
transp.remoteAddress(), Http500)
arg = RequestFence.err(error)
await connection.getRequest().wait(server.headersTimeout)
connection.currentRawQuery = Opt.some(res.rawPath)
RequestFence.ok(res)
except CancelledError:
RequestFence.err(HttpProcessError.init(HttpServerError.InterruptError))
except AsyncTimeoutError as exc:
let address = connection.getRemoteAddress()
RequestFence.err(HttpProcessError.init(
HttpServerError.TimeoutError, exc, address, Http408))
except HttpRecoverableError as exc:
let address = connection.getRemoteAddress()
RequestFence.err(HttpProcessError.init(
HttpServerError.RecoverableError, exc, address, exc.code))
except HttpCriticalError as exc:
let address = connection.getRemoteAddress()
RequestFence.err(HttpProcessError.init(
HttpServerError.CriticalError, exc, address, exc.code))
except HttpDisconnectError as exc:
let address = connection.getRemoteAddress()
RequestFence.err(HttpProcessError.init(
HttpServerError.DisconnectError, exc, address, Http400))
except CatchableError as exc:
let address = connection.getRemoteAddress()
RequestFence.err(HttpProcessError.init(
HttpServerError.CatchableError, exc, address, Http500))
if breakLoop:
break
proc getConnectionFence*(server: HttpServerRef,
transp: StreamTransport): Future[ConnectionFence] {.
async.} =
try:
let res = await server.createConnCallback(server, transp)
ConnectionFence.ok(res)
except CancelledError:
ConnectionFence.err(HttpProcessError.init(HttpServerError.InterruptError))
except HttpCriticalError as exc:
# On error `transp` will be closed by `createConnCallback()` call.
let address = Opt.none(TransportAddress)
ConnectionFence.err(HttpProcessError.init(
HttpServerError.CriticalError, exc, address, exc.code))
except CatchableError as exc:
# On error `transp` will be closed by `createConnCallback()` call.
let address = Opt.none(TransportAddress)
ConnectionFence.err(HttpProcessError.init(
HttpServerError.CriticalError, exc, address, Http503))
breakLoop = false
var lastErrorCode: Opt[HttpCode]
try:
resp = await conn.server.processCallback(arg)
except CancelledError:
breakLoop = true
except HttpCriticalError as exc:
lastErrorCode = Opt.some(exc.code)
except HttpRecoverableError as exc:
lastErrorCode = Opt.some(exc.code)
except CatchableError:
lastErrorCode = Opt.some(Http503)
if breakLoop:
break
if arg.isErr():
let code = arg.error().code
try:
case arg.error().error
of HttpServerError.TimeoutError:
discard await conn.sendErrorResponse(HttpVersion11, code, false)
of HttpServerError.RecoverableError:
discard await conn.sendErrorResponse(HttpVersion11, code, false)
of HttpServerError.CriticalError:
discard await conn.sendErrorResponse(HttpVersion11, code, false)
of HttpServerError.CatchableError:
discard await conn.sendErrorResponse(HttpVersion11, code, false)
of HttpServerError.DisconnectError:
discard
except CancelledError:
# We swallowing `CancelledError` in a loop, but we going to exit
# loop ASAP.
discard
break
proc processRequest(server: HttpServerRef,
connection: HttpConnectionRef,
connId: string): Future[HttpProcessExitType] {.async.} =
let requestFence = await getRequestFence(server, connection)
if requestFence.isErr():
case requestFence.error.kind
of HttpServerError.InterruptError:
return HttpProcessExitType.Immediate
of HttpServerError.DisconnectError:
if HttpServerFlags.NotifyDisconnect notin server.flags:
return HttpProcessExitType.Immediate
else:
let request = arg.get()
var keepConn = if request.version == HttpVersion11: true else: false
if lastErrorCode.isNone():
if isNil(resp):
# Response was `nil`.
try:
discard await conn.sendErrorResponse(HttpVersion11, Http404, false)
except CancelledError:
keepConn = false
else:
try:
case resp.state
of HttpResponseState.Empty:
# Response was ignored
discard await conn.sendErrorResponse(HttpVersion11, Http404,
keepConn)
of HttpResponseState.Prepared:
# Response was prepared but not sent.
discard await conn.sendErrorResponse(HttpVersion11, Http409,
keepConn)
else:
# some data was already sent to the client.
keepConn = resp.keepalive()
except CancelledError:
keepConn = false
else:
try:
discard await conn.sendErrorResponse(HttpVersion11,
lastErrorCode.get(), false)
except CancelledError:
keepConn = false
discard
# Closing and releasing all the request resources.
let responseFence = await getResponseFence(connection, requestFence)
if responseFence.isErr() and
(responseFence.error.kind == HttpServerError.InterruptError):
if requestFence.isOk():
await requestFence.get().closeWait()
return HttpProcessExitType.Immediate
let res =
if responseFence.isErr():
await connection.sendErrorResponse(requestFence, responseFence.error)
else:
await connection.sendDefaultResponse(requestFence, responseFence.get())
if requestFence.isOk():
await requestFence.get().closeWait()
res
proc processLoop(holder: HttpConnectionHolderRef) {.async.} =
let
server = holder.server
transp = holder.transp
connectionId = holder.connectionId
connection =
block:
let res = await server.getConnectionFence(transp)
if res.isErr():
if res.error.kind != HttpServerError.InterruptError:
discard await server.getResponseFence(res)
server.connections.del(connectionId)
return
res.get()
holder.connection = connection
var runLoop = HttpProcessExitType.KeepAlive
while runLoop == HttpProcessExitType.KeepAlive:
runLoop =
try:
await request.closeWait()
await server.processRequest(connection, connectionId)
except CancelledError:
# We swallowing `CancelledError` in a loop, but we still need to close
# `request` before exiting.
await request.closeWait()
HttpProcessExitType.Immediate
except CatchableError as exc:
raiseAssert "Unexpected error [" & $exc.name & "] happens: " & $exc.msg
if not(keepConn):
break
# Connection could be `nil` only when secure handshake is failed.
if not(isNil(conn)):
try:
await conn.closeWait()
except CancelledError:
# Cancellation could be happened while we closing `conn`. But we still
# need to close it.
await conn.closeWait()
server.connections.del(connId)
# if server.maxConnections > 0:
# server.semaphore.release()
server.connections.del(connectionId)
case runLoop
of HttpProcessExitType.KeepAlive:
await connection.closeWait()
of HttpProcessExitType.Immediate:
await connection.closeWait()
of HttpProcessExitType.Graceful:
await connection.gracefulCloseWait()
proc acceptClientLoop(server: HttpServerRef) {.async.} =
var breakLoop = false
while true:
var runLoop = true
while runLoop:
try:
# if server.maxConnections > 0:
# await server.semaphore.acquire()
@ -859,28 +1059,18 @@ proc acceptClientLoop(server: HttpServerRef) {.async.} =
# We are unable to identify remote peer, it means that remote peer
# disconnected before identification.
await transp.closeWait()
breakLoop = false
runLoop = false
else:
let connId = resId.get()
server.connections[connId] = processLoop(server, transp, connId)
except CancelledError:
# Server was stopped
breakLoop = true
except TransportOsError:
# This is some critical unrecoverable error.
breakLoop = true
except TransportTooManyError:
# Non critical error
breakLoop = false
except TransportAbortedError:
# Non critical error
breakLoop = false
except CatchableError:
# Unexpected error
breakLoop = true
if breakLoop:
break
let holder = HttpConnectionHolderRef.new(server, transp, resId.get())
server.connections[connId] = holder
holder.future = processLoop(holder)
except TransportTooManyError, TransportAbortedError:
# Non-critical error
discard
except CancelledError, TransportOsError, CatchableError:
# Critical, cancellation or unexpected error
runLoop = false
proc state*(server: HttpServerRef): HttpServerState {.raises: [].} =
## Returns current HTTP server's state.
@ -909,11 +1099,11 @@ proc drop*(server: HttpServerRef) {.async.} =
## Drop all pending HTTP connections.
var pending: seq[Future[void]]
if server.state in {ServerStopped, ServerRunning}:
for fut in server.connections.values():
if not(fut.finished()):
fut.cancel()
pending.add(fut)
for holder in server.connections.values():
if not(isNil(holder.future)) and not(holder.future.finished()):
pending.add(holder.future.cancelAndWait())
await allFutures(pending)
server.connections.clear()
proc closeWait*(server: HttpServerRef) {.async.} =
## Stop HTTP server and drop all the pending connections.

View File

@ -197,3 +197,7 @@ proc toList*(ht: HttpTables, normKey = false): auto =
for key, value in ht.stringItems(normKey):
res.add((key, value))
res
proc clear*(ht: var HttpTables) =
## Resets the HtppTable so that it is empty.
ht.table.clear()

View File

@ -24,6 +24,29 @@ type
SecureHttpConnectionRef* = ref SecureHttpConnection
proc closeSecConnection(conn: HttpConnectionRef) {.async.} =
if conn.state == HttpState.Alive:
conn.state = HttpState.Closing
var pending: seq[Future[void]]
pending.add(conn.writer.closeWait())
pending.add(conn.reader.closeWait())
try:
await allFutures(pending)
except CancelledError:
await allFutures(pending)
# After we going to close everything else.
pending.setLen(3)
pending[0] = conn.mainReader.closeWait()
pending[1] = conn.mainWriter.closeWait()
pending[2] = conn.transp.closeWait()
try:
await allFutures(pending)
except CancelledError:
await allFutures(pending)
reset(cast[SecureHttpConnectionRef](conn)[])
untrackCounter(HttpServerSecureConnectionTrackerName)
conn.state = HttpState.Closed
proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef,
transp: StreamTransport): SecureHttpConnectionRef =
var res = SecureHttpConnectionRef()
@ -37,6 +60,8 @@ proc new*(ht: typedesc[SecureHttpConnectionRef], server: SecureHttpServerRef,
res.tlsStream = tlsStream
res.reader = AsyncStreamReader(tlsStream.reader)
res.writer = AsyncStreamWriter(tlsStream.writer)
res.closeCb = closeSecConnection
trackCounter(HttpServerSecureConnectionTrackerName)
res
proc createSecConnection(server: HttpServerRef,
@ -50,9 +75,16 @@ proc createSecConnection(server: HttpServerRef,
except CancelledError as exc:
await HttpConnectionRef(sconn).closeWait()
raise exc
except TLSStreamError:
except TLSStreamError as exc:
await HttpConnectionRef(sconn).closeWait()
raiseHttpCriticalError("Unable to establish secure connection")
let msg = "Unable to establish secure connection, reason [" &
$exc.msg & "]"
raiseHttpCriticalError(msg)
except CatchableError as exc:
await HttpConnectionRef(sconn).closeWait()
let msg = "Unexpected error while trying to establish secure connection, " &
"reason [" & $exc.msg & "]"
raiseHttpCriticalError(msg)
proc new*(htype: typedesc[SecureHttpServerRef],
address: TransportAddress,
@ -66,7 +98,7 @@ proc new*(htype: typedesc[SecureHttpServerRef],
secureFlags: set[TLSFlags] = {},
maxConnections: int = -1,
bufferSize: int = 4096,
backlogSize: int = 100,
backlogSize: int = DefaultBacklogSize,
httpHeadersTimeout = 10.seconds,
maxHeadersSize: int = 8192,
maxRequestBodySize: int = 1_048_576
@ -100,7 +132,7 @@ proc new*(htype: typedesc[SecureHttpServerRef],
createConnCallback: createSecConnection,
baseUri: serverUri,
serverIdent: serverIdent,
flags: serverFlags,
flags: serverFlags + {HttpServerFlags.Secure},
socketFlags: socketFlags,
maxConnections: maxConnections,
bufferSize: bufferSize,
@ -114,7 +146,7 @@ proc new*(htype: typedesc[SecureHttpServerRef],
# else:
# nil
lifetime: newFuture[void]("http.server.lifetime"),
connections: initTable[string, Future[void]](),
connections: initOrderedTable[string, HttpConnectionHolderRef](),
tlsCertificate: tlsCertificate,
tlsPrivateKey: tlsPrivateKey,
secureFlags: secureFlags

View File

@ -171,11 +171,16 @@ type
dump*: proc(): string {.gcsafe, raises: [].}
isLeaked*: proc(): bool {.gcsafe, raises: [].}
TrackerCounter* = object
opened*: uint64
closed*: uint64
PDispatcherBase = ref object of RootRef
timers*: HeapQueue[TimerCallback]
callbacks*: Deque[AsyncCallback]
idlers*: Deque[AsyncCallback]
trackers*: Table[string, TrackerBase]
counters*: Table[string, TrackerCounter]
proc sentinelCallbackImpl(arg: pointer) {.gcsafe.} =
raiseAssert "Sentinel callback MUST not be scheduled"
@ -308,6 +313,7 @@ when defined(windows):
getAcceptExSockAddrs*: WSAPROC_GETACCEPTEXSOCKADDRS
transmitFile*: WSAPROC_TRANSMITFILE
getQueuedCompletionStatusEx*: LPFN_GETQUEUEDCOMPLETIONSTATUSEX
disconnectEx*: WSAPROC_DISCONNECTEX
flags: set[DispatcherFlag]
PtrCustomOverlapped* = ptr CustomOverlapped
@ -388,6 +394,13 @@ when defined(windows):
"dispatcher's TransmitFile()")
loop.transmitFile = cast[WSAPROC_TRANSMITFILE](funcPointer)
block:
let res = getFunc(sock, funcPointer, WSAID_DISCONNECTEX)
if not(res):
raiseOsDefect(osLastError(), "initAPI(): Unable to initialize " &
"dispatcher's DisconnectEx()")
loop.disconnectEx = cast[WSAPROC_DISCONNECTEX](funcPointer)
if closeFd(sock) != 0:
raiseOsDefect(osLastError(), "initAPI(): Unable to close control socket")
@ -404,7 +417,8 @@ when defined(windows):
timers: initHeapQueue[TimerCallback](),
callbacks: initDeque[AsyncCallback](64),
idlers: initDeque[AsyncCallback](),
trackers: initTable[string, TrackerBase]()
trackers: initTable[string, TrackerBase](),
counters: initTable[string, TrackerCounter]()
)
res.callbacks.addLast(SentinelCallback)
initAPI(res)
@ -811,10 +825,11 @@ elif defined(macosx) or defined(freebsd) or defined(netbsd) or
var res = PDispatcher(
selector: selector,
timers: initHeapQueue[TimerCallback](),
callbacks: initDeque[AsyncCallback](asyncEventsCount),
callbacks: initDeque[AsyncCallback](chronosEventsCount),
idlers: initDeque[AsyncCallback](),
keys: newSeq[ReadyKey](asyncEventsCount),
trackers: initTable[string, TrackerBase]()
keys: newSeq[ReadyKey](chronosEventsCount),
trackers: initTable[string, TrackerBase](),
counters: initTable[string, TrackerCounter]()
)
res.callbacks.addLast(SentinelCallback)
initAPI(res)
@ -994,7 +1009,7 @@ elif defined(macosx) or defined(freebsd) or defined(netbsd) or
## You can execute ``aftercb`` before actual socket close operation.
closeSocket(fd, aftercb)
when asyncEventEngine in ["epoll", "kqueue"]:
when chronosEventEngine in ["epoll", "kqueue"]:
type
ProcessHandle* = distinct int
SignalHandle* = distinct int
@ -1108,7 +1123,7 @@ elif defined(macosx) or defined(freebsd) or defined(netbsd) or
if not isNil(adata.reader.function):
loop.callbacks.addLast(adata.reader)
when asyncEventEngine in ["epoll", "kqueue"]:
when chronosEventEngine in ["epoll", "kqueue"]:
let customSet = {Event.Timer, Event.Signal, Event.Process,
Event.Vnode}
if customSet * events != {}:
@ -1242,10 +1257,7 @@ proc callIdle*(cbproc: CallbackFunc) =
include asyncfutures2
when defined(macosx) or defined(macos) or defined(freebsd) or
defined(netbsd) or defined(openbsd) or defined(dragonfly) or
defined(linux) or defined(windows):
when (chronosEventEngine in ["epoll", "kqueue"]) or defined(windows):
proc waitSignal*(signal: int): Future[void] {.raises: [].} =
var retFuture = newFuture[void]("chronos.waitSignal()")
@ -1505,16 +1517,54 @@ proc waitFor*[T](fut: Future[T]): T {.raises: [CatchableError].} =
fut.read()
proc addTracker*[T](id: string, tracker: T) =
proc addTracker*[T](id: string, tracker: T) {.
deprecated: "Please use trackCounter facility instead".} =
## Add new ``tracker`` object to current thread dispatcher with identifier
## ``id``.
let loop = getThreadDispatcher()
loop.trackers[id] = tracker
getThreadDispatcher().trackers[id] = tracker
proc getTracker*(id: string): TrackerBase =
proc getTracker*(id: string): TrackerBase {.
deprecated: "Please use getTrackerCounter() instead".} =
## Get ``tracker`` from current thread dispatcher using identifier ``id``.
let loop = getThreadDispatcher()
result = loop.trackers.getOrDefault(id, nil)
getThreadDispatcher().trackers.getOrDefault(id, nil)
proc trackCounter*(name: string) {.noinit.} =
## Increase tracker counter with name ``name`` by 1.
let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64)
inc(getThreadDispatcher().counters.mgetOrPut(name, tracker).opened)
proc untrackCounter*(name: string) {.noinit.} =
## Decrease tracker counter with name ``name`` by 1.
let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64)
inc(getThreadDispatcher().counters.mgetOrPut(name, tracker).closed)
proc getTrackerCounter*(name: string): TrackerCounter {.noinit.} =
## Return value of counter with name ``name``.
let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64)
getThreadDispatcher().counters.getOrDefault(name, tracker)
proc isCounterLeaked*(name: string): bool {.noinit.} =
## Returns ``true`` if leak is detected, number of `opened` not equal to
## number of `closed` requests.
let tracker = TrackerCounter(opened: 0'u64, closed: 0'u64)
let res = getThreadDispatcher().counters.getOrDefault(name, tracker)
res.opened == res.closed
iterator trackerCounters*(
loop: PDispatcher
): tuple[name: string, value: TrackerCounter] =
## Iterates over `loop` thread dispatcher tracker counter table, returns all
## the tracker counter's names and values.
doAssert(not(isNil(loop)))
for key, value in loop.counters.pairs():
yield (key, value)
iterator trackerCounterKeys*(loop: PDispatcher): string =
doAssert(not(isNil(loop)))
## Iterates over `loop` thread dispatcher tracker counter table, returns all
## tracker names.
for key in loop.counters.keys():
yield key
when chronosFutureTracking:
iterator pendingFutures*(): FutureBase =

View File

@ -175,9 +175,25 @@ proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} =
nnkElseExpr.newTree(
newStmtList(
quote do: {.push warning[resultshadowed]: off.},
# var result: `baseType`
nnkVarSection.newTree(
nnkIdentDefs.newTree(ident "result", baseType, newEmptyNode())),
# var result {.used.}: `baseType`
# In the proc body, result may or may not end up being used
# depending on how the body is written - with implicit returns /
# expressions in particular, it is likely but not guaranteed that
# it is not used. Ideally, we would avoid emitting it in this
# case to avoid the default initializaiton. {.used.} typically
# works better than {.push.} which has a tendency to leak out of
# scope.
# TODO figure out if there's a way to detect `result` usage in
# the proc body _after_ template exapnsion, and therefore
# avoid creating this variable - one option is to create an
# addtional when branch witha fake `result` and check
# `compiles(procBody)` - this is not without cost though
nnkVarSection.newTree(nnkIdentDefs.newTree(
nnkPragmaExpr.newTree(
ident "result",
nnkPragma.newTree(ident "used")),
baseType, newEmptyNode())
),
quote do: {.pop.},
)
)

View File

@ -23,10 +23,9 @@ const
AsyncProcessTrackerName* = "async.process"
## AsyncProcess leaks tracker name
type
AsyncProcessError* = object of CatchableError
AsyncProcessError* = object of AsyncError
AsyncProcessTimeoutError* = object of AsyncProcessError
AsyncProcessResult*[T] = Result[T, OSErrorCode]
@ -109,49 +108,12 @@ type
stdError*: string
status*: int
AsyncProcessTracker* = ref object of TrackerBase
opened*: int64
closed*: int64
WaitOperation {.pure.} = enum
Kill, Terminate
template Pipe*(t: typedesc[AsyncProcess]): ProcessStreamHandle =
ProcessStreamHandle(kind: ProcessStreamHandleKind.Auto)
proc setupAsyncProcessTracker(): AsyncProcessTracker {.gcsafe.}
proc getAsyncProcessTracker(): AsyncProcessTracker {.inline.} =
var res = cast[AsyncProcessTracker](getTracker(AsyncProcessTrackerName))
if isNil(res):
res = setupAsyncProcessTracker()
res
proc dumpAsyncProcessTracking(): string {.gcsafe.} =
var tracker = getAsyncProcessTracker()
let res = "Started async processes: " & $tracker.opened & "\n" &
"Closed async processes: " & $tracker.closed
res
proc leakAsyncProccessTracker(): bool {.gcsafe.} =
var tracker = getAsyncProcessTracker()
tracker.opened != tracker.closed
proc trackAsyncProccess(t: AsyncProcessRef) {.inline.} =
var tracker = getAsyncProcessTracker()
inc(tracker.opened)
proc untrackAsyncProcess(t: AsyncProcessRef) {.inline.} =
var tracker = getAsyncProcessTracker()
inc(tracker.closed)
proc setupAsyncProcessTracker(): AsyncProcessTracker {.gcsafe.} =
var res = AsyncProcessTracker(
opened: 0,
closed: 0,
dump: dumpAsyncProcessTracking,
isLeaked: leakAsyncProccessTracker
)
addTracker(AsyncProcessTrackerName, res)
res
proc init*(t: typedesc[AsyncFD], handle: ProcessStreamHandle): AsyncFD =
case handle.kind
of ProcessStreamHandleKind.ProcHandle:
@ -336,6 +298,11 @@ proc raiseAsyncProcessError(msg: string, exc: ref CatchableError = nil) {.
msg & " ([" & $exc.name & "]: " & $exc.msg & ")"
raise newException(AsyncProcessError, message)
proc raiseAsyncProcessTimeoutError() {.
noreturn, noinit, noinline, raises: [AsyncProcessTimeoutError].} =
let message = "Operation timed out"
raise newException(AsyncProcessTimeoutError, message)
proc raiseAsyncProcessError(msg: string, error: OSErrorCode|cint) {.
noreturn, noinit, noinline, raises: [AsyncProcessError].} =
when error is OSErrorCode:
@ -502,7 +469,7 @@ when defined(windows):
flags: pipes.flags
)
trackAsyncProccess(process)
trackCounter(AsyncProcessTrackerName)
return process
proc peekProcessExitCode(p: AsyncProcessRef): AsyncProcessResult[int] =
@ -919,7 +886,7 @@ else:
flags: pipes.flags
)
trackAsyncProccess(process)
trackCounter(AsyncProcessTrackerName)
return process
proc peekProcessExitCode(p: AsyncProcessRef,
@ -1231,13 +1198,52 @@ proc closeProcessStreams(pipes: AsyncProcessPipes,
res
allFutures(pending)
proc opAndWaitForExit(p: AsyncProcessRef, op: WaitOperation,
timeout = InfiniteDuration): Future[int] {.async.} =
let timerFut =
if timeout == InfiniteDuration:
newFuture[void]("chronos.killAndwaitForExit")
else:
sleepAsync(timeout)
while true:
if p.running().get(true):
# We ignore operation errors because we going to repeat calling
# operation until process will not exit.
case op
of WaitOperation.Kill:
discard p.kill()
of WaitOperation.Terminate:
discard p.terminate()
else:
let exitCode = p.peekExitCode().valueOr:
raiseAsyncProcessError("Unable to peek process exit code", error)
if not(timerFut.finished()):
await cancelAndWait(timerFut)
return exitCode
let waitFut = p.waitForExit().wait(100.milliseconds)
discard await race(FutureBase(waitFut), FutureBase(timerFut))
if waitFut.finished() and not(waitFut.failed()):
let res = p.peekExitCode()
if res.isOk():
if not(timerFut.finished()):
await cancelAndWait(timerFut)
return res.get()
if timerFut.finished():
if not(waitFut.finished()):
await waitFut.cancelAndWait()
raiseAsyncProcessTimeoutError()
proc closeWait*(p: AsyncProcessRef) {.async.} =
# Here we ignore all possible errrors, because we do not want to raise
# exceptions.
discard closeProcessHandles(p.pipes, p.options, OSErrorCode(0))
await p.pipes.closeProcessStreams(p.options)
discard p.closeThreadAndProcessHandle()
untrackAsyncProcess(p)
untrackCounter(AsyncProcessTrackerName)
proc stdinStream*(p: AsyncProcessRef): AsyncStreamWriter =
doAssert(p.pipes.stdinHolder.kind == StreamKind.Writer,
@ -1258,14 +1264,15 @@ proc execCommand*(command: string,
options = {AsyncProcessOption.EvalCommand},
timeout = InfiniteDuration
): Future[int] {.async.} =
let poptions = options + {AsyncProcessOption.EvalCommand}
let process = await startProcess(command, options = poptions)
let res =
try:
await process.waitForExit(timeout)
finally:
await process.closeWait()
return res
let
poptions = options + {AsyncProcessOption.EvalCommand}
process = await startProcess(command, options = poptions)
res =
try:
await process.waitForExit(timeout)
finally:
await process.closeWait()
res
proc execCommandEx*(command: string,
options = {AsyncProcessOption.EvalCommand},
@ -1298,10 +1305,43 @@ proc execCommandEx*(command: string,
finally:
await process.closeWait()
return res
res
proc pid*(p: AsyncProcessRef): int =
## Returns process ``p`` identifier.
int(p.processId)
template processId*(p: AsyncProcessRef): int = pid(p)
proc killAndWaitForExit*(p: AsyncProcessRef,
timeout = InfiniteDuration): Future[int] =
## Perform continuous attempts to kill the ``p`` process for specified period
## of time ``timeout``.
##
## On Posix systems, killing means sending ``SIGKILL`` to the process ``p``,
## On Windows, it uses ``TerminateProcess`` to kill the process ``p``.
##
## If the process ``p`` fails to be killed within the ``timeout`` time, it
## will raise ``AsyncProcessTimeoutError``.
##
## In case of error this it will raise ``AsyncProcessError``.
##
## Returns process ``p`` exit code.
opAndWaitForExit(p, WaitOperation.Kill, timeout)
proc terminateAndWaitForExit*(p: AsyncProcessRef,
timeout = InfiniteDuration): Future[int] =
## Perform continuous attempts to terminate the ``p`` process for specified
## period of time ``timeout``.
##
## On Posix systems, terminating means sending ``SIGTERM`` to the process
## ``p``, on Windows, it uses ``TerminateProcess`` to terminate the process
## ``p``.
##
## If the process ``p`` fails to be terminated within the ``timeout`` time, it
## will raise ``AsyncProcessTimeoutError``.
##
## In case of error this it will raise ``AsyncProcessError``.
##
## Returns process ``p`` exit code.
opAndWaitForExit(p, WaitOperation.Terminate, timeout)

View File

@ -49,6 +49,27 @@ when (NimMajor, NimMinor) >= (1, 4):
## using `AsyncProcessOption.EvalCommand` and API calls such as
## ``execCommand(command)`` and ``execCommandEx(command)``.
chronosEventsCount* {.intdefine.} = 64
## Number of OS poll events retrieved by syscall (epoll, kqueue, poll).
chronosInitialSize* {.intdefine.} = 64
## Initial size of Selector[T]'s array of file descriptors.
chronosEventEngine* {.strdefine.}: string =
when defined(linux) and not(defined(android) or defined(emscripten)):
"epoll"
elif defined(macosx) or defined(macos) or defined(ios) or
defined(freebsd) or defined(netbsd) or defined(openbsd) or
defined(dragonfly):
"kqueue"
elif defined(android) or defined(emscripten):
"poll"
elif defined(posix):
"poll"
else:
""
## OS polling engine type which is going to be used by chronos.
else:
# 1.2 doesn't support `booldefine` in `when` properly
const
@ -69,6 +90,21 @@ else:
"/system/bin/sh"
else:
"/bin/sh"
chronosEventsCount*: int = 64
chronosInitialSize*: int = 64
chronosEventEngine* {.strdefine.}: string =
when defined(linux) and not(defined(android) or defined(emscripten)):
"epoll"
elif defined(macosx) or defined(macos) or defined(ios) or
defined(freebsd) or defined(netbsd) or defined(openbsd) or
defined(dragonfly):
"kqueue"
elif defined(android) or defined(emscripten):
"poll"
elif defined(posix):
"poll"
else:
""
when defined(debug) or defined(chronosConfig):
import std/macros
@ -83,3 +119,6 @@ when defined(debug) or defined(chronosConfig):
printOption("chronosFutureTracking", chronosFutureTracking)
printOption("chronosDumpAsync", chronosDumpAsync)
printOption("chronosProcShell", chronosProcShell)
printOption("chronosEventEngine", chronosEventEngine)
printOption("chronosEventsCount", chronosEventsCount)
printOption("chronosInitialSize", chronosInitialSize)

View File

@ -97,12 +97,12 @@ proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] =
var nmask: Sigset
if sigemptyset(nmask) < 0:
return err(osLastError())
let epollFd = epoll_create(asyncEventsCount)
let epollFd = epoll_create(chronosEventsCount)
if epollFd < 0:
return err(osLastError())
let selector = Selector[T](
epollFd: epollFd,
fds: initTable[int32, SelectorKey[T]](asyncInitialSize),
fds: initTable[int32, SelectorKey[T]](chronosInitialSize),
signalMask: nmask,
virtualId: -1'i32, # Should start with -1, because `InvalidIdent` == -1
childrenExited: false,
@ -627,7 +627,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
readyKeys: var openArray[ReadyKey]
): SelectResult[int] =
var
queueEvents: array[asyncEventsCount, EpollEvent]
queueEvents: array[chronosEventsCount, EpollEvent]
k: int = 0
verifySelectParams(timeout, -1, int(high(cint)))
@ -668,7 +668,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
ok(k)
proc select2*[T](s: Selector[T], timeout: int): SelectResult[seq[ReadyKey]] =
var res = newSeq[ReadyKey](asyncEventsCount)
var res = newSeq[ReadyKey](chronosEventsCount)
let count = ? selectInto2(s, timeout, res)
res.setLen(count)
ok(res)

View File

@ -110,7 +110,7 @@ proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] =
let selector = Selector[T](
kqFd: kqFd,
fds: initTable[int32, SelectorKey[T]](asyncInitialSize),
fds: initTable[int32, SelectorKey[T]](chronosInitialSize),
virtualId: -1'i32, # Should start with -1, because `InvalidIdent` == -1
virtualHoles: initDeque[int32]()
)
@ -559,7 +559,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
): SelectResult[int] =
var
tv: Timespec
queueEvents: array[asyncEventsCount, KEvent]
queueEvents: array[chronosEventsCount, KEvent]
verifySelectParams(timeout, -1, high(int))
@ -575,7 +575,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
addr tv
else:
nil
maxEventsCount = cint(min(asyncEventsCount, len(readyKeys)))
maxEventsCount = cint(min(chronosEventsCount, len(readyKeys)))
eventsCount =
block:
var res = 0
@ -601,7 +601,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
proc select2*[T](s: Selector[T],
timeout: int): Result[seq[ReadyKey], OSErrorCode] =
var res = newSeq[ReadyKey](asyncEventsCount)
var res = newSeq[ReadyKey](chronosEventsCount)
let count = ? selectInto2(s, timeout, res)
res.setLen(count)
ok(res)

View File

@ -16,7 +16,7 @@ import stew/base10
type
SelectorImpl[T] = object
fds: Table[int32, SelectorKey[T]]
pollfds: seq[TPollFd]
pollfds: seq[TPollfd]
Selector*[T] = ref SelectorImpl[T]
type
@ -50,7 +50,7 @@ proc freeKey[T](s: Selector[T], key: int32) =
proc new*(t: typedesc[Selector], T: typedesc): SelectResult[Selector[T]] =
let selector = Selector[T](
fds: initTable[int32, SelectorKey[T]](asyncInitialSize)
fds: initTable[int32, SelectorKey[T]](chronosInitialSize)
)
ok(selector)
@ -72,7 +72,7 @@ proc trigger2*(event: SelectEvent): SelectResult[void] =
if res == -1:
err(osLastError())
elif res != sizeof(uint64):
err(OSErrorCode(osdefs.EINVAL))
err(osdefs.EINVAL)
else:
ok()
@ -98,13 +98,14 @@ template toPollEvents(events: set[Event]): cshort =
res
template pollAdd[T](s: Selector[T], sock: cint, events: set[Event]) =
s.pollfds.add(TPollFd(fd: sock, events: toPollEvents(events), revents: 0))
s.pollfds.add(TPollfd(fd: sock, events: toPollEvents(events), revents: 0))
template pollUpdate[T](s: Selector[T], sock: cint, events: set[Event]) =
var updated = false
for mitem in s.pollfds.mitems():
if mitem.fd == sock:
mitem.events = toPollEvents(events)
updated = true
break
if not(updated):
raiseAssert "Descriptor [" & $sock & "] is not registered in the queue!"
@ -177,7 +178,6 @@ proc unregister2*[T](s: Selector[T], event: SelectEvent): SelectResult[void] =
proc prepareKey[T](s: Selector[T], event: var TPollfd): Opt[ReadyKey] =
let
defaultKey = SelectorKey[T](ident: InvalidIdent)
fdi32 = int32(event.fd)
revents = event.revents
@ -224,7 +224,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
eventsCount =
if maxEventsCount > 0:
let res = handleEintr(poll(addr(s.pollfds[0]), Tnfds(maxEventsCount),
timeout))
cint(timeout)))
if res < 0:
return err(osLastError())
res
@ -241,7 +241,7 @@ proc selectInto2*[T](s: Selector[T], timeout: int,
ok(k)
proc select2*[T](s: Selector[T], timeout: int): SelectResult[seq[ReadyKey]] =
var res = newSeq[ReadyKey](asyncEventsCount)
var res = newSeq[ReadyKey](chronosEventsCount)
let count = ? selectInto2(s, timeout, res)
res.setLen(count)
ok(res)

View File

@ -237,6 +237,10 @@ when defined(windows):
GUID(D1: 0xb5367df0'u32, D2: 0xcbac'u16, D3: 0x11cf'u16,
D4: [0x95'u8, 0xca'u8, 0x00'u8, 0x80'u8,
0x5f'u8, 0x48'u8, 0xa1'u8, 0x92'u8])
WSAID_DISCONNECTEX* =
GUID(D1: 0x7fda2e11'u32, D2: 0x8630'u16, D3: 0x436f'u16,
D4: [0xa0'u8, 0x31'u8, 0xf5'u8, 0x36'u8,
0xa6'u8, 0xee'u8, 0xc1'u8, 0x57'u8])
GAA_FLAG_INCLUDE_PREFIX* = 0x0010'u32
@ -497,6 +501,11 @@ when defined(windows):
lpTransmitBuffers: pointer, dwReserved: DWORD): WINBOOL {.
stdcall, gcsafe, raises: [].}
WSAPROC_DISCONNECTEX* = proc (
hSocket: SocketHandle, lpOverlapped: POVERLAPPED, dwFlags: DWORD,
dwReserved: DWORD): WINBOOL {.
stdcall, gcsafe, raises: [].}
LPFN_GETQUEUEDCOMPLETIONSTATUSEX* = proc (
completionPort: HANDLE, lpPortEntries: ptr OVERLAPPED_ENTRY,
ulCount: ULONG, ulEntriesRemoved: var ULONG,
@ -699,7 +708,7 @@ when defined(windows):
res: var ptr AddrInfo): cint {.
stdcall, dynlib: "ws2_32", importc: "getaddrinfo", sideEffect.}
proc freeaddrinfo*(ai: ptr AddrInfo) {.
proc freeAddrInfo*(ai: ptr AddrInfo) {.
stdcall, dynlib: "ws2_32", importc: "freeaddrinfo", sideEffect.}
proc createIoCompletionPort*(fileHandle: HANDLE,
@ -870,16 +879,20 @@ elif defined(macos) or defined(macosx):
setrlimit, getpid, pthread_sigmask, sigprocmask,
sigemptyset, sigaddset, sigismember, fcntl, accept,
pipe, write, signal, read, setsockopt, getsockopt,
getcwd, chdir, waitpid, kill,
getcwd, chdir, waitpid, kill, select, pselect,
socketpair, poll, freeAddrInfo,
Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr,
SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6,
Sockaddr_un, SocketHandle, AddrInfo, RLimit,
Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet,
Suseconds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, MSG_NOSIGNAL,
AF_INET, AF_INET6, SO_ERROR, SO_REUSEADDR,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM,
SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR,
SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP,
IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE,
SIG_BLOCK, SIG_UNBLOCK,
SIG_BLOCK, SIG_UNBLOCK, SHUT_RD, SHUT_WR, SHUT_RDWR,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -891,16 +904,20 @@ elif defined(macos) or defined(macosx):
setrlimit, getpid, pthread_sigmask, sigprocmask,
sigemptyset, sigaddset, sigismember, fcntl, accept,
pipe, write, signal, read, setsockopt, getsockopt,
getcwd, chdir, waitpid, kill,
getcwd, chdir, waitpid, kill, select, pselect,
socketpair, poll, freeAddrInfo,
Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr,
SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6,
Sockaddr_un, SocketHandle, AddrInfo, RLimit,
Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet,
Suseconds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, MSG_NOSIGNAL,
AF_INET, AF_INET6, SO_ERROR, SO_REUSEADDR,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM,
SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR,
SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP,
IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE,
SIG_BLOCK, SIG_UNBLOCK,
SIG_BLOCK, SIG_UNBLOCK, SHUT_RD, SHUT_WR, SHUT_RDWR,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -912,6 +929,21 @@ elif defined(macos) or defined(macosx):
numer*: uint32
denom*: uint32
TPollfd* {.importc: "struct pollfd", pure, final,
header: "<poll.h>".} = object
fd*: cint
events*: cshort
revents*: cshort
Tnfds* {.importc: "nfds_t", header: "<poll.h>".} = cuint
const
POLLIN* = 0x0001
POLLOUT* = 0x0004
POLLERR* = 0x0008
POLLHUP* = 0x0010
POLLNVAL* = 0x0020
proc posix_gettimeofday*(tp: var Timeval, unused: pointer = nil) {.
importc: "gettimeofday", header: "<sys/time.h>".}
@ -921,6 +953,9 @@ elif defined(macos) or defined(macosx):
proc mach_absolute_time*(): uint64 {.
importc, header: "<mach/mach_time.h>".}
proc poll*(a1: ptr TPollfd, a2: Tnfds, a3: cint): cint {.
importc, header: "<poll.h>", sideEffect.}
elif defined(linux):
from std/posix import close, shutdown, sigemptyset, sigaddset, sigismember,
sigdelset, write, read, waitid, getaddrinfo,
@ -929,17 +964,22 @@ elif defined(linux):
recvfrom, sendto, send, bindSocket, recv, connect,
unlink, listen, sendmsg, recvmsg, getpid, fcntl,
pthread_sigmask, sigprocmask, clock_gettime, signal,
getcwd, chdir, waitpid, kill,
getcwd, chdir, waitpid, kill, select, pselect,
socketpair, poll, freeAddrInfo,
ClockId, Itimerspec, Timespec, Sigset, Time, Pid, Mode,
SigInfo, Id, Tmsghdr, IOVec, RLimit,
SigInfo, Id, Tmsghdr, IOVec, RLimit, Timeval, TFdSet,
SockAddr, SockLen, Sockaddr_storage, Sockaddr_in,
Sockaddr_in6, Sockaddr_un, AddrInfo, SocketHandle,
Suseconds, TPollfd, Tnfds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
CLOCK_MONOTONIC, F_GETFL, F_SETFL, F_GETFD, F_SETFD,
FD_CLOEXEC, O_NONBLOCK, SIG_BLOCK, SIG_UNBLOCK,
SOL_SOCKET, SO_ERROR, RLIMIT_NOFILE, MSG_NOSIGNAL,
AF_INET, AF_INET6, SO_REUSEADDR, SO_REUSEPORT,
MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_REUSEADDR, SO_REUSEPORT,
SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS,
SOCK_DGRAM,
SOCK_DGRAM, SOCK_STREAM, SHUT_RD, SHUT_WR, SHUT_RDWR,
POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -952,17 +992,22 @@ elif defined(linux):
recvfrom, sendto, send, bindSocket, recv, connect,
unlink, listen, sendmsg, recvmsg, getpid, fcntl,
pthread_sigmask, sigprocmask, clock_gettime, signal,
getcwd, chdir, waitpid, kill,
getcwd, chdir, waitpid, kill, select, pselect,
socketpair, poll, freeAddrInfo,
ClockId, Itimerspec, Timespec, Sigset, Time, Pid, Mode,
SigInfo, Id, Tmsghdr, IOVec, RLimit,
SigInfo, Id, Tmsghdr, IOVec, RLimit, TFdSet, Timeval,
SockAddr, SockLen, Sockaddr_storage, Sockaddr_in,
Sockaddr_in6, Sockaddr_un, AddrInfo, SocketHandle,
Suseconds, TPollfd, Tnfds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
CLOCK_MONOTONIC, F_GETFL, F_SETFL, F_GETFD, F_SETFD,
FD_CLOEXEC, O_NONBLOCK, SIG_BLOCK, SIG_UNBLOCK,
SOL_SOCKET, SO_ERROR, RLIMIT_NOFILE, MSG_NOSIGNAL,
AF_INET, AF_INET6, SO_REUSEADDR, SO_REUSEPORT,
MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_REUSEADDR, SO_REUSEPORT,
SO_BROADCAST, IPPROTO_IP, IPV6_MULTICAST_HOPS,
SOCK_DGRAM,
SOCK_DGRAM, SOCK_STREAM, SHUT_RD, SHUT_WR, SHUT_RDWR,
POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -1001,13 +1046,22 @@ elif defined(linux):
EPOLL_CTL_DEL* = 2
EPOLL_CTL_MOD* = 3
# https://github.com/torvalds/linux/blob/ff6992735ade75aae3e35d16b17da1008d753d28/include/uapi/linux/eventpoll.h#L77
when defined(linux) and defined(amd64):
{.pragma: epollPacked, packed.}
else:
{.pragma: epollPacked.}
type
EpollData* {.importc: "union epoll_data",
header: "<sys/epoll.h>", pure, final.} = object
EpollData* {.importc: "epoll_data_t",
header: "<sys/epoll.h>", pure, final, union.} = object
`ptr`* {.importc: "ptr".}: pointer
fd* {.importc: "fd".}: cint
u32* {.importc: "u32".}: uint32
u64* {.importc: "u64".}: uint64
EpollEvent* {.importc: "struct epoll_event", header: "<sys/epoll.h>",
pure, final.} = object
EpollEvent* {.importc: "struct epoll_event",
header: "<sys/epoll.h>", pure, final, epollPacked.} = object
events*: uint32 # Epoll events
data*: EpollData # User data variable
@ -1062,16 +1116,22 @@ elif defined(freebsd) or defined(openbsd) or defined(netbsd) or
setrlimit, getpid, pthread_sigmask, sigemptyset,
sigaddset, sigismember, fcntl, accept, pipe, write,
signal, read, setsockopt, getsockopt, clock_gettime,
getcwd, chdir, waitpid, kill,
getcwd, chdir, waitpid, kill, select, pselect,
socketpair, poll, freeAddrInfo,
Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr,
SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6,
Sockaddr_un, SocketHandle, AddrInfo, RLimit,
Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet,
Suseconds, TPollfd, Tnfds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, MSG_NOSIGNAL,
AF_INET, AF_INET6, SO_ERROR, SO_REUSEADDR,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM,
SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR,
SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP,
IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE,
SIG_BLOCK, SIG_UNBLOCK, CLOCK_MONOTONIC,
SHUT_RD, SHUT_WR, SHUT_RDWR,
POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,
@ -1083,15 +1143,22 @@ elif defined(freebsd) or defined(openbsd) or defined(netbsd) or
setrlimit, getpid, pthread_sigmask, sigemptyset,
sigaddset, sigismember, fcntl, accept, pipe, write,
signal, read, setsockopt, getsockopt, clock_gettime,
getcwd, chdir, waitpid, kill, select, pselect,
socketpair, poll, freeAddrInfo,
Timeval, Timespec, Pid, Mode, Time, Sigset, SockAddr,
SockLen, Sockaddr_storage, Sockaddr_in, Sockaddr_in6,
Sockaddr_un, SocketHandle, AddrInfo, RLimit,
Sockaddr_un, SocketHandle, AddrInfo, RLimit, TFdSet,
Suseconds, TPollfd, Tnfds,
FD_CLR, FD_ISSET, FD_SET, FD_ZERO,
F_GETFL, F_SETFL, F_GETFD, F_SETFD, FD_CLOEXEC,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, MSG_NOSIGNAL,
AF_INET, AF_INET6, SO_ERROR, SO_REUSEADDR,
O_NONBLOCK, SOL_SOCKET, SOCK_RAW, SOCK_DGRAM,
SOCK_STREAM, MSG_NOSIGNAL, MSG_PEEK,
AF_INET, AF_INET6, AF_UNIX, SO_ERROR, SO_REUSEADDR,
SO_REUSEPORT, SO_BROADCAST, IPPROTO_IP,
IPV6_MULTICAST_HOPS, SOCK_DGRAM, RLIMIT_NOFILE,
SIG_BLOCK, SIG_UNBLOCK, CLOCK_MONOTONIC,
SHUT_RD, SHUT_WR, SHUT_RDWR,
POLLIN, POLLOUT, POLLERR, POLLHUP, POLLNVAL,
SIGHUP, SIGINT, SIGQUIT, SIGILL, SIGTRAP, SIGABRT,
SIGBUS, SIGFPE, SIGKILL, SIGUSR1, SIGSEGV, SIGUSR2,
SIGPIPE, SIGALRM, SIGTERM, SIGPIPE, SIGCHLD, SIGSTOP,

View File

@ -28,13 +28,15 @@ type
pendingRequests: seq[BucketWaiter]
manuallyReplenished: AsyncEvent
proc update(bucket: TokenBucket) =
proc update(bucket: TokenBucket, currentTime: Moment) =
if bucket.fillDuration == default(Duration):
bucket.budget = min(bucket.budgetCap, bucket.budget)
return
if currentTime < bucket.lastUpdate:
return
let
currentTime = Moment.now()
timeDelta = currentTime - bucket.lastUpdate
fillPercent = timeDelta.milliseconds.float / bucket.fillDuration.milliseconds.float
replenished =
@ -46,7 +48,7 @@ proc update(bucket: TokenBucket) =
bucket.lastUpdate += milliseconds(deltaFromReplenished)
bucket.budget = min(bucket.budgetCap, bucket.budget + replenished)
proc tryConsume*(bucket: TokenBucket, tokens: int): bool =
proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool =
## If `tokens` are available, consume them,
## Otherwhise, return false.
@ -54,7 +56,7 @@ proc tryConsume*(bucket: TokenBucket, tokens: int): bool =
bucket.budget -= tokens
return true
bucket.update()
bucket.update(now)
if bucket.budget >= tokens:
bucket.budget -= tokens
@ -93,12 +95,12 @@ proc worker(bucket: TokenBucket) {.async.} =
bucket.workFuture = nil
proc consume*(bucket: TokenBucket, tokens: int): Future[void] =
proc consume*(bucket: TokenBucket, tokens: int, now = Moment.now()): Future[void] =
## Wait for `tokens` to be available, and consume them.
let retFuture = newFuture[void]("TokenBucket.consume")
if isNil(bucket.workFuture) or bucket.workFuture.finished():
if bucket.tryConsume(tokens):
if bucket.tryConsume(tokens, now):
retFuture.complete()
return retFuture
@ -119,10 +121,10 @@ proc consume*(bucket: TokenBucket, tokens: int): Future[void] =
return retFuture
proc replenish*(bucket: TokenBucket, tokens: int) =
proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) =
## Add `tokens` to the budget (capped to the bucket capacity)
bucket.budget += tokens
bucket.update()
bucket.update(now)
bucket.manuallyReplenished.fire()
proc new*(

View File

@ -32,29 +32,9 @@
# backwards-compatible.
import stew/results
import osdefs, osutils, oserrno
import config, osdefs, osutils, oserrno
export results, oserrno
const
asyncEventsCount* {.intdefine.} = 64
## Number of epoll events retrieved by syscall.
asyncInitialSize* {.intdefine.} = 64
## Initial size of Selector[T]'s array of file descriptors.
asyncEventEngine* {.strdefine.} =
when defined(linux):
"epoll"
elif defined(macosx) or defined(macos) or defined(ios) or
defined(freebsd) or defined(netbsd) or defined(openbsd) or
defined(dragonfly):
"kqueue"
elif defined(posix):
"poll"
else:
""
## Engine type which is going to be used by module.
hasThreadSupport = compileOption("threads")
when defined(nimdoc):
type
@ -281,7 +261,9 @@ else:
var err = newException(IOSelectorsException, msg)
raise err
when asyncEventEngine in ["epoll", "kqueue"]:
when chronosEventEngine in ["epoll", "kqueue"]:
const hasThreadSupport = compileOption("threads")
proc blockSignals(newmask: Sigset,
oldmask: var Sigset): Result[void, OSErrorCode] =
var nmask = newmask
@ -324,11 +306,11 @@ else:
doAssert((timeout >= min) and (timeout <= max),
"Cannot select with incorrect timeout value, got " & $timeout)
when asyncEventEngine == "epoll":
when chronosEventEngine == "epoll":
include ./ioselects/ioselectors_epoll
elif asyncEventEngine == "kqueue":
elif chronosEventEngine == "kqueue":
include ./ioselects/ioselectors_kqueue
elif asyncEventEngine == "poll":
elif chronosEventEngine == "poll":
include ./ioselects/ioselectors_poll
else:
{.fatal: "Event engine `" & asyncEventEngine & "` is not supported!".}
{.fatal: "Event engine `" & chronosEventEngine & "` is not supported!".}

View File

@ -38,8 +38,12 @@ when defined(nimdoc):
## be prepared to retry the call if there were unsent bytes.
##
## On error, ``-1`` is returned.
elif defined(emscripten):
elif defined(linux) or defined(android):
proc sendfile*(outfd, infd: int, offset: int, count: var int): int =
raiseAssert "sendfile() is not implemented yet"
elif (defined(linux) or defined(android)) and not(defined(emscripten)):
proc osSendFile*(outfd, infd: cint, offset: ptr int, count: int): int
{.importc: "sendfile", header: "<sys/sendfile.h>".}

View File

@ -96,10 +96,6 @@ type
reader*: AsyncStreamReader
writer*: AsyncStreamWriter
AsyncStreamTracker* = ref object of TrackerBase
opened*: int64
closed*: int64
AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter
proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer =
@ -332,79 +328,6 @@ template checkStreamClosed*(t: untyped) =
template checkStreamFinished*(t: untyped) =
if t.atEof(): raiseAsyncStreamWriteEOFError()
proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.
gcsafe, raises: [].}
proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.
gcsafe, raises: [].}
proc getAsyncStreamReaderTracker(): AsyncStreamTracker {.inline.} =
var res = cast[AsyncStreamTracker](getTracker(AsyncStreamReaderTrackerName))
if isNil(res):
res = setupAsyncStreamReaderTracker()
res
proc getAsyncStreamWriterTracker(): AsyncStreamTracker {.inline.} =
var res = cast[AsyncStreamTracker](getTracker(AsyncStreamWriterTrackerName))
if isNil(res):
res = setupAsyncStreamWriterTracker()
res
proc dumpAsyncStreamReaderTracking(): string {.gcsafe.} =
var tracker = getAsyncStreamReaderTracker()
let res = "Opened async stream readers: " & $tracker.opened & "\n" &
"Closed async stream readers: " & $tracker.closed
res
proc dumpAsyncStreamWriterTracking(): string {.gcsafe.} =
var tracker = getAsyncStreamWriterTracker()
let res = "Opened async stream writers: " & $tracker.opened & "\n" &
"Closed async stream writers: " & $tracker.closed
res
proc leakAsyncStreamReader(): bool {.gcsafe.} =
var tracker = getAsyncStreamReaderTracker()
tracker.opened != tracker.closed
proc leakAsyncStreamWriter(): bool {.gcsafe.} =
var tracker = getAsyncStreamWriterTracker()
tracker.opened != tracker.closed
proc trackAsyncStreamReader(t: AsyncStreamReader) {.inline.} =
var tracker = getAsyncStreamReaderTracker()
inc(tracker.opened)
proc untrackAsyncStreamReader*(t: AsyncStreamReader) {.inline.} =
var tracker = getAsyncStreamReaderTracker()
inc(tracker.closed)
proc trackAsyncStreamWriter(t: AsyncStreamWriter) {.inline.} =
var tracker = getAsyncStreamWriterTracker()
inc(tracker.opened)
proc untrackAsyncStreamWriter*(t: AsyncStreamWriter) {.inline.} =
var tracker = getAsyncStreamWriterTracker()
inc(tracker.closed)
proc setupAsyncStreamReaderTracker(): AsyncStreamTracker {.gcsafe.} =
var res = AsyncStreamTracker(
opened: 0,
closed: 0,
dump: dumpAsyncStreamReaderTracking,
isLeaked: leakAsyncStreamReader
)
addTracker(AsyncStreamReaderTrackerName, res)
res
proc setupAsyncStreamWriterTracker(): AsyncStreamTracker {.gcsafe.} =
var res = AsyncStreamTracker(
opened: 0,
closed: 0,
dump: dumpAsyncStreamWriterTracking,
isLeaked: leakAsyncStreamWriter
)
addTracker(AsyncStreamWriterTrackerName, res)
res
template readLoop(body: untyped): untyped =
while true:
if rstream.buffer.dataLen() == 0:
@ -953,10 +876,10 @@ proc join*(rw: AsyncStreamRW): Future[void] =
else:
var retFuture = newFuture[void]("async.stream.writer.join")
proc continuation(udata: pointer) {.gcsafe.} =
proc continuation(udata: pointer) {.gcsafe, raises:[].} =
retFuture.complete()
proc cancellation(udata: pointer) {.gcsafe.} =
proc cancellation(udata: pointer) {.gcsafe, raises:[].} =
rw.future.removeCallback(continuation, cast[pointer](retFuture))
if not(rw.future.finished()):
@ -980,9 +903,9 @@ proc close*(rw: AsyncStreamRW) =
if not(rw.future.finished()):
rw.future.complete()
when rw is AsyncStreamReader:
untrackAsyncStreamReader(rw)
untrackCounter(AsyncStreamReaderTrackerName)
elif rw is AsyncStreamWriter:
untrackAsyncStreamWriter(rw)
untrackCounter(AsyncStreamWriterTrackerName)
rw.state = AsyncStreamState.Closed
when rw is AsyncStreamReader:
@ -1031,7 +954,7 @@ proc init*(child, wsource: AsyncStreamWriter, loop: StreamWriterLoop,
child.wsource = wsource
child.tsource = wsource.tsource
child.queue = newAsyncQueue[WriteItem](queueSize)
trackAsyncStreamWriter(child)
trackCounter(AsyncStreamWriterTrackerName)
child.startWriter()
proc init*[T](child, wsource: AsyncStreamWriter, loop: StreamWriterLoop,
@ -1045,7 +968,7 @@ proc init*[T](child, wsource: AsyncStreamWriter, loop: StreamWriterLoop,
if not isNil(udata):
GC_ref(udata)
child.udata = cast[pointer](udata)
trackAsyncStreamWriter(child)
trackCounter(AsyncStreamWriterTrackerName)
child.startWriter()
proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
@ -1056,7 +979,7 @@ proc init*(child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
child.rsource = rsource
child.tsource = rsource.tsource
child.buffer = AsyncBuffer.init(bufferSize)
trackAsyncStreamReader(child)
trackCounter(AsyncStreamReaderTrackerName)
child.startReader()
proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
@ -1071,7 +994,7 @@ proc init*[T](child, rsource: AsyncStreamReader, loop: StreamReaderLoop,
if not isNil(udata):
GC_ref(udata)
child.udata = cast[pointer](udata)
trackAsyncStreamReader(child)
trackCounter(AsyncStreamReaderTrackerName)
child.startReader()
proc init*(child: AsyncStreamWriter, tsource: StreamTransport) =
@ -1080,7 +1003,7 @@ proc init*(child: AsyncStreamWriter, tsource: StreamTransport) =
child.writerLoop = nil
child.wsource = nil
child.tsource = tsource
trackAsyncStreamWriter(child)
trackCounter(AsyncStreamWriterTrackerName)
child.startWriter()
proc init*[T](child: AsyncStreamWriter, tsource: StreamTransport,
@ -1090,7 +1013,7 @@ proc init*[T](child: AsyncStreamWriter, tsource: StreamTransport,
child.writerLoop = nil
child.wsource = nil
child.tsource = tsource
trackAsyncStreamWriter(child)
trackCounter(AsyncStreamWriterTrackerName)
child.startWriter()
proc init*(child, wsource: AsyncStreamWriter) =
@ -1099,7 +1022,7 @@ proc init*(child, wsource: AsyncStreamWriter) =
child.writerLoop = nil
child.wsource = wsource
child.tsource = wsource.tsource
trackAsyncStreamWriter(child)
trackCounter(AsyncStreamWriterTrackerName)
child.startWriter()
proc init*[T](child, wsource: AsyncStreamWriter, udata: ref T) =
@ -1111,7 +1034,7 @@ proc init*[T](child, wsource: AsyncStreamWriter, udata: ref T) =
if not isNil(udata):
GC_ref(udata)
child.udata = cast[pointer](udata)
trackAsyncStreamWriter(child)
trackCounter(AsyncStreamWriterTrackerName)
child.startWriter()
proc init*(child: AsyncStreamReader, tsource: StreamTransport) =
@ -1120,7 +1043,7 @@ proc init*(child: AsyncStreamReader, tsource: StreamTransport) =
child.readerLoop = nil
child.rsource = nil
child.tsource = tsource
trackAsyncStreamReader(child)
trackCounter(AsyncStreamReaderTrackerName)
child.startReader()
proc init*[T](child: AsyncStreamReader, tsource: StreamTransport,
@ -1133,7 +1056,7 @@ proc init*[T](child: AsyncStreamReader, tsource: StreamTransport,
if not isNil(udata):
GC_ref(udata)
child.udata = cast[pointer](udata)
trackAsyncStreamReader(child)
trackCounter(AsyncStreamReaderTrackerName)
child.startReader()
proc init*(child, rsource: AsyncStreamReader) =
@ -1142,7 +1065,7 @@ proc init*(child, rsource: AsyncStreamReader) =
child.readerLoop = nil
child.rsource = rsource
child.tsource = rsource.tsource
trackAsyncStreamReader(child)
trackCounter(AsyncStreamReaderTrackerName)
child.startReader()
proc init*[T](child, rsource: AsyncStreamReader, udata: ref T) =
@ -1154,7 +1077,7 @@ proc init*[T](child, rsource: AsyncStreamReader, udata: ref T) =
if not isNil(udata):
GC_ref(udata)
child.udata = cast[pointer](udata)
trackAsyncStreamReader(child)
trackCounter(AsyncStreamReaderTrackerName)
child.startReader()
proc newAsyncStreamReader*[T](rsource: AsyncStreamReader,

View File

@ -95,6 +95,7 @@ type
trustAnchors: TrustAnchorStore
SomeTLSStreamType* = TLSStreamReader|TLSStreamWriter|TLSAsyncStream
SomeTrustAnchorType* = TrustAnchorStore | openArray[X509TrustAnchor]
TLSStreamError* = object of AsyncStreamError
TLSStreamHandshakeError* = object of TLSStreamError
@ -139,12 +140,14 @@ proc newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError =
proc raiseTLSStreamProtocolError[T](message: T) {.noreturn, noinline.} =
raise newTLSStreamProtocolImpl(message)
proc new*(T: typedesc[TrustAnchorStore], anchors: openArray[X509TrustAnchor]): TrustAnchorStore =
proc new*(T: typedesc[TrustAnchorStore],
anchors: openArray[X509TrustAnchor]): TrustAnchorStore =
var res: seq[X509TrustAnchor]
for anchor in anchors:
res.add(anchor)
doAssert(unsafeAddr(anchor) != unsafeAddr(res[^1]), "Anchors should be copied")
return TrustAnchorStore(anchors: res)
doAssert(unsafeAddr(anchor) != unsafeAddr(res[^1]),
"Anchors should be copied")
TrustAnchorStore(anchors: res)
proc tlsWriteRec(engine: ptr SslEngineContext,
writer: TLSStreamWriter): Future[TLSResult] {.async.} =
@ -453,15 +456,16 @@ proc getSignerAlgo(xc: X509Certificate): int =
else:
int(x509DecoderGetSignerKeyType(dc))
proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
wsource: AsyncStreamWriter,
serverName: string,
bufferSize = SSL_BUFSIZE_BIDI,
minVersion = TLSVersion.TLS12,
maxVersion = TLSVersion.TLS12,
flags: set[TLSFlags] = {},
trustAnchors: TrustAnchorStore | openArray[X509TrustAnchor] = MozillaTrustAnchors
): TLSAsyncStream =
proc newTLSClientAsyncStream*(
rsource: AsyncStreamReader,
wsource: AsyncStreamWriter,
serverName: string,
bufferSize = SSL_BUFSIZE_BIDI,
minVersion = TLSVersion.TLS12,
maxVersion = TLSVersion.TLS12,
flags: set[TLSFlags] = {},
trustAnchors: SomeTrustAnchorType = MozillaTrustAnchors
): TLSAsyncStream =
## Create new TLS asynchronous stream for outbound (client) connections
## using reading stream ``rsource`` and writing stream ``wsource``.
##
@ -484,7 +488,8 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
## a ``TrustAnchorStore`` you should reuse the same instance for
## every call to avoid making a copy of the trust anchors per call.
when trustAnchors is TrustAnchorStore:
doAssert(len(trustAnchors.anchors) > 0, "Empty trust anchor list is invalid")
doAssert(len(trustAnchors.anchors) > 0,
"Empty trust anchor list is invalid")
else:
doAssert(len(trustAnchors) > 0, "Empty trust anchor list is invalid")
var res = TLSAsyncStream()
@ -524,7 +529,7 @@ proc newTLSClientAsyncStream*(rsource: AsyncStreamReader,
uint16(maxVersion))
if TLSFlags.NoVerifyServerName in flags:
let err = sslClientReset(res.ccontext, "", 0)
let err = sslClientReset(res.ccontext, nil, 0)
if err == 0:
raise newException(TLSStreamInitError, "Could not initialize TLS layer")
else:

416
chronos/threadsync.nim Normal file
View File

@ -0,0 +1,416 @@
#
# Chronos multithreaded synchronization primitives
#
# (c) Copyright 2023-Present Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
## This module implements some core async thread synchronization primitives.
import stew/results
import "."/[timer, asyncloop]
export results
{.push raises: [].}
const hasThreadSupport* = compileOption("threads")
when not(hasThreadSupport):
{.fatal: "Compile this program with threads enabled!".}
import "."/[osdefs, osutils, oserrno]
type
ThreadSignal* = object
when defined(windows):
event: HANDLE
elif defined(linux):
efd: AsyncFD
else:
rfd, wfd: AsyncFD
ThreadSignalPtr* = ptr ThreadSignal
proc new*(t: typedesc[ThreadSignalPtr]): Result[ThreadSignalPtr, string] =
## Create new ThreadSignal object.
let res = cast[ptr ThreadSignal](allocShared0(sizeof(ThreadSignal)))
when defined(windows):
var sa = getSecurityAttributes()
let event = osdefs.createEvent(addr sa, DWORD(0), DWORD(0), nil)
if event == HANDLE(0):
deallocShared(res)
return err(osErrorMsg(osLastError()))
res[] = ThreadSignal(event: event)
elif defined(linux):
let efd = eventfd(0, EFD_CLOEXEC or EFD_NONBLOCK)
if efd == -1:
deallocShared(res)
return err(osErrorMsg(osLastError()))
res[] = ThreadSignal(efd: AsyncFD(efd))
else:
var sockets: array[2, cint]
block:
let sres = socketpair(AF_UNIX, SOCK_DGRAM, 0, sockets)
if sres < 0:
deallocShared(res)
return err(osErrorMsg(osLastError()))
# MacOS do not have SOCK_NONBLOCK and SOCK_CLOEXEC, so we forced to use
# setDescriptorFlags() for every socket.
block:
let sres = setDescriptorFlags(sockets[0], true, true)
if sres.isErr():
discard closeFd(sockets[0])
discard closeFd(sockets[1])
deallocShared(res)
return err(osErrorMsg(sres.error))
block:
let sres = setDescriptorFlags(sockets[1], true, true)
if sres.isErr():
discard closeFd(sockets[0])
discard closeFd(sockets[1])
deallocShared(res)
return err(osErrorMsg(sres.error))
res[] = ThreadSignal(rfd: AsyncFD(sockets[0]), wfd: AsyncFD(sockets[1]))
ok(ThreadSignalPtr(res))
when not(defined(windows)):
type
WaitKind {.pure.} = enum
Read, Write
when defined(linux):
proc checkBusy(fd: cint): bool = false
else:
proc checkBusy(fd: cint): bool =
var data = 0'u64
let res = handleEintr(recv(SocketHandle(fd),
addr data, sizeof(uint64), MSG_PEEK))
if res == sizeof(uint64):
true
else:
false
func toTimeval(a: Duration): Timeval =
## Convert Duration ``a`` to ``Timeval`` object.
let nanos = a.nanoseconds
let m = nanos mod Second.nanoseconds()
Timeval(
tv_sec: Time(nanos div Second.nanoseconds()),
tv_usec: Suseconds(m div Microsecond.nanoseconds())
)
proc waitReady(fd: cint, kind: WaitKind,
timeout: Duration): Result[bool, OSErrorCode] =
var
tv: Timeval
fdset =
block:
var res: TFdSet
FD_ZERO(res)
FD_SET(SocketHandle(fd), res)
res
let
ptv =
if not(timeout.isInfinite()):
tv = timeout.toTimeval()
addr tv
else:
nil
nfd = cint(fd) + 1
res =
case kind
of WaitKind.Read:
handleEintr(select(nfd, addr fdset, nil, nil, ptv))
of WaitKind.Write:
handleEintr(select(nfd, nil, addr fdset, nil, ptv))
if res > 0:
ok(true)
elif res == 0:
ok(false)
else:
err(osLastError())
proc safeUnregisterAndCloseFd(fd: AsyncFD): Result[void, OSErrorCode] =
let loop = getThreadDispatcher()
if loop.contains(fd):
? unregister2(fd)
if closeFd(cint(fd)) != 0:
err(osLastError())
else:
ok()
proc close*(signal: ThreadSignalPtr): Result[void, string] =
## Close ThreadSignal object and free all the resources.
defer: deallocShared(signal)
when defined(windows):
# We do not need to perform unregistering on Windows, we can only close it.
if closeHandle(signal[].event) == 0'u32:
return err(osErrorMsg(osLastError()))
elif defined(linux):
let res = safeUnregisterAndCloseFd(signal[].efd)
if res.isErr():
return err(osErrorMsg(res.error))
else:
let res1 = safeUnregisterAndCloseFd(signal[].rfd)
let res2 = safeUnregisterAndCloseFd(signal[].wfd)
if res1.isErr(): return err(osErrorMsg(res1.error))
if res2.isErr(): return err(osErrorMsg(res2.error))
ok()
proc fireSync*(signal: ThreadSignalPtr,
timeout = InfiniteDuration): Result[bool, string] =
## Set state of ``signal`` to signaled state in blocking way.
##
## Returns ``false`` if signal was not signalled in time, and ``true``
## if operation was successful.
when defined(windows):
if setEvent(signal[].event) == 0'u32:
return err(osErrorMsg(osLastError()))
ok(true)
else:
let
eventFd =
when defined(linux):
cint(signal[].efd)
else:
cint(signal[].wfd)
checkFd =
when defined(linux):
cint(signal[].efd)
else:
cint(signal[].rfd)
if checkBusy(checkFd):
# Signal is already in signalled state
return ok(true)
var data = 1'u64
while true:
let res =
when defined(linux):
handleEintr(write(eventFd, addr data, sizeof(uint64)))
else:
handleEintr(send(SocketHandle(eventFd), addr data, sizeof(uint64),
MSG_NOSIGNAL))
if res < 0:
let errorCode = osLastError()
case errorCode
of EAGAIN:
let wres = waitReady(eventFd, WaitKind.Write, timeout)
if wres.isErr():
return err(osErrorMsg(wres.error))
if not(wres.get()):
return ok(false)
else:
return err(osErrorMsg(errorCode))
elif res != sizeof(data):
return err(osErrorMsg(EINVAL))
else:
return ok(true)
proc waitSync*(signal: ThreadSignalPtr,
timeout = InfiniteDuration): Result[bool, string] =
## Wait until the signal become signaled. This procedure is ``NOT`` async,
## so it blocks execution flow, but this procedure do not need asynchronous
## event loop to be present.
when defined(windows):
let
timeoutWin =
if timeout.isInfinite():
INFINITE
else:
DWORD(timeout.milliseconds())
handle = signal[].event
res = waitForSingleObject(handle, timeoutWin)
if res == WAIT_OBJECT_0:
ok(true)
elif res == WAIT_TIMEOUT:
ok(false)
elif res == WAIT_ABANDONED:
err("The wait operation has been abandoned")
else:
err("The wait operation has been failed")
else:
let eventFd =
when defined(linux):
cint(signal[].efd)
else:
cint(signal[].rfd)
var
data = 0'u64
timer = timeout
while true:
let wres =
block:
let
start = Moment.now()
res = waitReady(eventFd, WaitKind.Read, timer)
timer = timer - (Moment.now() - start)
res
if wres.isErr():
return err(osErrorMsg(wres.error))
if not(wres.get()):
return ok(false)
let res =
when defined(linux):
handleEintr(read(eventFd, addr data, sizeof(uint64)))
else:
handleEintr(recv(SocketHandle(eventFd), addr data, sizeof(uint64),
cint(0)))
if res < 0:
let errorCode = osLastError()
# If errorCode == EAGAIN it means that reading operation is already
# pending and so some other consumer reading eventfd or pipe end, in
# this case we going to ignore error and wait for another event.
if errorCode != EAGAIN:
return err(osErrorMsg(errorCode))
elif res != sizeof(data):
return err(osErrorMsg(EINVAL))
else:
return ok(true)
proc fire*(signal: ThreadSignalPtr): Future[void] =
## Set state of ``signal`` to signaled in asynchronous way.
var retFuture = newFuture[void]("asyncthreadsignal.fire")
when defined(windows):
if setEvent(signal[].event) == 0'u32:
retFuture.fail(newException(AsyncError, osErrorMsg(osLastError())))
else:
retFuture.complete()
else:
var data = 1'u64
let
eventFd =
when defined(linux):
cint(signal[].efd)
else:
cint(signal[].wfd)
checkFd =
when defined(linux):
cint(signal[].efd)
else:
cint(signal[].rfd)
proc continuation(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
let res =
when defined(linux):
handleEintr(write(eventFd, addr data, sizeof(uint64)))
else:
handleEintr(send(SocketHandle(eventFd), addr data, sizeof(uint64),
MSG_NOSIGNAL))
if res < 0:
let errorCode = osLastError()
discard removeWriter2(AsyncFD(eventFd))
retFuture.fail(newException(AsyncError, osErrorMsg(errorCode)))
elif res != sizeof(data):
discard removeWriter2(AsyncFD(eventFd))
retFuture.fail(newException(AsyncError, osErrorMsg(EINVAL)))
else:
let eres = removeWriter2(AsyncFD(eventFd))
if eres.isErr():
retFuture.fail(newException(AsyncError, osErrorMsg(eres.error)))
else:
retFuture.complete()
proc cancellation(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
discard removeWriter2(AsyncFD(eventFd))
if checkBusy(checkFd):
# Signal is already in signalled state
retFuture.complete()
return retFuture
let res =
when defined(linux):
handleEintr(write(eventFd, addr data, sizeof(uint64)))
else:
handleEintr(send(SocketHandle(eventFd), addr data, sizeof(uint64),
MSG_NOSIGNAL))
if res < 0:
let errorCode = osLastError()
case errorCode
of EAGAIN:
let loop = getThreadDispatcher()
if not(loop.contains(AsyncFD(eventFd))):
let rres = register2(AsyncFD(eventFd))
if rres.isErr():
retFuture.fail(newException(AsyncError, osErrorMsg(rres.error)))
return retFuture
let wres = addWriter2(AsyncFD(eventFd), continuation)
if wres.isErr():
retFuture.fail(newException(AsyncError, osErrorMsg(wres.error)))
else:
retFuture.cancelCallback = cancellation
else:
retFuture.fail(newException(AsyncError, osErrorMsg(errorCode)))
elif res != sizeof(data):
retFuture.fail(newException(AsyncError, osErrorMsg(EINVAL)))
else:
retFuture.complete()
retFuture
when defined(windows):
proc wait*(signal: ThreadSignalPtr) {.async.} =
let handle = signal[].event
let res = await waitForSingleObject(handle, InfiniteDuration)
# There should be no other response, because we use `InfiniteDuration`.
doAssert(res == WaitableResult.Ok)
else:
proc wait*(signal: ThreadSignalPtr): Future[void] =
var retFuture = newFuture[void]("asyncthreadsignal.wait")
var data = 1'u64
let eventFd =
when defined(linux):
cint(signal[].efd)
else:
cint(signal[].rfd)
proc continuation(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
let res =
when defined(linux):
handleEintr(read(eventFd, addr data, sizeof(uint64)))
else:
handleEintr(recv(SocketHandle(eventFd), addr data, sizeof(uint64),
cint(0)))
if res < 0:
let errorCode = osLastError()
# If errorCode == EAGAIN it means that reading operation is already
# pending and so some other consumer reading eventfd or pipe end, in
# this case we going to ignore error and wait for another event.
if errorCode != EAGAIN:
discard removeReader2(AsyncFD(eventFd))
retFuture.fail(newException(AsyncError, osErrorMsg(errorCode)))
elif res != sizeof(data):
discard removeReader2(AsyncFD(eventFd))
retFuture.fail(newException(AsyncError, osErrorMsg(EINVAL)))
else:
let eres = removeReader2(AsyncFD(eventFd))
if eres.isErr():
retFuture.fail(newException(AsyncError, osErrorMsg(eres.error)))
else:
retFuture.complete()
proc cancellation(udata: pointer) {.gcsafe, raises: [].} =
if not(retFuture.finished()):
# Future is already cancelled so we ignore errors.
discard removeReader2(AsyncFD(eventFd))
let loop = getThreadDispatcher()
if not(loop.contains(AsyncFD(eventFd))):
let res = register2(AsyncFD(eventFd))
if res.isErr():
retFuture.fail(newException(AsyncError, osErrorMsg(res.error)))
return retFuture
let res = addReader2(AsyncFD(eventFd), continuation)
if res.isErr():
retFuture.fail(newException(AsyncError, osErrorMsg(res.error)))
return retFuture
retFuture.cancelCallback = cancellation
retFuture

View File

@ -298,6 +298,9 @@ proc getAddrInfo(address: string, port: Port, domain: Domain,
raises: [TransportAddressError].} =
## We have this one copy of ``getAddrInfo()`` because of AI_V4MAPPED in
## ``net.nim:getAddrInfo()``, which is not cross-platform.
##
## Warning: `ptr AddrInfo` returned by `getAddrInfo()` needs to be freed by
## calling `freeAddrInfo()`.
var hints: AddrInfo
var res: ptr AddrInfo = nil
hints.ai_family = toInt(domain)
@ -420,6 +423,7 @@ proc resolveTAddress*(address: string, port: Port,
if ta notin res:
res.add(ta)
it = it.ai_next
freeAddrInfo(aiList)
res
proc resolveTAddress*(address: string, domain: Domain): seq[TransportAddress] {.
@ -574,10 +578,8 @@ template getTransportUseClosedError*(): ref TransportUseClosedError =
newException(TransportUseClosedError, "Transport is already closed!")
template getTransportOsError*(err: OSErrorCode): ref TransportOsError =
var msg = "(" & $int(err) & ") " & osErrorMsg(err)
var tre = newException(TransportOsError, msg)
tre.code = err
tre
(ref TransportOsError)(
code: err, msg: "(" & $int(err) & ") " & osErrorMsg(err))
template getTransportOsError*(err: cint): ref TransportOsError =
getTransportOsError(OSErrorCode(err))
@ -608,15 +610,16 @@ template getTransportTooManyError*(
): ref TransportTooManyError =
let msg =
when defined(posix):
if code == OSErrorCode(0):
case code
of OSErrorCode(0):
"Too many open transports"
elif code == oserrno.EMFILE:
of EMFILE:
"[EMFILE] Too many open files in the process"
elif code == oserrno.ENFILE:
of ENFILE:
"[ENFILE] Too many open files in system"
elif code == oserrno.ENOBUFS:
of ENOBUFS:
"[ENOBUFS] No buffer space available"
elif code == oserrno.ENOMEM:
of ENOMEM:
"[ENOMEM] Not enough memory availble"
else:
"[" & $int(code) & "] Too many open transports"
@ -649,23 +652,26 @@ template getConnectionAbortedError*(
): ref TransportAbortedError =
let msg =
when defined(posix):
if code == OSErrorCode(0):
case code
of OSErrorCode(0), ECONNABORTED:
"[ECONNABORTED] Connection has been aborted before being accepted"
elif code == oserrno.EPERM:
of EPERM:
"[EPERM] Firewall rules forbid connection"
elif code == oserrno.ETIMEDOUT:
of ETIMEDOUT:
"[ETIMEDOUT] Operation has been timed out"
of ENOTCONN:
"[ENOTCONN] Transport endpoint is not connected"
else:
"[" & $int(code) & "] Connection has been aborted"
elif defined(windows):
case code
of OSErrorCode(0), oserrno.WSAECONNABORTED:
of OSErrorCode(0), WSAECONNABORTED:
"[ECONNABORTED] Connection has been aborted before being accepted"
of WSAENETDOWN:
"[ENETDOWN] Network is down"
of oserrno.WSAENETRESET:
of WSAENETRESET:
"[ENETRESET] Network dropped connection on reset"
of oserrno.WSAECONNRESET:
of WSAECONNRESET:
"[ECONNRESET] Connection reset by peer"
of WSAETIMEDOUT:
"[ETIMEDOUT] Connection timed out"
@ -675,3 +681,42 @@ template getConnectionAbortedError*(
"[" & $int(code) & "] Connection has been aborted"
newException(TransportAbortedError, msg)
template getTransportError*(ecode: OSErrorCode): untyped =
when defined(posix):
case ecode
of ECONNABORTED, EPERM, ETIMEDOUT, ENOTCONN:
getConnectionAbortedError(ecode)
of EMFILE, ENFILE, ENOBUFS, ENOMEM:
getTransportTooManyError(ecode)
else:
getTransportOsError(ecode)
else:
case ecode
of WSAECONNABORTED, WSAENETDOWN, WSAENETRESET, WSAECONNRESET, WSAETIMEDOUT:
getConnectionAbortedError(ecode)
of ERROR_TOO_MANY_OPEN_FILES, WSAENOBUFS, WSAEMFILE:
getTransportTooManyError(ecode)
else:
getTransportOsError(ecode)
proc raiseTransportError*(ecode: OSErrorCode) {.
raises: [TransportAbortedError, TransportTooManyError, TransportOsError],
noreturn.} =
## Raises transport specific OS error.
when defined(posix):
case ecode
of ECONNABORTED, EPERM, ETIMEDOUT, ENOTCONN:
raise getConnectionAbortedError(ecode)
of EMFILE, ENFILE, ENOBUFS, ENOMEM:
raise getTransportTooManyError(ecode)
else:
raise getTransportOsError(ecode)
else:
case ecode
of WSAECONNABORTED, WSAENETDOWN, WSAENETRESET, WSAECONNRESET, WSAETIMEDOUT:
raise getConnectionAbortedError(ecode)
of ERROR_TOO_MANY_OPEN_FILES, WSAENOBUFS, WSAEMFILE:
raise getTransportTooManyError(ecode)
else:
raise getTransportOsError(ecode)

View File

@ -53,10 +53,6 @@ type
rwsabuf: WSABUF # Reader WSABUF structure
wwsabuf: WSABUF # Writer WSABUF structure
DgramTransportTracker* = ref object of TrackerBase
opened*: int64
closed*: int64
const
DgramTransportTrackerName* = "datagram.transport"
@ -88,39 +84,6 @@ template setReadError(t, e: untyped) =
(t).state.incl(ReadError)
(t).error = getTransportOsError(e)
proc setupDgramTransportTracker(): DgramTransportTracker {.
gcsafe, raises: [].}
proc getDgramTransportTracker(): DgramTransportTracker {.inline.} =
var res = cast[DgramTransportTracker](getTracker(DgramTransportTrackerName))
if isNil(res):
res = setupDgramTransportTracker()
doAssert(not(isNil(res)))
res
proc dumpTransportTracking(): string {.gcsafe.} =
var tracker = getDgramTransportTracker()
"Opened transports: " & $tracker.opened & "\n" &
"Closed transports: " & $tracker.closed
proc leakTransport(): bool {.gcsafe.} =
let tracker = getDgramTransportTracker()
tracker.opened != tracker.closed
proc trackDgram(t: DatagramTransport) {.inline.} =
var tracker = getDgramTransportTracker()
inc(tracker.opened)
proc untrackDgram(t: DatagramTransport) {.inline.} =
var tracker = getDgramTransportTracker()
inc(tracker.closed)
proc setupDgramTransportTracker(): DgramTransportTracker {.gcsafe.} =
let res = DgramTransportTracker(
opened: 0, closed: 0, dump: dumpTransportTracking, isLeaked: leakTransport)
addTracker(DgramTransportTrackerName, res)
res
when defined(windows):
template setWriterWSABuffer(t, v: untyped) =
(t).wwsabuf.buf = cast[cstring](v.buf)
@ -213,7 +176,7 @@ when defined(windows):
transp.state.incl(ReadPaused)
if ReadClosed in transp.state and not(transp.future.finished()):
# Stop tracking transport
untrackDgram(transp)
untrackCounter(DgramTransportTrackerName)
# If `ReadClosed` present, then close(transport) was called.
transp.future.complete()
GC_unref(transp)
@ -259,7 +222,7 @@ when defined(windows):
# WSARecvFrom session.
if ReadClosed in transp.state and not(transp.future.finished()):
# Stop tracking transport
untrackDgram(transp)
untrackCounter(DgramTransportTrackerName)
transp.future.complete()
GC_unref(transp)
break
@ -394,7 +357,7 @@ when defined(windows):
len: ULONG(len(res.buffer)))
GC_ref(res)
# Start tracking transport
trackDgram(res)
trackCounter(DgramTransportTrackerName)
if NoAutoRead notin flags:
let rres = res.resumeRead()
if rres.isErr(): raiseTransportOsError(rres.error())
@ -503,11 +466,11 @@ else:
var res = if isNil(child): DatagramTransport() else: child
if sock == asyncInvalidSocket:
var proto = Protocol.IPPROTO_UDP
if local.family == AddressFamily.Unix:
# `Protocol` enum is missing `0` value, so we making here cast, until
# `Protocol` enum will not support IPPROTO_IP == 0.
proto = cast[Protocol](0)
let proto =
if local.family == AddressFamily.Unix:
Protocol.IPPROTO_IP
else:
Protocol.IPPROTO_UDP
localSock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM,
proto)
if localSock == asyncInvalidSocket:
@ -592,7 +555,7 @@ else:
res.future = newFuture[void]("datagram.transport")
GC_ref(res)
# Start tracking transport
trackDgram(res)
trackCounter(DgramTransportTrackerName)
if NoAutoRead notin flags:
let rres = res.resumeRead()
if rres.isErr(): raiseTransportOsError(rres.error())
@ -603,7 +566,7 @@ proc close*(transp: DatagramTransport) =
proc continuation(udata: pointer) {.raises: [].} =
if not(transp.future.finished()):
# Stop tracking transport
untrackDgram(transp)
untrackCounter(DgramTransportTrackerName)
transp.future.complete()
GC_unref(transp)

View File

@ -54,15 +54,6 @@ type
ReuseAddr,
ReusePort
StreamTransportTracker* = ref object of TrackerBase
opened*: int64
closed*: int64
StreamServerTracker* = ref object of TrackerBase
opened*: int64
closed*: int64
ReadMessagePredicate* = proc (data: openArray[byte]): tuple[consumed: int,
done: bool] {.
gcsafe, raises: [].}
@ -70,6 +61,7 @@ type
const
StreamTransportTrackerName* = "stream.transport"
StreamServerTrackerName* = "stream.server"
DefaultBacklogSize* = high(int32)
when defined(windows):
type
@ -141,30 +133,28 @@ type
# transport for new client
proc remoteAddress*(transp: StreamTransport): TransportAddress {.
raises: [TransportError].} =
raises: [TransportAbortedError, TransportTooManyError, TransportOsError].} =
## Returns ``transp`` remote socket address.
if transp.kind != TransportKind.Socket:
raise newException(TransportError, "Socket required!")
doAssert(transp.kind == TransportKind.Socket, "Socket transport required!")
if transp.remote.family == AddressFamily.None:
var saddr: Sockaddr_storage
var slen = SockLen(sizeof(saddr))
if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
raiseTransportOsError(osLastError())
raiseTransportError(osLastError())
fromSAddr(addr saddr, slen, transp.remote)
transp.remote
proc localAddress*(transp: StreamTransport): TransportAddress {.
raises: [TransportError].} =
raises: [TransportAbortedError, TransportTooManyError, TransportOsError].} =
## Returns ``transp`` local socket address.
if transp.kind != TransportKind.Socket:
raise newException(TransportError, "Socket required!")
doAssert(transp.kind == TransportKind.Socket, "Socket transport required!")
if transp.local.family == AddressFamily.None:
var saddr: Sockaddr_storage
var slen = SockLen(sizeof(saddr))
if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
raiseTransportOsError(osLastError())
raiseTransportError(osLastError())
fromSAddr(addr saddr, slen, transp.local)
transp.local
@ -201,71 +191,6 @@ template shiftVectorFile(v: var StreamVector, o: untyped) =
(v).buf = cast[pointer](cast[uint]((v).buf) - uint(o))
(v).offset += uint(o)
proc setupStreamTransportTracker(): StreamTransportTracker {.
gcsafe, raises: [].}
proc setupStreamServerTracker(): StreamServerTracker {.
gcsafe, raises: [].}
proc getStreamTransportTracker(): StreamTransportTracker {.inline.} =
var res = cast[StreamTransportTracker](getTracker(StreamTransportTrackerName))
if isNil(res):
res = setupStreamTransportTracker()
doAssert(not(isNil(res)))
res
proc getStreamServerTracker(): StreamServerTracker {.inline.} =
var res = cast[StreamServerTracker](getTracker(StreamServerTrackerName))
if isNil(res):
res = setupStreamServerTracker()
doAssert(not(isNil(res)))
res
proc dumpTransportTracking(): string {.gcsafe.} =
var tracker = getStreamTransportTracker()
"Opened transports: " & $tracker.opened & "\n" &
"Closed transports: " & $tracker.closed
proc dumpServerTracking(): string {.gcsafe.} =
var tracker = getStreamServerTracker()
"Opened servers: " & $tracker.opened & "\n" &
"Closed servers: " & $tracker.closed
proc leakTransport(): bool {.gcsafe.} =
var tracker = getStreamTransportTracker()
tracker.opened != tracker.closed
proc leakServer(): bool {.gcsafe.} =
var tracker = getStreamServerTracker()
tracker.opened != tracker.closed
proc trackStream(t: StreamTransport) {.inline.} =
var tracker = getStreamTransportTracker()
inc(tracker.opened)
proc untrackStream(t: StreamTransport) {.inline.} =
var tracker = getStreamTransportTracker()
inc(tracker.closed)
proc trackServer(s: StreamServer) {.inline.} =
var tracker = getStreamServerTracker()
inc(tracker.opened)
proc untrackServer(s: StreamServer) {.inline.} =
var tracker = getStreamServerTracker()
inc(tracker.closed)
proc setupStreamTransportTracker(): StreamTransportTracker {.gcsafe.} =
let res = StreamTransportTracker(
opened: 0, closed: 0, dump: dumpTransportTracking, isLeaked: leakTransport)
addTracker(StreamTransportTrackerName, res)
res
proc setupStreamServerTracker(): StreamServerTracker {.gcsafe.} =
let res = StreamServerTracker(
opened: 0, closed: 0, dump: dumpServerTracking, isLeaked: leakServer)
addTracker(StreamServerTrackerName, res)
res
proc completePendingWriteQueue(queue: var Deque[StreamVector],
v: int) {.inline.} =
while len(queue) > 0:
@ -282,7 +207,7 @@ proc failPendingWriteQueue(queue: var Deque[StreamVector],
proc clean(server: StreamServer) {.inline.} =
if not(server.loopFuture.finished()):
untrackServer(server)
untrackCounter(StreamServerTrackerName)
server.loopFuture.complete()
if not(isNil(server.udata)) and (GCUserData in server.flags):
GC_unref(cast[ref int](server.udata))
@ -290,7 +215,7 @@ proc clean(server: StreamServer) {.inline.} =
proc clean(transp: StreamTransport) {.inline.} =
if not(transp.future.finished()):
untrackStream(transp)
untrackCounter(StreamTransportTrackerName)
transp.future.complete()
GC_unref(transp)
@ -786,7 +711,7 @@ when defined(windows):
else:
let transp = newStreamSocketTransport(sock, bufferSize, child)
# Start tracking transport
trackStream(transp)
trackCounter(StreamTransportTrackerName)
retFuture.complete(transp)
else:
sock.closeSocket()
@ -855,7 +780,7 @@ when defined(windows):
let transp = newStreamPipeTransport(AsyncFD(pipeHandle),
bufferSize, child)
# Start tracking transport
trackStream(transp)
trackCounter(StreamTransportTrackerName)
retFuture.complete(transp)
pipeContinuation(nil)
@ -911,7 +836,7 @@ when defined(windows):
ntransp = newStreamPipeTransport(server.sock, server.bufferSize,
nil, flags)
# Start tracking transport
trackStream(ntransp)
trackCounter(StreamTransportTrackerName)
asyncSpawn server.function(server, ntransp)
of ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or close call.
@ -1015,7 +940,7 @@ when defined(windows):
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
trackCounter(StreamTransportTrackerName)
asyncSpawn server.function(server, ntransp)
of ERROR_OPERATION_ABORTED:
@ -1158,7 +1083,7 @@ when defined(windows):
ntransp = newStreamSocketTransport(server.asock,
server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
trackCounter(StreamTransportTrackerName)
retFuture.complete(ntransp)
of ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or close.
@ -1218,7 +1143,7 @@ when defined(windows):
retFuture.fail(getTransportOsError(error))
return
trackStream(ntransp)
trackCounter(StreamTransportTrackerName)
retFuture.complete(ntransp)
of ERROR_OPERATION_ABORTED, ERROR_PIPE_NOT_CONNECTED:
@ -1550,14 +1475,13 @@ else:
var
saddr: Sockaddr_storage
slen: SockLen
proto: Protocol
var retFuture = newFuture[StreamTransport]("stream.transport.connect")
address.toSAddr(saddr, slen)
proto = Protocol.IPPROTO_TCP
if address.family == AddressFamily.Unix:
# `Protocol` enum is missing `0` value, so we making here cast, until
# `Protocol` enum will not support IPPROTO_IP == 0.
proto = cast[Protocol](0)
let proto =
if address.family == AddressFamily.Unix:
Protocol.IPPROTO_IP
else:
Protocol.IPPROTO_TCP
let sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM,
proto)
@ -1628,7 +1552,7 @@ else:
let transp = newStreamSocketTransport(sock, bufferSize, child)
# Start tracking transport
trackStream(transp)
trackCounter(StreamTransportTrackerName)
retFuture.complete(transp)
proc cancel(udata: pointer) =
@ -1641,7 +1565,7 @@ else:
if res == 0:
let transp = newStreamSocketTransport(sock, bufferSize, child)
# Start tracking transport
trackStream(transp)
trackCounter(StreamTransportTrackerName)
retFuture.complete(transp)
break
else:
@ -1696,7 +1620,7 @@ else:
newStreamSocketTransport(sock, server.bufferSize, transp)
else:
newStreamSocketTransport(sock, server.bufferSize, nil)
trackStream(ntransp)
trackCounter(StreamTransportTrackerName)
asyncSpawn server.function(server, ntransp)
else:
# Client was accepted, so we not going to raise assertion, but
@ -1784,7 +1708,7 @@ else:
else:
newStreamSocketTransport(sock, server.bufferSize, nil)
# Start tracking transport
trackStream(ntransp)
trackCounter(StreamTransportTrackerName)
retFuture.complete(ntransp)
else:
discard closeFd(cint(sock))
@ -1895,11 +1819,32 @@ proc closeWait*(server: StreamServer): Future[void] =
server.close()
server.join()
proc getBacklogSize(backlog: int): cint =
doAssert(backlog >= 0 and backlog <= high(int32))
when defined(windows):
# The maximum length of the queue of pending connections. If set to
# SOMAXCONN, the underlying service provider responsible for
# socket s will set the backlog to a maximum reasonable value. If set to
# SOMAXCONN_HINT(N) (where N is a number), the backlog value will be N,
# adjusted to be within the range (200, 65535). Note that SOMAXCONN_HINT
# can be used to set the backlog to a larger value than possible with
# SOMAXCONN.
#
# Microsoft SDK values are
# #define SOMAXCONN 0x7fffffff
# #define SOMAXCONN_HINT(b) (-(b))
if backlog != high(int32):
cint(-backlog)
else:
cint(backlog)
else:
cint(backlog)
proc createStreamServer*(host: TransportAddress,
cbproc: StreamCallback,
flags: set[ServerFlags] = {},
sock: AsyncFD = asyncInvalidSocket,
backlog: int = 100,
backlog: int = DefaultBacklogSize,
bufferSize: int = DefaultStreamBufferSize,
child: StreamServer = nil,
init: TransportInitCallback = nil,
@ -1982,7 +1927,7 @@ proc createStreamServer*(host: TransportAddress,
raiseTransportOsError(err)
fromSAddr(addr saddr, slen, localAddress)
if listen(SocketHandle(serverSocket), cint(backlog)) != 0:
if listen(SocketHandle(serverSocket), getBacklogSize(backlog)) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
discard closeFd(SocketHandle(serverSocket))
@ -1992,11 +1937,10 @@ proc createStreamServer*(host: TransportAddress,
else:
# Posix
if sock == asyncInvalidSocket:
var proto = Protocol.IPPROTO_TCP
if host.family == AddressFamily.Unix:
# `Protocol` enum is missing `0` value, so we making here cast, until
# `Protocol` enum will not support IPPROTO_IP == 0.
proto = cast[Protocol](0)
let proto = if host.family == AddressFamily.Unix:
Protocol.IPPROTO_IP
else:
Protocol.IPPROTO_TCP
serverSocket = createAsyncSocket(host.getDomain(),
SockType.SOCK_STREAM,
proto)
@ -2056,7 +2000,7 @@ proc createStreamServer*(host: TransportAddress,
raiseTransportOsError(err)
fromSAddr(addr saddr, slen, localAddress)
if listen(SocketHandle(serverSocket), cint(backlog)) != 0:
if listen(SocketHandle(serverSocket), getBacklogSize(backlog)) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
discard unregisterAndCloseFd(serverSocket)
@ -2100,14 +2044,14 @@ proc createStreamServer*(host: TransportAddress,
sres.apending = false
# Start tracking server
trackServer(sres)
trackCounter(StreamServerTrackerName)
GC_ref(sres)
sres
proc createStreamServer*(host: TransportAddress,
flags: set[ServerFlags] = {},
sock: AsyncFD = asyncInvalidSocket,
backlog: int = 100,
backlog: int = DefaultBacklogSize,
bufferSize: int = DefaultStreamBufferSize,
child: StreamServer = nil,
init: TransportInitCallback = nil,
@ -2121,7 +2065,7 @@ proc createStreamServer*[T](host: TransportAddress,
flags: set[ServerFlags] = {},
udata: ref T,
sock: AsyncFD = asyncInvalidSocket,
backlog: int = 100,
backlog: int = DefaultBacklogSize,
bufferSize: int = DefaultStreamBufferSize,
child: StreamServer = nil,
init: TransportInitCallback = nil): StreamServer {.
@ -2135,7 +2079,7 @@ proc createStreamServer*[T](host: TransportAddress,
flags: set[ServerFlags] = {},
udata: ref T,
sock: AsyncFD = asyncInvalidSocket,
backlog: int = 100,
backlog: int = DefaultBacklogSize,
bufferSize: int = DefaultStreamBufferSize,
child: StreamServer = nil,
init: TransportInitCallback = nil): StreamServer {.
@ -2650,6 +2594,57 @@ proc closeWait*(transp: StreamTransport): Future[void] =
transp.close()
transp.join()
proc shutdownWait*(transp: StreamTransport): Future[void] =
## Perform graceful shutdown of TCP connection backed by transport ``transp``.
doAssert(transp.kind == TransportKind.Socket)
let retFuture = newFuture[void]("stream.transport.shutdown")
transp.checkClosed(retFuture)
transp.checkWriteEof(retFuture)
when defined(windows):
let loop = getThreadDispatcher()
proc continuation(udata: pointer) {.gcsafe.} =
let ovl = cast[RefCustomOverlapped](udata)
if not(retFuture.finished()):
if ovl.data.errCode == OSErrorCode(-1):
retFuture.complete()
else:
transp.state.excl({WriteEof})
retFuture.fail(getTransportOsError(ovl.data.errCode))
GC_unref(ovl)
let povl = RefCustomOverlapped(data: CompletionData(cb: continuation))
GC_ref(povl)
let res = loop.disconnectEx(SocketHandle(transp.fd),
cast[POVERLAPPED](povl), 0'u32, 0'u32)
if res == FALSE:
let err = osLastError()
case err
of ERROR_IO_PENDING:
transp.state.incl({WriteEof})
else:
GC_unref(povl)
retFuture.fail(getTransportOsError(err))
else:
transp.state.incl({WriteEof})
retFuture.complete()
retFuture
else:
proc continuation(udata: pointer) {.gcsafe.} =
if not(retFuture.finished()):
retFuture.complete()
let res = osdefs.shutdown(SocketHandle(transp.fd), SHUT_WR)
if res < 0:
let err = osLastError()
retFuture.fail(getTransportOsError(err))
else:
transp.state.incl({WriteEof})
callSoon(continuation, nil)
retFuture
proc closed*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport in closed state.
({ReadClosed, WriteClosed} * transp.state != {})
@ -2676,7 +2671,7 @@ proc fromPipe2*(fd: AsyncFD, child: StreamTransport = nil,
? register2(fd)
var res = newStreamPipeTransport(fd, bufferSize, child)
# Start tracking transport
trackStream(res)
trackCounter(StreamTransportTrackerName)
ok(res)
proc fromPipe*(fd: AsyncFD, child: StreamTransport = nil,

View File

@ -6,6 +6,7 @@
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/tables
import unittest2
import ../../chronos
@ -17,3 +18,14 @@ template asyncTest*(name: string, body: untyped): untyped =
proc() {.async, gcsafe.} =
body
)())
template checkLeaks*(name: string): untyped =
let counter = getTrackerCounter(name)
if counter.opened != counter.closed:
echo "[" & name & "] opened = ", counter.opened,
", closed = ", counter.closed
check counter.opened == counter.closed
template checkLeaks*(): untyped =
for key in getThreadDispatcher().trackerCounterKeys():
checkLeaks(key)

View File

@ -5,10 +5,22 @@
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import testmacro, testsync, testsoon, testtime, testfut, testsignal,
testaddress, testdatagram, teststream, testserver, testbugs, testnet,
testasyncstream, testhttpserver, testshttpserver, testhttpclient,
testproc, testratelimit, testfutures
import ".."/chronos/config
# Must be imported last to check for Pending futures
import testutils
when (chronosEventEngine in ["epoll", "kqueue"]) or defined(windows):
import testmacro, testsync, testsoon, testtime, testfut, testsignal,
testaddress, testdatagram, teststream, testserver, testbugs, testnet,
testasyncstream, testhttpserver, testshttpserver, testhttpclient,
testproc, testratelimit, testfutures, testthreadsync
# Must be imported last to check for Pending futures
import testutils
elif chronosEventEngine == "poll":
# `poll` engine do not support signals and processes
import testmacro, testsync, testsoon, testtime, testfut, testaddress,
testdatagram, teststream, testserver, testbugs, testnet,
testasyncstream, testhttpserver, testshttpserver, testhttpclient,
testratelimit, testfutures, testthreadsync
# Must be imported last to check for Pending futures
import testutils

View File

@ -7,8 +7,8 @@
# MIT license (LICENSE-MIT)
import unittest2
import bearssl/[x509]
import ../chronos
import ../chronos/streams/[tlsstream, chunkstream, boundstream]
import ".."/chronos/unittest2/asynctests
import ".."/chronos/streams/[tlsstream, chunkstream, boundstream]
{.used.}
@ -145,7 +145,7 @@ proc createBigMessage(message: string, size: int): seq[byte] =
suite "AsyncStream test suite":
test "AsyncStream(StreamTransport) readExactly() test":
proc testReadExactly(address: TransportAddress): Future[bool] {.async.} =
proc testReadExactly(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -157,9 +157,10 @@ suite "AsyncStream test suite":
server.close()
var buffer = newSeq[byte](10)
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
await rstream.readExactly(addr buffer[0], 10)
check cast[string](buffer) == "0000000000"
@ -171,9 +172,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testReadExactly(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testReadExactly()) == true
test "AsyncStream(StreamTransport) readUntil() test":
proc testReadUntil(address: TransportAddress): Future[bool] {.async.} =
proc testReadUntil(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -186,9 +188,10 @@ suite "AsyncStream test suite":
var buffer = newSeq[byte](13)
var sep = @[byte('N'), byte('N'), byte('z')]
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var r1 = await rstream.readUntil(addr buffer[0], len(buffer), sep)
check:
@ -207,9 +210,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testReadUntil(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testReadUntil()) == true
test "AsyncStream(StreamTransport) readLine() test":
proc testReadLine(address: TransportAddress): Future[bool] {.async.} =
proc testReadLine(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -220,9 +224,10 @@ suite "AsyncStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var r1 = await rstream.readLine()
check r1 == "0000000000"
@ -234,9 +239,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testReadLine(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testReadLine()) == true
test "AsyncStream(StreamTransport) read() test":
proc testRead(address: TransportAddress): Future[bool] {.async.} =
proc testRead(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -247,9 +253,10 @@ suite "AsyncStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var buf1 = await rstream.read(10)
check cast[string](buf1) == "0000000000"
@ -259,9 +266,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testRead(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testRead()) == true
test "AsyncStream(StreamTransport) consume() test":
proc testConsume(address: TransportAddress): Future[bool] {.async.} =
proc testConsume(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -272,9 +280,10 @@ suite "AsyncStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var res1 = await rstream.consume(10)
check:
@ -290,16 +299,13 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testConsume(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testConsume()) == true
test "AsyncStream(StreamTransport) leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
checkLeaks()
test "AsyncStream(AsyncStream) readExactly() test":
proc testReadExactly2(address: TransportAddress): Future[bool] {.async.} =
proc testReadExactly2(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -323,9 +329,10 @@ suite "AsyncStream test suite":
server.close()
var buffer = newSeq[byte](10)
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
await rstream2.readExactly(addr buffer[0], 10)
@ -347,9 +354,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testReadExactly2(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testReadExactly2()) == true
test "AsyncStream(AsyncStream) readUntil() test":
proc testReadUntil2(address: TransportAddress): Future[bool] {.async.} =
proc testReadUntil2(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -373,9 +381,10 @@ suite "AsyncStream test suite":
var buffer = newSeq[byte](13)
var sep = @[byte('N'), byte('N'), byte('z')]
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
@ -404,9 +413,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testReadUntil2(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testReadUntil2()) == true
test "AsyncStream(AsyncStream) readLine() test":
proc testReadLine2(address: TransportAddress): Future[bool] {.async.} =
proc testReadLine2(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -425,9 +435,10 @@ suite "AsyncStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
var r1 = await rstream2.readLine()
@ -449,9 +460,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testReadLine2(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testReadLine2()) == true
test "AsyncStream(AsyncStream) read() test":
proc testRead2(address: TransportAddress): Future[bool] {.async.} =
proc testRead2(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -469,9 +481,10 @@ suite "AsyncStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
var buf1 = await rstream2.read(10)
@ -488,9 +501,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testRead2(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testRead2()) == true
test "AsyncStream(AsyncStream) consume() test":
proc testConsume2(address: TransportAddress): Future[bool] {.async.} =
proc testConsume2(): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
const
@ -518,9 +532,10 @@ suite "AsyncStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
@ -547,9 +562,10 @@ suite "AsyncStream test suite":
await transp.closeWait()
await server.join()
result = true
check waitFor(testConsume2(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testConsume2()) == true
test "AsyncStream(AsyncStream) write(eof) test":
proc testWriteEof(address: TransportAddress): Future[bool] {.async.} =
proc testWriteEof(): Future[bool] {.async.} =
let
size = 10240
message = createBigMessage("ABCDEFGHIJKLMNOP", size)
@ -578,7 +594,8 @@ suite "AsyncStream test suite":
await transp.closeWait()
let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay}
var server = createStreamServer(address, processClient, flags = flags)
var server = createStreamServer(initTAddress("127.0.0.1:0"),
processClient, flags = flags)
server.start()
var conn = await connect(server.localAddress())
try:
@ -589,13 +606,10 @@ suite "AsyncStream test suite":
await server.closeWait()
return true
check waitFor(testWriteEof(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testWriteEof()) == true
test "AsyncStream(AsyncStream) leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
checkLeaks()
suite "ChunkedStream test suite":
test "ChunkedStream test vectors":
@ -624,8 +638,7 @@ suite "ChunkedStream test suite":
" in\r\n\r\nchunks.\r\n0;position=4\r\n\r\n",
"Wikipedia in\r\n\r\nchunks."],
]
proc checkVector(address: TransportAddress,
inputstr: string): Future[string] {.async.} =
proc checkVector(inputstr: string): Future[string] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -637,9 +650,10 @@ suite "ChunkedStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
var res = await rstream2.read()
@ -650,15 +664,16 @@ suite "ChunkedStream test suite":
await server.join()
result = ress
proc testVectors(address: TransportAddress): Future[bool] {.async.} =
proc testVectors(): Future[bool] {.async.} =
var res = true
for i in 0..<len(ChunkedVectors):
var r = await checkVector(address, ChunkedVectors[i][0])
var r = await checkVector(ChunkedVectors[i][0])
if r != ChunkedVectors[i][1]:
res = false
break
result = res
check waitFor(testVectors(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testVectors()) == true
test "ChunkedStream incorrect chunk test":
const BadVectors = [
["10000000;\r\n1"],
@ -673,8 +688,7 @@ suite "ChunkedStream test suite":
["FFFFFFFF ;\r\n1"],
["z\r\n1"]
]
proc checkVector(address: TransportAddress,
inputstr: string): Future[bool] {.async.} =
proc checkVector(inputstr: string): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -687,9 +701,10 @@ suite "ChunkedStream test suite":
server.close()
var res = false
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
try:
@ -732,15 +747,15 @@ suite "ChunkedStream test suite":
await server.join()
result = res
proc testVectors2(address: TransportAddress): Future[bool] {.async.} =
proc testVectors2(): Future[bool] {.async.} =
var res = true
for i in 0..<len(BadVectors):
var r = await checkVector(address, BadVectors[i][0])
var r = await checkVector(BadVectors[i][0])
if not(r):
res = false
break
result = res
check waitFor(testVectors2(initTAddress("127.0.0.1:46001"))) == true
check waitFor(testVectors2()) == true
test "ChunkedStream hex decoding test":
for i in 0 ..< 256:
@ -756,8 +771,7 @@ suite "ChunkedStream test suite":
check hexValue(byte(ch)) == -1
test "ChunkedStream too big chunk header test":
proc checkTooBigChunkHeader(address: TransportAddress,
inputstr: seq[byte]): Future[bool] {.async.} =
proc checkTooBigChunkHeader(inputstr: seq[byte]): Future[bool] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp)
@ -768,9 +782,10 @@ suite "ChunkedStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
let res =
@ -787,15 +802,13 @@ suite "ChunkedStream test suite":
await server.join()
return res
let address = initTAddress("127.0.0.1:46001")
var data1 = createBigMessage("REQUESTSTREAMMESSAGE", 65600)
var data2 = createBigMessage("REQUESTSTREAMMESSAGE", 262400)
check waitFor(checkTooBigChunkHeader(address, data1)) == true
check waitFor(checkTooBigChunkHeader(address, data2)) == true
check waitFor(checkTooBigChunkHeader(data1)) == true
check waitFor(checkTooBigChunkHeader(data2)) == true
test "ChunkedStream read/write test":
proc checkVector(address: TransportAddress,
inputstr: seq[byte],
proc checkVector(inputstr: seq[byte],
chunkSize: int): Future[seq[byte]] {.async.} =
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
@ -816,9 +829,10 @@ suite "ChunkedStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
var res = await rstream2.read()
@ -828,20 +842,17 @@ suite "ChunkedStream test suite":
await server.join()
return res
proc testBigData(address: TransportAddress,
datasize: int, chunksize: int): Future[bool] {.async.} =
proc testBigData(datasize: int, chunksize: int): Future[bool] {.async.} =
var data = createBigMessage("REQUESTSTREAMMESSAGE", datasize)
var check = await checkVector(address, data, chunksize)
var check = await checkVector(data, chunksize)
return (data == check)
let address = initTAddress("127.0.0.1:46001")
check waitFor(testBigData(address, 65600, 1024)) == true
check waitFor(testBigData(address, 262400, 4096)) == true
check waitFor(testBigData(address, 767309, 4457)) == true
check waitFor(testBigData(65600, 1024)) == true
check waitFor(testBigData(262400, 4096)) == true
check waitFor(testBigData(767309, 4457)) == true
test "ChunkedStream read small chunks test":
proc checkVector(address: TransportAddress,
inputstr: seq[byte],
proc checkVector(inputstr: seq[byte],
writeChunkSize: int,
readChunkSize: int): Future[seq[byte]] {.async.} =
proc serveClient(server: StreamServer,
@ -863,9 +874,10 @@ suite "ChunkedStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newChunkedStreamReader(rstream)
var res: seq[byte]
@ -878,27 +890,20 @@ suite "ChunkedStream test suite":
await server.join()
return res
proc testSmallChunk(address: TransportAddress,
datasize: int,
proc testSmallChunk(datasize: int,
writeChunkSize: int,
readChunkSize: int): Future[bool] {.async.} =
var data = createBigMessage("REQUESTSTREAMMESSAGE", datasize)
var check = await checkVector(address, data, writeChunkSize,
readChunkSize)
var check = await checkVector(data, writeChunkSize, readChunkSize)
return (data == check)
let address = initTAddress("127.0.0.1:46001")
check waitFor(testSmallChunk(address, 4457, 128, 1)) == true
check waitFor(testSmallChunk(address, 65600, 1024, 17)) == true
check waitFor(testSmallChunk(address, 262400, 4096, 61)) == true
check waitFor(testSmallChunk(address, 767309, 4457, 173)) == true
check waitFor(testSmallChunk(4457, 128, 1)) == true
check waitFor(testSmallChunk(65600, 1024, 17)) == true
check waitFor(testSmallChunk(262400, 4096, 61)) == true
check waitFor(testSmallChunk(767309, 4457, 173)) == true
test "ChunkedStream leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
checkLeaks()
suite "TLSStream test suite":
const HttpHeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
@ -933,8 +938,7 @@ suite "TLSStream test suite":
"www.google.com"))
check res == true
proc checkSSLServer(address: TransportAddress,
pemkey, pemcert: string): Future[bool] {.async.} =
proc checkSSLServer(pemkey, pemcert: string): Future[bool] {.async.} =
var key: TLSPrivateKey
var cert: TLSCertificate
let testMessage = "TEST MESSAGE"
@ -958,9 +962,10 @@ suite "TLSStream test suite":
key = TLSPrivateKey.init(pemkey)
cert = TLSCertificate.init(pemcert)
var server = createStreamServer(address, serveClient, {ServerFlags.ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ServerFlags.ReuseAddr})
server.start()
var conn = await connect(address)
var conn = await connect(server.localAddress())
var creader = newAsyncStreamReader(conn)
var cwriter = newAsyncStreamWriter(conn)
# We are using self-signed certificate
@ -976,8 +981,7 @@ suite "TLSStream test suite":
return cast[string](res) == (testMessage & "\r\n")
test "Simple server with RSA self-signed certificate":
let res = waitFor(checkSSLServer(initTAddress("127.0.0.1:43808"),
SelfSignedRsaKey, SelfSignedRsaCert))
let res = waitFor(checkSSLServer(SelfSignedRsaKey, SelfSignedRsaCert))
check res == true
test "Custom TrustAnchors test":
@ -985,7 +989,6 @@ suite "TLSStream test suite":
var key = TLSPrivateKey.init(SelfSignedRsaKey)
var cert = TLSCertificate.init(SelfSignedRsaCert)
let trustAnchors = TrustAnchorStore.new(SelfSignedTrustAnchors)
let address = initTAddress("127.0.0.1:43808")
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
@ -1003,9 +1006,10 @@ suite "TLSStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var conn = await connect(address)
var conn = await connect(server.localAddress())
var creader = newAsyncStreamReader(conn)
var cwriter = newAsyncStreamWriter(conn)
let flags = {NoVerifyServerName}
@ -1023,11 +1027,7 @@ suite "TLSStream test suite":
check res == "Some message\r\n"
test "TLSStream leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
checkLeaks()
suite "BoundedStream test suite":
@ -1041,7 +1041,7 @@ suite "BoundedStream test suite":
for itemComp in [BoundCmp.Equal, BoundCmp.LessOrEqual]:
for itemSize in [100, 60000]:
proc boundaryTest(address: TransportAddress, btest: BoundaryBytesTest,
proc boundaryTest(btest: BoundaryBytesTest,
size: int, boundary: seq[byte],
cmp: BoundCmp): Future[bool] {.async.} =
var message = createBigMessage("ABCDEFGHIJKLMNOP", size)
@ -1091,7 +1091,8 @@ suite "BoundedStream test suite":
var res = false
let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay}
var server = createStreamServer(address, processClient, flags = flags)
var server = createStreamServer(initTAddress("127.0.0.1:0"),
processClient, flags = flags)
server.start()
var conn = await connect(server.localAddress())
var rstream = newAsyncStreamReader(conn)
@ -1137,7 +1138,7 @@ suite "BoundedStream test suite":
await server.join()
return (res and clientRes)
proc boundedTest(address: TransportAddress, stest: BoundarySizeTest,
proc boundedTest(stest: BoundarySizeTest,
size: int, cmp: BoundCmp): Future[bool] {.async.} =
var clientRes = false
var res = false
@ -1205,7 +1206,8 @@ suite "BoundedStream test suite":
server.close()
let flags = {ServerFlags.ReuseAddr, ServerFlags.TcpNoDelay}
var server = createStreamServer(address, processClient, flags = flags)
var server = createStreamServer(initTAddress("127.0.0.1:0"),
processClient, flags = flags)
server.start()
var conn = await connect(server.localAddress())
var rstream = newAsyncStreamReader(conn)
@ -1258,7 +1260,6 @@ suite "BoundedStream test suite":
await server.join()
return (res and clientRes)
let address = initTAddress("127.0.0.1:0")
let suffix =
case itemComp
of BoundCmp.Equal:
@ -1267,39 +1268,38 @@ suite "BoundedStream test suite":
"<= " & $itemSize
test "BoundedStream(size) reading/writing test [" & suffix & "]":
check waitFor(boundedTest(address, SizeReadWrite, itemSize,
check waitFor(boundedTest(SizeReadWrite, itemSize,
itemComp)) == true
test "BoundedStream(size) overflow test [" & suffix & "]":
check waitFor(boundedTest(address, SizeOverflow, itemSize,
check waitFor(boundedTest(SizeOverflow, itemSize,
itemComp)) == true
test "BoundedStream(size) incomplete test [" & suffix & "]":
check waitFor(boundedTest(address, SizeIncomplete, itemSize,
check waitFor(boundedTest(SizeIncomplete, itemSize,
itemComp)) == true
test "BoundedStream(size) empty message test [" & suffix & "]":
check waitFor(boundedTest(address, SizeEmpty, itemSize,
check waitFor(boundedTest(SizeEmpty, itemSize,
itemComp)) == true
test "BoundedStream(boundary) reading test [" & suffix & "]":
check waitFor(boundaryTest(address, BoundaryRead, itemSize,
check waitFor(boundaryTest(BoundaryRead, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) double message test [" & suffix & "]":
check waitFor(boundaryTest(address, BoundaryDouble, itemSize,
check waitFor(boundaryTest(BoundaryDouble, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(size+boundary) reading size-bound test [" &
suffix & "]":
check waitFor(boundaryTest(address, BoundarySize, itemSize,
check waitFor(boundaryTest(BoundarySize, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) reading incomplete test [" &
suffix & "]":
check waitFor(boundaryTest(address, BoundaryIncomplete, itemSize,
check waitFor(boundaryTest(BoundaryIncomplete, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) empty message test [" &
suffix & "]":
check waitFor(boundaryTest(address, BoundaryEmpty, itemSize,
check waitFor(boundaryTest(BoundaryEmpty, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream read small chunks test":
proc checkVector(address: TransportAddress,
inputstr: seq[byte],
proc checkVector(inputstr: seq[byte],
writeChunkSize: int,
readChunkSize: int): Future[seq[byte]] {.async.} =
proc serveClient(server: StreamServer,
@ -1321,9 +1321,10 @@ suite "BoundedStream test suite":
server.stop()
server.close()
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var rstream2 = newBoundedStreamReader(rstream, 1048576,
comparison = BoundCmp.LessOrEqual)
@ -1337,23 +1338,19 @@ suite "BoundedStream test suite":
await server.join()
return res
proc testSmallChunk(address: TransportAddress,
datasize: int,
writeChunkSize: int,
proc testSmallChunk(datasize: int, writeChunkSize: int,
readChunkSize: int): Future[bool] {.async.} =
var data = createBigMessage("0123456789ABCDEFGHI", datasize)
var check = await checkVector(address, data, writeChunkSize,
readChunkSize)
var check = await checkVector(data, writeChunkSize, readChunkSize)
return (data == check)
let address = initTAddress("127.0.0.1:46001")
check waitFor(testSmallChunk(address, 4457, 128, 1)) == true
check waitFor(testSmallChunk(address, 65600, 1024, 17)) == true
check waitFor(testSmallChunk(address, 262400, 4096, 61)) == true
check waitFor(testSmallChunk(address, 767309, 4457, 173)) == true
check waitFor(testSmallChunk(4457, 128, 1)) == true
check waitFor(testSmallChunk(65600, 1024, 17)) == true
check waitFor(testSmallChunk(262400, 4096, 61)) == true
check waitFor(testSmallChunk(767309, 4457, 173)) == true
test "BoundedStream zero-sized streams test":
proc checkEmptyStreams(address: TransportAddress): Future[bool] {.async.} =
proc checkEmptyStreams(): Future[bool] {.async.} =
var writer1Res = false
proc serveClient(server: StreamServer,
transp: StreamTransport) {.async.} =
@ -1368,9 +1365,10 @@ suite "BoundedStream test suite":
server.close()
writer1Res = res
var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(initTAddress("127.0.0.1:0"),
serveClient, {ReuseAddr})
server.start()
var transp = await connect(address)
var transp = await connect(server.localAddress())
var rstream = newAsyncStreamReader(transp)
var wstream3 = newAsyncStreamWriter(transp)
var rstream2 = newBoundedStreamReader(rstream, 0'u64)
@ -1394,12 +1392,7 @@ suite "BoundedStream test suite":
await server.join()
return (writer1Res and writer2Res and readerRes)
let address = initTAddress("127.0.0.1:46001")
check waitFor(checkEmptyStreams(address)) == true
check waitFor(checkEmptyStreams()) == true
test "BoundedStream leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
checkLeaks()

View File

@ -6,8 +6,8 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[strutils, net]
import unittest2
import ../chronos
import ".."/chronos/unittest2/asynctests
import ".."/chronos
{.used.}
@ -558,4 +558,4 @@ suite "Datagram Transport test suite":
test "0.0.0.0/::0 (INADDR_ANY) test":
check waitFor(testAnyAddress()) == 6
test "Transports leak test":
check getTracker("datagram.transport").isLeaked() == false
checkLeaks()

View File

@ -6,8 +6,9 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[strutils, sha1]
import unittest2
import ../chronos, ../chronos/apps/http/[httpserver, shttpserver, httpclient]
import ".."/chronos/unittest2/asynctests
import ".."/chronos,
".."/chronos/apps/http/[httpserver, shttpserver, httpclient]
import stew/base10
{.used.}
@ -85,30 +86,36 @@ suite "HTTP client testing suite":
proc createServer(address: TransportAddress,
process: HttpProcessCallback, secure: bool): HttpServerRef =
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
serverFlags = {HttpServerFlags.Http11Pipeline}
if secure:
let secureKey = TLSPrivateKey.init(HttpsSelfSignedRsaKey)
let secureCert = TLSCertificate.init(HttpsSelfSignedRsaCert)
let res = SecureHttpServerRef.new(address, process,
socketFlags = socketFlags,
serverFlags = serverFlags,
tlsPrivateKey = secureKey,
tlsCertificate = secureCert)
HttpServerRef(res.get())
else:
let res = HttpServerRef.new(address, process, socketFlags = socketFlags)
let res = HttpServerRef.new(address, process,
socketFlags = socketFlags,
serverFlags = serverFlags)
res.get()
proc createSession(secure: bool,
maxRedirections = HttpMaxRedirections): HttpSessionRef =
if secure:
HttpSessionRef.new({HttpClientFlag.NoVerifyHost,
HttpClientFlag.NoVerifyServerName},
HttpClientFlag.NoVerifyServerName,
HttpClientFlag.Http11Pipeline},
maxRedirections = maxRedirections)
else:
HttpSessionRef.new(maxRedirections = maxRedirections)
HttpSessionRef.new({HttpClientFlag.Http11Pipeline},
maxRedirections = maxRedirections)
proc testMethods(address: TransportAddress,
secure: bool): Future[int] {.async.} =
proc testMethods(secure: bool): Future[int] {.async.} =
let RequestTests = [
(MethodGet, "/test/get"),
(MethodPost, "/test/post"),
@ -132,10 +139,11 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, secure)
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
var counter = 0
var session = createSession(secure)
@ -175,8 +183,7 @@ suite "HTTP client testing suite":
await server.closeWait()
return counter
proc testResponseStreamReadingTest(address: TransportAddress,
secure: bool): Future[int] {.async.} =
proc testResponseStreamReadingTest(secure: bool): Future[int] {.async.} =
let ResponseTests = [
(MethodGet, "/test/short_size_response", 65600, 1024,
"SHORTSIZERESPONSE"),
@ -235,10 +242,11 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, secure)
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
var counter = 0
var session = createSession(secure)
@ -297,8 +305,7 @@ suite "HTTP client testing suite":
await server.closeWait()
return counter
proc testRequestSizeStreamWritingTest(address: TransportAddress,
secure: bool): Future[int] {.async.} =
proc testRequestSizeStreamWritingTest(secure: bool): Future[int] {.async.} =
let RequestTests = [
(MethodPost, "/test/big_request", 65600),
(MethodPost, "/test/big_request", 262400)
@ -318,10 +325,11 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, secure)
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
var counter = 0
var session = createSession(secure)
@ -366,7 +374,7 @@ suite "HTTP client testing suite":
await server.closeWait()
return counter
proc testRequestChunkedStreamWritingTest(address: TransportAddress,
proc testRequestChunkedStreamWritingTest(
secure: bool): Future[int] {.async.} =
let RequestTests = [
(MethodPost, "/test/big_chunk_request", 65600),
@ -387,10 +395,11 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, secure)
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
var counter = 0
var session = createSession(secure)
@ -435,8 +444,7 @@ suite "HTTP client testing suite":
await server.closeWait()
return counter
proc testRequestPostUrlEncodedTest(address: TransportAddress,
secure: bool): Future[int] {.async.} =
proc testRequestPostUrlEncodedTest(secure: bool): Future[int] {.async.} =
let PostRequests = [
("/test/post/urlencoded_size",
"field1=value1&field2=value2&field3=value3", "value1:value2:value3"),
@ -463,10 +471,11 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, secure)
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
var counter = 0
## Sized url-encoded form
@ -533,8 +542,7 @@ suite "HTTP client testing suite":
await server.closeWait()
return counter
proc testRequestPostMultipartTest(address: TransportAddress,
secure: bool): Future[int] {.async.} =
proc testRequestPostMultipartTest(secure: bool): Future[int] {.async.} =
let PostRequests = [
("/test/post/multipart_size", "some-part-boundary",
[("field1", "value1"), ("field2", "value2"), ("field3", "value3")],
@ -562,10 +570,11 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, secure)
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
var counter = 0
## Sized multipart form
@ -635,17 +644,9 @@ suite "HTTP client testing suite":
await server.closeWait()
return counter
proc testRequestRedirectTest(address: TransportAddress,
secure: bool,
proc testRequestRedirectTest(secure: bool,
max: int): Future[string] {.async.} =
var session = createSession(secure, maxRedirections = max)
let ha =
if secure:
getAddress(address, HttpClientScheme.Secure, "/")
else:
getAddress(address, HttpClientScheme.NonSecure, "/")
let lastAddress = ha.getUri().combine(parseUri("/final/5"))
var lastAddress: Uri
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -667,10 +668,22 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, secure)
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
var session = createSession(secure, maxRedirections = max)
let ha =
if secure:
getAddress(address, HttpClientScheme.Secure, "/")
else:
getAddress(address, HttpClientScheme.NonSecure, "/")
lastAddress = ha.getUri().combine(parseUri("/final/5"))
if session.maxRedirections >= 5:
let (code, data) = await session.fetch(ha.getUri())
await session.closeWait()
@ -691,26 +704,22 @@ suite "HTTP client testing suite":
await server.closeWait()
return "redirect-" & $res
proc testBasicAuthorization(): Future[bool] {.async.} =
let session = HttpSessionRef.new({HttpClientFlag.NoVerifyHost},
maxRedirections = 10)
let url = parseUri("https://guest:guest@jigsaw.w3.org/HTTP/Basic/")
let resp = await session.fetch(url)
await session.closeWait()
if (resp.status == 200) and
("Your browser made it!" in bytesToString(resp.data)):
return true
else:
echo "RESPONSE STATUS = [", resp.status, "]"
echo "RESPONSE = [", bytesToString(resp.data), "]"
return false
# proc testBasicAuthorization(): Future[bool] {.async.} =
# let session = HttpSessionRef.new({HttpClientFlag.NoVerifyHost},
# maxRedirections = 10)
# let url = parseUri("https://guest:guest@jigsaw.w3.org/HTTP/Basic/")
# let resp = await session.fetch(url)
# await session.closeWait()
# if (resp.status == 200) and
# ("Your browser made it!" in bytesToString(resp.data)):
# return true
# else:
# echo "RESPONSE STATUS = [", resp.status, "]"
# echo "RESPONSE = [", bytesToString(resp.data), "]"
# return false
proc testConnectionManagement(address: TransportAddress): Future[bool] {.
proc testConnectionManagement(): Future[bool] {.
async.} =
let
keepHa = getAddress(address, HttpClientScheme.NonSecure, "/keep")
dropHa = getAddress(address, HttpClientScheme.NonSecure, "/drop")
proc test1(
a1: HttpAddress,
version: HttpVersion,
@ -770,10 +779,15 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, false)
var server = createServer(initTAddress("127.0.0.1:0"), process, false)
server.start()
let address = server.instance.localAddress()
let
keepHa = getAddress(address, HttpClientScheme.NonSecure, "/keep")
dropHa = getAddress(address, HttpClientScheme.NonSecure, "/drop")
try:
let
@ -872,11 +886,7 @@ suite "HTTP client testing suite":
return true
proc testIdleConnection(address: TransportAddress): Future[bool] {.
async.} =
let
ha = getAddress(address, HttpClientScheme.NonSecure, "/test")
proc testIdleConnection(): Future[bool] {.async.} =
proc test(
session: HttpSessionRef,
a: HttpAddress
@ -900,13 +910,16 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, false)
var server = createServer(initTAddress("127.0.0.1:0"), process, false)
server.start()
let session = HttpSessionRef.new({HttpClientFlag.Http11Pipeline},
idleTimeout = 1.seconds,
idlePeriod = 200.milliseconds)
let
address = server.instance.localAddress()
ha = getAddress(address, HttpClientScheme.NonSecure, "/test")
session = HttpSessionRef.new({HttpClientFlag.Http11Pipeline},
idleTimeout = 1.seconds,
idlePeriod = 200.milliseconds)
try:
var f1 = test(session, ha)
var f2 = test(session, ha)
@ -932,12 +945,7 @@ suite "HTTP client testing suite":
return true
proc testNoPipeline(address: TransportAddress): Future[bool] {.
async.} =
let
ha = getAddress(address, HttpClientScheme.NonSecure, "/test")
hb = getAddress(address, HttpClientScheme.NonSecure, "/keep-test")
proc testNoPipeline(): Future[bool] {.async.} =
proc test(
session: HttpSessionRef,
a: HttpAddress
@ -964,12 +972,16 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, false)
var server = createServer(initTAddress("127.0.0.1:0"), process, false)
server.start()
let session = HttpSessionRef.new(idleTimeout = 100.seconds,
idlePeriod = 10.milliseconds)
let
address = server.instance.localAddress()
ha = getAddress(address, HttpClientScheme.NonSecure, "/test")
hb = getAddress(address, HttpClientScheme.NonSecure, "/keep-test")
session = HttpSessionRef.new(idleTimeout = 100.seconds,
idlePeriod = 10.milliseconds)
try:
var f1 = test(session, ha)
var f2 = test(session, ha)
@ -1001,8 +1013,7 @@ suite "HTTP client testing suite":
return true
proc testServerSentEvents(address: TransportAddress,
secure: bool): Future[bool] {.async.} =
proc testServerSentEvents(secure: bool): Future[bool] {.async.} =
const
SingleGoodTests = [
("/test/single/1", "a:b\r\nc: d\re:f\n:comment\r\ng:\n h: j \n\n",
@ -1115,10 +1126,11 @@ suite "HTTP client testing suite":
else:
return await request.respond(Http404, "Page not found")
else:
return dumbResponse()
return defaultResponse()
var server = createServer(address, process, secure)
var server = createServer(initTAddress("127.0.0.1:0"), process, secure)
server.start()
let address = server.instance.localAddress()
var session = createSession(secure)
@ -1184,100 +1196,71 @@ suite "HTTP client testing suite":
return true
test "HTTP all request methods test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testMethods(address, false)) == 18
check waitFor(testMethods(false)) == 18
test "HTTP(S) all request methods test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testMethods(address, true)) == 18
check waitFor(testMethods(true)) == 18
test "HTTP client response streaming test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testResponseStreamReadingTest(address, false)) == 8
check waitFor(testResponseStreamReadingTest(false)) == 8
test "HTTP(S) client response streaming test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testResponseStreamReadingTest(address, true)) == 8
check waitFor(testResponseStreamReadingTest(true)) == 8
test "HTTP client (size) request streaming test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestSizeStreamWritingTest(address, false)) == 2
check waitFor(testRequestSizeStreamWritingTest(false)) == 2
test "HTTP(S) client (size) request streaming test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestSizeStreamWritingTest(address, true)) == 2
check waitFor(testRequestSizeStreamWritingTest(true)) == 2
test "HTTP client (chunked) request streaming test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestChunkedStreamWritingTest(address, false)) == 2
check waitFor(testRequestChunkedStreamWritingTest(false)) == 2
test "HTTP(S) client (chunked) request streaming test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestChunkedStreamWritingTest(address, true)) == 2
check waitFor(testRequestChunkedStreamWritingTest(true)) == 2
test "HTTP client (size + chunked) url-encoded POST test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestPostUrlEncodedTest(address, false)) == 2
check waitFor(testRequestPostUrlEncodedTest(false)) == 2
test "HTTP(S) client (size + chunked) url-encoded POST test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestPostUrlEncodedTest(address, true)) == 2
check waitFor(testRequestPostUrlEncodedTest(true)) == 2
test "HTTP client (size + chunked) multipart POST test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestPostMultipartTest(address, false)) == 2
check waitFor(testRequestPostMultipartTest(false)) == 2
test "HTTP(S) client (size + chunked) multipart POST test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestPostMultipartTest(address, true)) == 2
check waitFor(testRequestPostMultipartTest(true)) == 2
test "HTTP client redirection test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestRedirectTest(address, false, 5)) == "ok-5-200"
check waitFor(testRequestRedirectTest(false, 5)) == "ok-5-200"
test "HTTP(S) client redirection test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestRedirectTest(address, true, 5)) == "ok-5-200"
check waitFor(testRequestRedirectTest(true, 5)) == "ok-5-200"
test "HTTP client maximum redirections test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestRedirectTest(address, false, 4)) == "redirect-true"
check waitFor(testRequestRedirectTest(false, 4)) == "redirect-true"
test "HTTP(S) client maximum redirections test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testRequestRedirectTest(address, true, 4)) == "redirect-true"
check waitFor(testRequestRedirectTest(true, 4)) == "redirect-true"
test "HTTPS basic authorization test":
check waitFor(testBasicAuthorization()) == true
skip()
# This test disabled because remote service is pretty flaky and fails pretty
# often. As soon as more stable service will be found this test should be
# recovered
# check waitFor(testBasicAuthorization()) == true
test "HTTP client connection management test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testConnectionManagement(address)) == true
check waitFor(testConnectionManagement()) == true
test "HTTP client idle connection test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testIdleConnection(address)) == true
check waitFor(testIdleConnection()) == true
test "HTTP client no-pipeline test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testNoPipeline(address)) == true
check waitFor(testNoPipeline()) == true
test "HTTP client server-sent events test":
let address = initTAddress("127.0.0.1:30080")
check waitFor(testServerSentEvents(address, false)) == true
check waitFor(testServerSentEvents(false)) == true
test "Leaks test":
proc getTrackerLeaks(tracker: string): bool =
let tracker = getTracker(tracker)
if isNil(tracker): false else: tracker.isLeaked()
check:
getTrackerLeaks("http.body.reader") == false
getTrackerLeaks("http.body.writer") == false
getTrackerLeaks("httpclient.connection") == false
getTrackerLeaks("httpclient.request") == false
getTrackerLeaks("httpclient.response") == false
getTrackerLeaks("async.stream.reader") == false
getTrackerLeaks("async.stream.writer") == false
getTrackerLeaks("stream.server") == false
getTrackerLeaks("stream.transport") == false
checkLeaks()

View File

@ -6,9 +6,9 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[strutils, algorithm]
import unittest2
import ../chronos, ../chronos/apps/http/httpserver,
../chronos/apps/http/httpcommon
import ".."/chronos/unittest2/asynctests,
".."/chronos,
".."/chronos/apps/http/[httpserver, httpcommon, httpdebug]
import stew/base10
{.used.}
@ -17,6 +17,9 @@ suite "HTTP server testing suite":
type
TooBigTest = enum
GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest
TestHttpResponse = object
headers: HttpTable
data: string
proc httpClient(address: TransportAddress,
data: string): Future[string] {.async.} =
@ -33,8 +36,32 @@ suite "HTTP server testing suite":
if not(isNil(transp)):
await closeWait(transp)
proc testTooBigBodyChunked(address: TransportAddress,
operation: TooBigTest): Future[bool] {.async.} =
proc httpClient2(transp: StreamTransport,
request: string,
length: int): Future[TestHttpResponse] {.async.} =
var buffer = newSeq[byte](4096)
var sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8]
let wres = await transp.write(request)
if wres != len(request):
raise newException(ValueError, "Unable to write full request")
let hres = await transp.readUntil(addr buffer[0], len(buffer), sep)
var hdata = @buffer
hdata.setLen(hres)
zeroMem(addr buffer[0], len(buffer))
await transp.readExactly(addr buffer[0], length)
let data = bytesToString(buffer.toOpenArray(0, length - 1))
let headers =
block:
let resp = parseResponse(hdata, false)
if resp.failed():
raise newException(ValueError, "Unable to decode response headers")
var res = HttpTable.init()
for key, value in resp.headers(hdata):
res.add(key, value)
res
return TestHttpResponse(headers: headers, data: data)
proc testTooBigBodyChunked(operation: TooBigTest): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -56,10 +83,10 @@ suite "HTTP server testing suite":
# Reraising exception, because processor should properly handle it.
raise exc
else:
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
maxRequestBodySize = 10,
socketFlags = socketFlags)
if res.isErr():
@ -67,18 +94,19 @@ suite "HTTP server testing suite":
let server = res.get()
server.start()
let address = server.instance.localAddress()
let request =
case operation
of GetBodyTest, ConsumeBodyTest, PostUrlTest:
"POST / HTTP/1.0\r\n" &
"POST / HTTP/1.1\r\n" &
"Content-Type: application/x-www-form-urlencoded\r\n" &
"Transfer-Encoding: chunked\r\n" &
"Cookie: 2\r\n\r\n" &
"5\r\na=a&b\r\n5\r\n=b&c=\r\n4\r\nc&d=\r\n4\r\n%D0%\r\n" &
"2\r\n9F\r\n0\r\n\r\n"
of PostMultipartTest:
"POST / HTTP/1.0\r\n" &
"POST / HTTP/1.1\r\n" &
"Host: 127.0.0.1:30080\r\n" &
"Transfer-Encoding: chunked\r\n" &
"Content-Type: multipart/form-data; boundary=f98f0\r\n\r\n" &
@ -97,7 +125,7 @@ suite "HTTP server testing suite":
return serverRes and (data.startsWith("HTTP/1.1 413"))
test "Request headers timeout test":
proc testTimeout(address: TransportAddress): Future[bool] {.async.} =
proc testTimeout(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -105,28 +133,29 @@ suite "HTTP server testing suite":
let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init())
else:
if r.error().error == HttpServerError.TimeoutError:
if r.error.kind == HttpServerError.TimeoutError:
serverRes = true
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process, socketFlags = socketFlags,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"),
process, socketFlags = socketFlags,
httpHeadersTimeout = 100.milliseconds)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let data = await httpClient(address, "")
await server.stop()
await server.closeWait()
return serverRes and (data.startsWith("HTTP/1.1 408"))
check waitFor(testTimeout(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testTimeout()) == true
test "Empty headers test":
proc testEmpty(address: TransportAddress): Future[bool] {.async.} =
proc testEmpty(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -134,27 +163,29 @@ suite "HTTP server testing suite":
let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init())
else:
if r.error().error == HttpServerError.CriticalError:
if r.error.kind == HttpServerError.CriticalError:
serverRes = true
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process, socketFlags = socketFlags)
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"),
process, socketFlags = socketFlags)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let data = await httpClient(address, "\r\n\r\n")
await server.stop()
await server.closeWait()
return serverRes and (data.startsWith("HTTP/1.1 400"))
check waitFor(testEmpty(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testEmpty()) == true
test "Too big headers test":
proc testTooBig(address: TransportAddress): Future[bool] {.async.} =
proc testTooBig(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -162,12 +193,12 @@ suite "HTTP server testing suite":
let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init())
else:
if r.error().error == HttpServerError.CriticalError:
if r.error.error == HttpServerError.CriticalError:
serverRes = true
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
maxHeadersSize = 10,
socketFlags = socketFlags)
if res.isErr():
@ -175,28 +206,29 @@ suite "HTTP server testing suite":
let server = res.get()
server.start()
let address = server.instance.localAddress()
let data = await httpClient(address, "GET / HTTP/1.1\r\n\r\n")
await server.stop()
await server.closeWait()
return serverRes and (data.startsWith("HTTP/1.1 431"))
check waitFor(testTooBig(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testTooBig()) == true
test "Too big request body test (content-length)":
proc testTooBigBody(address: TransportAddress): Future[bool] {.async.} =
proc testTooBigBody(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
if r.isOk():
discard
else:
if r.error().error == HttpServerError.CriticalError:
if r.error.error == HttpServerError.CriticalError:
serverRes = true
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
maxRequestBodySize = 10,
socketFlags = socketFlags)
if res.isErr():
@ -204,6 +236,7 @@ suite "HTTP server testing suite":
let server = res.get()
server.start()
let address = server.instance.localAddress()
let request = "GET / HTTP/1.1\r\nContent-Length: 20\r\n\r\n"
let data = await httpClient(address, request)
@ -211,30 +244,26 @@ suite "HTTP server testing suite":
await server.closeWait()
return serverRes and (data.startsWith("HTTP/1.1 413"))
check waitFor(testTooBigBody(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testTooBigBody()) == true
test "Too big request body test (getBody()/chunked encoding)":
check:
waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"),
GetBodyTest)) == true
waitFor(testTooBigBodyChunked(GetBodyTest)) == true
test "Too big request body test (consumeBody()/chunked encoding)":
check:
waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"),
ConsumeBodyTest)) == true
waitFor(testTooBigBodyChunked(ConsumeBodyTest)) == true
test "Too big request body test (post()/urlencoded/chunked encoding)":
check:
waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"),
PostUrlTest)) == true
waitFor(testTooBigBodyChunked(PostUrlTest)) == true
test "Too big request body test (post()/multipart/chunked encoding)":
check:
waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"),
PostMultipartTest)) == true
waitFor(testTooBigBodyChunked(PostMultipartTest)) == true
test "Query arguments test":
proc testQuery(address: TransportAddress): Future[bool] {.async.} =
proc testQuery(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -249,16 +278,17 @@ suite "HTTP server testing suite":
HttpTable.init())
else:
serverRes = false
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let data1 = await httpClient(address,
"GET /?a=1&a=2&b=3&c=4 HTTP/1.0\r\n\r\n")
@ -271,10 +301,10 @@ suite "HTTP server testing suite":
(data2.find("TEST_OK:a:П:b:Ц:c:Ю:Ф:Б") >= 0)
return r
check waitFor(testQuery(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testQuery()) == true
test "Headers test":
proc testHeaders(address: TransportAddress): Future[bool] {.async.} =
proc testHeaders(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -289,16 +319,17 @@ suite "HTTP server testing suite":
HttpTable.init())
else:
serverRes = false
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let message =
"GET / HTTP/1.0\r\n" &
@ -314,10 +345,10 @@ suite "HTTP server testing suite":
await server.closeWait()
return serverRes and (data.find(expect) >= 0)
check waitFor(testHeaders(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testHeaders()) == true
test "POST arguments (urlencoded/content-length) test":
proc testPostUrl(address: TransportAddress): Future[bool] {.async.} =
proc testPostUrl(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -334,16 +365,17 @@ suite "HTTP server testing suite":
HttpTable.init())
else:
serverRes = false
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let message =
"POST / HTTP/1.0\r\n" &
@ -357,10 +389,10 @@ suite "HTTP server testing suite":
await server.closeWait()
return serverRes and (data.find(expect) >= 0)
check waitFor(testPostUrl(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testPostUrl()) == true
test "POST arguments (urlencoded/chunked encoding) test":
proc testPostUrl2(address: TransportAddress): Future[bool] {.async.} =
proc testPostUrl2(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -377,16 +409,17 @@ suite "HTTP server testing suite":
HttpTable.init())
else:
serverRes = false
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let message =
"POST / HTTP/1.0\r\n" &
@ -401,10 +434,10 @@ suite "HTTP server testing suite":
await server.closeWait()
return serverRes and (data.find(expect) >= 0)
check waitFor(testPostUrl2(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testPostUrl2()) == true
test "POST arguments (multipart/content-length) test":
proc testPostMultipart(address: TransportAddress): Future[bool] {.async.} =
proc testPostMultipart(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -421,16 +454,17 @@ suite "HTTP server testing suite":
HttpTable.init())
else:
serverRes = false
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let message =
"POST / HTTP/1.0\r\n" &
@ -456,10 +490,10 @@ suite "HTTP server testing suite":
await server.closeWait()
return serverRes and (data.find(expect) >= 0)
check waitFor(testPostMultipart(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testPostMultipart()) == true
test "POST arguments (multipart/chunked encoding) test":
proc testPostMultipart2(address: TransportAddress): Future[bool] {.async.} =
proc testPostMultipart2(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -476,16 +510,17 @@ suite "HTTP server testing suite":
HttpTable.init())
else:
serverRes = false
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let message =
"POST / HTTP/1.0\r\n" &
@ -520,12 +555,12 @@ suite "HTTP server testing suite":
await server.closeWait()
return serverRes and (data.find(expect) >= 0)
check waitFor(testPostMultipart2(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testPostMultipart2()) == true
test "drop() connections test":
const ClientsCount = 10
proc testHTTPdrop(address: TransportAddress): Future[bool] {.async.} =
proc testHTTPdrop(): Future[bool] {.async.} =
var eventWait = newAsyncEvent()
var eventContinue = newAsyncEvent()
var count = 0
@ -539,10 +574,10 @@ suite "HTTP server testing suite":
await eventContinue.wait()
return await request.respond(Http404, "", HttpTable.init())
else:
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags,
maxConnections = 100)
if res.isErr():
@ -550,6 +585,7 @@ suite "HTTP server testing suite":
let server = res.get()
server.start()
let address = server.instance.localAddress()
var clients: seq[Future[string]]
let message = "GET / HTTP/1.0\r\nHost: https://127.0.0.1:80\r\n\r\n"
@ -572,7 +608,7 @@ suite "HTTP server testing suite":
return false
return true
check waitFor(testHTTPdrop(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testHTTPdrop()) == true
test "Content-Type multipart boundary test":
const AllowedCharacters = {
@ -1190,7 +1226,7 @@ suite "HTTP server testing suite":
r6.get() == MediaType.init(req[1][6])
test "SSE server-side events stream test":
proc testPostMultipart2(address: TransportAddress): Future[bool] {.async.} =
proc testPostMultipart2(): Future[bool] {.async.} =
var serverRes = false
proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} =
@ -1209,16 +1245,17 @@ suite "HTTP server testing suite":
return response
else:
serverRes = false
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(address, process,
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags)
if res.isErr():
return false
let server = res.get()
server.start()
let address = server.instance.localAddress()
let message =
"GET / HTTP/1.1\r\n" &
@ -1237,12 +1274,158 @@ suite "HTTP server testing suite":
await server.closeWait()
return serverRes and (data.find(expect) >= 0)
check waitFor(testPostMultipart2(initTAddress("127.0.0.1:30080"))) == true
check waitFor(testPostMultipart2()) == true
asyncTest "HTTP/1.1 pipeline test":
const TestMessages = [
("GET / HTTP/1.0\r\n\r\n",
{HttpServerFlags.Http11Pipeline}, false, "close"),
("GET / HTTP/1.0\r\nConnection: close\r\n\r\n",
{HttpServerFlags.Http11Pipeline}, false, "close"),
("GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
{HttpServerFlags.Http11Pipeline}, false, "close"),
("GET / HTTP/1.0\r\n\r\n",
{}, false, "close"),
("GET / HTTP/1.0\r\nConnection: close\r\n\r\n",
{}, false, "close"),
("GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
{}, false, "close"),
("GET / HTTP/1.1\r\n\r\n",
{HttpServerFlags.Http11Pipeline}, true, "keep-alive"),
("GET / HTTP/1.1\r\nConnection: close\r\n\r\n",
{HttpServerFlags.Http11Pipeline}, false, "close"),
("GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n",
{HttpServerFlags.Http11Pipeline}, true, "keep-alive"),
("GET / HTTP/1.1\r\n\r\n",
{}, false, "close"),
("GET / HTTP/1.1\r\nConnection: close\r\n\r\n",
{}, false, "close"),
("GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n",
{}, false, "close")
]
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isOk():
let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init())
else:
return defaultResponse()
for test in TestMessages:
let
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
serverFlags = test[1]
res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
socketFlags = socketFlags,
serverFlags = serverFlags)
check res.isOk()
let
server = res.get()
address = server.instance.localAddress()
server.start()
var transp: StreamTransport
try:
transp = await connect(address)
block:
let response = await transp.httpClient2(test[0], 7)
check:
response.data == "TEST_OK"
response.headers.getString("connection") == test[3]
# We do this sleeping here just because we running both server and
# client in single process, so when we received response from server
# it does not mean that connection has been immediately closed - it
# takes some more calls, so we trying to get this calls happens.
await sleepAsync(50.milliseconds)
let connectionStillAvailable =
try:
let response {.used.} = await transp.httpClient2(test[0], 7)
true
except CatchableError:
false
check connectionStillAvailable == test[2]
finally:
if not(isNil(transp)):
await transp.closeWait()
await server.stop()
await server.closeWait()
asyncTest "HTTP debug tests":
const
TestsCount = 10
TestRequest = "GET /httpdebug HTTP/1.1\r\nConnection: keep-alive\r\n\r\n"
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isOk():
let request = r.get()
return await request.respond(Http200, "TEST_OK", HttpTable.init())
else:
return defaultResponse()
proc client(address: TransportAddress,
data: string): Future[StreamTransport] {.async.} =
var transp: StreamTransport
var buffer = newSeq[byte](4096)
var sep = @[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8]
try:
transp = await connect(address)
let wres {.used.} =
await transp.write(data)
let hres {.used.} =
await transp.readUntil(addr buffer[0], len(buffer), sep)
transp
except CatchableError:
if not(isNil(transp)): await transp.closeWait()
nil
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let res = HttpServerRef.new(initTAddress("127.0.0.1:0"), process,
serverFlags = {HttpServerFlags.Http11Pipeline},
socketFlags = socketFlags)
check res.isOk()
let server = res.get()
server.start()
let address = server.instance.localAddress()
let info = server.getServerInfo()
check:
info.connectionType == ConnectionType.NonSecure
info.address == address
info.state == HttpServerState.ServerRunning
info.flags == {HttpServerFlags.Http11Pipeline}
info.socketFlags == socketFlags
try:
var clientFutures: seq[Future[StreamTransport]]
for i in 0 ..< TestsCount:
clientFutures.add(client(address, TestRequest))
await allFutures(clientFutures)
let connections = server.getConnections()
check len(connections) == TestsCount
let currentTime = Moment.now()
for index, connection in connections.pairs():
let transp = clientFutures[index].read()
check:
connection.remoteAddress.get() == transp.localAddress()
connection.localAddress.get() == transp.remoteAddress()
connection.connectionType == ConnectionType.NonSecure
connection.connectionState == ConnectionState.Alive
connection.query.get("") == "/httpdebug"
(currentTime - connection.createMoment.get()) != ZeroDuration
(currentTime - connection.acceptMoment) != ZeroDuration
var pending: seq[Future[void]]
for transpFut in clientFutures:
pending.add(closeWait(transpFut.read()))
await allFutures(pending)
finally:
await server.stop()
await server.closeWait()
test "Leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
checkLeaks()

View File

@ -177,6 +177,10 @@ suite "Macro transformations test suite":
of false: await implicit7(v)
of true: 42
proc implicit9(): Future[int] {.async.} =
result = 42
result
let fin = new int
check:
waitFor(implicit()) == 42
@ -193,6 +197,8 @@ suite "Macro transformations test suite":
waitFor(implicit8(true)) == 42
waitFor(implicit8(false)) == 33
waitFor(implicit9()) == 42
suite "Closure iterator's exception transformation issues":
test "Nested defer/finally not called on return":
# issue #288

View File

@ -2,6 +2,8 @@
IF /I "%1" == "STDIN" (
GOTO :STDINTEST
) ELSE IF /I "%1" == "TIMEOUT1" (
GOTO :TIMEOUTTEST1
) ELSE IF /I "%1" == "TIMEOUT2" (
GOTO :TIMEOUTTEST2
) ELSE IF /I "%1" == "TIMEOUT10" (
@ -19,6 +21,10 @@ SET /P "INPUTDATA="
ECHO STDIN DATA: %INPUTDATA%
EXIT 0
:TIMEOUTTEST1
ping -n 1 127.0.0.1 > NUL
EXIT 1
:TIMEOUTTEST2
ping -n 2 127.0.0.1 > NUL
EXIT 2

View File

@ -6,8 +6,9 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/os
import unittest2, stew/[base10, byteutils]
import stew/[base10, byteutils]
import ".."/chronos/unittest2/asynctests
import ".."/chronos/asyncproc
when defined(posix):
from ".."/chronos/osdefs import SIGKILL
@ -96,7 +97,11 @@ suite "Asynchronous process management test suite":
let
options = {AsyncProcessOption.EvalCommand}
command = "exit 1"
command =
when defined(windows):
"tests\\testproc.bat timeout1"
else:
"tests/testproc.sh timeout1"
process = await startProcess(command, options = options)
@ -407,6 +412,52 @@ suite "Asynchronous process management test suite":
finally:
await process.closeWait()
asyncTest "killAndWaitForExit() test":
let command =
when defined(windows):
("tests\\testproc.bat", "timeout10", 0)
else:
("tests/testproc.sh", "timeout10", 128 + int(SIGKILL))
let process = await startProcess(command[0], arguments = @[command[1]])
try:
let exitCode = await process.killAndWaitForExit(10.seconds)
check exitCode == command[2]
finally:
await process.closeWait()
asyncTest "terminateAndWaitForExit() test":
let command =
when defined(windows):
("tests\\testproc.bat", "timeout10", 0)
else:
("tests/testproc.sh", "timeout10", 128 + int(SIGTERM))
let process = await startProcess(command[0], arguments = @[command[1]])
try:
let exitCode = await process.terminateAndWaitForExit(10.seconds)
check exitCode == command[2]
finally:
await process.closeWait()
asyncTest "terminateAndWaitForExit() timeout test":
when defined(windows):
skip()
else:
let
command = ("tests/testproc.sh", "noterm", 128 + int(SIGKILL))
process = await startProcess(command[0], arguments = @[command[1]])
# We should wait here to allow `bash` execute `trap` command, otherwise
# our test script will be killed with SIGTERM. Increase this timeout
# if test become flaky.
await sleepAsync(1.seconds)
try:
expect AsyncProcessTimeoutError:
let exitCode {.used.} =
await process.terminateAndWaitForExit(1.seconds)
let exitCode = await process.killAndWaitForExit(10.seconds)
check exitCode == command[2]
finally:
await process.closeWait()
test "File descriptors leaks test":
when defined(windows):
skip()
@ -414,12 +465,4 @@ suite "Asynchronous process management test suite":
check getCurrentFD() == markFD
test "Leaks test":
proc getTrackerLeaks(tracker: string): bool =
let tracker = getTracker(tracker)
if isNil(tracker): false else: tracker.isLeaked()
check:
getTrackerLeaks("async.process") == false
getTrackerLeaks("async.stream.reader") == false
getTrackerLeaks("async.stream.writer") == false
getTrackerLeaks("stream.transport") == false
checkLeaks()

View File

@ -3,6 +3,9 @@
if [ "$1" == "stdin" ]; then
read -r inputdata
echo "STDIN DATA: $inputdata"
elif [ "$1" == "timeout1" ]; then
sleep 1
exit 1
elif [ "$1" == "timeout2" ]; then
sleep 2
exit 2
@ -15,6 +18,11 @@ elif [ "$1" == "bigdata" ]; then
done
elif [ "$1" == "envtest" ]; then
echo "$CHRONOSASYNC"
elif [ "$1" == "noterm" ]; then
trap -- '' SIGTERM
while true; do
sleep 1
done
else
echo "arguments missing"
fi

View File

@ -15,22 +15,23 @@ import ../chronos/ratelimit
suite "Token Bucket":
test "Sync test":
var bucket = TokenBucket.new(1000, 1.milliseconds)
let
start = Moment.now()
fullTime = start + 1.milliseconds
check:
bucket.tryConsume(800) == true
bucket.tryConsume(200) == true
bucket.tryConsume(800, start) == true
bucket.tryConsume(200, start) == true
# Out of budget
bucket.tryConsume(100) == false
waitFor(sleepAsync(10.milliseconds))
check:
bucket.tryConsume(800) == true
bucket.tryConsume(200) == true
bucket.tryConsume(100, start) == false
bucket.tryConsume(800, fullTime) == true
bucket.tryConsume(200, fullTime) == true
# Out of budget
bucket.tryConsume(100) == false
bucket.tryConsume(100, fullTime) == false
test "Async test":
var bucket = TokenBucket.new(1000, 500.milliseconds)
var bucket = TokenBucket.new(1000, 1000.milliseconds)
check: bucket.tryConsume(1000) == true
var toWait = newSeq[Future[void]]()
@ -41,28 +42,26 @@ suite "Token Bucket":
waitFor(allFutures(toWait))
let duration = Moment.now() - start
check: duration in 700.milliseconds .. 1100.milliseconds
check: duration in 1400.milliseconds .. 2200.milliseconds
test "Over budget async":
var bucket = TokenBucket.new(100, 10.milliseconds)
var bucket = TokenBucket.new(100, 100.milliseconds)
# Consume 10* the budget cap
let beforeStart = Moment.now()
waitFor(bucket.consume(1000).wait(1.seconds))
when not defined(macosx):
# CI's macos scheduler is so jittery that this tests sometimes takes >500ms
# the test will still fail if it's >1 seconds
check Moment.now() - beforeStart in 90.milliseconds .. 150.milliseconds
waitFor(bucket.consume(1000).wait(5.seconds))
check Moment.now() - beforeStart in 900.milliseconds .. 1500.milliseconds
test "Sync manual replenish":
var bucket = TokenBucket.new(1000, 0.seconds)
let start = Moment.now()
check:
bucket.tryConsume(1000) == true
bucket.tryConsume(1000) == false
bucket.tryConsume(1000, start) == true
bucket.tryConsume(1000, start) == false
bucket.replenish(2000)
check:
bucket.tryConsume(1000) == true
bucket.tryConsume(1000, start) == true
# replenish is capped to the bucket max
bucket.tryConsume(1000) == false
bucket.tryConsume(1000, start) == false
test "Async manual replenish":
var bucket = TokenBucket.new(10 * 150, 0.seconds)
@ -102,24 +101,25 @@ suite "Token Bucket":
test "Very long replenish":
var bucket = TokenBucket.new(7000, 1.hours)
check bucket.tryConsume(7000)
check bucket.tryConsume(1) == false
let start = Moment.now()
check bucket.tryConsume(7000, start)
check bucket.tryConsume(1, start) == false
# With this setting, it takes 514 milliseconds
# to tick one. Check that we can eventually
# consume, even if we update multiple time
# before that
let start = Moment.now()
while Moment.now() - start >= 514.milliseconds:
check bucket.tryConsume(1) == false
waitFor(sleepAsync(10.milliseconds))
var fakeNow = start
while fakeNow - start < 514.milliseconds:
check bucket.tryConsume(1, fakeNow) == false
fakeNow += 30.milliseconds
check bucket.tryConsume(1) == false
check bucket.tryConsume(1, fakeNow) == true
test "Short replenish":
var bucket = TokenBucket.new(15000, 1.milliseconds)
check bucket.tryConsume(15000)
check bucket.tryConsume(1) == false
let start = Moment.now()
check bucket.tryConsume(15000, start)
check bucket.tryConsume(1, start) == false
waitFor(sleepAsync(1.milliseconds))
check bucket.tryConsume(15000) == true
check bucket.tryConsume(15000, start + 1.milliseconds) == true

View File

@ -6,8 +6,9 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/strutils
import unittest2
import ../chronos, ../chronos/apps/http/shttpserver
import ".."/chronos/unittest2/asynctests
import ".."/chronos,
".."/chronos/apps/http/shttpserver
import stew/base10
{.used.}
@ -115,7 +116,7 @@ suite "Secure HTTP server testing suite":
HttpTable.init())
else:
serverRes = false
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let serverFlags = {Secure}
@ -154,7 +155,7 @@ suite "Secure HTTP server testing suite":
else:
serverRes = true
testFut.complete()
return dumbResponse()
return defaultResponse()
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
let serverFlags = {Secure}
@ -178,3 +179,6 @@ suite "Secure HTTP server testing suite":
return serverRes and data == "EXCEPTION"
check waitFor(testHTTPS2(initTAddress("127.0.0.1:30080"))) == true
test "Leaks test":
checkLeaks()

View File

@ -6,7 +6,7 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[strutils, os]
import unittest2
import ".."/chronos/unittest2/asynctests
import ".."/chronos, ".."/chronos/[osdefs, oserrno]
{.used.}
@ -34,7 +34,7 @@ suite "Stream Transport test suite":
]
else:
let addresses = [
initTAddress("127.0.0.1:33335"),
initTAddress("127.0.0.1:0"),
initTAddress(r"/tmp/testpipe")
]
@ -43,7 +43,7 @@ suite "Stream Transport test suite":
var markFD: int
proc getCurrentFD(): int =
let local = initTAddress("127.0.0.1:33334")
let local = initTAddress("127.0.0.1:0")
let sock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM,
Protocol.IPPROTO_UDP)
closeSocket(sock)
@ -348,7 +348,7 @@ suite "Stream Transport test suite":
proc test1(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient1, {ReuseAddr})
server.start()
result = await swarmManager1(address)
result = await swarmManager1(server.local)
server.stop()
server.close()
await server.join()
@ -356,7 +356,7 @@ suite "Stream Transport test suite":
proc test2(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient2, {ReuseAddr})
server.start()
result = await swarmManager2(address)
result = await swarmManager2(server.local)
server.stop()
server.close()
await server.join()
@ -364,7 +364,7 @@ suite "Stream Transport test suite":
proc test3(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient3, {ReuseAddr})
server.start()
result = await swarmManager3(address)
result = await swarmManager3(server.local)
server.stop()
server.close()
await server.join()
@ -372,7 +372,7 @@ suite "Stream Transport test suite":
proc testSendFile(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient4, {ReuseAddr})
server.start()
result = await swarmManager4(address)
result = await swarmManager4(server.local)
server.stop()
server.close()
await server.join()
@ -414,7 +414,7 @@ suite "Stream Transport test suite":
var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start()
result = await swarmManager(address)
result = await swarmManager(server.local)
await server.join()
proc testWCR(address: TransportAddress): Future[int] {.async.} =
@ -456,13 +456,13 @@ suite "Stream Transport test suite":
var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start()
result = await swarmManager(address)
result = await swarmManager(server.local)
await server.join()
proc test7(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient7, {ReuseAddr})
server.start()
result = await swarmWorker7(address)
result = await swarmWorker7(server.local)
server.stop()
server.close()
await server.join()
@ -470,7 +470,7 @@ suite "Stream Transport test suite":
proc test8(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient8, {ReuseAddr})
server.start()
result = await swarmWorker8(address)
result = await swarmWorker8(server.local)
await server.join()
# proc serveClient9(server: StreamServer, transp: StreamTransport) {.async.} =
@ -553,7 +553,7 @@ suite "Stream Transport test suite":
proc test11(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient11, {ReuseAddr})
server.start()
result = await swarmWorker11(address)
result = await swarmWorker11(server.local)
server.stop()
server.close()
await server.join()
@ -579,7 +579,7 @@ suite "Stream Transport test suite":
proc test12(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient12, {ReuseAddr})
server.start()
result = await swarmWorker12(address)
result = await swarmWorker12(server.local)
server.stop()
server.close()
await server.join()
@ -601,7 +601,7 @@ suite "Stream Transport test suite":
proc test13(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient13, {ReuseAddr})
server.start()
result = await swarmWorker13(address)
result = await swarmWorker13(server.local)
server.stop()
server.close()
await server.join()
@ -621,7 +621,7 @@ suite "Stream Transport test suite":
subres = 0
server.start()
var transp = await connect(address)
var transp = await connect(server.local)
var fut = swarmWorker(transp)
# We perfrom shutdown(SHUT_RD/SD_RECEIVE) for the socket, in such way its
# possible to emulate socket's EOF.
@ -674,7 +674,7 @@ suite "Stream Transport test suite":
proc test16(address: TransportAddress): Future[int] {.async.} =
var server = createStreamServer(address, serveClient16, {ReuseAddr})
server.start()
result = await swarmWorker16(address)
result = await swarmWorker16(server.local)
server.stop()
server.close()
await server.join()
@ -701,7 +701,7 @@ suite "Stream Transport test suite":
var server = createStreamServer(address, client, {ReuseAddr})
server.start()
var msg = "HELLO"
var ntransp = await connect(address)
var ntransp = await connect(server.local)
await syncFut
while true:
var res = await ntransp.write(msg)
@ -763,7 +763,7 @@ suite "Stream Transport test suite":
var transp: StreamTransport
try:
transp = await connect(address)
transp = await connect(server.local)
flag = true
except CatchableError:
server.stop()
@ -796,31 +796,31 @@ suite "Stream Transport test suite":
server.start()
try:
var r1, r2, r3, r4, r5: string
var t1 = await connect(address)
var t1 = await connect(server.local)
try:
r1 = await t1.readLine(4)
finally:
await t1.closeWait()
var t2 = await connect(address)
var t2 = await connect(server.local)
try:
r2 = await t2.readLine(6)
finally:
await t2.closeWait()
var t3 = await connect(address)
var t3 = await connect(server.local)
try:
r3 = await t3.readLine(8)
finally:
await t3.closeWait()
var t4 = await connect(address)
var t4 = await connect(server.local)
try:
r4 = await t4.readLine(8)
finally:
await t4.closeWait()
var t5 = await connect(address)
var t5 = await connect(server.local)
try:
r5 = await t5.readLine()
finally:
@ -945,7 +945,7 @@ suite "Stream Transport test suite":
var server = createStreamServer(address, serveClient, {ReuseAddr})
server.start()
var t1 = await connect(address)
var t1 = await connect(server.local)
try:
discard await t1.readLV(2000)
except TransportIncompleteError:
@ -959,7 +959,7 @@ suite "Stream Transport test suite":
await server.join()
return false
var t2 = await connect(address)
var t2 = await connect(server.local)
try:
var r2 = await t2.readLV(2000)
c2 = (r2 == @[])
@ -972,7 +972,7 @@ suite "Stream Transport test suite":
await server.join()
return false
var t3 = await connect(address)
var t3 = await connect(server.local)
try:
discard await t3.readLV(2000)
except TransportIncompleteError:
@ -986,7 +986,7 @@ suite "Stream Transport test suite":
await server.join()
return false
var t4 = await connect(address)
var t4 = await connect(server.local)
try:
discard await t4.readLV(2000)
except TransportIncompleteError:
@ -1000,7 +1000,7 @@ suite "Stream Transport test suite":
await server.join()
return false
var t5 = await connect(address)
var t5 = await connect(server.local)
try:
discard await t5.readLV(1000)
except ValueError:
@ -1014,7 +1014,7 @@ suite "Stream Transport test suite":
await server.join()
return false
var t6 = await connect(address)
var t6 = await connect(server.local)
try:
var expectMsg = createMessage(1024)
var r6 = await t6.readLV(2000)
@ -1029,7 +1029,7 @@ suite "Stream Transport test suite":
await server.join()
return false
var t7 = await connect(address)
var t7 = await connect(server.local)
try:
var expectMsg = createMessage(1024)
var expectDone = "DONE"
@ -1062,7 +1062,7 @@ suite "Stream Transport test suite":
try:
for i in 0 ..< TestsCount:
transp = await connect(address)
transp = await connect(server.local)
await sleepAsync(10.milliseconds)
await transp.closeWait()
inc(connected)
@ -1117,7 +1117,7 @@ suite "Stream Transport test suite":
try:
for i in 0 ..< 3:
try:
let transp = await connect(address)
let transp = await connect(server.local)
await sleepAsync(10.milliseconds)
await transp.closeWait()
except TransportTooManyError:
@ -1166,7 +1166,7 @@ suite "Stream Transport test suite":
await server.closeWait()
var acceptFut = acceptTask(server)
var transp = await connect(address)
var transp = await connect(server.local)
await server.join()
await transp.closeWait()
await acceptFut
@ -1187,7 +1187,7 @@ suite "Stream Transport test suite":
await server.closeWait()
var acceptFut = acceptTask(server)
var transp = await connect(address)
var transp = await connect(server.local)
await server.join()
await transp.closeWait()
await acceptFut
@ -1259,46 +1259,39 @@ suite "Stream Transport test suite":
return buffer == message
proc testConnectBindLocalAddress() {.async.} =
let dst1 = initTAddress("127.0.0.1:33335")
let dst2 = initTAddress("127.0.0.1:33336")
let dst3 = initTAddress("127.0.0.1:33337")
proc client(server: StreamServer, transp: StreamTransport) {.async.} =
await transp.closeWait()
# We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when
# running the test multiple times or if a test ran previously used the same port.
let servers =
[createStreamServer(dst1, client, {ReuseAddr}),
createStreamServer(dst2, client, {ReuseAddr}),
createStreamServer(dst3, client, {ReusePort})]
let server1 = createStreamServer(initTAddress("127.0.0.1:0"), client)
let server2 = createStreamServer(initTAddress("127.0.0.1:0"), client)
let server3 = createStreamServer(initTAddress("127.0.0.1:0"), client, {ReusePort})
for server in servers:
server.start()
server1.start()
server2.start()
server3.start()
let ta = initTAddress("0.0.0.0:35000")
# It works cause there's no active listening socket bound to ta and we are using ReuseAddr
var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr})
var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr})
# It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort
var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort})
# It works cause even though there's an active listening socket bound to dst3, we are using ReusePort
var transp1 = await connect(server1.local, localAddress = server3.local, flags={SocketFlags.ReusePort})
var transp2 = await connect(server2.local, localAddress = server3.local, flags={SocketFlags.ReusePort})
expect(TransportOsError):
var transp2 {.used.} = await connect(dst3, localAddress = ta)
var transp2 {.used.} = await connect(server2.local, localAddress = server3.local)
expect(TransportOsError):
var transp3 {.used.} =
await connect(dst3, localAddress = initTAddress(":::35000"))
var transp3 {.used.} = await connect(server2.local, localAddress = initTAddress("::", server3.local.port))
await transp1.closeWait()
await transp2.closeWait()
await transp3.closeWait()
for server in servers:
server.stop()
await server.closeWait()
server1.stop()
await server1.closeWait()
server2.stop()
await server2.closeWait()
server3.stop()
await server3.closeWait()
markFD = getCurrentFD()
@ -1339,7 +1332,10 @@ suite "Stream Transport test suite":
else:
skip()
else:
check waitFor(testSendFile(addresses[i])) == FilesCount
if defined(emscripten):
skip()
else:
check waitFor(testSendFile(addresses[i])) == FilesCount
test prefixes[i] & "Connection refused test":
var address: TransportAddress
if addresses[i].family == AddressFamily.Unix:
@ -1370,10 +1366,11 @@ suite "Stream Transport test suite":
test prefixes[i] & "close() while in accept() waiting test":
check waitFor(testAcceptClose(addresses[i])) == true
test prefixes[i] & "Intermediate transports leak test #1":
checkLeaks()
when defined(windows):
skip()
else:
check getTracker("stream.transport").isLeaked() == false
checkLeaks(StreamTransportTrackerName)
test prefixes[i] & "accept() too many file descriptors test":
when defined(windows):
skip()
@ -1389,10 +1386,8 @@ suite "Stream Transport test suite":
check waitFor(testPipe()) == true
test "[IP] bind connect to local address":
waitFor(testConnectBindLocalAddress())
test "Servers leak test":
check getTracker("stream.server").isLeaked() == false
test "Transports leak test":
check getTracker("stream.transport").isLeaked() == false
test "Leaks test":
checkLeaks()
test "File descriptors leak test":
when defined(windows):
# Windows handle numbers depends on many conditions, so we can't use

369
tests/testthreadsync.nim Normal file
View File

@ -0,0 +1,369 @@
# Chronos Test Suite
# (c) Copyright 2023-Present
# Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[cpuinfo, locks, strutils]
import ../chronos/unittest2/asynctests
import ../chronos/threadsync
{.used.}
type
ThreadResult = object
value: int
ThreadResultPtr = ptr ThreadResult
LockPtr = ptr Lock
ThreadArg = object
signal: ThreadSignalPtr
retval: ThreadResultPtr
index: int
ThreadArg2 = object
signal1: ThreadSignalPtr
signal2: ThreadSignalPtr
retval: ThreadResultPtr
ThreadArg3 = object
lock: LockPtr
signal: ThreadSignalPtr
retval: ThreadResultPtr
index: int
WaitSendKind {.pure.} = enum
Sync, Async
const
TestsCount = 1000
suite "Asynchronous multi-threading sync primitives test suite":
proc setResult(thr: ThreadResultPtr, value: int) =
thr[].value = value
proc new(t: typedesc[ThreadResultPtr], value: int = 0): ThreadResultPtr =
var res = cast[ThreadResultPtr](allocShared0(sizeof(ThreadResult)))
res[].value = value
res
proc free(thr: ThreadResultPtr) =
doAssert(not(isNil(thr)))
deallocShared(thr)
let numProcs = countProcessors() * 2
template threadSignalTest(sendFlag, waitFlag: WaitSendKind) =
proc testSyncThread(arg: ThreadArg) {.thread.} =
let res = waitSync(arg.signal, 1500.milliseconds)
if res.isErr():
arg.retval.setResult(1)
else:
if res.get():
arg.retval.setResult(2)
else:
arg.retval.setResult(3)
proc testAsyncThread(arg: ThreadArg) {.thread.} =
proc testAsyncCode(arg: ThreadArg) {.async.} =
try:
await wait(arg.signal).wait(1500.milliseconds)
arg.retval.setResult(2)
except AsyncTimeoutError:
arg.retval.setResult(3)
except CatchableError:
arg.retval.setResult(1)
waitFor testAsyncCode(arg)
let signal = ThreadSignalPtr.new().tryGet()
var args: seq[ThreadArg]
var threads = newSeq[Thread[ThreadArg]](numProcs)
for i in 0 ..< numProcs:
let
res = ThreadResultPtr.new()
arg = ThreadArg(signal: signal, retval: res, index: i)
args.add(arg)
case waitFlag
of WaitSendKind.Sync:
createThread(threads[i], testSyncThread, arg)
of WaitSendKind.Async:
createThread(threads[i], testAsyncThread, arg)
await sleepAsync(500.milliseconds)
case sendFlag
of WaitSendKind.Sync:
check signal.fireSync().isOk()
of WaitSendKind.Async:
await signal.fire()
joinThreads(threads)
var ncheck: array[3, int]
for item in args:
if item.retval[].value == 1:
inc(ncheck[0])
elif item.retval[].value == 2:
inc(ncheck[1])
elif item.retval[].value == 3:
inc(ncheck[2])
free(item.retval)
check:
signal.close().isOk()
ncheck[0] == 0
ncheck[1] == 1
ncheck[2] == numProcs - 1
template threadSignalTest2(testsCount: int,
sendFlag, waitFlag: WaitSendKind) =
proc testSyncThread(arg: ThreadArg2) {.thread.} =
for i in 0 ..< testsCount:
block:
let res = waitSync(arg.signal1, 1500.milliseconds)
if res.isErr():
arg.retval.setResult(-1)
return
if not(res.get()):
arg.retval.setResult(-2)
return
block:
let res = arg.signal2.fireSync()
if res.isErr():
arg.retval.setResult(-3)
return
arg.retval.setResult(i + 1)
proc testAsyncThread(arg: ThreadArg2) {.thread.} =
proc testAsyncCode(arg: ThreadArg2) {.async.} =
for i in 0 ..< testsCount:
try:
await wait(arg.signal1).wait(1500.milliseconds)
except AsyncTimeoutError:
arg.retval.setResult(-2)
return
except AsyncError:
arg.retval.setResult(-1)
return
except CatchableError:
arg.retval.setResult(-3)
return
try:
await arg.signal2.fire()
except AsyncError:
arg.retval.setResult(-4)
return
except CatchableError:
arg.retval.setResult(-5)
return
arg.retval.setResult(i + 1)
waitFor testAsyncCode(arg)
let
signal1 = ThreadSignalPtr.new().tryGet()
signal2 = ThreadSignalPtr.new().tryGet()
retval = ThreadResultPtr.new()
arg = ThreadArg2(signal1: signal1, signal2: signal2, retval: retval)
var thread: Thread[ThreadArg2]
case waitFlag
of WaitSendKind.Sync:
createThread(thread, testSyncThread, arg)
of WaitSendKind.Async:
createThread(thread, testAsyncThread, arg)
let start = Moment.now()
for i in 0 ..< testsCount:
case sendFlag
of WaitSendKind.Sync:
block:
let res = signal1.fireSync()
check res.isOk()
block:
let res = waitSync(arg.signal2, 1500.milliseconds)
check:
res.isOk()
res.get() == true
of WaitSendKind.Async:
await arg.signal1.fire()
await wait(arg.signal2).wait(1500.milliseconds)
joinThreads(thread)
let finish = Moment.now()
let perf = (float64(nanoseconds(1.seconds)) /
float64(nanoseconds(finish - start))) * float64(testsCount)
echo "Switches tested: ", testsCount, ", elapsed time: ", (finish - start),
", performance = ", formatFloat(perf, ffDecimal, 4),
" switches/second"
check:
arg.retval[].value == testsCount
template threadSignalTest3(testsCount: int,
sendFlag, waitFlag: WaitSendKind) =
proc testSyncThread(arg: ThreadArg3) {.thread.} =
withLock(arg.lock[]):
let res = waitSync(arg.signal, 10.milliseconds)
if res.isErr():
arg.retval.setResult(1)
else:
if res.get():
arg.retval.setResult(2)
else:
arg.retval.setResult(3)
proc testAsyncThread(arg: ThreadArg3) {.thread.} =
proc testAsyncCode(arg: ThreadArg3) {.async.} =
withLock(arg.lock[]):
try:
await wait(arg.signal).wait(10.milliseconds)
arg.retval.setResult(2)
except AsyncTimeoutError:
arg.retval.setResult(3)
except CatchableError:
arg.retval.setResult(1)
waitFor testAsyncCode(arg)
let signal = ThreadSignalPtr.new().tryGet()
var args: seq[ThreadArg3]
var threads = newSeq[Thread[ThreadArg3]](numProcs)
var lockPtr = cast[LockPtr](allocShared0(sizeof(Lock)))
initLock(lockPtr[])
acquire(lockPtr[])
for i in 0 ..< numProcs:
let
res = ThreadResultPtr.new()
arg = ThreadArg3(signal: signal, retval: res, index: i, lock: lockPtr)
args.add(arg)
case waitFlag
of WaitSendKind.Sync:
createThread(threads[i], testSyncThread, arg)
of WaitSendKind.Async:
createThread(threads[i], testAsyncThread, arg)
await sleepAsync(500.milliseconds)
case sendFlag
of WaitSendKind.Sync:
for i in 0 ..< testsCount:
check signal.fireSync().isOk()
of WaitSendKind.Async:
for i in 0 ..< testsCount:
await signal.fire()
release(lockPtr[])
joinThreads(threads)
deinitLock(lockPtr[])
deallocShared(lockPtr)
var ncheck: array[3, int]
for item in args:
if item.retval[].value == 1:
inc(ncheck[0])
elif item.retval[].value == 2:
inc(ncheck[1])
elif item.retval[].value == 3:
inc(ncheck[2])
free(item.retval)
check:
signal.close().isOk()
ncheck[0] == 0
ncheck[1] == 1
ncheck[2] == numProcs - 1
template threadSignalTest4(testsCount: int,
sendFlag, waitFlag: WaitSendKind) =
let signal = ThreadSignalPtr.new().tryGet()
let start = Moment.now()
for i in 0 ..< testsCount:
case sendFlag
of WaitSendKind.Sync:
check signal.fireSync().isOk()
of WaitSendKind.Async:
await signal.fire()
case waitFlag
of WaitSendKind.Sync:
check waitSync(signal).isOk()
of WaitSendKind.Async:
await wait(signal)
let finish = Moment.now()
let perf = (float64(nanoseconds(1.seconds)) /
float64(nanoseconds(finish - start))) * float64(testsCount)
echo "Switches tested: ", testsCount, ", elapsed time: ", (finish - start),
", performance = ", formatFloat(perf, ffDecimal, 4),
" switches/second"
check:
signal.close.isOk()
asyncTest "ThreadSignal: Multiple [" & $numProcs &
"] threads waiting test [sync -> sync]":
threadSignalTest(WaitSendKind.Sync, WaitSendKind.Sync)
asyncTest "ThreadSignal: Multiple [" & $numProcs &
"] threads waiting test [async -> async]":
threadSignalTest(WaitSendKind.Async, WaitSendKind.Async)
asyncTest "ThreadSignal: Multiple [" & $numProcs &
"] threads waiting test [async -> sync]":
threadSignalTest(WaitSendKind.Async, WaitSendKind.Sync)
asyncTest "ThreadSignal: Multiple [" & $numProcs &
"] threads waiting test [sync -> async]":
threadSignalTest(WaitSendKind.Sync, WaitSendKind.Async)
asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount &
"] test [sync -> sync]":
threadSignalTest2(TestsCount, WaitSendKind.Sync, WaitSendKind.Sync)
asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount &
"] test [async -> async]":
threadSignalTest2(TestsCount, WaitSendKind.Async, WaitSendKind.Async)
asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount &
"] test [sync -> async]":
threadSignalTest2(TestsCount, WaitSendKind.Sync, WaitSendKind.Async)
asyncTest "ThreadSignal: Multiple thread switches [" & $TestsCount &
"] test [async -> sync]":
threadSignalTest2(TestsCount, WaitSendKind.Async, WaitSendKind.Sync)
asyncTest "ThreadSignal: Multiple signals [" & $TestsCount &
"] to multiple threads [" & $numProcs & "] test [sync -> sync]":
threadSignalTest3(TestsCount, WaitSendKind.Sync, WaitSendKind.Sync)
asyncTest "ThreadSignal: Multiple signals [" & $TestsCount &
"] to multiple threads [" & $numProcs & "] test [async -> async]":
threadSignalTest3(TestsCount, WaitSendKind.Async, WaitSendKind.Async)
asyncTest "ThreadSignal: Multiple signals [" & $TestsCount &
"] to multiple threads [" & $numProcs & "] test [sync -> async]":
threadSignalTest3(TestsCount, WaitSendKind.Sync, WaitSendKind.Async)
asyncTest "ThreadSignal: Multiple signals [" & $TestsCount &
"] to multiple threads [" & $numProcs & "] test [async -> sync]":
threadSignalTest3(TestsCount, WaitSendKind.Async, WaitSendKind.Sync)
asyncTest "ThreadSignal: Single threaded switches [" & $TestsCount &
"] test [sync -> sync]":
threadSignalTest4(TestsCount, WaitSendKind.Sync, WaitSendKind.Sync)
asyncTest "ThreadSignal: Single threaded switches [" & $TestsCount &
"] test [sync -> sync]":
threadSignalTest4(TestsCount, WaitSendKind.Async, WaitSendKind.Async)
asyncTest "ThreadSignal: Single threaded switches [" & $TestsCount &
"] test [sync -> async]":
threadSignalTest4(TestsCount, WaitSendKind.Sync, WaitSendKind.Async)
asyncTest "ThreadSignal: Single threaded switches [" & $TestsCount &
"] test [async -> sync]":
threadSignalTest4(TestsCount, WaitSendKind.Async, WaitSendKind.Sync)