diff --git a/chronos/apps/http/httpagent.nim b/chronos/apps/http/httpagent.nim new file mode 100644 index 00000000..c8cac48f --- /dev/null +++ b/chronos/apps/http/httpagent.nim @@ -0,0 +1,24 @@ +# +# Chronos HTTP/S client implementation +# (c) Copyright 2021-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import strutils + +const + ChronosName* = "nim-chronos" + ## Project name string + ChronosMajor* {.intdefine.}: int = 3 + ## Major number of Chronos' version. + ChronosMinor* {.intdefine.}: int = 0 + ## Minor number of Chronos' version. + ChronosPatch* {.intdefine.}: int = 2 + ## Patch number of Chronos' version. + ChronosVersion* = $ChronosMajor & "." & $ChronosMinor & "." & $ChronosPatch + ## Version of Chronos as a string. + ChronosIdent* = "$1/$2 ($3/$4)" % [ChronosName, ChronosVersion, hostCPU, + hostOS] + ## Project ident name for networking services diff --git a/chronos/apps/http/httpbodyrw.nim b/chronos/apps/http/httpbodyrw.nim new file mode 100644 index 00000000..efa043ea --- /dev/null +++ b/chronos/apps/http/httpbodyrw.nim @@ -0,0 +1,146 @@ +# +# Chronos HTTP/S body reader/writer +# (c) Copyright 2021-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import ../../asyncloop, ../../asyncsync +import ../../streams/[asyncstream, boundstream] + +const + HttpBodyReaderTrackerName* = "http.body.reader" + ## HTTP body reader leaks tracker name + HttpBodyWriterTrackerName* = "http.body.writer" + ## HTTP body writer leaks tracker name + +type + HttpBodyReader* = ref object of AsyncStreamReader + streams*: seq[AsyncStreamReader] + + HttpBodyWriter* = ref object of AsyncStreamWriter + streams*: seq[AsyncStreamWriter] + + HttpBodyTracker* = ref object of TrackerBase + opened*: int64 + closed*: int64 + +proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe, raises: [Defect].} +proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe, raises: [Defect].} + +proc getHttpBodyWriterTracker(): HttpBodyTracker {.inline.} = + var res = cast[HttpBodyTracker](getTracker(HttpBodyWriterTrackerName)) + if isNil(res): + res = setupHttpBodyWriterTracker() + res + +proc getHttpBodyReaderTracker(): HttpBodyTracker {.inline.} = + var res = cast[HttpBodyTracker](getTracker(HttpBodyReaderTrackerName)) + if isNil(res): + res = setupHttpBodyReaderTracker() + res + +proc dumpHttpBodyWriterTracking(): string {.gcsafe.} = + let tracker = getHttpBodyWriterTracker() + "Opened HTTP body writers: " & $tracker.opened & "\n" & + "Closed HTTP body writers: " & $tracker.closed + +proc dumpHttpBodyReaderTracking(): string {.gcsafe.} = + let tracker = getHttpBodyReaderTracker() + "Opened HTTP body readers: " & $tracker.opened & "\n" & + "Closed HTTP body readers: " & $tracker.closed + +proc leakHttpBodyWriter(): bool {.gcsafe.} = + var tracker = getHttpBodyWriterTracker() + tracker.opened != tracker.closed + +proc leakHttpBodyReader(): bool {.gcsafe.} = + var tracker = getHttpBodyReaderTracker() + tracker.opened != tracker.closed + +proc trackHttpBodyWriter(t: HttpBodyWriter) {.inline.} = + inc(getHttpBodyWriterTracker().opened) + +proc untrackHttpBodyWriter*(t: HttpBodyWriter) {.inline.} = + inc(getHttpBodyWriterTracker().closed) + +proc trackHttpBodyReader(t: HttpBodyReader) {.inline.} = + inc(getHttpBodyReaderTracker().opened) + +proc untrackHttpBodyReader*(t: HttpBodyReader) {.inline.} = + inc(getHttpBodyReaderTracker().closed) + +proc setupHttpBodyWriterTracker(): HttpBodyTracker {.gcsafe.} = + var res = HttpBodyTracker(opened: 0, closed: 0, + dump: dumpHttpBodyWriterTracking, + isLeaked: leakHttpBodyWriter + ) + addTracker(HttpBodyWriterTrackerName, res) + res + +proc setupHttpBodyReaderTracker(): HttpBodyTracker {.gcsafe.} = + var res = HttpBodyTracker(opened: 0, closed: 0, + dump: dumpHttpBodyReaderTracking, + isLeaked: leakHttpBodyReader + ) + addTracker(HttpBodyReaderTrackerName, res) + res + +proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader = + ## HttpBodyReader is AsyncStreamReader which holds references to all the + ## ``streams``. Also on close it will close all the ``streams``. + ## + ## First stream in sequence will be used as a source. + doAssert(len(streams) > 0, "At least one stream must be added") + var res = HttpBodyReader(streams: @streams) + res.init(streams[0]) + trackHttpBodyReader(res) + res + +proc closeWait*(bstream: HttpBodyReader) {.async.} = + ## Close and free resource allocated by body reader. + var res = newSeq[Future[void]]() + # We closing streams in reversed order because stream at position [0], uses + # data from stream at position [1]. + for index in countdown((len(bstream.streams) - 1), 0): + res.add(bstream.streams[index].closeWait()) + await allFutures(res) + await procCall(closeWait(AsyncStreamReader(bstream))) + untrackHttpBodyReader(bstream) + +proc newHttpBodyWriter*(streams: varargs[AsyncStreamWriter]): HttpBodyWriter = + ## HttpBodyWriter is AsyncStreamWriter which holds references to all the + ## ``streams``. Also on close it will close all the ``streams``. + ## + ## First stream in sequence will be used as a destination. + doAssert(len(streams) > 0, "At least one stream must be added") + var res = HttpBodyWriter(streams: @streams) + res.init(streams[0]) + trackHttpBodyWriter(res) + res + +proc closeWait*(bstream: HttpBodyWriter) {.async.} = + ## Close and free all the resources allocated by body writer. + var res = newSeq[Future[void]]() + for index in countdown(len(bstream.streams) - 1, 0): + res.add(bstream.streams[index].closeWait()) + await allFutures(res) + await procCall(closeWait(AsyncStreamWriter(bstream))) + untrackHttpBodyWriter(bstream) + +proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [Defect].} = + if len(bstream.streams) == 1: + # If HttpBodyReader has only one stream it has ``BoundedStreamReader``, in + # such case its impossible to get more bytes then expected amount. + false + else: + # If HttpBodyReader has two or more streams, we check if + # ``BoundedStreamReader`` at EOF. + if bstream.streams[0].atEof(): + for i in 1 ..< len(bstream.streams): + if not(bstream.streams[1].atEof()): + return true + false + else: + false diff --git a/chronos/apps/http/httpclient.nim b/chronos/apps/http/httpclient.nim new file mode 100644 index 00000000..25fa837e --- /dev/null +++ b/chronos/apps/http/httpclient.nim @@ -0,0 +1,1233 @@ +# +# Chronos HTTP/S client +# (c) Copyright 2021-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import std/[uri, tables, strutils, sequtils] +import stew/[results, base10], httputils +import ../../asyncloop, ../../asyncsync +import ../../streams/[asyncstream, tlsstream, chunkstream, boundstream] +import httptable, httpcommon, httpagent, httpbodyrw, multipart +export httptable, httpcommon, httpagent, httpbodyrw, multipart + +const + HttpMaxHeadersSize* = 8192 + ## Maximum size of HTTP headers in octets + HttpConnectTimeout* = 12.seconds + ## Timeout for connecting to host (12 sec) + HttpHeadersTimeout* = 120.seconds + ## Timeout for receiving response headers (120 sec) + HttpMaxRedirections* = 10 + ## Maximum number of Location redirections. + HttpClientConnectionTrackerName* = "httpclient.connection" + ## HttpClient connection leaks tracker name + HttpClientRequestTrackerName* = "httpclient.request" + ## HttpClient request leaks tracker name + HttpClientResponseTrackerName* = "httpclient.response" + ## HttpClient response leaks tracker name + +type + HttpClientConnectionState* {.pure.} = enum + Closed ## Connection has been closed + Resolving, ## Resolving remote hostname + Connecting, ## Connecting to remote server + Ready, ## Connected to remote server + RequestHeadersSending, ## Sending request headers + RequestHeadersSent, ## Request headers has been sent + RequestBodySending, ## Sending request body + RequestBodySent, ## Request body has been sent + ResponseHeadersReceiving, ## Receiving response headers + ResponseHeadersReceived, ## Response headers has been received + ResponseBodyReceiving, ## Receiving response body + ResponseBodyReceived, ## Response body has been received + Error ## Error happens + + HttpClientScheme* {.pure.} = enum + NonSecure, ## Non-secure connection + Secure ## Secure TLS connection + + HttpClientRequestState* {.pure.} = enum + Closed, ## Request has been closed + Created, ## Request created + Connecting, ## Connecting to remote host + HeadersSending, ## Sending request headers + HeadersSent, ## Request headers has been sent + BodySending, ## Sending request body + BodySent, ## Request body has been sent + ResponseReceived, ## Request's response headers received + Error ## Error happens + + HttpClientResponseState* {.pure.} = enum + Closed, ## Response has been closed + HeadersReceived, ## Response headers received + BodyReceiving, ## Response body receiving + BodyReceived, ## Response body received + Error ## Error happens + + HttpClientBodyFlag* {.pure.} = enum + Sized, ## `Content-Length` present + Chunked, ## `Transfer-Encoding: chunked` present + Custom ## None of the above + + HttpClientRequestFlag* {.pure.} = enum + CloseConnection, ## Send `Connection: close` in request + + HttpHeaderTuple* = tuple + key: string + value: string + + HttpResponseTuple* = tuple + status: int + data: seq[byte] + + HttpClientConnection* = object of RootObj + case kind*: HttpClientScheme + of HttpClientScheme.NonSecure: + discard + of HttpClientScheme.Secure: + treader*: AsyncStreamReader + twriter*: AsyncStreamWriter + tls*: TLSAsyncStream + transp*: StreamTransport + reader*: AsyncStreamReader + writer*: AsyncStreamWriter + state*: HttpClientConnectionState + error*: ref HttpError + remoteHostname*: string + + HttpClientConnectionRef* = ref HttpClientConnection + + HttpSessionRef* = ref object + connections*: Table[string, seq[HttpClientConnectionRef]] + maxRedirections*: int + connectTimeout*: Duration + headersTimeout*: Duration + connectionBufferSize*: int + maxConnections*: int + flags*: HttpClientFlags + + HttpAddress* = object + id*: string + scheme*: HttpClientScheme + hostname*: string + port*: uint16 + path*: string + query*: string + anchor*: string + username*: string + password*: string + addresses*: seq[TransportAddress] + + HttpClientRequest* = object + meth*: HttpMethod + address: HttpAddress + state: HttpClientRequestState + version*: HttpVersion + headers*: HttpTable + bodyFlag: HttpClientBodyFlag + flags: set[HttpClientRequestFlag] + connection*: HttpClientConnectionRef + session*: HttpSessionRef + error*: ref HttpError + buffer*: seq[byte] + writer*: HttpBodyWriter + redirectCount: int + + HttpClientRequestRef* = ref HttpClientRequest + + HttpClientResponse* = object + state: HttpClientResponseState + requestMethod*: HttpMethod + address*: HttpAddress + status*: int + reason*: string + version*: HttpVersion + headers*: HttpTable + connection*: HttpClientConnectionRef + session*: HttpSessionRef + reader*: HttpBodyReader + error*: ref HttpError + bodyFlag*: HttpClientBodyFlag + contentEncoding*: set[ContentEncodingFlags] + transferEncoding*: set[TransferEncodingFlags] + contentLength*: uint64 + + HttpClientResponseRef* = ref HttpClientResponse + + HttpClientFlag* {.pure.} = enum + NoVerifyHost, ## Skip remote server certificate verification + NoVerifyServerName, ## Skip remote server name CN verification + NoInet4Resolution, ## Do not resolve server hostname to IPv4 addresses + NoInet6Resolution, ## Do not resolve server hostname to IPv6 addresses + NoAutomaticRedirect ## Do not handle HTTP redirection automatically + + HttpClientFlags* = set[HttpClientFlag] + + HttpClientTracker* = ref object of TrackerBase + opened*: int64 + closed*: int64 + +proc setupHttpClientConnectionTracker(): HttpClientTracker {. + gcsafe, raises: [Defect].} +proc setupHttpClientRequestTracker(): HttpClientTracker {. + gcsafe, raises: [Defect].} +proc setupHttpClientResponseTracker(): HttpClientTracker {. + gcsafe, raises: [Defect].} + +proc getHttpClientConnectionTracker(): HttpClientTracker {.inline.} = + var res = cast[HttpClientTracker](getTracker(HttpClientConnectionTrackerName)) + if isNil(res): + res = setupHttpClientConnectionTracker() + res + +proc getHttpClientRequestTracker(): HttpClientTracker {.inline.} = + var res = cast[HttpClientTracker](getTracker(HttpClientRequestTrackerName)) + if isNil(res): + res = setupHttpClientRequestTracker() + res + +proc getHttpClientResponseTracker(): HttpClientTracker {.inline.} = + var res = cast[HttpClientTracker](getTracker(HttpClientResponseTrackerName)) + if isNil(res): + res = setupHttpClientResponseTracker() + res + +proc dumpHttpClientConnectionTracking(): string {.gcsafe.} = + let tracker = getHttpClientConnectionTracker() + "Opened HTTP client connections: " & $tracker.opened & "\n" & + "Closed HTTP client connections: " & $tracker.closed + +proc dumpHttpClientRequestTracking(): string {.gcsafe.} = + let tracker = getHttpClientRequestTracker() + "Opened HTTP client requests: " & $tracker.opened & "\n" & + "Closed HTTP client requests: " & $tracker.closed + +proc dumpHttpClientResponseTracking(): string {.gcsafe.} = + let tracker = getHttpClientResponseTracker() + "Opened HTTP client responses: " & $tracker.opened & "\n" & + "Closed HTTP client responses: " & $tracker.closed + +proc leakHttpClientConnection(): bool {.gcsafe.} = + var tracker = getHttpClientConnectionTracker() + tracker.opened != tracker.closed + +proc leakHttpClientRequest(): bool {.gcsafe.} = + var tracker = getHttpClientRequestTracker() + tracker.opened != tracker.closed + +proc leakHttpClientResponse(): bool {.gcsafe.} = + var tracker = getHttpClientResponseTracker() + tracker.opened != tracker.closed + +proc trackHttpClientConnection(t: HttpClientConnectionRef) {.inline.} = + inc(getHttpClientConnectionTracker().opened) + +proc untrackHttpClientConnection*(t: HttpClientConnectionRef) {.inline.} = + inc(getHttpClientConnectionTracker().closed) + +proc trackHttpClientRequest(t: HttpClientRequestRef) {.inline.} = + inc(getHttpClientRequestTracker().opened) + +proc untrackHttpClientRequest*(t: HttpClientRequestRef) {.inline.} = + inc(getHttpClientRequestTracker().closed) + +proc trackHttpClientResponse(t: HttpClientResponseRef) {.inline.} = + inc(getHttpClientResponseTracker().opened) + +proc untrackHttpClientResponse*(t: HttpClientResponseRef) {.inline.} = + inc(getHttpClientResponseTracker().closed) + +proc setupHttpClientConnectionTracker(): HttpClientTracker {.gcsafe.} = + var res = HttpClientTracker(opened: 0, closed: 0, + dump: dumpHttpClientConnectionTracking, + isLeaked: leakHttpClientConnection + ) + addTracker(HttpClientConnectionTrackerName, res) + res + +proc setupHttpClientRequestTracker(): HttpClientTracker {.gcsafe.} = + var res = HttpClientTracker(opened: 0, closed: 0, + dump: dumpHttpClientRequestTracking, + isLeaked: leakHttpClientRequest + ) + addTracker(HttpClientRequestTrackerName, res) + res + +proc setupHttpClientResponseTracker(): HttpClientTracker {.gcsafe.} = + var res = HttpClientTracker(opened: 0, closed: 0, + dump: dumpHttpClientResponseTracking, + isLeaked: leakHttpClientResponse + ) + addTracker(HttpClientResponseTrackerName, res) + res + +proc new*(t: typedesc[HttpSessionRef], + flags: HttpClientFlags = {}, + maxRedirections = HttpMaxRedirections, + connectTimeout = HttpConnectTimeout, + headersTimeout = HttpHeadersTimeout, + connectionBufferSize = DefaultStreamBufferSize, + maxConnections = -1): HttpSessionRef {. + raises: [Defect] .} = + ## Create new HTTP session object. + ## + ## ``maxRedirections`` - maximum number of HTTP 3xx redirections + ## ``connectTimeout`` - timeout for ongoing HTTP connection + ## ``headersTimeout`` - timeout for receiving HTTP response headers + doAssert(maxRedirections >= 0, "maxRedirections should not be negative") + HttpSessionRef( + flags: flags, + maxRedirections: maxRedirections, + connectTimeout: connectTimeout, + headersTimeout: headersTimeout, + connectionBufferSize: connectionBufferSize, + maxConnections: maxConnections, + connections: initTable[string, seq[HttpClientConnectionRef]]() + ) + +proc getTLSFlags(flags: HttpClientFlags): set[TLSFlags] {.raises: [Defect] .} = + var res: set[TLSFlags] + if HttpClientFlag.NoVerifyHost in flags: + res.incl(TLSFlags.NoVerifyHost) + if HttpClientFlag.NoVerifyServerName in flags: + res.incl(TLSFlags.NoVerifyServerName) + res + +proc getAddress*(session: HttpSessionRef, url: Uri): HttpResult[HttpAddress] {. + raises: [Defect] .} = + let scheme = + if len(url.scheme) == 0: + HttpClientScheme.NonSecure + else: + case toLowerAscii(url.scheme) + of "http": + HttpClientScheme.NonSecure + of "https": + HttpClientScheme.Secure + else: + return err("URL scheme not supported") + + let port = + if len(url.port) == 0: + case scheme + of HttpClientScheme.NonSecure: + 80'u16 + of HttpClientScheme.Secure: + 443'u16 + else: + let res = Base10.decode(uint16, url.port) + if res.isErr(): + return err("Invalid URL port number") + res.get() + + let hostname = + block: + if len(url.hostname) == 0: + return err("URL hostname is missing") + url.hostname + + let id = hostname & ":" & Base10.toString(port) + + let addresses = + try: + if (HttpClientFlag.NoInet4Resolution in session.flags) and + (HttpClientFlag.NoInet6Resolution in session.flags): + # DNS resolution is disabled. + @[initTAddress(hostname, Port(port))] + else: + if (HttpClientFlag.NoInet4Resolution notin session.flags) and + (HttpClientFlag.NoInet6Resolution notin session.flags): + # DNS resolution for both IPv4 and IPv6 addresses. + resolveTAddress(hostname, Port(port)) + else: + if HttpClientFlag.NoInet6Resolution in session.flags: + # DNS resolution only for IPv4 addresses. + resolveTAddress(hostname, Port(port), AddressFamily.IPv4) + else: + # DNS resolution only for IPv6 addresses + resolveTAddress(hostname, Port(port), AddressFamily.IPv6) + except TransportAddressError: + return err("Could not resolve address of remote server") + + if len(addresses) == 0: + return err("Could not resolve address of remote server") + + ok(HttpAddress(id: id, scheme: scheme, hostname: hostname, port: port, + path: url.path, query: url.query, anchor: url.anchor, + username: url.username, password: url.password, + addresses: addresses)) + +proc getAddress*(session: HttpSessionRef, + url: string): HttpResult[HttpAddress] {.raises: [Defect].} = + ## Create new HTTP address using URL string ``url`` and . + session.getAddress(parseUri(url)) + +proc getAddress*(address: TransportAddress, + ctype: HttpClientScheme = HttpClientScheme.NonSecure, + queryString: string = "/"): HttpAddress {.raises: [Defect].} = + ## Create new HTTP address using Transport address ``address``, connection + ## type ``ctype`` and query string ``queryString``. + let uri = parseUri(queryString) + HttpAddress(id: $address, scheme: ctype, hostname: address.host, + port: uint16(address.port), path: uri.path, query: uri.query, + anchor: uri.anchor, username: "", password: "", addresses: @[address] + ) + +proc getUri*(address: HttpAddress): Uri = + ## Retrieve URI from ``address``. + let scheme = + case address.scheme + of HttpClientScheme.NonSecure: + "http" + of HttpClientScheme.Secure: + "https" + Uri( + scheme: scheme, username: address.username, password: address.password, + hostname: address.hostname, port: Base10.toString(address.port), + path: address.path, query: address.query, anchor: address.anchor, + opaque: false + ) + +proc redirect*(srcuri, dsturi: Uri): Uri = + ## Transform original's URL ``srcuri`` to ``dsturi``. + if (len(dsturi.scheme) > 0) and (len(dsturi.hostname) > 0): + # `dsturi` is absolute URL, replace + dsturi + else: + # `dsturi` is relative URL, combine + var tmpuri = dsturi + tmpuri.username = "" + tmpuri.password = "" + combine(srcuri, tmpuri) + +proc redirect*(session: HttpSessionRef, + srcaddr: HttpAddress, uri: Uri): HttpResult[HttpAddress] = + ## Transform original address ``srcaddr`` using redirected url ``uri`` and + ## session ``session`` parameters. + let srcuri = srcaddr.getUri() + var newuri = srcuri.redirect(uri) + if newuri.hostname != srcuri.hostname: + session.getAddress(newuri) + else: + let scheme = + case newuri.scheme + of "http": + HttpClientScheme.NonSecure + of "https": + HttpClientScheme.Secure + else: + return err("URL scheme not supported") + + let port = + if len(newuri.port) == 0: + case scheme: + of HttpClientScheme.NonSecure: + 80'u16 + of HttpClientScheme.Secure: + 443'u16 + else: + let res = Base10.decode(uint16, newuri.port) + if res.isErr(): + return err("Invalid URL port number") + res.get() + + if len(newuri.hostname) == 0: + return err("URL hostname is missing") + + let id = newuri.hostname & ":" & Base10.toString(port) + + ok(HttpAddress( + id: id, scheme: scheme, hostname: newuri.hostname, port: port, + path: newuri.path, query: newuri.query, anchor: newuri.anchor, + username: newuri.username, password: newuri.password, + addresses: srcaddr.addresses + )) + +proc new(t: typedesc[HttpClientConnectionRef], session: HttpSessionRef, + ha: HttpAddress, transp: StreamTransport): HttpClientConnectionRef = + case ha.scheme + of HttpClientScheme.NonSecure: + let res = HttpClientConnectionRef( + kind: HttpClientScheme.NonSecure, + transp: transp, + reader: newAsyncStreamReader(transp), + writer: newAsyncStreamWriter(transp), + state: HttpClientConnectionState.Connecting, + remoteHostname: ha.id + ) + trackHttpClientConnection(res) + res + of HttpClientScheme.Secure: + let treader = newAsyncStreamReader(transp) + let twriter = newAsyncStreamWriter(transp) + let tls = newTLSClientAsyncStream(treader, twriter, ha.hostname, + flags = session.flags.getTLSFlags()) + let res = HttpClientConnectionRef( + kind: HttpClientScheme.Secure, + transp: transp, + treader: treader, + twriter: twriter, + reader: tls.reader, + writer: tls.writer, + tls: tls, + state: HttpClientConnectionState.Connecting, + remoteHostname: ha.id + ) + trackHttpClientConnection(res) + res + +proc setState(request: HttpClientRequestRef, state: HttpClientRequestState) {. + raises: [Defect] .} = + request.state = state + case state + of HttpClientRequestState.HeadersSending: + request.connection.state = HttpClientConnectionState.RequestHeadersSending + of HttpClientRequestState.HeadersSent: + request.connection.state = HttpClientConnectionState.RequestHeadersSent + of HttpClientRequestState.BodySending: + request.connection.state = HttpClientConnectionState.RequestBodySending + of HttpClientRequestState.BodySent: + request.connection.state = HttpClientConnectionState.RequestBodySent + of HttpClientRequestState.ResponseReceived: + request.connection.state = HttpClientConnectionState.ResponseHeadersReceived + else: + discard + +proc setState(response: HttpClientResponseRef, + state: HttpClientResponseState) {.raises: [Defect] .} = + response.state = state + case state + of HttpClientResponseState.HeadersReceived: + response.connection.state = + HttpClientConnectionState.ResponseHeadersReceived + of HttpClientResponseState.BodyReceiving: + response.connection.state = HttpClientConnectionState.ResponseBodyReceiving + of HttpClientResponseState.BodyReceived: + response.connection.state = HttpClientConnectionState.ResponseBodyReceived + else: + discard + +proc setError(request: HttpClientRequestRef, error: ref HttpError) {. + raises: [Defect] .} = + request.error = error + request.setState(HttpClientRequestState.Error) + if not(isNil(request.connection)): + request.connection.state = HttpClientConnectionState.Error + request.connection.error = error + +proc setError(response: HttpClientResponseRef, error: ref HttpError) {. + raises: [Defect] .} = + response.error = error + response.setState(HttpClientResponseState.Error) + if not(isNil(response.connection)): + response.connection.state = HttpClientConnectionState.Error + response.connection.error = error + +proc closeWait(conn: HttpClientConnectionRef) {.async.} = + ## Close HttpClientConnectionRef instance ``conn`` and free all the resources. + if conn.state != HttpClientConnectionState.Closed: + await allFutures(conn.reader.closeWait(), conn.writer.closeWait()) + case conn.kind + of HttpClientScheme.Secure: + await allFutures(conn.treader.closeWait(), conn.twriter.closeWait()) + of HttpClientScheme.NonSecure: + discard + await conn.transp.closeWait() + conn.state = HttpClientConnectionState.Closed + untrackHttpClientConnection(conn) + +proc connect(session: HttpSessionRef, + ha: HttpAddress): Future[HttpClientConnectionRef] {.async.} = + ## Establish new connection with remote server using ``url`` and ``flags``. + ## On success returns ``HttpClientConnectionRef`` object. + + # Here we trying to connect to every possible remote host address we got after + # DNS resolution. + for address in ha.addresses: + let transp = + try: + await connect(address, bufferSize = session.connectionBufferSize) + except CancelledError as exc: + raise exc + except CatchableError: + nil + if not(isNil(transp)): + let conn = + block: + let res = HttpClientConnectionRef.new(session, ha, transp) + case res.kind + of HttpClientScheme.Secure: + try: + await res.tls.handshake() + res.state = HttpClientConnectionState.Ready + except CancelledError as exc: + await res.closeWait() + raise exc + except AsyncStreamError: + await res.closeWait() + of HttpClientScheme.Nonsecure: + res.state = HttpClientConnectionState.Ready + res + if conn.state == HttpClientConnectionState.Ready: + return conn + + # If all attempts to connect to the remote host have failed. + raiseHttpConnectionError("Could not connect to remote host") + +proc acquireConnection(session: HttpSessionRef, + ha: HttpAddress): Future[HttpClientConnectionRef] {. + async.} = + ## Obtain connection from ``session`` or establish a new one. + let conn = + block: + let conns = session.connections.getOrDefault(ha.id) + if len(conns) > 0: + var res: HttpClientConnectionRef = nil + for item in conns: + if item.state == HttpClientConnectionState.Ready: + res = item + break + res + else: + nil + if not(isNil(conn)): + return conn + else: + var default: seq[HttpClientConnectionRef] + let res = + try: + await session.connect(ha).wait(session.connectTimeout) + except AsyncTimeoutError: + raiseHttpConnectionError("Connection timed out") + session.connections.mgetOrPut(ha.id, default).add(res) + return res + +proc removeConnection(session: HttpSessionRef, + conn: HttpClientConnectionRef) {.async.} = + var conns = session.connections.getOrDefault(conn.remoteHostname) + conns.keepItIf(it != conn) + session.connections[conn.remoteHostname] = conns + await conn.closeWait() + +proc releaseConnection(session: HttpSessionRef, + conn: HttpClientConnectionRef) {.async.} = + ## Return connection back to the ``session``. + ## + ## If connection not in ``Ready`` state it will be closed and removed from + ## the ``session``. + if conn.state != HttpClientConnectionState.Ready: + await session.removeConnection(conn) + +proc closeWait*(session: HttpSessionRef) {.async.} = + ## Closes HTTP session object. + ## + ## This closes all the connections opened to remote servers. + var pending: seq[Future[void]] + for items in session.connections.values(): + for item in items: + pending.add(closeWait(item)) + await allFutures(pending) + +proc closeWait*(request: HttpClientRequestRef) {.async.} = + if request.state != HttpClientRequestState.Closed: + if not(isNil(request.writer)): + if not(request.writer.closed()): + await request.writer.closeWait() + request.writer = nil + if request.state != HttpClientRequestState.ResponseReceived: + if not(isNil(request.connection)): + await request.session.releaseConnection(request.connection) + request.connection = nil + request.session = nil + request.error = nil + request.setState(HttpClientRequestState.Closed) + untrackHttpClientRequest(request) + +proc closeWait*(response: HttpClientResponseRef) {.async.} = + if response.state != HttpClientResponseState.Closed: + if not(isNil(response.reader)): + if not(response.reader.closed()): + await response.reader.closeWait() + response.reader = nil + if not(isNil(response.connection)): + await response.session.releaseConnection(response.connection) + response.connection = nil + response.session = nil + response.error = nil + response.setState(HttpClientResponseState.Closed) + untrackHttpClientResponse(response) + +proc prepareResponse(request: HttpClientRequestRef, + data: openarray[byte]): HttpResult[HttpClientResponseRef] {. + raises: [Defect] .} = + ## Process response headers. + let resp = parseResponse(data, false) + if resp.failed(): + return err("Invalid headers received") + + let headers = + block: + var res = HttpTable.init() + for key, value in resp.headers(data): + res.add(key, value) + if res.count(ContentTypeHeader) > 1: + return err("Invalid headers received, too many `Content-Type`") + if res.count(ContentLengthHeader) > 1: + return err("Invalid headers received, too many `Content-Length`") + if res.count(TransferEncodingHeader) > 1: + return err("Invalid headers received, too many `Transfer-Encoding`") + res + + # Preprocessing "Content-Encoding" header. + let contentEncoding = + block: + let res = getContentEncoding(headers.getList(ContentEncodingHeader)) + if res.isErr(): + return err("Invalid headers received, invalid `Content-Encoding`") + else: + res.get() + + # Preprocessing "Transfer-Encoding" header. + let transferEncoding = + block: + let res = getTransferEncoding(headers.getList(TransferEncodingHeader)) + if res.isErr(): + return err("Invalid headers received, invalid `Transfer-Encoding`") + else: + res.get() + + # Preprocessing "Content-Length" header. + let (contentLength, bodyFlag) = + if ContentLengthHeader in headers: + let length = headers.getInt(ContentLengthHeader) + (length, HttpClientBodyFlag.Sized) + else: + if TransferEncodingFlags.Chunked in transferEncoding: + (0'u64, HttpClientBodyFlag.Chunked) + else: + (0'u64, HttpClientBodyFlag.Custom) + + let res = HttpClientResponseRef( + state: HttpClientResponseState.HeadersReceived, status: resp.code, + address: request.address, requestMethod: request.meth, + reason: resp.reason(data), version: resp.version, session: request.session, + connection: request.connection, headers: headers, + contentEncoding: contentEncoding, transferEncoding: transferEncoding, + contentLength: contentLength, bodyFlag: bodyFlag + ) + request.setState(HttpClientRequestState.ResponseReceived) + trackHttpClientResponse(res) + ok(res) + +proc getResponse(req: HttpClientRequestRef): Future[HttpClientResponseRef] {. + async.} = + var buffer: array[HttpMaxHeadersSize, byte] + let bytesRead = + try: + await req.connection.reader.readUntil(addr buffer[0], + len(buffer), HeadersMark).wait( + req.session.headersTimeout) + except CancelledError as exc: + raise exc + except AsyncTimeoutError: + raiseHttpReadError("Reading response headers timed out") + except AsyncStreamError: + raiseHttpReadError("Could not read response headers") + let resp = prepareResponse(req, buffer.toOpenArray(0, bytesRead - 1)) + if resp.isErr(): + raiseHttpProtocolError(resp.error()) + return resp.get() + +proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, + ha: HttpAddress, meth: HttpMethod = MethodGet, + version: HttpVersion = HttpVersion11, + flags: set[HttpClientRequestFlag] = {}, + headers: openarray[HttpHeaderTuple] = [], + body: openarray[byte] = []): HttpClientRequestRef {. + raises: [Defect].} = + let res = HttpClientRequestRef( + state: HttpClientRequestState.Created, session: session, meth: meth, + version: version, flags: flags, headers: HttpTable.init(headers), + address: ha, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body + ) + trackHttpClientRequest(res) + res + +proc new*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, + url: string, meth: HttpMethod = MethodGet, + version: HttpVersion = HttpVersion11, + flags: set[HttpClientRequestFlag] = {}, + headers: openarray[HttpHeaderTuple] = [], + body: openarray[byte] = []): HttpResult[HttpClientRequestRef] {. + raises: [Defect].} = + let address = ? session.getAddress(parseUri(url)) + let res = HttpClientRequestRef( + state: HttpClientRequestState.Created, session: session, meth: meth, + version: version, flags: flags, headers: HttpTable.init(headers), + address: address, bodyFlag: HttpClientBodyFlag.Custom, buffer: @body + ) + trackHttpClientRequest(res) + ok(res) + +proc get*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, + url: string, version: HttpVersion = HttpVersion11, + flags: set[HttpClientRequestFlag] = {}, + headers: openarray[HttpHeaderTuple] = [] + ): HttpResult[HttpClientRequestRef] {.raises: [Defect].} = + HttpClientRequestRef.new(session, url, MethodGet, version, flags, headers) + +proc get*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, + ha: HttpAddress, version: HttpVersion = HttpVersion11, + flags: set[HttpClientRequestFlag] = {}, + headers: openarray[HttpHeaderTuple] = [] + ): HttpClientRequestRef {.raises: [Defect].} = + HttpClientRequestRef.new(session, ha, MethodGet, version, flags, headers) + +proc post*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, + url: string, version: HttpVersion = HttpVersion11, + flags: set[HttpClientRequestFlag] = {}, + headers: openarray[HttpHeaderTuple] = [], + body: openarray[byte] = [] + ): HttpResult[HttpClientRequestRef] {.raises: [Defect].} = + HttpClientRequestRef.new(session, url, MethodPost, version, flags, headers, + body) + +proc post*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, + url: string, version: HttpVersion = HttpVersion11, + flags: set[HttpClientRequestFlag] = {}, + headers: openarray[HttpHeaderTuple] = [], + body: openarray[char] = []): HttpResult[HttpClientRequestRef] {. + raises: [Defect].} = + HttpClientRequestRef.new(session, url, MethodPost, version, flags, headers, + body.toOpenArrayByte(0, len(body) - 1)) + +proc post*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, + ha: HttpAddress, version: HttpVersion = HttpVersion11, + flags: set[HttpClientRequestFlag] = {}, + headers: openarray[HttpHeaderTuple] = [], + body: openarray[byte] = []): HttpClientRequestRef {. + raises: [Defect].} = + HttpClientRequestRef.new(session, ha, MethodPost, version, flags, headers, + body) + +proc post*(t: typedesc[HttpClientRequestRef], session: HttpSessionRef, + ha: HttpAddress, version: HttpVersion = HttpVersion11, + flags: set[HttpClientRequestFlag] = {}, + headers: openarray[HttpHeaderTuple] = [], + body: openarray[char] = []): HttpClientRequestRef {. + raises: [Defect].} = + HttpClientRequestRef.new(session, ha, MethodPost, version, flags, headers, + body.toOpenArrayByte(0, len(body) - 1)) + +proc prepareRequest(request: HttpClientRequestRef): string {. + raises: [Defect].} = + template hasChunkedEncoding(request: HttpClientRequestRef): bool = + toLowerAscii(request.headers.getString(TransferEncodingHeader)) == "chunked" + + # We use ChronosIdent as `User-Agent` string if its not set. + if UserAgentHeader notin request.headers: + discard request.headers.hasKeyOrPut(UserAgentHeader, ChronosIdent) + # We use request's hostname as `Host` string if its not set. + if HostHeader notin request.headers: + discard request.headers.hasKeyOrPut(HostHeader, request.address.hostname) + # We set `Connection` to value according to flags if its not set. + if ConnectionHeader notin request.headers: + if HttpClientRequestFlag.CloseConnection in request.flags: + discard request.headers.hasKeyOrPut(ConnectionHeader, "close") + else: + discard request.headers.hasKeyOrPut(ConnectionHeader, "keep-alive") + # We set `Accept` to accept any content if its not set. + if AcceptHeader notin request.headers: + discard request.headers.hasKeyOrPut(AcceptHeader, "*/*") + + # Here we perform automatic detection: if request was created with non-zero + # body and `Content-Length` header is missing we will create one with size + # of body stored in request. + if ContentLengthHeader notin request.headers: + if len(request.buffer) > 0: + let slength = Base10.toString(uint64(len(request.buffer))) + discard request.headers.hasKeyOrPut(ContentLengthHeader, slength) + + request.bodyFlag = + if ContentLengthHeader in request.headers: + HttpClientBodyFlag.Sized + else: + if request.hasChunkedEncoding(): + HttpClientBodyFlag.Chunked + else: + HttpClientBodyFlag.Custom + + let entity = + block: + var res = + if len(request.address.path) > 0: + request.address.path + else: + "/" + if len(request.address.query) > 0: + res.add("?") + res.add(request.address.query) + if len(request.address.anchor) > 0: + res.add("#") + res.add(request.address.anchor) + res + + var res = $request.meth + res.add(" ") + res.add(entity) + res.add(" ") + res.add($request.version) + res.add("\r\n") + for k, v in request.headers.stringItems(): + if len(v) > 0: + res.add(normalizeHeaderName(k)) + res.add(": ") + res.add(v) + res.add("\r\n") + res.add("\r\n") + res + +proc send*(request: HttpClientRequestRef): Future[HttpClientResponseRef] {. + async.} = + doAssert(request.state == HttpClientRequestState.Created) + request.setState(HttpClientRequestState.Connecting) + request.connection = + try: + await request.session.acquireConnection(request.address) + except CancelledError as exc: + request.setError(newHttpInterruptError()) + raise exc + except HttpError as exc: + request.setError(exc) + raise exc + + let headers = request.prepareRequest() + + try: + request.setState(HttpClientRequestState.HeadersSending) + await request.connection.writer.write(headers) + request.setState(HttpClientRequestState.HeadersSent) + request.setState(HttpClientRequestState.BodySending) + if len(request.buffer) > 0: + await request.connection.writer.write(request.buffer) + request.setState(HttpClientRequestState.BodySent) + except CancelledError as exc: + request.setError(newHttpInterruptError()) + raise exc + except AsyncStreamError as exc: + let error = newHttpWriteError("Could not send request headers") + request.setError(error) + raise error + + let resp = + try: + await request.getResponse() + except CancelledError as exc: + request.setError(newHttpInterruptError()) + raise exc + except HttpError as exc: + request.setError(exc) + raise exc + return resp + +proc open*(request: HttpClientRequestRef): Future[HttpBodyWriter] {. + async.} = + ## Start sending request's headers and return `HttpBodyWriter`, which can be + ## used to send request's body. + doAssert(request.state == HttpClientRequestState.Created) + doAssert(len(request.buffer) == 0, + "Request should not have static body content (len(buffer) == 0)") + request.setState(HttpClientRequestState.Connecting) + request.connection = + try: + await request.session.acquireConnection(request.address) + except CancelledError as exc: + request.setError(newHttpInterruptError()) + raise exc + except HttpError as exc: + request.setError(exc) + raise exc + + let headers = request.prepareRequest() + + try: + request.setState(HttpClientRequestState.HeadersSending) + await request.connection.writer.write(headers) + request.setState(HttpClientRequestState.HeadersSent) + except CancelledError as exc: + request.setError(newHttpInterruptError()) + raise exc + except AsyncStreamError as exc: + let error = newHttpWriteError("Could not send request headers") + request.setError(error) + raise error + + let writer = + case request.bodyFlag + of HttpClientBodyFlag.Sized: + let size = Base10.decode(uint64, + request.headers.getString("content-length")) + let writer = newBoundedStreamWriter(request.connection.writer, size.get()) + newHttpBodyWriter([AsyncStreamWriter(writer)]) + of HttpClientBodyFlag.Chunked: + let writer = newChunkedStreamWriter(request.connection.writer) + newHttpBodyWriter([AsyncStreamWriter(writer)]) + of HttpClientBodyFlag.Custom: + let writer = newAsyncStreamWriter(request.connection.writer) + newHttpBodyWriter([writer]) + + request.writer = writer + request.setState(HttpClientRequestState.BodySending) + return writer + +proc finish*(request: HttpClientRequestRef): Future[HttpClientResponseRef] {. + async.} = + ## Finish sending request and receive response. + doAssert(request.state == HttpClientRequestState.BodySending) + doAssert(request.connection.state == + HttpClientConnectionState.RequestBodySending) + doAssert(request.writer.closed()) + request.setState(HttpClientRequestState.BodySent) + let resp = + try: + await request.getResponse() + except CancelledError as exc: + request.setError(newHttpInterruptError()) + raise exc + except HttpError as exc: + request.setError(exc) + raise exc + return resp + +proc getNewLocation*(resp: HttpClientResponseRef): HttpResult[HttpAddress] = + ## Returns new address according to response's `Location` header value. + if "location" in resp.headers: + let location = resp.headers.getString("location") + if len(location) > 0: + resp.session.redirect(resp.address, parseUri(location)) + else: + err("Location header with empty value") + else: + err("Location header is missing") + +proc getBodyReader*(resp: HttpClientResponseRef): HttpBodyReader = + ## Returns stream's reader instance which can be used to read response's body. + ## + ## Streams which was obtained using this procedure must be closed to avoid + ## leaks. + doAssert(resp.state in { + HttpClientResponseState.HeadersReceived, + HttpClientResponseState.BodyReceiving}) + doAssert(resp.connection.state in { + HttpClientConnectionState.ResponseHeadersReceived, + HttpClientConnectionState.ResponseBodyReceiving}) + if isNil(resp.reader): + let reader = + case resp.bodyFlag + of HttpClientBodyFlag.Sized: + let bstream = newBoundedStreamReader(resp.connection.reader, + resp.contentLength) + newHttpBodyReader(bstream) + of HttpClientBodyFlag.Chunked: + newHttpBodyReader(newChunkedStreamReader(resp.connection.reader)) + of HttpClientBodyFlag.Custom: + newHttpBodyReader(newAsyncStreamReader(resp.connection.reader)) + resp.setState(HttpClientResponseState.BodyReceiving) + resp.reader = reader + resp.reader + +proc finish*(resp: HttpClientResponseRef) {.async.} = + ## Finish receiving response. + doAssert(resp.state == HttpClientResponseState.BodyReceiving) + doAssert(resp.connection.state == + HttpClientConnectionState.ResponseBodyReceiving) + doAssert(resp.reader.closed()) + resp.setState(HttpClientResponseState.BodyReceived) + resp.connection.state = HttpClientConnectionState.Ready + +proc getBodyBytes*(response: HttpClientResponseRef): Future[seq[byte]] {. + async.} = + ## Read all bytes from response ``response``. + doAssert(response.state == HttpClientResponseState.HeadersReceived) + doAssert(response.connection.state == + HttpClientConnectionState.ResponseHeadersReceived) + var reader = response.getBodyReader() + try: + let data = await reader.read() + await reader.closeWait() + reader = nil + await response.finish() + return data + except CancelledError as exc: + if not(isNil(reader)): + await reader.closeWait() + response.setError(newHttpInterruptError()) + raise exc + except AsyncStreamError as exc: + if not(isNil(reader)): + await reader.closeWait() + let error = newHttpReadError("Could not read response") + response.setError(error) + raise error + +proc getBodyBytes*(response: HttpClientResponseRef, + nbytes: int): Future[seq[byte]] {.async.} = + ## Read all bytes (nbytes <= 0) or exactly `nbytes` bytes from response + ## ``response``. + doAssert(response.state == HttpClientResponseState.HeadersReceived) + doAssert(response.connection.state == + HttpClientConnectionState.ResponseHeadersReceived) + var reader = response.getBodyReader() + try: + let data = await reader.read(nbytes) + await reader.closeWait() + reader = nil + await response.finish() + return data + except CancelledError as exc: + if not(isNil(reader)): + await reader.closeWait() + response.setError(newHttpInterruptError()) + raise exc + except AsyncStreamError as exc: + if not(isNil(reader)): + await reader.closeWait() + let error = newHttpReadError("Could not read response") + response.setError(error) + raise error + +proc consumeBody*(response: HttpClientResponseRef): Future[int] {.async.} = + ## Consume/discard response and return number of bytes consumed. + doAssert(response.state == HttpClientResponseState.HeadersReceived) + doAssert(response.connection.state == + HttpClientConnectionState.ResponseHeadersReceived) + var reader = response.getBodyReader() + try: + let res = await reader.consume() + await reader.closeWait() + reader = nil + await response.finish() + return res + except CancelledError as exc: + if not(isNil(reader)): + await reader.closeWait() + response.setError(newHttpInterruptError()) + raise exc + except AsyncStreamError as exc: + if not(isNil(reader)): + await reader.closeWait() + let error = newHttpReadError("Could not read response") + response.setError(error) + raise error + +proc redirect*(request: HttpClientRequestRef, + ha: HttpAddress): HttpResult[HttpClientRequestRef] = + ## Create new request object using original request object ``request`` and + ## new redirected address ``ha``. + ## + ## This procedure could return an error if number of redirects exceeded + ## maximum allowed number of redirects in request's session. + let redirectCount = request.redirectCount + 1 + if redirectCount > request.session.maxRedirections: + err("Maximum number of redirects exceeded") + else: + var res = HttpClientRequestRef.new(request.session, ha, request.meth, + request.version, request.flags, request.headers.toList(), request.buffer) + res.redirectCount = redirectCount + ok(res) + +proc redirect*(request: HttpClientRequestRef, + uri: Uri): HttpResult[HttpClientRequestRef] = + ## Create new request object using original request object ``request`` and + ## redirected URL ``uri``. + ## + ## This procedure could return an error if number of redirects exceeded + ## maximum allowed number of redirects in request's session or ``uri`` is + ## incorrect or not supported. + let redirectCount = request.redirectCount + 1 + if redirectCount > request.session.maxRedirections: + err("Maximum number of redirects exceeded") + else: + let address = ? request.session.redirect(request.address, uri) + var res = HttpClientRequestRef.new(request.session, address, request.meth, + request.version, request.flags, request.headers.toList(), request.buffer) + res.redirectCount = redirectCount + ok(res) + +proc fetch*(request: HttpClientRequestRef): Future[HttpResponseTuple] {. + async.} = + let response = await request.send() + let data = await response.getBodyBytes() + let code = response.status + await response.closeWait() + return (code, data) + +proc fetch*(session: HttpSessionRef, url: Uri): Future[HttpResponseTuple] {. + async.} = + ## Fetch resource pointed by ``url`` using HTTP GET method and ``session`` + ## parameters. + ## + ## This procedure supports HTTP redirections. + let address = + block: + let res = session.getAddress(url) + if res.isErr(): + raiseHttpAddressError(res.error()) + res.get() + + var + request = HttpClientRequestRef.new(session, address) + response: HttpClientResponseRef = nil + redirect: HttpClientRequestRef = nil + + while true: + try: + response = await request.send() + if response.status >= 300 and response.status < 400: + redirect = + block: + if "location" in response.headers: + let location = response.headers.getString("location") + if len(location) > 0: + let res = request.redirect(parseUri(location)) + if res.isErr(): + raiseHttpRedirectError(res.error()) + res.get() + else: + raiseHttpRedirectError("Location header with an empty value") + else: + raiseHttpRedirectError("Location header missing") + await request.closeWait() + request = nil + discard await response.consumeBody() + await response.closeWait() + response = nil + request = redirect + redirect = nil + else: + await request.closeWait() + request = nil + let data = await response.getBodyBytes() + let code = response.status + await response.closeWait() + response = nil + return (code, data) + except CancelledError as exc: + if not(isNil(request)): + await closeWait(request) + if not(isNil(redirect)): + await closeWait(redirect) + if not(isNil(response)): + await closeWait(response) + raise exc + except HttpError as exc: + if not(isNil(request)): + await closeWait(request) + if not(isNil(redirect)): + await closeWait(redirect) + if not(isNil(response)): + await closeWait(response) + raise exc diff --git a/chronos/apps/http/httpcommon.nim b/chronos/apps/http/httpcommon.nim index ac4b39f2..14c2e956 100644 --- a/chronos/apps/http/httpcommon.nim +++ b/chronos/apps/http/httpcommon.nim @@ -7,17 +7,33 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/[strutils, uri] -import stew/results, httputils +import stew/[results, endians2], httputils import ../../asyncloop, ../../asyncsync import ../../streams/[asyncstream, boundstream] export results, httputils, strutils const - HeadersMark* = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] + HeadersMark* = @[0x0d'u8, 0x0a'u8, 0x0d'u8, 0x0a'u8] PostMethods* = {MethodPost, MethodPatch, MethodPut, MethodDelete} MaximumBodySizeError* = "Maximum size of request's body reached" + UserAgentHeader* = "user-agent" + DateHeader* = "date" + HostHeader* = "host" + ConnectionHeader* = "connection" + AcceptHeader* = "accept" + ContentLengthHeader* = "content-length" + TransferEncodingHeader* = "transfer-encoding" + ContentEncodingHeader* = "content-encoding" + ContentTypeHeader* = "content-type" + ExpectHeader* = "expect" + ServerHeader* = "server" + LocationHeader* = "location" + + UrlEncodedContentType* = "application/x-www-form-urlencoded" + MultipartContentType* = "multipart/form-data" + type HttpResult*[T] = Result[T, string] HttpResultCode*[T] = Result[T, HttpCode] @@ -29,6 +45,13 @@ type HttpRecoverableError* = object of HttpError code*: HttpCode HttpDisconnectError* = object of HttpError + HttpConnectionError* = object of HttpError + HttpInterruptError* = object of HttpError + HttpReadError* = object of HttpError + HttpWriteError* = object of HttpError + HttpProtocolError* = object of HttpError + HttpRedirectError* = object of HttpError + HttpAddressError* = object of HttpError TransferEncodingFlags* {.pure.} = enum Identity, Chunked, Compress, Deflate, Gzip @@ -36,45 +59,6 @@ type ContentEncodingFlags* {.pure.} = enum Identity, Br, Compress, Deflate, Gzip - HttpBodyReader* = ref object of AsyncStreamReader - streams*: seq[AsyncStreamReader] - -proc newHttpBodyReader*(streams: varargs[AsyncStreamReader]): HttpBodyReader = - ## HttpBodyReader is AsyncStreamReader which holds references to all the - ## ``streams``. Also on close it will close all the ``streams``. - ## - ## First stream in sequence will be used as a source. - doAssert(len(streams) > 0, "At least one stream must be added") - var res = HttpBodyReader(streams: @streams) - res.init(streams[0]) - res - -proc closeWait*(bstream: HttpBodyReader) {.async.} = - ## Close and free resource allocated by body reader. - var res = newSeq[Future[void]]() - # We closing streams in reversed order because stream at position [0], uses - # data from stream at position [1]. - for index in countdown((len(bstream.streams) - 1), 0): - res.add(bstream.streams[index].closeWait()) - await allFutures(res) - await procCall(closeWait(AsyncStreamReader(bstream))) - -proc hasOverflow*(bstream: HttpBodyReader): bool {.raises: [Defect].} = - if len(bstream.streams) == 1: - # If HttpBodyReader has only one stream it has ``BoundedStreamReader``, in - # such case its impossible to get more bytes then expected amount. - false - else: - # If HttpBodyReader has two or more streams, we check if - # ``BoundedStreamReader`` at EOF. - if bstream.streams[0].atEof(): - for i in 1 ..< len(bstream.streams): - if not(bstream.streams[1].atEof()): - return true - false - else: - false - proc raiseHttpCriticalError*(msg: string, code = Http400) {.noinline, noreturn.} = raise (ref HttpCriticalError)(code: code, msg: msg) @@ -85,6 +69,36 @@ proc raiseHttpDisconnectError*() {.noinline, noreturn.} = proc raiseHttpDefect*(msg: string) {.noinline, noreturn.} = raise (ref HttpDefect)(msg: msg) +proc raiseHttpConnectionError*(msg: string) {.noinline, noreturn.} = + raise (ref HttpConnectionError)(msg: msg) + +proc raiseHttpInterruptError*() {.noinline, noreturn.} = + raise (ref HttpInterruptError)(msg: "Connection was interrupted") + +proc raiseHttpReadError*(msg: string) {.noinline, noreturn.} = + raise (ref HttpReadError)(msg: msg) + +proc raiseHttpProtocolError*(msg: string) {.noinline, noreturn.} = + raise (ref HttpProtocolError)(msg: msg) + +proc raiseHttpWriteError*(msg: string) {.noinline, noreturn.} = + raise (ref HttpWriteError)(msg: msg) + +proc raiseHttpRedirectError*(msg: string) {.noinline, noreturn.} = + raise (ref HttpRedirectError)(msg: msg) + +proc raiseHttpAddressError*(msg: string) {.noinline, noreturn.} = + raise (ref HttpAddressError)(msg: msg) + +template newHttpInterruptError*(): ref HttpInterruptError = + newException(HttpInterruptError, "Connection was interrupted") + +template newHttpReadError*(message: string): ref HttpReadError = + newException(HttpReadError, message) + +template newHttpWriteError*(message: string): ref HttpWriteError = + newException(HttpWriteError, message) + iterator queryParams*(query: string): tuple[key: string, value: string] {. raises: [Defect].} = ## Iterate over url-encoded query string. @@ -211,3 +225,48 @@ func stringToBytes*(src: openarray[char]): seq[byte] = dst else: default + +proc dumpHex*(pbytes: openarray[byte], groupBy = 1, ascii = true): string = + ## Get hexadecimal dump of memory for array ``pbytes``. + var res = "" + var offset = 0 + var ascii = "" + + while offset < len(pbytes): + if (offset mod 16) == 0: + res = res & toHex(uint64(offset)) & ": " + + for k in 0 ..< groupBy: + let ch = pbytes[offset + k] + ascii.add(if ord(ch) > 31 and ord(ch) < 127: char(ch) else: '.') + + let item = + case groupBy: + of 1: + toHex(pbytes[offset]) + of 2: + toHex(uint16.fromBytes(pbytes.toOpenArray(offset, len(pbytes) - 1))) + of 4: + toHex(uint32.fromBytes(pbytes.toOpenArray(offset, len(pbytes) - 1))) + of 8: + toHex(uint64.fromBytes(pbytes.toOpenArray(offset, len(pbytes) - 1))) + else: + "" + res.add(item) + res.add(" ") + offset = offset + groupBy + + if (offset mod 16) == 0: + res.add(" ") + res.add(ascii) + ascii.setLen(0) + res.add("\p") + + if (offset mod 16) != 0: + let spacesCount = ((16 - (offset mod 16)) div groupBy) * + (groupBy * 2 + 1) + 1 + res = res & repeat(' ', spacesCount) + res = res & ascii + + res.add("\p") + res diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index cc4e5da0..bf440d3d 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -289,18 +289,19 @@ proc prepareRequest(conn: HttpConnectionRef, table.add(key, value) # Validating HTTP request headers # Some of the headers must be present only once. - if table.count("content-type") > 1: + if table.count(ContentTypeHeader) > 1: return err(Http400) - if table.count("content-length") > 1: + if table.count(ContentLengthHeader) > 1: return err(Http400) - if table.count("transfer-encoding") > 1: + if table.count(TransferEncodingHeader) > 1: return err(Http400) table # Preprocessing "Content-Encoding" header. request.contentEncoding = block: - let res = getContentEncoding(request.headers.getList("content-encoding")) + let res = getContentEncoding( + request.headers.getList(ContentEncodingHeader)) if res.isErr(): return err(Http400) else: @@ -310,7 +311,7 @@ proc prepareRequest(conn: HttpConnectionRef, request.transferEncoding = block: let res = getTransferEncoding( - request.headers.getList("transfer-encoding")) + request.headers.getList(TransferEncodingHeader)) if res.isErr(): return err(Http400) else: @@ -318,8 +319,8 @@ proc prepareRequest(conn: HttpConnectionRef, # Almost all HTTP requests could have body (except TRACE), we perform some # steps to reveal information about body. - if "content-length" in request.headers: - let length = request.headers.getInt("content-length") + if ContentLengthHeader in request.headers: + let length = request.headers.getInt(ContentLengthHeader) if length > 0: if request.meth == MethodTrace: return err(Http400) @@ -337,20 +338,16 @@ proc prepareRequest(conn: HttpConnectionRef, if request.hasBody(): # If request has body, we going to understand how its encoded. - const - UrlEncodedType = "application/x-www-form-urlencoded" - MultipartType = "multipart/form-data" - - if "content-type" in request.headers: - let contentType = request.headers.getString("content-type") + if ContentTypeHeader in request.headers: + let contentType = request.headers.getString(ContentTypeHeader) let tmp = strip(contentType).toLowerAscii() - if tmp.startsWith(UrlEncodedType): + if tmp.startsWith(UrlEncodedContentType): request.requestFlags.incl(HttpRequestFlags.UrlencodedForm) - elif tmp.startsWith(MultipartType): + elif tmp.startsWith(MultipartContentType): request.requestFlags.incl(HttpRequestFlags.MultipartForm) - if "expect" in request.headers: - let expectHeader = request.headers.getString("expect") + if ExpectHeader in request.headers: + let expectHeader = request.headers.getString(ExpectHeader) if strip(expectHeader).toLowerAscii() == "100-continue": request.requestFlags.incl(HttpRequestFlags.ClientExpect) @@ -430,17 +427,29 @@ proc sendErrorResponse(conn: HttpConnectionRef, version: HttpVersion, datatype = "text/text", databody = ""): Future[bool] {.async.} = var answer = $version & " " & $code & "\r\n" - answer.add("Date: " & httpDate() & "\r\n") - if len(datatype) > 0: - answer.add("Content-Type: " & datatype & "\r\n") - answer.add("Content-Length: " & - Base10.toString(uint64(len(databody))) & "\r\n") - if keepAlive: - answer.add("Connection: keep-alive\r\n") - else: - answer.add("Connection: close\r\n") - answer.add("Host: " & conn.server.getHostname() & "\r\n") + answer.add(DateHeader) + answer.add(": ") + answer.add(httpDate()) answer.add("\r\n") + if len(datatype) > 0: + answer.add(ContentTypeHeader) + answer.add(": ") + answer.add(datatype) + answer.add("\r\n") + answer.add(ContentLengthHeader) + answer.add(": ") + answer.add(Base10.toString(uint64(len(databody)))) + answer.add("\r\n") + if keepAlive: + answer.add(ConnectionHeader) + answer.add(": keep-alive\r\n") + else: + answer.add(ConnectionHeader) + answer.add(": close\r\n") + answer.add(HostHeader) + answer.add(": ") + answer.add(conn.server.getHostname()) + answer.add("\r\n\r\n") if len(databody) > 0: answer.add(databody) try: @@ -744,12 +753,12 @@ proc getMultipartReader*(req: HttpRequestRef): HttpResult[MultiPartReaderRef] = ## Create new MultiPartReader interface for specific request. if req.meth in PostMethods: if MultipartForm in req.requestFlags: - let ctype = ? getContentType(req.headers.getList("content-type")) - if ctype != "multipart/form-data": + let ctype = ? getContentType(req.headers.getList(ContentTypeHeader)) + if ctype != MultipartContentType: err("Content type is not supported") else: let boundary = ? getMultipartBoundary( - req.headers.getList("content-type") + req.headers.getList(ContentTypeHeader) ) var stream = ? req.getBodyReader() ok(MultiPartReaderRef.new(stream, boundary)) @@ -877,21 +886,21 @@ template checkPending(t: untyped) = proc prepareLengthHeaders(resp: HttpResponseRef, length: int): string {. raises: [Defect].}= - if not(resp.hasHeader("date")): - resp.setHeader("date", httpDate()) - if not(resp.hasHeader("content-type")): - resp.setHeader("content-type", "text/html; charset=utf-8") - if not(resp.hasHeader("content-length")): - resp.setHeader("content-length", Base10.toString(uint64(length))) - if not(resp.hasHeader("server")): - resp.setHeader("server", resp.connection.server.serverIdent) - if not(resp.hasHeader("host")): - resp.setHeader("host", resp.connection.server.getHostname()) - if not(resp.hasHeader("connection")): + if not(resp.hasHeader(DateHeader)): + resp.setHeader(DateHeader, httpDate()) + if not(resp.hasHeader(ContentTypeHeader)): + resp.setHeader(ContentTypeHeader, "text/html; charset=utf-8") + if not(resp.hasHeader(ContentLengthHeader)): + resp.setHeader(ContentLengthHeader, Base10.toString(uint64(length))) + if not(resp.hasHeader(ServerHeader)): + resp.setHeader(ServerHeader, resp.connection.server.serverIdent) + if not(resp.hasHeader(HostHeader)): + resp.setHeader(HostHeader, resp.connection.server.getHostname()) + if not(resp.hasHeader(ConnectionHeader)): if KeepAlive in resp.flags: - resp.setHeader("connection", "keep-alive") + resp.setHeader(ConnectionHeader, "keep-alive") else: - resp.setHeader("connection", "close") + resp.setHeader(ConnectionHeader, "close") var answer = $(resp.version) & " " & $(resp.status) & "\r\n" for k, v in resp.headersTable.stringItems(): if len(v) > 0: @@ -904,21 +913,21 @@ proc prepareLengthHeaders(resp: HttpResponseRef, length: int): string {. proc prepareChunkedHeaders(resp: HttpResponseRef): string {. raises: [Defect].} = - if not(resp.hasHeader("date")): - resp.setHeader("date", httpDate()) - if not(resp.hasHeader("content-type")): - resp.setHeader("content-type", "text/html; charset=utf-8") - if not(resp.hasHeader("transfer-encoding")): - resp.setHeader("transfer-encoding", "chunked") - if not(resp.hasHeader("server")): - resp.setHeader("server", resp.connection.server.serverIdent) - if not(resp.hasHeader("host")): - resp.setHeader("host", resp.connection.server.getHostname()) - if not(resp.hasHeader("connection")): + if not(resp.hasHeader(DateHeader)): + resp.setHeader(DateHeader, httpDate()) + if not(resp.hasHeader(ContentTypeHeader)): + resp.setHeader(ContentTypeHeader, "text/html; charset=utf-8") + if not(resp.hasHeader(TransferEncodingHeader)): + resp.setHeader(TransferEncodingHeader, "chunked") + if not(resp.hasHeader(ServerHeader)): + resp.setHeader(ServerHeader, resp.connection.server.serverIdent) + if not(resp.hasHeader(HostHeader)): + resp.setHeader(HostHeader, resp.connection.server.getHostname()) + if not(resp.hasHeader(ConnectionHeader)): if KeepAlive in resp.flags: - resp.setHeader("connection", "keep-alive") + resp.setHeader(ConnectionHeader, "keep-alive") else: - resp.setHeader("connection", "close") + resp.setHeader(ConnectionHeader, "close") var answer = $(resp.version) & " " & $(resp.status) & "\r\n" for k, v in resp.headersTable.stringItems(): if len(v) > 0: @@ -1076,9 +1085,21 @@ proc respond*(req: HttpRequestRef, code: HttpCode, respond(req, code, content, HttpTable.init()) proc respond*(req: HttpRequestRef, code: HttpCode): Future[HttpResponseRef] = - ## Reponds to the request with specified ``HttpCode`` only. + ## Responds to the request with specified ``HttpCode`` only. respond(req, code, "", HttpTable.init()) +proc redirect*(req: HttpRequestRef, code: HttpCode, + location: Uri): Future[HttpResponseRef] = + ## Responds to the request with redirection to location ``location``. + let headers = HttpTable.init([("location", $location)]) + respond(req, code, "", headers) + +proc redirect*(req: HttpRequestRef, code: HttpCode, + location: string): Future[HttpResponseRef] = + ## Responds to the request with redirection to location ``location``. + let headers = HttpTable.init([("location", location)]) + respond(req, code, "", headers) + proc responded*(req: HttpRequestRef): bool = ## Returns ``true`` if request ``req`` has been responded or responding. if isSome(req.response): diff --git a/chronos/apps/http/httptable.nim b/chronos/apps/http/httptable.nim index 0e83400e..ee4e1e99 100644 --- a/chronos/apps/http/httptable.nim +++ b/chronos/apps/http/httptable.nim @@ -190,3 +190,10 @@ proc `$`*(ht: HttpTables): string = res.add(item) res.add("\p") res + +proc toList*(ht: HttpTables, normKey = false): auto = + ## Returns sequence of (key, value) pairs. + var res: seq[tuple[key: string, value: string]] + for key, value in ht.stringItems(normKey): + res.add((key, value)) + res diff --git a/chronos/apps/http/multipart.nim b/chronos/apps/http/multipart.nim index a5d2fc10..13c84f8c 100644 --- a/chronos/apps/http/multipart.nim +++ b/chronos/apps/http/multipart.nim @@ -11,8 +11,8 @@ import std/[monotimes, strutils] import stew/results import ../../asyncloop import ../../streams/[asyncstream, boundstream, chunkstream] -import httptable, httpcommon -export httptable, httpcommon, asyncstream +import httptable, httpcommon, httpbodyrw +export httptable, httpcommon, httpbodyrw, asyncstream const UnableToReadMultipartBody = "Unable to read multipart message body" @@ -21,8 +21,12 @@ type MultiPartSource* {.pure.} = enum Stream, Buffer + MultiPartWriterState* {.pure.} = enum + MessagePreparing, MessageStarted, PartStarted, PartFinished, + MessageFinished, MessageFailure + MultiPartReader* = object - case kind: MultiPartSource + case kind*: MultiPartSource of MultiPartSource.Stream: stream*: HttpBodyReader of MultiPartSource.Buffer: @@ -35,6 +39,20 @@ type MultiPartReaderRef* = ref MultiPartReader + MultiPartWriter* = object + case kind*: MultiPartSource + of MultiPartSource.Stream: + stream*: HttpBodyWriter + of MultiPartSource.Buffer: + buffer*: seq[byte] + beginMark: seq[byte] + finishMark: seq[byte] + beginPartMark: seq[byte] + finishPartMark: seq[byte] + state*: MultiPartWriterState + + MultiPartWriterRef* = ref MultiPartWriter + MultiPart* = object case kind: MultiPartSource of MultiPartSource.Stream: @@ -409,6 +427,18 @@ func isEmpty*(mp: MultiPart): bool {. ## Returns ``true`` is multipart ``mp`` is not initialized/filled yet. mp.counter == 0 +func validateBoundary[B: BChar](boundary: openarray[B]): HttpResult[void] = + if len(boundary) == 0: + err("Content-Type boundary must be at least 1 character size") + elif len(boundary) > 70: + err("Content-Type boundary must be less then 70 characters") + else: + for ch in boundary: + if chr(ord(ch)) notin {'a' .. 'z', 'A' .. 'Z', '0' .. '9', + '\'' .. ')', '+' .. '/', ':', '=', '?', '_'}: + return err("Content-Type boundary alphabet incorrect") + ok() + func getMultipartBoundary*(ch: openarray[string]): HttpResult[string] {. raises: [Defect].} = ## Returns ``multipart/form-data`` boundary value from ``Content-Type`` @@ -453,13 +483,280 @@ func getMultipartBoundary*(ch: openarray[string]): HttpResult[string] {. err("Missing Content-Type boundary") else: let candidate = strip(bparts[1]) - if len(candidate) == 0: - err("Content-Type boundary must be at least 1 character size") - elif len(candidate) > 70: - err("Content-Type boundary must be less then 70 characters") + let res = validateBoundary(candidate) + if res.isErr(): + err($res.error()) else: - for ch in candidate: - if ch notin {'a' .. 'z', 'A' .. 'Z', '0' .. '9', - '\'' .. ')', '+' .. '/', ':', '=', '?', '_'}: - return err("Content-Type boundary alphabet incorrect") ok(candidate) + +proc quoteCheck(name: string): HttpResult[string] = + if len(name) > 0: + var res = newStringOfCap(len(name)) + for ch in name: + case ch + of '\x00' .. '\x08', '\x0a' .. '\x1f': + return err("Incorrect character encountered") + of '\x09', '\x20', '\x21': + res.add(ch) + of '\x22': + res.add('\\') + res.add('"') + of '\x23' .. '\x7f': + res.add(ch) + else: + return err("Incorrect character encountered") + ok(res) + else: + ok(name) + +proc init*[B: BChar](mpt: typedesc[MultiPartWriter], + boundary: openarray[B]): MultiPartWriter {. + raises: [Defect].} = + ## Create new MultiPartWriter instance with `buffer` interface. + ## + ## ``boundary`` - is multipart boundary, this value must not be empty. + doAssert(validateBoundary(boundary).isOk()) + + let sboundary = + when B is char: + @(boundary.toOpenArrayByte(0, len(boundary) - 1)) + else: + @boundary + + var finishMark = sboundary + finishMark.add([0x2d'u8, 0x2d'u8, 0x0d'u8, 0x0a'u8]) + var beginPartMark = sboundary + beginPartMark.add([0x0d'u8, 0x0a'u8]) + + MultiPartWriter( + kind: MultiPartSource.Buffer, + buffer: newSeq[byte](), + beginMark: @[0x2d'u8, 0x2d'u8], + finishMark: finishMark, + beginPartMark: beginPartMark, + finishPartMark: @[0x0d'u8, 0x0a'u8, 0x2d'u8, 0x2d'u8], + state: MultiPartWriterState.MessagePreparing + ) + +proc new*[B: BChar](mpt: typedesc[MultiPartWriterRef], + stream: HttpBodyWriter, + boundary: openarray[B]): MultiPartWriterRef {. + raises: [Defect].} = + doAssert(validateBoundary(boundary).isOk()) + doAssert(not(isNil(stream))) + + let sboundary = + when B is char: + @(boundary.toOpenArrayByte(0, len(boundary) - 1)) + else: + @boundary + + var finishMark = sboundary + finishMark.add([0x2d'u8, 0x2d'u8, 0x0d'u8, 0x0a'u8]) + var beginPartMark = sboundary + beginPartMark.add([0x0d'u8, 0x0a'u8]) + + MultiPartWriterRef( + kind: MultiPartSource.Stream, + stream: stream, + beginMark: @[0x2d'u8, 0x2d'u8], + finishMark: finishMark, + beginPartMark: beginPartMark, + finishPartMark: @[0x0d'u8, 0x0a'u8, 0x2d'u8, 0x2d'u8], + state: MultiPartWriterState.MessagePreparing + ) + +proc prepareHeaders(partMark: openarray[byte], name: string, filename: string, + headers: HttpTable): string = + const ContentDisposition = "Content-Disposition" + let qname = + block: + let res = quoteCheck(name) + doAssert(res.isOk()) + res.get() + let qfilename = + block: + let res = quoteCheck(filename) + doAssert(res.isOk()) + res.get() + var buffer = newString(len(partMark)) + copyMem(addr buffer[0], unsafeAddr partMark[0], len(partMark)) + buffer.add(ContentDisposition) + buffer.add(": ") + if ContentDisposition in headers: + buffer.add(headers.getString(ContentDisposition)) + buffer.add("\r\n") + else: + buffer.add("form-data; name=\"") + buffer.add(qname) + buffer.add("\"") + if len(qfilename) > 0: + buffer.add("; filename=\"") + buffer.add(qfilename) + buffer.add("\"") + buffer.add("\r\n") + + for k, v in headers.stringItems(): + if k != toLowerAscii(ContentDisposition): + if len(v) > 0: + buffer.add(k) + buffer.add(": ") + buffer.add(v) + buffer.add("\r\n") + buffer.add("\r\n") + buffer + +proc begin*(mpw: MultiPartWriterRef) {.async.} = + ## Starts multipart message form and write approprate markers to output + ## stream. + doAssert(mpw.kind == MultiPartSource.Stream) + doAssert(mpw.state == MultiPartWriterState.MessagePreparing) + # write "--" + try: + await mpw.stream.write(mpw.beginMark) + except AsyncStreamError: + mpw.state = MultiPartWriterState.MessageFailure + raiseHttpCriticalError("Unable to start multipart message") + mpw.state = MultiPartWriterState.MessageStarted + +proc begin*(mpw: var MultiPartWriter) = + ## Starts multipart message form and write approprate markers to output + ## buffer. + doAssert(mpw.kind == MultiPartSource.Buffer) + doAssert(mpw.state == MultiPartWriterState.MessagePreparing) + # write "--" + mpw.buffer.add(mpw.beginMark) + mpw.state = MultiPartWriterState.MessageStarted + +proc beginPart*(mpw: MultiPartWriterRef, name: string, + filename: string, headers: HttpTable) {.async.} = + ## Starts part of multipart message and write appropriate ``headers`` to the + ## output stream. + ## + ## Note: `filename` and `name` arguments could be only ASCII strings. + const ContentDisposition = "Content-Disposition" + doAssert(mpw.kind == MultiPartSource.Stream) + doAssert(mpw.state in {MultiPartWriterState.MessageStarted, + MultiPartWriterState.PartFinished}) + # write "" + # write "" + # write "" + let buffer = prepareHeaders(mpw.beginPartMark, name, filename, headers) + try: + await mpw.stream.write(buffer) + mpw.state = MultiPartWriterState.PartStarted + except AsyncStreamError: + mpw.state = MultiPartWriterState.MessageFailure + raiseHttpCriticalError("Unable to start multipart part") + +proc beginPart*(mpw: var MultiPartWriter, name: string, + filename: string, headers: HttpTable) = + ## Starts part of multipart message and write appropriate ``headers`` to the + ## output stream. + ## + ## Note: `filename` and `name` arguments could be only ASCII strings. + doAssert(mpw.kind == MultiPartSource.Buffer) + doAssert(mpw.state in {MultiPartWriterState.MessageStarted, + MultiPartWriterState.PartFinished}) + let buffer = prepareHeaders(mpw.beginPartMark, name, filename, headers) + # write "" + # write "" + # write "" + mpw.buffer.add(buffer.toOpenArrayByte(0, len(buffer) - 1)) + mpw.state = MultiPartWriterState.PartStarted + +proc write*(mpw: MultiPartWriterRef, pbytes: pointer, nbytes: int) {.async.} = + ## Write part's data ``data`` to the output stream. + doAssert(mpw.kind == MultiPartSource.Stream) + doAssert(mpw.state == MultiPartWriterState.PartStarted) + try: + # write of data + await mpw.stream.write(pbytes, nbytes) + except AsyncStreamError: + mpw.state = MultiPartWriterState.MessageFailure + raiseHttpCriticalError("Unable to write multipart data") + +proc write*(mpw: MultiPartWriterRef, data: seq[byte]) {.async.} = + ## Write part's data ``data`` to the output stream. + doAssert(mpw.kind == MultiPartSource.Stream) + doAssert(mpw.state == MultiPartWriterState.PartStarted) + try: + # write of data + await mpw.stream.write(data) + except AsyncStreamError: + mpw.state = MultiPartWriterState.MessageFailure + raiseHttpCriticalError("Unable to write multipart data") + +proc write*(mpw: MultiPartWriterRef, data: string) {.async.} = + ## Write part's data ``data`` to the output stream. + doAssert(mpw.kind == MultiPartSource.Stream) + doAssert(mpw.state == MultiPartWriterState.PartStarted) + try: + # write of data + await mpw.stream.write(data) + except AsyncStreamError: + mpw.state = MultiPartWriterState.MessageFailure + raiseHttpCriticalError("Unable to write multipart data") + +proc write*(mpw: var MultiPartWriter, pbytes: pointer, nbytes: int) = + ## Write part's data ``data`` to the output stream. + doAssert(mpw.kind == MultiPartSource.Buffer) + doAssert(mpw.state == MultiPartWriterState.PartStarted) + let index = len(mpw.buffer) + if nbytes > 0: + mpw.buffer.setLen(index + nbytes) + copyMem(addr mpw.buffer[0], pbytes, nbytes) + +proc write*(mpw: var MultiPartWriter, data: openarray[byte]) = + ## Write part's data ``data`` to the output stream. + doAssert(mpw.kind == MultiPartSource.Buffer) + doAssert(mpw.state == MultiPartWriterState.PartStarted) + mpw.buffer.add(data) + +proc write*(mpw: var MultiPartWriter, data: openarray[char]) = + ## Write part's data ``data`` to the output stream. + doAssert(mpw.kind == MultiPartSource.Buffer) + doAssert(mpw.state == MultiPartWriterState.PartStarted) + mpw.buffer.add(data.toOpenArrayByte(0, len(data) - 1)) + +proc finishPart*(mpw: MultiPartWriterRef) {.async.} = + ## Finish multipart's message part and send proper markers to output stream. + doAssert(mpw.state == MultiPartWriterState.PartStarted) + try: + # write "--" + await mpw.stream.write(mpw.finishPartMark) + mpw.state = MultiPartWriterState.PartFinished + except AsyncStreamError: + mpw.state = MultiPartWriterState.MessageFailure + raiseHttpCriticalError("Unable to finish multipart message part") + +proc finishPart*(mpw: var MultiPartWriter) = + ## Finish multipart's message part and send proper markers to output stream. + doAssert(mpw.kind == MultiPartSource.Buffer) + doAssert(mpw.state == MultiPartWriterState.PartStarted) + # write "--" + mpw.buffer.add(mpw.finishPartMark) + mpw.state = MultiPartWriterState.PartFinished + +proc finish*(mpw: MultiPartWriterRef) {.async.} = + ## Finish multipart's message form and send finishing markers to the output + ## stream. + doAssert(mpw.kind == MultiPartSource.Stream) + doAssert(mpw.state == MultiPartWriterState.PartFinished) + try: + # write "--" + await mpw.stream.write(mpw.finishMark) + mpw.state = MultiPartWriterState.MessageFinished + except AsyncStreamError: + mpw.state = MultiPartWriterState.MessageFailure + raiseHttpCriticalError("Unable to finish multipart message") + +proc finish*(mpw: var MultiPartWriter): seq[byte] = + ## Finish multipart's message form and send finishing markers to the output + ## stream. + doAssert(mpw.kind == MultiPartSource.Buffer) + doAssert(mpw.state == MultiPartWriterState.PartFinished) + # write "--" + mpw.buffer.add(mpw.finishMark) + mpw.state = MultiPartWriterState.MessageFinished + mpw.buffer diff --git a/chronos/transports/common.nim b/chronos/transports/common.nim index be0ebc84..305090c7 100644 --- a/chronos/transports/common.nim +++ b/chronos/transports/common.nim @@ -377,6 +377,21 @@ proc address*(ta: TransportAddress): IpAddress {. else: raise newException(ValueError, "IpAddress supports only IPv4/IPv6!") +proc host*(ta: TransportAddress): string {.raises: [Defect].} = + ## Returns ``host`` of TransportAddress ``ta``. + ## + ## For IPv4 and IPv6 addresses it will return IP address as string, or empty + ## string for Unix address. + case ta.family + of AddressFamily.IPv4: + $IpAddress(family: IpAddressFamily.IPv4, address_v4: ta.address_v4) + of AddressFamily.IPv6: + let a = $IpAddress(family: IpAddressFamily.IPv6, + address_v6: ta.address_v6) + "[" & a & "]" + else: + "" + proc resolveTAddress*(address: string, port: Port, domain: Domain): seq[TransportAddress] {. raises: [Defect, TransportAddressError].} = diff --git a/tests/testall.nim b/tests/testall.nim index decec74c..021442e2 100644 --- a/tests/testall.nim +++ b/tests/testall.nim @@ -7,5 +7,5 @@ # MIT license (LICENSE-MIT) import testmacro, testsync, testsoon, testtime, testfut, testsignal, testaddress, testdatagram, teststream, testserver, testbugs, testnet, - testasyncstream, testhttpserver, testshttpserver + testasyncstream, testhttpserver, testshttpserver, testhttpclient import testutils diff --git a/tests/testhttpclient.nim b/tests/testhttpclient.nim new file mode 100644 index 00000000..34da4683 --- /dev/null +++ b/tests/testhttpclient.nim @@ -0,0 +1,769 @@ +# Chronos Test Suite +# (c) Copyright 2021-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import std/[strutils, strutils, sha1] +import unittest2 +import ../chronos, ../chronos/apps/http/[httpserver, shttpserver, httpclient] +import stew/base10 + +when defined(nimHasUsed): {.used.} + +# To create self-signed certificate and key you can use openssl +# openssl req -new -x509 -sha256 -newkey rsa:2048 -nodes \ +# -keyout example-com.key.pem -days 3650 -out example-com.cert.pem +const HttpsSelfSignedRsaKey = """ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCn7tXGLKMIMzOG +tVzUixax1/ftlSLcpEAkZMORuiCCnYjtIJhGZdzRFZC8fBlfAJZpLIAOfX2L2f1J +ZuwpwDkOIvNqKMBrl5Mvkl5azPT0rtnjuwrcqN5NFtbmZPKFYvbjex2aXGqjl5MW +nQIs/ZA++DVEXmaN9oDxcZsvRMDKfrGQf9iLeoVL47Gx9KpqNqD/JLIn4LpieumV +yYidm6ukTOqHRvrWm36y6VvKW4TE97THacULmkeahtTf8zDJbbh4EO+gifgwgJ2W +BUS0+5hMcWu8111mXmanlOVlcoW8fH8RmPjL1eK1Z3j3SVHEf7oWZtIVW5gGA0jQ +nfA4K51RAgMBAAECggEANZ7/R13tWKrwouy6DWuz/WlWUtgx333atUQvZhKmWs5u +cDjeJmxUC7b1FhoSB9GqNT7uTLIpKkSaqZthgRtNnIPwcU890Zz+dEwqMJgNByvl +it+oYjjRco/+YmaNQaYN6yjelPE5Y678WlYb4b29Fz4t0/zIhj/VgEKkKH2tiXpS +TIicoM7pSOscEUfaW3yp5bS5QwNU6/AaF1wws0feBACd19ZkcdPvr52jopbhxlXw +h3XTV/vXIJd5zWGp0h/Jbd4xcD4MVo2GjfkeORKY6SjDaNzt8OGtePcKnnbUVu8b +2XlDxukhDQXqJ3g0sHz47mhvo4JeIM+FgymRm+3QmQKBgQDTawrEA3Zy9WvucaC7 +Zah02oE9nuvpF12lZ7WJh7+tZ/1ss+Fm7YspEKaUiEk7nn1CAVFtem4X4YCXTBiC +Oqq/o+ipv1yTur0ae6m4pwLm5wcMWBh3H5zjfQTfrClNN8yjWv8u3/sq8KesHPnT +R92/sMAptAChPgTzQphWbxFiYwKBgQDLWFaBqXfZYVnTyUvKX8GorS6jGWc6Eh4l +lAFA+2EBWDICrUxsDPoZjEXrWCixdqLhyehaI3KEFIx2bcPv6X2c7yx3IG5lA/Gx +TZiKlY74c6jOTstkdLW9RJbg1VUHUVZMf/Owt802YmEfUI5S5v7jFmKW6VG+io+K ++5KYeHD1uwKBgQDMf53KPA82422jFwYCPjLT1QduM2q97HwIomhWv5gIg63+l4BP +rzYMYq6+vZUYthUy41OAMgyLzPQ1ZMXQMi83b7R9fTxvKRIBq9xfYCzObGnE5vHD +SDDZWvR75muM5Yxr9nkfPkgVIPMO6Hg+hiVYZf96V0LEtNjU9HWmJYkLQQKBgQCQ +ULGUdGHKtXy7AjH3/t3CiKaAupa4cANVSCVbqQy/l4hmvfdu+AbH+vXkgTzgNgKD +nHh7AI1Vj//gTSayLlQn/Nbh9PJkXtg5rYiFUn+VdQBo6yMOuIYDPZqXFtCx0Nge +kvCwisHpxwiG4PUhgS+Em259DDonsM8PJFx2OYRx4QKBgEQpGhg71Oi9MhPJshN7 +dYTowaMS5eLTk2264ARaY+hAIV7fgvUa+5bgTVaWL+Cfs33hi4sMRqlEwsmfds2T +cnQiJ4cU20Euldfwa5FLnk6LaWdOyzYt/ICBJnKFRwfCUbS4Bu5rtMEM+3t0wxnJ +IgaD04WhoL9EX0Qo3DC1+0kG +-----END PRIVATE KEY----- +""" + +# This SSL certificate will expire 13 October 2030. +const HttpsSelfSignedRsaCert = """ +-----BEGIN CERTIFICATE----- +MIIDnzCCAoegAwIBAgIUUdcusjDd3XQi3FPM8urdFG3qI+8wDQYJKoZIhvcNAQEL +BQAwXzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEYMBYGA1UEAwwPMTI3LjAuMC4xOjQz +ODA4MB4XDTIwMTAxMjIxNDUwMVoXDTMwMTAxMDIxNDUwMVowXzELMAkGA1UEBhMC +QVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdp +dHMgUHR5IEx0ZDEYMBYGA1UEAwwPMTI3LjAuMC4xOjQzODA4MIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEAp+7VxiyjCDMzhrVc1IsWsdf37ZUi3KRAJGTD +kboggp2I7SCYRmXc0RWQvHwZXwCWaSyADn19i9n9SWbsKcA5DiLzaijAa5eTL5Je +Wsz09K7Z47sK3KjeTRbW5mTyhWL243sdmlxqo5eTFp0CLP2QPvg1RF5mjfaA8XGb +L0TAyn6xkH/Yi3qFS+OxsfSqajag/ySyJ+C6YnrplcmInZurpEzqh0b61pt+sulb +yluExPe0x2nFC5pHmobU3/MwyW24eBDvoIn4MICdlgVEtPuYTHFrvNddZl5mp5Tl +ZXKFvHx/EZj4y9XitWd490lRxH+6FmbSFVuYBgNI0J3wOCudUQIDAQABo1MwUTAd +BgNVHQ4EFgQUBKha84woY5WkFxKw7qx1cONg1H8wHwYDVR0jBBgwFoAUBKha84wo +Y5WkFxKw7qx1cONg1H8wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOC +AQEAHZMYt9Ry+Xj3vTbzpGFQzYQVTJlfJWSN6eWNOivRFQE5io9kOBEe5noa8aLo +dLkw6ztxRP2QRJmlhGCO9/HwS17ckrkgZp3EC2LFnzxcBmoZu+owfxOT1KqpO52O +IKOl8eVohi1pEicE4dtTJVcpI7VCMovnXUhzx1Ci4Vibns4a6H+BQa19a1JSpifN +tO8U5jkjJ8Jprs/VPFhJj2O3di53oDHaYSE5eOrm2ZO14KFHSk9cGcOGmcYkUv8B +nV5vnGadH5Lvfxb/BCpuONabeRdOxMt9u9yQ89vNpxFtRdZDCpGKZBCfmUP+5m3m +N8r5CwGcIX/XPC3lKazzbZ8baA== +-----END CERTIFICATE----- +""" + +suite "HTTP client testing suite": + + proc createBigMessage(message: string, size: int): seq[byte] = + var res = newSeq[byte](size) + for i in 0 ..< len(res): + res[i] = byte(message[i mod len(message)]) + res + + proc createServer(address: TransportAddress, + process: HttpProcessCallback, secure: bool): HttpServerRef = + let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + if secure: + let secureKey = TLSPrivateKey.init(HttpsSelfSignedRsaKey) + let secureCert = TLSCertificate.init(HttpsSelfSignedRsaCert) + let res = SecureHttpServerRef.new(address, process, + socketFlags = socketFlags, + tlsPrivateKey = secureKey, + tlsCertificate = secureCert) + HttpServerRef(res.get()) + else: + let res = HttpServerRef.new(address, process, socketFlags = socketFlags) + res.get() + + proc createSession(secure: bool, + maxRedirections = HttpMaxRedirections): HttpSessionRef = + if secure: + HttpSessionRef.new({HttpClientFlag.NoVerifyHost, + HttpClientFlag.NoVerifyServerName}, + maxRedirections = maxRedirections) + else: + HttpSessionRef.new(maxRedirections = maxRedirections) + + proc testMethods(address: TransportAddress, + secure: bool): Future[int] {.async.} = + let RequestTests = [ + (MethodGet, "/test/get"), + (MethodPost, "/test/post"), + (MethodHead, "/test/head"), + (MethodPut, "/test/put"), + (MethodDelete, "/test/delete"), + (MethodTrace, "/test/trace"), + (MethodOptions, "/test/options"), + (MethodConnect, "/test/connect"), + (MethodPatch, "/test/patch") + ] + proc process(r: RequestFence): Future[HttpResponseRef] {. + async.} = + if r.isOk(): + let request = r.get() + case request.uri.path + of "/test/get", "/test/post", "/test/head", "/test/put", + "/test/delete", "/test/trace", "/test/options", "/test/connect", + "/test/patch", "/test/error": + return await request.respond(Http200, request.uri.path) + else: + return await request.respond(Http404, "Page not found") + else: + return dumbResponse() + + var server = createServer(address, process, secure) + server.start() + var counter = 0 + + var session = createSession(secure) + + for item in RequestTests: + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, item[1]) + else: + getAddress(address, HttpClientScheme.NonSecure, item[1]) + var req = HttpClientRequestRef.new(session, ha, item[0]) + let response = await fetch(req) + if response.status == 200: + let data = cast[string](response.data) + if data == item[1]: + inc(counter) + await req.closeWait() + await session.closeWait() + + for item in RequestTests: + var session = createSession(secure) + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, item[1]) + else: + getAddress(address, HttpClientScheme.NonSecure, item[1]) + var req = HttpClientRequestRef.new(session, ha, item[0]) + let response = await fetch(req) + if response.status == 200: + let data = cast[string](response.data) + if data == item[1]: + inc(counter) + await req.closeWait() + await session.closeWait() + + await server.stop() + await server.closeWait() + return counter + + proc testResponseStreamReadingTest(address: TransportAddress, + secure: bool): Future[int] {.async.} = + let ResponseTests = [ + (MethodGet, "/test/short_size_response", 65600, 1024, + "SHORTSIZERESPONSE"), + (MethodGet, "/test/long_size_response", 262400, 1024, + "LONGSIZERESPONSE"), + (MethodGet, "/test/short_chunked_response", 65600, 1024, + "SHORTCHUNKRESPONSE"), + (MethodGet, "/test/long_chunked_response", 262400, 1024, + "LONGCHUNKRESPONSE") + ] + proc process(r: RequestFence): Future[HttpResponseRef] {. + async.} = + if r.isOk(): + let request = r.get() + case request.uri.path + of "/test/short_size_response": + var response = request.getResponse() + var data = createBigMessage(ResponseTests[0][4], ResponseTests[0][2]) + response.status = Http200 + await response.sendBody(data) + return response + of "/test/long_size_response": + var response = request.getResponse() + var data = createBigMessage(ResponseTests[1][4], ResponseTests[1][2]) + response.status = Http200 + await response.sendBody(data) + return response + of "/test/short_chunked_response": + var response = request.getResponse() + var data = createBigMessage(ResponseTests[2][4], ResponseTests[2][2]) + response.status = Http200 + await response.prepare() + var offset = 0 + while true: + if len(data) == offset: + break + let toWrite = min(1024, len(data) - offset) + await response.sendChunk(addr data[offset], toWrite) + offset = offset + toWrite + await response.finish() + return response + of "/test/long_chunked_response": + var response = request.getResponse() + var data = createBigMessage(ResponseTests[3][4], ResponseTests[3][2]) + response.status = Http200 + await response.prepare() + var offset = 0 + while true: + if len(data) == offset: + break + let toWrite = min(1024, len(data) - offset) + await response.sendChunk(addr data[offset], toWrite) + offset = offset + toWrite + await response.finish() + return response + else: + return await request.respond(Http404, "Page not found") + else: + return dumbResponse() + + var server = createServer(address, process, secure) + server.start() + var counter = 0 + + var session = createSession(secure) + for item in ResponseTests: + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, item[1]) + else: + getAddress(address, HttpClientScheme.NonSecure, item[1]) + var req = HttpClientRequestRef.new(session, ha, item[0]) + var response = await send(req) + if response.status == 200: + var reader = response.getBodyReader() + var res: seq[byte] + while true: + var data = await reader.read(item[3]) + res.add(data) + if len(data) != item[3]: + break + await reader.closeWait() + if len(res) == item[2]: + let expect = createBigMessage(item[4], len(res)) + if expect == res: + inc(counter) + await response.closeWait() + await req.closeWait() + await session.closeWait() + + for item in ResponseTests: + var session = createSession(secure) + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, item[1]) + else: + getAddress(address, HttpClientScheme.NonSecure, item[1]) + var req = HttpClientRequestRef.new(session, ha, item[0]) + var response = await send(req) + if response.status == 200: + var reader = response.getBodyReader() + var res: seq[byte] + while true: + var data = await reader.read(item[3]) + res.add(data) + if len(data) != item[3]: + break + await reader.closeWait() + if len(res) == item[2]: + let expect = createBigMessage(item[4], len(res)) + if expect == res: + inc(counter) + await response.closeWait() + await req.closeWait() + await session.closeWait() + + await server.stop() + await server.closeWait() + return counter + + proc testRequestSizeStreamWritingTest(address: TransportAddress, + secure: bool): Future[int] {.async.} = + let RequestTests = [ + (MethodPost, "/test/big_request", 65600), + (MethodPost, "/test/big_request", 262400) + ] + proc process(r: RequestFence): Future[HttpResponseRef] {. + async.} = + if r.isOk(): + let request = r.get() + case request.uri.path + of "/test/big_request": + if request.hasBody(): + let body = await request.getBody() + let digest = $secureHash(cast[string](body)) + return await request.respond(Http200, digest) + else: + return await request.respond(Http400, "Missing content body") + else: + return await request.respond(Http404, "Page not found") + else: + return dumbResponse() + + var server = createServer(address, process, secure) + server.start() + var counter = 0 + + var session = createSession(secure) + for item in RequestTests: + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, item[1]) + else: + getAddress(address, HttpClientScheme.NonSecure, item[1]) + var data = createBigMessage("REQUESTSTREAMMESSAGE", item[2]) + let headers = [ + ("Content-Type", "application/octet-stream"), + ("Content-Length", Base10.toString(uint64(len(data)))) + ] + var request = HttpClientRequestRef.new( + session, ha, item[0], headers = headers + ) + + var expectDigest = $secureHash(cast[string](data)) + # Sending big request by 1024bytes long chunks + var writer = await open(request) + var offset = 0 + while true: + if len(data) == offset: + break + let toWrite = min(1024, len(data) - offset) + await writer.write(addr data[offset], toWrite) + offset = offset + toWrite + await writer.finish() + await writer.closeWait() + var response = await request.finish() + + if response.status == 200: + var res = await response.getBodyBytes() + if cast[string](res) == expectDigest: + inc(counter) + await response.closeWait() + await request.closeWait() + await session.closeWait() + + await server.stop() + await server.closeWait() + return counter + + proc testRequestChunkedStreamWritingTest(address: TransportAddress, + secure: bool): Future[int] {.async.} = + let RequestTests = [ + (MethodPost, "/test/big_chunk_request", 65600), + (MethodPost, "/test/big_chunk_request", 262400) + ] + proc process(r: RequestFence): Future[HttpResponseRef] {. + async.} = + if r.isOk(): + let request = r.get() + case request.uri.path + of "/test/big_chunk_request": + if request.hasBody(): + let body = await request.getBody() + let digest = $secureHash(cast[string](body)) + return await request.respond(Http200, digest) + else: + return await request.respond(Http400, "Missing content body") + else: + return await request.respond(Http404, "Page not found") + else: + return dumbResponse() + + var server = createServer(address, process, secure) + server.start() + var counter = 0 + + var session = createSession(secure) + for item in RequestTests: + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, item[1]) + else: + getAddress(address, HttpClientScheme.NonSecure, item[1]) + var data = createBigMessage("REQUESTSTREAMMESSAGE", item[2]) + let headers = [ + ("Content-Type", "application/octet-stream"), + ("Transfer-Encoding", "chunked") + ] + var request = HttpClientRequestRef.new( + session, ha, item[0], headers = headers + ) + + var expectDigest = $secureHash(cast[string](data)) + # Sending big request by 1024bytes long chunks + var writer = await open(request) + var offset = 0 + while true: + if len(data) == offset: + break + let toWrite = min(1024, len(data) - offset) + await writer.write(addr data[offset], toWrite) + offset = offset + toWrite + await writer.finish() + await writer.closeWait() + var response = await request.finish() + + if response.status == 200: + var res = await response.getBodyBytes() + if cast[string](res) == expectDigest: + inc(counter) + await response.closeWait() + await request.closeWait() + await session.closeWait() + + await server.stop() + await server.closeWait() + return counter + + proc testRequestPostUrlEncodedTest(address: TransportAddress, + secure: bool): Future[int] {.async.} = + let PostRequests = [ + ("/test/post/urlencoded_size", + "field1=value1&field2=value2&field3=value3", "value1:value2:value3"), + ("/test/post/urlencoded_chunked", + "field1=longlonglongvalue1&field2=longlonglongvalue2&" & + "field3=longlonglongvalue3", "longlonglongvalue1:longlonglongvalue2:" & + "longlonglongvalue3") + ] + + proc process(r: RequestFence): Future[HttpResponseRef] {. + async.} = + if r.isOk(): + let request = r.get() + case request.uri.path + of "/test/post/urlencoded_size", "/test/post/urlencoded_chunked": + if request.hasBody(): + var postTable = await request.post() + let body = postTable.getString("field1") & ":" & + postTable.getString("field2") & ":" & + postTable.getString("field3") + return await request.respond(Http200, body) + else: + return await request.respond(Http400, "Missing content body") + else: + return await request.respond(Http404, "Page not found") + else: + return dumbResponse() + + var server = createServer(address, process, secure) + server.start() + var counter = 0 + + ## Sized url-encoded form + block: + var session = createSession(secure) + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, PostRequests[0][0]) + else: + getAddress(address, HttpClientScheme.NonSecure, PostRequests[0][0]) + let headers = [ + ("Content-Type", "application/x-www-form-urlencoded"), + ] + var request = HttpClientRequestRef.new( + session, ha, MethodPost, headers = headers, + body = cast[seq[byte]](PostRequests[0][1])) + var response = await send(request) + + if response.status == 200: + var res = await response.getBodyBytes() + if cast[string](res) == PostRequests[0][2]: + inc(counter) + await response.closeWait() + await request.closeWait() + await session.closeWait() + + ## Chunked url-encoded form + block: + var session = createSession(secure) + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, PostRequests[1][0]) + else: + getAddress(address, HttpClientScheme.NonSecure, PostRequests[1][0]) + let headers = [ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Transfer-Encoding", "chunked") + ] + var request = HttpClientRequestRef.new( + session, ha, MethodPost, headers = headers) + + var data = PostRequests[1][1] + + var writer = await open(request) + var offset = 0 + while true: + if len(data) == offset: + break + let toWrite = min(16, len(data) - offset) + await writer.write(addr data[offset], toWrite) + offset = offset + toWrite + await writer.finish() + await writer.closeWait() + var response = await request.finish() + if response.status == 200: + var res = await response.getBodyBytes() + if cast[string](res) == PostRequests[1][2]: + inc(counter) + await response.closeWait() + await request.closeWait() + await session.closeWait() + + await server.stop() + await server.closeWait() + return counter + + proc testRequestPostMultipartTest(address: TransportAddress, + secure: bool): Future[int] {.async.} = + let PostRequests = [ + ("/test/post/multipart_size", "some-part-boundary", + [("field1", "value1"), ("field2", "value2"), ("field3", "value3")], + "value1:value2:value3"), + ("/test/post/multipart_chunked", "some-part-boundary", + [("field1", "longlonglongvalue1"), ("field2", "longlonglongvalue2"), + ("field3", "longlonglongvalue3")], + "longlonglongvalue1:longlonglongvalue2:longlonglongvalue3") + ] + + proc process(r: RequestFence): Future[HttpResponseRef] {. + async.} = + if r.isOk(): + let request = r.get() + case request.uri.path + of "/test/post/multipart_size", "/test/post/multipart_chunked": + if request.hasBody(): + var postTable = await request.post() + let body = postTable.getString("field1") & ":" & + postTable.getString("field2") & ":" & + postTable.getString("field3") + return await request.respond(Http200, body) + else: + return await request.respond(Http400, "Missing content body") + else: + return await request.respond(Http404, "Page not found") + else: + return dumbResponse() + + var server = createServer(address, process, secure) + server.start() + var counter = 0 + + ## Sized multipart form + block: + var mp = MultiPartWriter.init(PostRequests[0][1]) + mp.begin() + for item in PostRequests[0][2]: + mp.beginPart(item[0], "", HttpTable.init()) + mp.write(item[1]) + mp.finishPart() + let data = mp.finish() + + var session = createSession(secure) + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, PostRequests[0][0]) + else: + getAddress(address, HttpClientScheme.NonSecure, PostRequests[0][0]) + let headers = [ + ("Content-Type", "multipart/form-data; boundary=" & PostRequests[0][1]), + ] + var request = HttpClientRequestRef.new( + session, ha, MethodPost, headers = headers, body = data) + var response = await send(request) + if response.status == 200: + var res = await response.getBodyBytes() + if cast[string](res) == PostRequests[0][3]: + inc(counter) + await response.closeWait() + await request.closeWait() + await session.closeWait() + + ## Chunked multipart form + block: + var session = createSession(secure) + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, PostRequests[0][0]) + else: + getAddress(address, HttpClientScheme.NonSecure, PostRequests[0][0]) + let headers = [ + ("Content-Type", "multipart/form-data; boundary=" & PostRequests[1][1]), + ("Transfer-Encoding", "chunked") + ] + var request = HttpClientRequestRef.new( + session, ha, MethodPost, headers = headers) + var writer = await open(request) + var mpw = MultiPartWriterRef.new(writer, PostRequests[1][1]) + await mpw.begin() + for item in PostRequests[1][2]: + await mpw.beginPart(item[0], "", HttpTable.init()) + await mpw.write(item[1]) + await mpw.finishPart() + await mpw.finish() + await writer.finish() + await writer.closeWait() + let response = await request.finish() + if response.status == 200: + var res = await response.getBodyBytes() + if cast[string](res) == PostRequests[1][3]: + inc(counter) + await response.closeWait() + await request.closeWait() + await session.closeWait() + + await server.stop() + await server.closeWait() + return counter + + proc testRequestRedirectTest(address: TransportAddress, + secure: bool, + max: int): Future[string] {.async.} = + var session = createSession(secure, maxRedirections = max) + + let ha = + if secure: + getAddress(address, HttpClientScheme.Secure, "/") + else: + getAddress(address, HttpClientScheme.NonSecure, "/") + let lastAddress = ha.getUri().combine(parseUri("/final/5")) + + proc process(r: RequestFence): Future[HttpResponseRef] {. + async.} = + if r.isOk(): + let request = r.get() + case request.uri.path + of "/": + return await request.redirect(Http302, "/redirect/1") + of "/redirect/1": + return await request.redirect(Http302, "/next/redirect/2") + of "/next/redirect/2": + return await request.redirect(Http302, "redirect/3") + of "/next/redirect/redirect/3": + return await request.redirect(Http302, "next/redirect/4") + of "/next/redirect/redirect/next/redirect/4": + return await request.redirect(Http302, lastAddress) + of "/final/5": + return await request.respond(Http200, "ok-5") + else: + return await request.respond(Http404, "Page not found") + else: + return dumbResponse() + + var server = createServer(address, process, secure) + server.start() + if session.maxRedirections >= 5: + let (code, data) = await session.fetch(ha.getUri()) + await session.closeWait() + await server.stop() + await server.closeWait() + return data.bytesToString() & "-" & $code + else: + let res = + try: + let (code {.used.}, data {.used.}) = await session.fetch(ha.getUri()) + false + except HttpRedirectError: + true + except CatchableError: + false + await session.closeWait() + await server.stop() + await server.closeWait() + return "redirect-" & $res + + test "HTTP all request methods test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testMethods(address, false)) == 18 + + test "HTTP(S) all request methods test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testMethods(address, true)) == 18 + + test "HTTP client response streaming test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testResponseStreamReadingTest(address, false)) == 8 + + test "HTTP(S) client response streaming test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testResponseStreamReadingTest(address, true)) == 8 + + test "HTTP client (size) request streaming test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestSizeStreamWritingTest(address, false)) == 2 + + test "HTTP(S) client (size) request streaming test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestSizeStreamWritingTest(address, true)) == 2 + + test "HTTP client (chunked) request streaming test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestChunkedStreamWritingTest(address, false)) == 2 + + test "HTTP(S) client (chunked) request streaming test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestChunkedStreamWritingTest(address, true)) == 2 + + test "HTTP client (size + chunked) url-encoded POST test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestPostUrlEncodedTest(address, false)) == 2 + + test "HTTP(S) client (size + chunked) url-encoded POST test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestPostUrlEncodedTest(address, true)) == 2 + + test "HTTP client (size + chunked) multipart POST test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestPostMultipartTest(address, false)) == 2 + + test "HTTP(S) client (size + chunked) multipart POST test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestPostMultipartTest(address, true)) == 2 + + test "HTTP client redirection test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestRedirectTest(address, false, 5)) == "ok-5-200" + + test "HTTP(S) client redirection test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestRedirectTest(address, true, 5)) == "ok-5-200" + + test "HTTP client maximum redirections test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestRedirectTest(address, false, 4)) == "redirect-true" + + test "HTTP(S) client maximum redirections test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testRequestRedirectTest(address, true, 4)) == "redirect-true" + + test "Leaks test": + proc getTrackerLeaks(tracker: string): bool = + let tracker = getTracker(tracker) + if isNil(tracker): false else: tracker.isLeaked() + + check: + getTrackerLeaks("http.body.reader") == false + getTrackerLeaks("http.body.writer") == false + getTrackerLeaks("httpclient.connection") == false + getTrackerLeaks("httpclient.request") == false + getTrackerLeaks("httpclient.response") == false + getTrackerLeaks("async.stream.reader") == false + getTrackerLeaks("async.stream.writer") == false + getTrackerLeaks("stream.server") == false + getTrackerLeaks("stream.transport") == false