diff --git a/chronos/apps/http/httpcommon.nim b/chronos/apps/http/httpcommon.nim new file mode 100644 index 0000000..997425b --- /dev/null +++ b/chronos/apps/http/httpcommon.nim @@ -0,0 +1,27 @@ +# +# Chronos HTTP/S common types +# (c) Copyright 2019-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import stew/results, httputils +export results, httputils + +const + useChroniclesLogging* {.booldefine.} = false + + HeadersMark* = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] + +type + HttpResult*[T] = Result[T, string] + HttpResultCode*[T] = Result[T, HttpCode] + + HttpError* = object of CatchableError + HttpCriticalFailure* = object of HttpError + HttpRecoverableFailure* = object of HttpError + +template log*(body: untyped) = + when defined(useChroniclesLogging): + body diff --git a/chronos/apps/http/httpserver.exe b/chronos/apps/http/httpserver.exe new file mode 100644 index 0000000..a27c3c8 Binary files /dev/null and b/chronos/apps/http/httpserver.exe differ diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim new file mode 100644 index 0000000..b4d6185 --- /dev/null +++ b/chronos/apps/http/httpserver.nim @@ -0,0 +1,560 @@ +# +# Chronos HTTP/S server implementation +# (c) Copyright 2019-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import std/[tables, options, uri, strutils] +import stew/results, httputils +import ../../asyncloop, ../../asyncsync +import ../../streams/[asyncstream, boundstream, chunkstream] +import httptable, httpcommon +export httpcommon + +when defined(useChroniclesLogging): + echo "Importing chronicles" + import chronicles + +type + HttpServerFlags* = enum + Secure + + TransferEncodingFlags* {.pure.} = enum + Identity, Chunked, Compress, Deflate, Gzip + + ContentEncodingFlags* {.pure.} = enum + Identity, Br, Compress, Deflate, Gzip + + HttpRequestFlags* {.pure.} = enum + BoundBody, UnboundBody, MultipartForm, UrlencodedForm + + HttpServer* = ref object of RootRef + instance*: StreamServer + # semaphore*: AsyncSemaphore + maxConnections*: int + baseUri*: Uri + flags*: set[HttpServerFlags] + connections*: Table[string, Future[void]] + acceptLoop*: Future[void] + lifetime*: Future[void] + headersTimeout: Duration + bodyTimeout: Duration + maxHeadersSize: int + maxRequestBodySize: int + + HttpServerState* = enum + ServerRunning, ServerStopped, ServerClosed + + HttpRequest* = object + headersTable: HttpTable + queryTable: HttpTable + postTable: HttpTable + rawPath*: string + rawQuery*: string + uri*: Uri + scheme*: string + version*: HttpVersion + meth*: HttpMethod + contentEncoding*: set[ContentEncodingFlags] + transferEncoding*: set[TransferEncodingFlags] + requestFlags*: set[HttpRequestFlags] + contentLength: int + connection*: HttpConnection + mainReader*: AsyncStreamReader + + HttpConnection* = ref object of RootRef + server: HttpServer + transp: StreamTransport + buffer: seq[byte] + +const + HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)] + +proc new*(htype: typedesc[HttpServer], + address: TransportAddress, + flags: set[HttpServerFlags] = {}, + serverUri = Uri(), + maxConnections: int = -1, + bufferSize: int = 4096, + backlogSize: int = 100, + httpHeadersTimeout = 10.seconds, + httpBodyTimeout = 30.seconds, + maxHeadersSize: int = 8192, + maxRequestBodySize: int = 1_048_576): HttpResult[HttpServer] = + var res = HttpServer( + maxConnections: maxConnections, + headersTimeout: httpHeadersTimeout, + bodyTimeout: httpBodyTimeout, + maxHeadersSize: maxHeadersSize, + maxRequestBodySize: maxRequestBodySize + ) + + res.baseUri = + if len(serverUri.hostname) > 0 and isAbsolute(serverUri): + serverUri + else: + if HttpServerFlags.Secure in flags: + parseUri("https://" & $address & "/") + else: + parseUri("http://" & $address & "/") + + try: + res.instance = createStreamServer(address, flags = {ReuseAddr}, + bufferSize = bufferSize, + backlog = backlogSize) + # if maxConnections > 0: + # res.semaphore = newAsyncSemaphore(maxConnections) + res.lifetime = newFuture[void]("http.server.lifetime") + res.connections = initTable[string, Future[void]]() + return ok(res) + except TransportOsError as exc: + return err(exc.msg) + except CatchableError as exc: + return err(exc.msg) + +proc getId(transp: StreamTransport): string {.inline.} = + ## Returns string unique transport's identifier as string. + $transp.remoteAddress() & "_" & $transp.localAddress() + +proc newHttpCriticalFailure*(msg: string): ref HttpCriticalFailure = + newException(HttpCriticalFailure, msg) + +proc newHttpRecoverableFailure*(msg: string): ref HttpRecoverableFailure = + newException(HttpRecoverableFailure, msg) + +iterator queryParams(query: string): tuple[key: string, value: string] = + for pair in query.split('&'): + let items = pair.split('=', maxsplit = 1) + let k = items[0] + let v = if len(items) > 1: items[1] else: "" + yield (decodeUrl(k), decodeUrl(v)) + +func getMultipartBoundary*(contentType: string): HttpResult[string] = + let mparts = contentType.split(";") + if strip(mparts[0]).toLowerAscii() != "multipart/form-data": + return err("Content-Type is not multipart") + if len(mparts) < 2: + return err("Content-Type missing boundary value") + let stripped = strip(mparts[1]) + if not(stripped.toLowerAscii().startsWith("boundary")): + return err("Incorrect Content-Type boundary format") + let bparts = stripped.split("=") + if len(bparts) < 2: + err("Missing Content-Type boundary") + else: + ok(strip(bparts[1])) + +func getContentType*(contentHeader: seq[string]): HttpResult[string] = + if len(contentHeader) > 1: + return err("Multiple Content-Type values found") + let mparts = contentHeader[0].split(";") + ok(strip(mparts[0]).toLowerAscii()) + +func getTransferEncoding(contentHeader: seq[string]): HttpResult[ + set[TransferEncodingFlags]] = + var res: set[TransferEncodingFlags] = {} + if len(contentHeader) == 0: + res.incl(TransferEncodingFlags.Identity) + ok(res) + else: + for header in contentHeader: + for item in header.split(","): + case strip(item.toLowerAscii()) + of "identity": + res.incl(TransferEncodingFlags.Identity) + of "chunked": + res.incl(TransferEncodingFlags.Chunked) + of "compress": + res.incl(TransferEncodingFlags.Compress) + of "deflate": + res.incl(TransferEncodingFlags.Deflate) + of "gzip": + res.incl(TransferEncodingFlags.Gzip) + of "": + res.incl(TransferEncodingFlags.Identity) + else: + return err("Incorrect Transfer-Encoding value") + ok(res) + +func getContentEncoding(contentHeader: seq[string]): HttpResult[ + set[ContentEncodingFlags]] = + var res: set[ContentEncodingFlags] = {} + if len(contentHeader) == 0: + res.incl(ContentEncodingFlags.Identity) + ok(res) + else: + for header in contentHeader: + for item in header.split(","): + case strip(item.toLowerAscii()): + of "identity": + res.incl(ContentEncodingFlags.Identity) + of "br": + res.incl(ContentEncodingFlags.Br) + of "compress": + res.incl(ContentEncodingFlags.Compress) + of "deflate": + res.incl(ContentEncodingFlags.Deflate) + of "gzip": + res.incl(ContentEncodingFlags.Gzip) + of "": + res.incl(ContentEncodingFlags.Identity) + else: + return err("Incorrect Content-Encoding value") + ok(res) + +proc hasBody*(request: HttpRequest): bool = + ## Returns ``true`` if request has body. + request.requestFlags * {HttpRequestFlags.BoundBody, + HttpRequestFlags.UnboundBody} != {} + +proc prepareRequest(conn: HttpConnection, + req: HttpRequestHeader): HttpResultCode[HttpRequest] = + var request = HttpRequest() + + if req.version notin {HttpVersion10, HttpVersion11}: + return err(Http505) + + request.version = req.version + request.meth = req.meth + + request.rawPath = + block: + let res = req.uri() + if len(res) == 0: + return err(Http400) + res + + request.uri = + if request.rawPath != "*": + let uri = parseUri(request.rawPath) + if uri.scheme notin ["http", "https", ""]: + return err(Http400) + uri + else: + var uri = initUri() + uri.path = "*" + uri + + request.queryTable = + block: + var table = HttpTable.init() + for key, value in queryParams(request.uri.query): + table.add(key, value) + table + + request.headersTable = + block: + var table = HttpTable.init() + # Retrieve headers and values + for key, value in req.headers(): + table.add(key, value) + # Validating HTTP request headers + # Some of the headers must be present only once. + if table.count("content-type") > 1: + return err(Http400) + if table.count("content-length") > 1: + return err(Http400) + if table.count("transfer-encoding") > 1: + return err(Http400) + table + + # Preprocessing "Content-Encoding" header. + request.contentEncoding = + block: + let res = getContentEncoding( + request.headersTable.getList("content-encoding")) + if res.isErr(): + return err(Http400) + else: + res.get() + + # Preprocessing "Transfer-Encoding" header. + request.transferEncoding = + block: + let res = getTransferEncoding( + request.headersTable.getList("transfer-encoding")) + if res.isErr(): + return err(Http400) + else: + res.get() + + # Almost all HTTP requests could have body (except TRACE), we perform some + # steps to reveal information about body. + if "content-length" in request.headersTable: + let length = request.headersTable.getInt("content-length") + if length > 0: + if request.meth == MethodTrace: + return err(Http400) + if length > uint64(high(int)): + return err(Http413) + if length > uint64(conn.server.maxRequestBodySize): + return err(Http413) + request.contentLength = int(length) + request.requestFlags.incl(HttpRequestFlags.BoundBody) + else: + if TransferEncodingFlags.Chunked in request.transferEncoding: + if request.meth == MethodTrace: + return err(Http400) + request.requestFlags.incl(HttpRequestFlags.UnboundBody) + + if request.hasBody(): + # If request has body, we going to understand how its encoded. + const + UrlEncodedType = "application/x-www-form-urlencoded" + MultipartType = "multipart/form-data" + + if "content-type" in request.headersTable: + let contentType = request.headersTable.getString("content-type") + let tmp = strip(contentType).toLowerAscii() + if tmp.startsWith(UrlEncodedType): + request.requestFlags.incl(UrlencodedForm) + elif tmp.startsWith(MultipartType): + request.requestFlags.incl(MultipartForm) + + request.mainReader = newAsyncStreamReader(conn.transp) + ok(request) + +proc getBodyStream*(request: HttpRequest): HttpResult[AsyncStreamReader] = + if HttpRequestFlags.BoundBody in request.requestFlags: + ok(newBoundedStreamReader(request.mainReader, request.contentLength)) + elif HttpRequestFlags.UnboundBody in request.requestFlags: + ok(newChunkedStreamReader(request.mainReader)) + else: + err("Request do not have body available") + +proc getBody*(request: HttpRequest): Future[seq[byte]] {.async.} = + ## Obtain request's body as sequence of bytes. + let res = request.getBodyStream() + if res.isErr(): + return @[] + else: + try: + return await read(res.get()) + except AsyncStreamError: + raise newHttpCriticalFailure("Read failure") + +proc consumeBody*(request: HttpRequest): Future[void] {.async.} = + ## Consume/discard request's body. + let res = request.getBodyStream() + if res.isErr(): + return + else: + let reader = res.get() + try: + discard await reader.consume() + return + except AsyncStreamError: + raise newHttpCriticalFailure("Read failure") + +proc sendErrorResponse(conn: HttpConnection, version: HttpVersion, + code: HttpCode, keepAlive = true, + datatype = "text/text", + databody = ""): Future[bool] {.async.} = + var answer = $version & " " & $code & "\r\n" + answer.add("Date: " & httpDate() & "\r\n") + if len(databody) > 0: + answer.add("Content-Type: " & datatype & "\r\n") + answer.add("Content-Length: " & $len(databody) & "\r\n") + if keepAlive: + answer.add("Connection: keep-alive\r\n") + else: + answer.add("Connection: close\r\n") + answer.add("\r\n") + if len(databody) > 0: + answer.add(databody) + try: + let res {.used.} = await conn.transp.write(answer) + return true + except CancelledError: + return false + except TransportOsError: + return false + except CatchableError: + return false + +proc sendErrorResponse(request: HttpRequest, code: HttpCode, keepAlive = true, + datatype = "text/text", + databody = ""): Future[bool] = + sendErrorResponse(request.connection, request.version, code, keepAlive, + datatype, databody) + +proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} = + when defined(useChroniclesLogging): + logScope: + peer = $conn.transp.remoteAddress + + try: + conn.buffer.setLen(conn.server.maxHeadersSize) + let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer), + HeadersMark) + conn.buffer.setLen(res) + let header = parseRequest(conn.buffer) + if header.failed(): + log debug "Malformed header received" + discard await conn.sendErrorResponse(HttpVersion11, Http400, false) + raise newHttpCriticalFailure("Malformed request recieved") + else: + let res = prepareRequest(conn, header) + if res.isErr(): + discard await conn.sendErrorResponse(HttpVersion11, Http400, false) + raise newHttpCriticalFailure("Invalid request received") + else: + return res.get() + except TransportOsError: + log debug "Unexpected OS error" + raise newHttpCriticalFailure("Unexpected OS error") + except TransportIncompleteError: + log debug "Remote peer disconnected" + raise newHttpCriticalFailure("Remote peer disconnected") + except TransportLimitError: + log debug "Maximum size of request headers reached" + discard await conn.sendErrorResponse(HttpVersion11, Http413, false) + raise newHttpCriticalFailure("Maximum size of request headers reached") + +proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} = + when defined(useChroniclesLogging): + logScope: + peer = $transp.remoteAddress + + var conn = HttpConnection( + transp: transp, buffer: newSeq[byte](server.maxHeadersSize), + server: server + ) + + log info "Client connected" + var breakLoop = false + while true: + try: + let request = await conn.getRequest().wait(server.headersTimeout) + echo "== HEADERS TABLE" + echo request.headersTable + echo "== QUERY TABLE" + echo request.queryTable + echo "== TRANSFER ENCODING ", request.transferEncoding + echo "== CONTENT ENCODING ", request.contentEncoding + echo "== REQUEST FLAGS ", request.requestFlags + var stream = await request.getBody() + echo cast[string](stream) + discard await conn.sendErrorResponse(HttpVersion11, Http200, true, + databody = "OK") + log debug "Response sent" + except AsyncTimeoutError: + log debug "Timeout reached while reading headers" + discard await conn.sendErrorResponse(HttpVersion11, Http408, false) + breakLoop = true + + except CancelledError: + breakLoop = true + + except HttpRecoverableFailure: + breakLoop = false + + except HttpCriticalFailure: + breakLoop = true + + except CatchableError as exc: + echo "CatchableError received ", exc.name + + if breakLoop: + break + + await transp.closeWait() + server.connections.del(transp.getId()) + # if server.maxConnections > 0: + # server.semaphore.release() + log info "Client got disconnected" + +proc acceptClientLoop(server: HttpServer) {.async.} = + var breakLoop = false + while true: + try: + # if server.maxConnections > 0: + # await server.semaphore.acquire() + + let transp = await server.instance.accept() + server.connections[transp.getId()] = processLoop(server, transp) + + except CancelledError: + # Server was stopped + breakLoop = true + except TransportOsError: + # This is some critical unrecoverable error. + breakLoop = true + except TransportTooManyError: + # Non critical error + breakLoop = false + except CatchableError: + # Unexpected error + breakLoop = true + discard + + if breakLoop: + break + +proc state*(server: HttpServer): HttpServerState = + ## Returns current HTTP server's state. + if server.lifetime.finished(): + ServerClosed + else: + if isNil(server.acceptLoop): + ServerStopped + else: + if server.acceptLoop.finished(): + ServerStopped + else: + ServerRunning + +proc start*(server: HttpServer) = + ## Starts HTTP server. + if server.state == ServerStopped: + server.acceptLoop = acceptClientLoop(server) + +proc stop*(server: HttpServer) {.async.} = + ## Stop HTTP server from accepting new connections. + if server.state == ServerRunning: + await server.acceptLoop.cancelAndWait() + +proc drop*(server: HttpServer) {.async.} = + ## Drop all pending HTTP connections. + if server.state in {ServerStopped, ServerRunning}: + discard + +proc close*(server: HttpServer) {.async.} = + ## Stop HTTP server and drop all the pending connections. + if server.state != ServerClosed: + await server.stop() + await server.drop() + await server.instance.closeWait() + server.lifetime.complete() + +proc join*(server: HttpServer): Future[void] = + ## Wait until HTTP server will not be closed. + var retFuture = newFuture[void]("http.server.join") + + proc continuation(udata: pointer) {.gcsafe.} = + if not(retFuture.finished()): + retFuture.complete() + + proc cancellation(udata: pointer) {.gcsafe.} = + if not(retFuture.finished()): + server.lifetime.removeCallback(continuation, cast[pointer](retFuture)) + + if server.state == ServerClosed: + retFuture.complete() + else: + server.lifetime.addCallback(continuation, cast[pointer](retFuture)) + retFuture.cancelCallback = cancellation + + retFuture + +when isMainModule: + let res = HttpServer.new(initTAddress("127.0.0.1:30080"), maxConnections = 1) + if res.isOk(): + let server = res.get() + server.start() + echo "HTTP server was started" + waitFor server.join() + else: + echo "Failed to start server: ", res.error diff --git a/chronos/apps/http/httptable.nim b/chronos/apps/http/httptable.nim new file mode 100644 index 0000000..db9cc51 --- /dev/null +++ b/chronos/apps/http/httptable.nim @@ -0,0 +1,124 @@ +# +# Chronos HTTP/S case-insensitive non-unique +# key-value memory storage +# (c) Copyright 2019-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import std/[tables, strutils] + +type + HttpTable* = object + table: Table[string, seq[string]] + + HttpTableRef* = ref HttpTable + + HttpTables* = HttpTable | HttpTableRef + +proc `-`(x: uint32): uint32 {.inline.} = + (0xFFFF_FFFF'u32 - x) + 1'u32 + +proc LT(x, y: uint32): uint32 {.inline.} = + let z = x - y + (z xor ((y xor x) and (y xor z))) shr 31 + +proc decValue(c: byte): int = + let x = uint32(c) - 0x30'u32 + let r = ((x + 1'u32) and -LT(x, 10)) + int(r) - 1 + +proc bytesToDec*[T: byte|char](src: openarray[T]): uint64 = + var v = 0'u64 + for i in 0 ..< len(src): + let d = + when T is byte: + decValue(src[i]) + else: + decValue(byte(src[i])) + if d < 0: + # non-decimal character encountered + return v + else: + let nv = ((v shl 3) + (v shl 1)) + uint64(d) + if nv < v: + # overflow happened + return v + else: + v = nv + v + +proc add*(ht: var HttpTables, key: string, value: string) = + let lowkey = key.toLowerAscii() + var nitem = @[value] + if ht.table.hasKeyOrPut(lowkey, nitem): + var oitem = ht.table[lowkey] + oitem.add(value) + ht.table[lowkey] = oitem + +proc add*(ht: var HttpTables, key: string, value: SomeInteger) = + ht.add(key, $value) + +proc contains*(ht: var HttpTables, key: string): bool = + ht.table.contains(key.toLowerAscii()) + +proc getList*(ht: HttpTables, key: string): seq[string] = + var default: seq[string] + ht.table.getOrDefault(key.toLowerAscii(), default) + +proc getString*(ht: HttpTables, key: string): string = + var default: seq[string] + ht.table.getOrDefault(key.toLowerAscii(), default).join(",") + +proc count*(ht: HttpTables, key: string): int = + var default: seq[string] + len(ht.table.getOrDefault(key, default)) + +proc getInt*(ht: HttpTables, key: string): uint64 = + bytesToDec(ht.getString(key)) + +proc getLastString*(ht: HttpTables, key: string): string = + var default: seq[string] + let item = ht.table.getOrDefault(key.toLowerAscii(), default) + if len(item) == 0: + "" + else: + item[^1] + +proc getLastInt*(ht: HttpTables, key: string): uint64 = + bytesToDec(ht.getLastString()) + +proc init*(htt: typedesc[HttpTable]): HttpTable = + HttpTable(table: initTable[string, seq[string]]()) + +proc new*(htt: typedesc[HttpTableRef]): HttpTableRef = + HttpTableRef(table: initTable[string, seq[string]]()) + +proc normalizeHeaderName*(value: string): string = + var res = value.toLowerAscii() + var k = 0 + while k < len(res): + if k == 0: + res[k] = toUpperAscii(res[k]) + inc(k, 1) + else: + if res[k] == '-': + if k + 1 < len(res): + res[k + 1] = toUpperAscii(res[k + 1]) + inc(k, 2) + else: + break + else: + inc(k, 1) + res + +proc `$`*(ht: HttpTables): string = + var res = "" + for key, value in ht.table.pairs(): + for item in value: + res.add(key.normalizeHeaderName()) + res.add(": ") + res.add(item) + res.add("\p") + res diff --git a/chronos/apps/http/multipart.nim b/chronos/apps/http/multipart.nim new file mode 100644 index 0000000..d217c0b --- /dev/null +++ b/chronos/apps/http/multipart.nim @@ -0,0 +1,286 @@ +# +# Chronos HTTP/S multipart/form +# encoding and decoding helper procedures +# (c) Copyright 2019-Present +# Status Research & Development GmbH +# +# Licensed under either of +# Apache License, version 2.0, (LICENSE-APACHEv2) +# MIT license (LICENSE-MIT) +import std/[monotimes, strutils] +import chronos, stew/results +import httptable, httpcommon +export httptable, httpcommon + +type + MultiPartSource {.pure.} = enum + Stream, Buffer + + MultiPartReader* = object + case kind: MultiPartSource + of MultiPartSource.Stream: + stream: AsyncStreamReader + last: BoundedAsyncStreamReader + of MultiPartSource.Buffer: + discard + buffer: seq[byte] + offset: int + boundary: seq[byte] + + MultiPartReaderRef* = ref MultiPartReader + + MultiPart* = object + headers: HttpTable + stream: BoundedAsyncStreamReader + offset: int + size: int + + MultipartError* = object of HttpError + MultipartEOMError* = object of MultipartError + MultiPartIncorrectError* = object of MultipartError + MultiPartIncompleteError* = object of MultipartError + + BChar* = byte | char + +proc startsWith*(s, prefix: openarray[byte]): bool = + var i = 0 + while true: + if i >= len(prefix): return true + if i >= len(s) or s[i] != prefix[i]: return false + inc(i) + +proc parseUntil*(s, until: openarray[byte]): int = + var i = 0 + while i < len(s): + if len(until) > 0 and s[i] == until[0]: + var u = 1 + while i + u < len(s) and u < len(until) and s[i + u] == until[u]: + inc u + if u >= len(until): return i + inc(i) + -1 + +proc init*[A: BChar, B: BChar](mpt: typedesc[MultiPartReader], + buffer: openarray[A], + boundary: openarray[B]): MultiPartReader = + # Boundary should not be empty. + doAssert(len(boundary) > 0) + # Our internal boundary has format `<-><->`, so we can reuse + # different parts of this sequence for processing. + var fboundary = newSeq[byte](len(boundary) + 4) + fboundary[0] = 0x0D'u8 + fboundary[1] = 0x0A'u8 + fboundary[2] = byte('-') + fboundary[3] = byte('-') + copyMem(addr fboundary[4], unsafeAddr boundary[0], len(boundary)) + # Make copy of buffer, because all the returned parts depending on it. + var buf = newSeq[byte](len(buffer)) + if len(buf) > 0: + copyMem(addr buf[0], unsafeAddr buffer[0], len(buffer)) + MultiPartReader(kind: MultiPartSource.Buffer, + buffer: buf, offset: 0, boundary: fboundary) + +proc init*[B: BChar](mpt: typedesc[MultiPartReader], + stream: AsyncStreamReader, + boundary: openarray[B], + partHeadersMaxSize = 4096): MultiPartReader = + # Boundary should not be empty. + doAssert(len(boundary) > 0) + # Our internal boundary has format `<-><->`, so we can reuse + # different parts of this sequence for processing. + var fboundary = newSeq[byte](len(boundary) + 4) + fboundary[0] = 0x0D'u8 + fboundary[1] = 0x0A'u8 + fboundary[2] = byte('-') + fboundary[3] = byte('-') + copyMem(addr fboundary[4], unsafeAddr boundary[0], len(boundary)) + MultiPartReader(kind: MultiPartSource.Stream, + stream: stream, offset: 0, boundary: fboundary, + buffer: newSeq[byte](partHeadersMaxSize)) + +proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} = + doAssert(mpr.kind == MultiPartSource.Stream) + # According to RFC1521 says that a boundary "must be no longer than 70 + # characters, not counting the two leading hyphens. + if mpr.firstTime: + # Read and verify initial <-><-> + mpr.firstTime = false + await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2) + if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5), + mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)): + if buffer[0] == byte('-') and buffer[1] == byte("-"): + raise newException(MultiPartEOMError, "Unexpected EOM encountered") + if buffer[0] != 0x0D'u8 or buffer[1] != 0x0A'u8: + raise newException(MultiPartIncorrectError, + "Unexpected boundary suffix") + else: + raise newException(MultiPartIncorrectError, + "Unexpected boundary encountered") + + # Reading part's headers + let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer), + HeadersMark) + var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1)) + if headersList.failed(): + raise newException(MultiPartIncorrectError, "Incorrect part headers found") + + var part = MultiPart() + + await mpr.stream.readExactly(addr buffer[0], len(mpr.boundary) - 4) + if startsWith(buffer.toOpenArray(0, len(mpr.boundary) - 5), + mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)): + await mpr.stream.readExactly(addr buffer[0], 2) + if buffer[0] == byte('-') and buffer[1] == byte("-"): + raise newException(MultiPartEOMError, "") + if buffer[0] == 0x0D'u8 and buffer[1] == 0x0A'u8: + + except: + discard + # if mpr.offset >= len(mpr.buffer): + # raise newException(MultiPartEOMError, "End of multipart form encountered") + +proc getStream*(mp: MultiPart): AsyncStreamReader = + mp.stream + +proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} = + try: + let res = await mp.stream.read() + return res + except AsyncStreamError: + raise newException(HttpCriticalError, "Could not read multipart body") + +proc consumeBody*(mp: MultiPart) {.async.} = + try: + await mp.stream.consume() + except AsyncStreamError: + raise newException(HttpCriticalError, "Could not consume multipart body") + +proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] = + doAssert(mpr.kind == MultiPartSource.Buffer) + if mpr.offset >= len(mpr.buffer): + return err("End of multipart form encountered") + + if startsWith(mpr.buffer.toOpenArray(mpr.offset, len(mpr.buffer) - 1), + mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)): + # Buffer must start at <-><-> + mpr.offset += (len(mpr.boundary) - 2) + + # After boundary there should be at least 2 symbols <-><-> or . + if len(mpr.buffer) <= mpr.offset + 1: + return err("Incomplete multipart form") + + if mpr.buffer[mpr.offset] == byte('-') and + mpr.buffer[mpr.offset + 1] == byte('-'): + # If we have <-><-><-><-> it means we have found last boundary + # of multipart message. + mpr.offset += 2 + return err("End of multipart form encountered") + + if mpr.buffer[mpr.offset] == 0x0D'u8 and + mpr.buffer[mpr.offset + 1] == 0x0A'u8: + # If we have <-><-> it means that we have found another + # part of multipart message. + mpr.offset += 2 + # Multipart form must always have at least single Content-Disposition + # header, so we searching position where all the headers should be + # finished . + let pos1 = parseUntil( + mpr.buffer.toOpenArray(mpr.offset, len(mpr.buffer) - 1), + [0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8] + ) + + if pos1 < 0: + return err("Incomplete multipart form") + + # parseUntil returns 0-based position without `until` sequence. + let start = mpr.offset + pos1 + 4 + + # Multipart headers position + let hstart = mpr.offset + let hfinish = mpr.offset + pos1 + 4 - 1 + + let headersList = parseHeaders(mpr.buffer.toOpenArray(hstart, hfinish), + false) + if headersList.failed(): + return err("Incorrect or incomplete multipart headers received") + + # Searching for value's boundary <-><->. + let pos2 = parseUntil( + mpr.buffer.toOpenArray(start, len(mpr.buffer) - 1), + mpr.boundary.toOpenArray(0, len(mpr.boundary) - 1) + ) + + if pos2 < 0: + return err("Incomplete multipart form") + + # We set reader's offset to the place right after + mpr.offset = start + pos2 + 2 + + var part = MultiPart(offset: start, size: pos2, headers: HttpTable.init()) + for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)): + part.headers.add(k, v) + ok(part) + else: + err("Incorrect multipart form") + else: + err("Incorrect multipart form") + +template `-`(x: uint32): uint32 = + (0xFFFF_FFFF'u32 - x) + 1'u32 + +template LT(x, y: uint32): uint32 = + let z = x - y + (z xor ((y xor x) and (y xor z))) shr 31 + +proc boundaryValue(c: char): bool = + let a0 = uint32(c) - 0x27'u32 + let a1 = uint32(c) - 0x2B'u32 + let a2 = uint32(c) - 0x3A'u32 + let a3 = uint32(c) - 0x3D'u32 + let a4 = uint32(c) - 0x3F'u32 + let a5 = uint32(c) - 0x41'u32 + let a6 = uint32(c) - 0x5F + let a7 = uint32(c) - 0x61'u32 + let r = ((a0 + 1'u32) and -LT(a0, 3)) or + ((a1 + 1'u32) and -LT(a1, 15)) or + ((a2 + 1'u32) and -LT(a2, 1)) or + ((a3 + 1'u32) and -LT(a3, 1)) or + ((a4 + 1'u32) and -LT(a4, 1)) or + ((a5 + 1'u32) and -LT(a5, 26)) or + ((a6 + 1'u32) and -LT(a6, 1)) or + ((a7 + 1'u32) and -LT(a7, 26)) + (int(r) - 1) > 0 + +proc boundaryValue2(c: char): bool = + c in {'a'..'z', 'A' .. 'Z', '0' .. '9', + '\'' .. ')', '+' .. '/', ':', '=', '?', '_'} + +func getMultipartBoundary*(contentType: string): Result[string, string] = + let mparts = contentType.split(";") + if strip(mparts[0]).toLowerAscii() != "multipart/form-data": + return err("Content-Type is not multipart") + if len(mparts) < 2: + return err("Content-Type missing boundary value") + let stripped = strip(mparts[1]) + if not(stripped.toLowerAscii().startsWith("boundary")): + return err("Incorrect Content-Type boundary format") + let bparts = stripped.split("=") + if len(bparts) < 2: + err("Missing Content-Type boundary") + else: + ok(strip(bparts[1])) + +func getContentType*(contentHeader: seq[string]): Result[string, string] = + if len(contentHeader) > 1: + return err("Multiple Content-Header values found") + let mparts = contentHeader[0].split(";") + ok(strip(mparts[0]).toLowerAscii()) + +when isMainModule: + var buf = "--------------------------5e7d0dd0ed6eb849\r\nContent-Disposition: form-data; name=\"key1\"\r\n\r\nvalue1\r\n--------------------------5e7d0dd0ed6eb849\r\nContent-Disposition: form-data; name=\"key2\"\r\n\r\nvalue2\r\n--------------------------5e7d0dd0ed6eb849--" + var reader = MultiPartReader.init(buf, "------------------------5e7d0dd0ed6eb849") + echo getPart(reader) + echo "====" + echo getPart(reader) + echo "====" + echo getPart(reader)