diff --git a/chronos/apps/http/httpcommon.nim b/chronos/apps/http/httpcommon.nim index 997425b8..e123cce1 100644 --- a/chronos/apps/http/httpcommon.nim +++ b/chronos/apps/http/httpcommon.nim @@ -6,8 +6,8 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import stew/results, httputils -export results, httputils +import stew/results, httputils, strutils, uri +export results, httputils, strutils const useChroniclesLogging* {.booldefine.} = false @@ -19,9 +19,110 @@ type HttpResultCode*[T] = Result[T, HttpCode] HttpError* = object of CatchableError - HttpCriticalFailure* = object of HttpError - HttpRecoverableFailure* = object of HttpError + HttpCriticalError* = object of HttpError + HttpRecoverableError* = object of HttpError + + TransferEncodingFlags* {.pure.} = enum + Identity, Chunked, Compress, Deflate, Gzip + + ContentEncodingFlags* {.pure.} = enum + Identity, Br, Compress, Deflate, Gzip template log*(body: untyped) = when defined(useChroniclesLogging): body + +proc newHttpCriticalError*(msg: string): ref HttpCriticalError = + newException(HttpCriticalError, msg) + +proc newHttpRecoverableError*(msg: string): ref HttpRecoverableError = + newException(HttpRecoverableError, msg) + +iterator queryParams*(query: string): tuple[key: string, value: string] = + ## Iterate over url-encoded query string. + for pair in query.split('&'): + let items = pair.split('=', maxsplit = 1) + let k = items[0] + let v = if len(items) > 1: items[1] else: "" + yield (decodeUrl(k), decodeUrl(v)) + +func getTransferEncoding*(ch: openarray[string]): HttpResult[ + set[TransferEncodingFlags]] = + ## Parse value of multiple HTTP headers ``Transfer-Encoding`` and return + ## it as set of ``TransferEncodingFlags``. + var res: set[TransferEncodingFlags] = {} + if len(ch) == 0: + res.incl(TransferEncodingFlags.Identity) + ok(res) + else: + for header in ch: + for item in header.split(","): + case strip(item.toLowerAscii()) + of "identity": + res.incl(TransferEncodingFlags.Identity) + of "chunked": + res.incl(TransferEncodingFlags.Chunked) + of "compress": + res.incl(TransferEncodingFlags.Compress) + of "deflate": + res.incl(TransferEncodingFlags.Deflate) + of "gzip": + res.incl(TransferEncodingFlags.Gzip) + of "": + res.incl(TransferEncodingFlags.Identity) + else: + return err("Incorrect Transfer-Encoding value") + ok(res) + +func getContentEncoding*(ch: openarray[string]): HttpResult[ + set[ContentEncodingFlags]] = + ## Parse value of multiple HTTP headers ``Content-Encoding`` and return + ## it as set of ``ContentEncodingFlags``. + var res: set[ContentEncodingFlags] = {} + if len(ch) == 0: + res.incl(ContentEncodingFlags.Identity) + ok(res) + else: + for header in ch: + for item in header.split(","): + case strip(item.toLowerAscii()): + of "identity": + res.incl(ContentEncodingFlags.Identity) + of "br": + res.incl(ContentEncodingFlags.Br) + of "compress": + res.incl(ContentEncodingFlags.Compress) + of "deflate": + res.incl(ContentEncodingFlags.Deflate) + of "gzip": + res.incl(ContentEncodingFlags.Gzip) + of "": + res.incl(ContentEncodingFlags.Identity) + else: + return err("Incorrect Content-Encoding value") + ok(res) + +func getContentType*(ch: openarray[string]): HttpResult[string] = + ## Check and prepare value of ``Content-Type`` header. + if len(ch) > 1: + err("Multiple Content-Type values found") + else: + let mparts = ch[0].split(";") + ok(strip(mparts[0]).toLowerAscii()) + +func getMultipartBoundary*(contentType: string): HttpResult[string] = + ## Process ``multipart/form-data`` ``Content-Type`` header and return + ## multipart boundary. + let mparts = contentType.split(";") + if strip(mparts[0]).toLowerAscii() != "multipart/form-data": + return err("Content-Type is not multipart") + if len(mparts) < 2: + return err("Content-Type missing boundary value") + let stripped = strip(mparts[1]) + if not(stripped.toLowerAscii().startsWith("boundary")): + return err("Incorrect Content-Type boundary format") + let bparts = stripped.split("=") + if len(bparts) < 2: + err("Missing Content-Type boundary") + else: + ok(strip(bparts[1])) diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index b4d6185c..5ca33dbb 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -21,15 +21,25 @@ type HttpServerFlags* = enum Secure - TransferEncodingFlags* {.pure.} = enum - Identity, Chunked, Compress, Deflate, Gzip + HttpConnectionStatus* = enum + DropConnection, KeepConnection - ContentEncodingFlags* {.pure.} = enum - Identity, Br, Compress, Deflate, Gzip + HttpErrorEnum* = enum + TimeoutError, CatchableError, RecoverableError, CriticalError + + HttpProcessError* = object + error*: HttpErrorEnum + exc*: HttpError + remote*: TransportAddress + + HttpProcessStatus*[T] = Result[T, HttpProcessError] HttpRequestFlags* {.pure.} = enum BoundBody, UnboundBody, MultipartForm, UrlencodedForm + HttpProcessCallback* = + proc(request: HttpProcessStatus[HttpRequest]): Future[HttpStatus] + HttpServer* = ref object of RootRef instance*: StreamServer # semaphore*: AsyncSemaphore @@ -43,6 +53,7 @@ type bodyTimeout: Duration maxHeadersSize: int maxRequestBodySize: int + processCallback: HttpProcessCallback HttpServerState* = enum ServerRunning, ServerStopped, ServerClosed @@ -64,18 +75,24 @@ type connection*: HttpConnection mainReader*: AsyncStreamReader + HttpResponse* = object + code*: HttpCode + version*: HttpVersion + headersTable: HttpTable + body*: seq[byte] + connection*: HttpConnection + mainWriter: AsyncStreamWriter + HttpConnection* = ref object of RootRef - server: HttpServer + server*: HttpServer transp: StreamTransport buffer: seq[byte] -const - HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] - proc new*(htype: typedesc[HttpServer], address: TransportAddress, flags: set[HttpServerFlags] = {}, serverUri = Uri(), + processCallback: HttpProcessCallback, maxConnections: int = -1, bufferSize: int = 4096, backlogSize: int = 100, @@ -88,7 +105,8 @@ proc new*(htype: typedesc[HttpServer], headersTimeout: httpHeadersTimeout, bodyTimeout: httpBodyTimeout, maxHeadersSize: maxHeadersSize, - maxRequestBodySize: maxRequestBodySize + maxRequestBodySize: maxRequestBodySize, + processCallback: processCallback ) res.baseUri = @@ -118,92 +136,6 @@ proc getId(transp: StreamTransport): string {.inline.} = ## Returns string unique transport's identifier as string. $transp.remoteAddress() & "_" & $transp.localAddress() -proc newHttpCriticalFailure*(msg: string): ref HttpCriticalFailure = - newException(HttpCriticalFailure, msg) - -proc newHttpRecoverableFailure*(msg: string): ref HttpRecoverableFailure = - newException(HttpRecoverableFailure, msg) - -iterator queryParams(query: string): tuple[key: string, value: string] = - for pair in query.split('&'): - let items = pair.split('=', maxsplit = 1) - let k = items[0] - let v = if len(items) > 1: items[1] else: "" - yield (decodeUrl(k), decodeUrl(v)) - -func getMultipartBoundary*(contentType: string): HttpResult[string] = - let mparts = contentType.split(";") - if strip(mparts[0]).toLowerAscii() != "multipart/form-data": - return err("Content-Type is not multipart") - if len(mparts) < 2: - return err("Content-Type missing boundary value") - let stripped = strip(mparts[1]) - if not(stripped.toLowerAscii().startsWith("boundary")): - return err("Incorrect Content-Type boundary format") - let bparts = stripped.split("=") - if len(bparts) < 2: - err("Missing Content-Type boundary") - else: - ok(strip(bparts[1])) - -func getContentType*(contentHeader: seq[string]): HttpResult[string] = - if len(contentHeader) > 1: - return err("Multiple Content-Type values found") - let mparts = contentHeader[0].split(";") - ok(strip(mparts[0]).toLowerAscii()) - -func getTransferEncoding(contentHeader: seq[string]): HttpResult[ - set[TransferEncodingFlags]] = - var res: set[TransferEncodingFlags] = {} - if len(contentHeader) == 0: - res.incl(TransferEncodingFlags.Identity) - ok(res) - else: - for header in contentHeader: - for item in header.split(","): - case strip(item.toLowerAscii()) - of "identity": - res.incl(TransferEncodingFlags.Identity) - of "chunked": - res.incl(TransferEncodingFlags.Chunked) - of "compress": - res.incl(TransferEncodingFlags.Compress) - of "deflate": - res.incl(TransferEncodingFlags.Deflate) - of "gzip": - res.incl(TransferEncodingFlags.Gzip) - of "": - res.incl(TransferEncodingFlags.Identity) - else: - return err("Incorrect Transfer-Encoding value") - ok(res) - -func getContentEncoding(contentHeader: seq[string]): HttpResult[ - set[ContentEncodingFlags]] = - var res: set[ContentEncodingFlags] = {} - if len(contentHeader) == 0: - res.incl(ContentEncodingFlags.Identity) - ok(res) - else: - for header in contentHeader: - for item in header.split(","): - case strip(item.toLowerAscii()): - of "identity": - res.incl(ContentEncodingFlags.Identity) - of "br": - res.incl(ContentEncodingFlags.Br) - of "compress": - res.incl(ContentEncodingFlags.Compress) - of "deflate": - res.incl(ContentEncodingFlags.Deflate) - of "gzip": - res.incl(ContentEncodingFlags.Gzip) - of "": - res.incl(ContentEncodingFlags.Identity) - else: - return err("Incorrect Content-Encoding value") - ok(res) - proc hasBody*(request: HttpRequest): bool = ## Returns ``true`` if request has body. request.requestFlags * {HttpRequestFlags.BoundBody, @@ -333,7 +265,7 @@ proc getBody*(request: HttpRequest): Future[seq[byte]] {.async.} = try: return await read(res.get()) except AsyncStreamError: - raise newHttpCriticalFailure("Read failure") + raise newHttpCriticalError("Read Error") proc consumeBody*(request: HttpRequest): Future[void] {.async.} = ## Consume/discard request's body. @@ -346,7 +278,7 @@ proc consumeBody*(request: HttpRequest): Future[void] {.async.} = discard await reader.consume() return except AsyncStreamError: - raise newHttpCriticalFailure("Read failure") + raise newHttpCriticalError("Read Error") proc sendErrorResponse(conn: HttpConnection, version: HttpVersion, code: HttpCode, keepAlive = true, @@ -374,17 +306,13 @@ proc sendErrorResponse(conn: HttpConnection, version: HttpVersion, except CatchableError: return false -proc sendErrorResponse(request: HttpRequest, code: HttpCode, keepAlive = true, - datatype = "text/text", - databody = ""): Future[bool] = +proc sendErrorResponse*(request: HttpRequest, code: HttpCode, keepAlive = true, + datatype = "text/text", + databody = ""): Future[bool] = sendErrorResponse(request.connection, request.version, code, keepAlive, datatype, databody) proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} = - when defined(useChroniclesLogging): - logScope: - peer = $conn.transp.remoteAddress - try: conn.buffer.setLen(conn.server.maxHeadersSize) let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer), @@ -392,42 +320,62 @@ proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} = conn.buffer.setLen(res) let header = parseRequest(conn.buffer) if header.failed(): - log debug "Malformed header received" discard await conn.sendErrorResponse(HttpVersion11, Http400, false) - raise newHttpCriticalFailure("Malformed request recieved") + raise newHttpCriticalError("Malformed request recieved") else: let res = prepareRequest(conn, header) if res.isErr(): discard await conn.sendErrorResponse(HttpVersion11, Http400, false) - raise newHttpCriticalFailure("Invalid request received") + raise newHttpCriticalError("Invalid request received") else: return res.get() except TransportOsError: - log debug "Unexpected OS error" - raise newHttpCriticalFailure("Unexpected OS error") + raise newHttpCriticalError("Unexpected OS error") except TransportIncompleteError: - log debug "Remote peer disconnected" - raise newHttpCriticalFailure("Remote peer disconnected") + raise newHttpCriticalError("Remote peer disconnected") except TransportLimitError: - log debug "Maximum size of request headers reached" discard await conn.sendErrorResponse(HttpVersion11, Http413, false) - raise newHttpCriticalFailure("Maximum size of request headers reached") + raise newHttpCriticalError("Maximum size of request headers reached") proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} = - when defined(useChroniclesLogging): - logScope: - peer = $transp.remoteAddress - var conn = HttpConnection( transp: transp, buffer: newSeq[byte](server.maxHeadersSize), server: server ) - log info "Client connected" var breakLoop = false while true: + var status: HttpProcessStatus + var arg: HttpProcessStatus[HttpRequest] try: let request = await conn.getRequest().wait(server.headersTimeout) + arg = ok(request) + except AsyncTimeoutError as exc: + discard await conn.sendErrorResponse(HttpVersion11, Http408, false) + breakLoop = true + arg = err(HttpProcessError(exc: exc, remote: transp.remoteAddress())) + except CancelledError: + breakLoop = true + except HttpRecoverableError: + breakLoop = false + arg = err() + except CatchableError: + breakLoop = true + + if breakLoop: + break + + breakLoop = false + let status = + try: + await conn.server.processCallback() + except CancelledError: + breakLoop = true + HttpCriticalError + except CatchableError: + breakLoop = true + HttpCriticalError + echo "== HEADERS TABLE" echo request.headersTable echo "== QUERY TABLE" @@ -439,19 +387,17 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} = echo cast[string](stream) discard await conn.sendErrorResponse(HttpVersion11, Http200, true, databody = "OK") - log debug "Response sent" except AsyncTimeoutError: - log debug "Timeout reached while reading headers" discard await conn.sendErrorResponse(HttpVersion11, Http408, false) breakLoop = true except CancelledError: breakLoop = true - except HttpRecoverableFailure: + except HttpRecoverableError: breakLoop = false - except HttpCriticalFailure: + except HttpCriticalError: breakLoop = true except CatchableError as exc: @@ -464,7 +410,6 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} = server.connections.del(transp.getId()) # if server.maxConnections > 0: # server.semaphore.release() - log info "Client got disconnected" proc acceptClientLoop(server: HttpServer) {.async.} = var breakLoop = false diff --git a/chronos/apps/http/multipart.nim b/chronos/apps/http/multipart.nim index d217c0b8..894bb4b1 100644 --- a/chronos/apps/http/multipart.nim +++ b/chronos/apps/http/multipart.nim @@ -8,9 +8,11 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) import std/[monotimes, strutils] -import chronos, stew/results +import stew/results +import ../../asyncloop +import ../../streams/[asyncstream, boundstream, chunkstream] import httptable, httpcommon -export httptable, httpcommon +export httptable, httpcommon, asyncstream type MultiPartSource {.pure.} = enum @@ -20,9 +22,9 @@ type case kind: MultiPartSource of MultiPartSource.Stream: stream: AsyncStreamReader - last: BoundedAsyncStreamReader of MultiPartSource.Buffer: discard + firstTime: bool buffer: seq[byte] offset: int boundary: seq[byte] @@ -30,18 +32,25 @@ type MultiPartReaderRef* = ref MultiPartReader MultiPart* = object + case kind: MultiPartSource + of MultiPartSource.Stream: + stream*: BoundedStreamReader + of MultiPartSource.Buffer: + discard + buffer: seq[byte] headers: HttpTable - stream: BoundedAsyncStreamReader - offset: int - size: int - MultipartError* = object of HttpError + MultipartError* = object of HttpCriticalError MultipartEOMError* = object of MultipartError - MultiPartIncorrectError* = object of MultipartError - MultiPartIncompleteError* = object of MultipartError + MultipartIncorrectError* = object of MultipartError + MultipartIncompleteError* = object of MultipartError + MultipartReadError* = object of MultipartError BChar* = byte | char +proc newMultipartReadError(msg: string): ref MultipartReadError = + newException(MultipartReadError, msg) + proc startsWith*(s, prefix: openarray[byte]): bool = var i = 0 while true: @@ -100,60 +109,106 @@ proc init*[B: BChar](mpt: typedesc[MultiPartReader], proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} = doAssert(mpr.kind == MultiPartSource.Stream) - # According to RFC1521 says that a boundary "must be no longer than 70 - # characters, not counting the two leading hyphens. if mpr.firstTime: - # Read and verify initial <-><-> - mpr.firstTime = false - await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2) - if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5), - mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)): - if buffer[0] == byte('-') and buffer[1] == byte("-"): - raise newException(MultiPartEOMError, "Unexpected EOM encountered") - if buffer[0] != 0x0D'u8 or buffer[1] != 0x0A'u8: + try: + # Read and verify initial <-><-> + await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2) + mpr.firstTime = false + if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5), + mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)): + if mpr.buffer[0] == byte('-') and mpr.buffer[1] == byte('-'): + raise newException(MultiPartEOMError, + "Unexpected EOM encountered") + if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8: + raise newException(MultiPartIncorrectError, + "Unexpected boundary suffix") + else: raise newException(MultiPartIncorrectError, - "Unexpected boundary suffix") - else: - raise newException(MultiPartIncorrectError, - "Unexpected boundary encountered") + "Unexpected boundary encountered") + except CancelledError as exc: + raise exc + except AsyncStreamIncompleteError: + raise newMultipartReadError("Error reading multipart message") + except AsyncStreamReadError: + raise newMultipartReadError("Error reading multipart message") # Reading part's headers - let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer), + try: + let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer), HeadersMark) - var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1)) - if headersList.failed(): - raise newException(MultiPartIncorrectError, "Incorrect part headers found") + var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false) + if headersList.failed(): + raise newException(MultiPartIncorrectError, + "Incorrect part headers found") - var part = MultiPart() + var part = MultiPart( + kind: MultiPartSource.Stream, + headers: HttpTable.init(), + stream: newBoundedStreamReader(mpr.stream, -1, mpr.boundary) + ) - await mpr.stream.readExactly(addr buffer[0], len(mpr.boundary) - 4) - if startsWith(buffer.toOpenArray(0, len(mpr.boundary) - 5), - mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)): - await mpr.stream.readExactly(addr buffer[0], 2) - if buffer[0] == byte('-') and buffer[1] == byte("-"): - raise newException(MultiPartEOMError, "") - if buffer[0] == 0x0D'u8 and buffer[1] == 0x0A'u8: + for k, v in headersList.headers(mpr.buffer.toOpenArray(0, res - 1)): + part.headers.add(k, v) - except: - discard - # if mpr.offset >= len(mpr.buffer): - # raise newException(MultiPartEOMError, "End of multipart form encountered") + return part -proc getStream*(mp: MultiPart): AsyncStreamReader = - mp.stream + except CancelledError as exc: + raise exc + except AsyncStreamIncompleteError: + raise newMultipartReadError("Error reading multipart message") + except AsyncStreamLimitError: + raise newMultipartReadError("Multipart message headers size too big") + except AsyncStreamReadError: + raise newMultipartReadError("Error reading multipart message") proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} = - try: - let res = await mp.stream.read() - return res - except AsyncStreamError: - raise newException(HttpCriticalError, "Could not read multipart body") + case mp.kind + of MultiPartSource.Stream: + try: + let res = await mp.stream.read() + return res + except AsyncStreamError: + raise newException(HttpCriticalError, "Could not read multipart body") + of MultiPartSource.Buffer: + return mp.buffer proc consumeBody*(mp: MultiPart) {.async.} = - try: - await mp.stream.consume() - except AsyncStreamError: - raise newException(HttpCriticalError, "Could not consume multipart body") + case mp.kind + of MultiPartSource.Stream: + try: + await mp.stream.consume() + except AsyncStreamError: + raise newException(HttpCriticalError, "Could not consume multipart body") + of MultiPartSource.Buffer: + discard + +proc getBytes*(mp: MultiPart): seq[byte] = + ## Returns MultiPart value as sequence of bytes. + case mp.kind + of MultiPartSource.Buffer: + mp.buffer + of MultiPartSource.Stream: + doAssert(not(mp.stream.atEof()), "Value is not obtained yet") + mp.buffer + +proc getString*(mp: MultiPart): string = + ## Returns MultiPart value as string. + case mp.kind + of MultiPartSource.Buffer: + if len(mp.buffer) > 0: + var res = newString(len(mp.buffer)) + copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer)) + res + else: + "" + of MultiPartSource.Stream: + doAssert(not(mp.stream.atEof()), "Value is not obtained yet") + if len(mp.buffer) > 0: + var res = newString(len(mp.buffer)) + copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer)) + res + else: + "" proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] = doAssert(mpr.kind == MultiPartSource.Buffer) @@ -215,8 +270,11 @@ proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] = # We set reader's offset to the place right after mpr.offset = start + pos2 + 2 - - var part = MultiPart(offset: start, size: pos2, headers: HttpTable.init()) + var part = MultiPart( + kind: MultiPartSource.Buffer, + headers: HttpTable.init(), + buffer: @(mpr.buffer.toOpenArray(start, start + pos2 - 1)) + ) for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)): part.headers.add(k, v) ok(part) @@ -255,7 +313,7 @@ proc boundaryValue2(c: char): bool = c in {'a'..'z', 'A' .. 'Z', '0' .. '9', '\'' .. ')', '+' .. '/', ':', '=', '?', '_'} -func getMultipartBoundary*(contentType: string): Result[string, string] = +func getMultipartBoundary*(contentType: string): HttpResult[string] = let mparts = contentType.split(";") if strip(mparts[0]).toLowerAscii() != "multipart/form-data": return err("Content-Type is not multipart") @@ -270,7 +328,7 @@ func getMultipartBoundary*(contentType: string): Result[string, string] = else: ok(strip(bparts[1])) -func getContentType*(contentHeader: seq[string]): Result[string, string] = +func getContentType*(contentHeader: seq[string]): HttpResult[string] = if len(contentHeader) > 1: return err("Multiple Content-Header values found") let mparts = contentHeader[0].split(";")