Prepare for HttpResponse.

This commit is contained in:
cheatfate 2021-01-27 21:39:14 +02:00 committed by zah
parent 60e5396a9e
commit 0e5ea5b737
3 changed files with 285 additions and 181 deletions

View File

@ -6,8 +6,8 @@
# Licensed under either of # Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
import stew/results, httputils import stew/results, httputils, strutils, uri
export results, httputils export results, httputils, strutils
const const
useChroniclesLogging* {.booldefine.} = false useChroniclesLogging* {.booldefine.} = false
@ -19,9 +19,110 @@ type
HttpResultCode*[T] = Result[T, HttpCode] HttpResultCode*[T] = Result[T, HttpCode]
HttpError* = object of CatchableError HttpError* = object of CatchableError
HttpCriticalFailure* = object of HttpError HttpCriticalError* = object of HttpError
HttpRecoverableFailure* = object of HttpError HttpRecoverableError* = object of HttpError
TransferEncodingFlags* {.pure.} = enum
Identity, Chunked, Compress, Deflate, Gzip
ContentEncodingFlags* {.pure.} = enum
Identity, Br, Compress, Deflate, Gzip
template log*(body: untyped) = template log*(body: untyped) =
when defined(useChroniclesLogging): when defined(useChroniclesLogging):
body body
proc newHttpCriticalError*(msg: string): ref HttpCriticalError =
newException(HttpCriticalError, msg)
proc newHttpRecoverableError*(msg: string): ref HttpRecoverableError =
newException(HttpRecoverableError, msg)
iterator queryParams*(query: string): tuple[key: string, value: string] =
## Iterate over url-encoded query string.
for pair in query.split('&'):
let items = pair.split('=', maxsplit = 1)
let k = items[0]
let v = if len(items) > 1: items[1] else: ""
yield (decodeUrl(k), decodeUrl(v))
func getTransferEncoding*(ch: openarray[string]): HttpResult[
set[TransferEncodingFlags]] =
## Parse value of multiple HTTP headers ``Transfer-Encoding`` and return
## it as set of ``TransferEncodingFlags``.
var res: set[TransferEncodingFlags] = {}
if len(ch) == 0:
res.incl(TransferEncodingFlags.Identity)
ok(res)
else:
for header in ch:
for item in header.split(","):
case strip(item.toLowerAscii())
of "identity":
res.incl(TransferEncodingFlags.Identity)
of "chunked":
res.incl(TransferEncodingFlags.Chunked)
of "compress":
res.incl(TransferEncodingFlags.Compress)
of "deflate":
res.incl(TransferEncodingFlags.Deflate)
of "gzip":
res.incl(TransferEncodingFlags.Gzip)
of "":
res.incl(TransferEncodingFlags.Identity)
else:
return err("Incorrect Transfer-Encoding value")
ok(res)
func getContentEncoding*(ch: openarray[string]): HttpResult[
set[ContentEncodingFlags]] =
## Parse value of multiple HTTP headers ``Content-Encoding`` and return
## it as set of ``ContentEncodingFlags``.
var res: set[ContentEncodingFlags] = {}
if len(ch) == 0:
res.incl(ContentEncodingFlags.Identity)
ok(res)
else:
for header in ch:
for item in header.split(","):
case strip(item.toLowerAscii()):
of "identity":
res.incl(ContentEncodingFlags.Identity)
of "br":
res.incl(ContentEncodingFlags.Br)
of "compress":
res.incl(ContentEncodingFlags.Compress)
of "deflate":
res.incl(ContentEncodingFlags.Deflate)
of "gzip":
res.incl(ContentEncodingFlags.Gzip)
of "":
res.incl(ContentEncodingFlags.Identity)
else:
return err("Incorrect Content-Encoding value")
ok(res)
func getContentType*(ch: openarray[string]): HttpResult[string] =
## Check and prepare value of ``Content-Type`` header.
if len(ch) > 1:
err("Multiple Content-Type values found")
else:
let mparts = ch[0].split(";")
ok(strip(mparts[0]).toLowerAscii())
func getMultipartBoundary*(contentType: string): HttpResult[string] =
## Process ``multipart/form-data`` ``Content-Type`` header and return
## multipart boundary.
let mparts = contentType.split(";")
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
return err("Content-Type is not multipart")
if len(mparts) < 2:
return err("Content-Type missing boundary value")
let stripped = strip(mparts[1])
if not(stripped.toLowerAscii().startsWith("boundary")):
return err("Incorrect Content-Type boundary format")
let bparts = stripped.split("=")
if len(bparts) < 2:
err("Missing Content-Type boundary")
else:
ok(strip(bparts[1]))

View File

@ -21,15 +21,25 @@ type
HttpServerFlags* = enum HttpServerFlags* = enum
Secure Secure
TransferEncodingFlags* {.pure.} = enum HttpConnectionStatus* = enum
Identity, Chunked, Compress, Deflate, Gzip DropConnection, KeepConnection
ContentEncodingFlags* {.pure.} = enum HttpErrorEnum* = enum
Identity, Br, Compress, Deflate, Gzip TimeoutError, CatchableError, RecoverableError, CriticalError
HttpProcessError* = object
error*: HttpErrorEnum
exc*: HttpError
remote*: TransportAddress
HttpProcessStatus*[T] = Result[T, HttpProcessError]
HttpRequestFlags* {.pure.} = enum HttpRequestFlags* {.pure.} = enum
BoundBody, UnboundBody, MultipartForm, UrlencodedForm BoundBody, UnboundBody, MultipartForm, UrlencodedForm
HttpProcessCallback* =
proc(request: HttpProcessStatus[HttpRequest]): Future[HttpStatus]
HttpServer* = ref object of RootRef HttpServer* = ref object of RootRef
instance*: StreamServer instance*: StreamServer
# semaphore*: AsyncSemaphore # semaphore*: AsyncSemaphore
@ -43,6 +53,7 @@ type
bodyTimeout: Duration bodyTimeout: Duration
maxHeadersSize: int maxHeadersSize: int
maxRequestBodySize: int maxRequestBodySize: int
processCallback: HttpProcessCallback
HttpServerState* = enum HttpServerState* = enum
ServerRunning, ServerStopped, ServerClosed ServerRunning, ServerStopped, ServerClosed
@ -64,18 +75,24 @@ type
connection*: HttpConnection connection*: HttpConnection
mainReader*: AsyncStreamReader mainReader*: AsyncStreamReader
HttpResponse* = object
code*: HttpCode
version*: HttpVersion
headersTable: HttpTable
body*: seq[byte]
connection*: HttpConnection
mainWriter: AsyncStreamWriter
HttpConnection* = ref object of RootRef HttpConnection* = ref object of RootRef
server: HttpServer server*: HttpServer
transp: StreamTransport transp: StreamTransport
buffer: seq[byte] buffer: seq[byte]
const
HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
proc new*(htype: typedesc[HttpServer], proc new*(htype: typedesc[HttpServer],
address: TransportAddress, address: TransportAddress,
flags: set[HttpServerFlags] = {}, flags: set[HttpServerFlags] = {},
serverUri = Uri(), serverUri = Uri(),
processCallback: HttpProcessCallback,
maxConnections: int = -1, maxConnections: int = -1,
bufferSize: int = 4096, bufferSize: int = 4096,
backlogSize: int = 100, backlogSize: int = 100,
@ -88,7 +105,8 @@ proc new*(htype: typedesc[HttpServer],
headersTimeout: httpHeadersTimeout, headersTimeout: httpHeadersTimeout,
bodyTimeout: httpBodyTimeout, bodyTimeout: httpBodyTimeout,
maxHeadersSize: maxHeadersSize, maxHeadersSize: maxHeadersSize,
maxRequestBodySize: maxRequestBodySize maxRequestBodySize: maxRequestBodySize,
processCallback: processCallback
) )
res.baseUri = res.baseUri =
@ -118,92 +136,6 @@ proc getId(transp: StreamTransport): string {.inline.} =
## Returns string unique transport's identifier as string. ## Returns string unique transport's identifier as string.
$transp.remoteAddress() & "_" & $transp.localAddress() $transp.remoteAddress() & "_" & $transp.localAddress()
proc newHttpCriticalFailure*(msg: string): ref HttpCriticalFailure =
newException(HttpCriticalFailure, msg)
proc newHttpRecoverableFailure*(msg: string): ref HttpRecoverableFailure =
newException(HttpRecoverableFailure, msg)
iterator queryParams(query: string): tuple[key: string, value: string] =
for pair in query.split('&'):
let items = pair.split('=', maxsplit = 1)
let k = items[0]
let v = if len(items) > 1: items[1] else: ""
yield (decodeUrl(k), decodeUrl(v))
func getMultipartBoundary*(contentType: string): HttpResult[string] =
let mparts = contentType.split(";")
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
return err("Content-Type is not multipart")
if len(mparts) < 2:
return err("Content-Type missing boundary value")
let stripped = strip(mparts[1])
if not(stripped.toLowerAscii().startsWith("boundary")):
return err("Incorrect Content-Type boundary format")
let bparts = stripped.split("=")
if len(bparts) < 2:
err("Missing Content-Type boundary")
else:
ok(strip(bparts[1]))
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
if len(contentHeader) > 1:
return err("Multiple Content-Type values found")
let mparts = contentHeader[0].split(";")
ok(strip(mparts[0]).toLowerAscii())
func getTransferEncoding(contentHeader: seq[string]): HttpResult[
set[TransferEncodingFlags]] =
var res: set[TransferEncodingFlags] = {}
if len(contentHeader) == 0:
res.incl(TransferEncodingFlags.Identity)
ok(res)
else:
for header in contentHeader:
for item in header.split(","):
case strip(item.toLowerAscii())
of "identity":
res.incl(TransferEncodingFlags.Identity)
of "chunked":
res.incl(TransferEncodingFlags.Chunked)
of "compress":
res.incl(TransferEncodingFlags.Compress)
of "deflate":
res.incl(TransferEncodingFlags.Deflate)
of "gzip":
res.incl(TransferEncodingFlags.Gzip)
of "":
res.incl(TransferEncodingFlags.Identity)
else:
return err("Incorrect Transfer-Encoding value")
ok(res)
func getContentEncoding(contentHeader: seq[string]): HttpResult[
set[ContentEncodingFlags]] =
var res: set[ContentEncodingFlags] = {}
if len(contentHeader) == 0:
res.incl(ContentEncodingFlags.Identity)
ok(res)
else:
for header in contentHeader:
for item in header.split(","):
case strip(item.toLowerAscii()):
of "identity":
res.incl(ContentEncodingFlags.Identity)
of "br":
res.incl(ContentEncodingFlags.Br)
of "compress":
res.incl(ContentEncodingFlags.Compress)
of "deflate":
res.incl(ContentEncodingFlags.Deflate)
of "gzip":
res.incl(ContentEncodingFlags.Gzip)
of "":
res.incl(ContentEncodingFlags.Identity)
else:
return err("Incorrect Content-Encoding value")
ok(res)
proc hasBody*(request: HttpRequest): bool = proc hasBody*(request: HttpRequest): bool =
## Returns ``true`` if request has body. ## Returns ``true`` if request has body.
request.requestFlags * {HttpRequestFlags.BoundBody, request.requestFlags * {HttpRequestFlags.BoundBody,
@ -333,7 +265,7 @@ proc getBody*(request: HttpRequest): Future[seq[byte]] {.async.} =
try: try:
return await read(res.get()) return await read(res.get())
except AsyncStreamError: except AsyncStreamError:
raise newHttpCriticalFailure("Read failure") raise newHttpCriticalError("Read Error")
proc consumeBody*(request: HttpRequest): Future[void] {.async.} = proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
## Consume/discard request's body. ## Consume/discard request's body.
@ -346,7 +278,7 @@ proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
discard await reader.consume() discard await reader.consume()
return return
except AsyncStreamError: except AsyncStreamError:
raise newHttpCriticalFailure("Read failure") raise newHttpCriticalError("Read Error")
proc sendErrorResponse(conn: HttpConnection, version: HttpVersion, proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
code: HttpCode, keepAlive = true, code: HttpCode, keepAlive = true,
@ -374,17 +306,13 @@ proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
except CatchableError: except CatchableError:
return false return false
proc sendErrorResponse(request: HttpRequest, code: HttpCode, keepAlive = true, proc sendErrorResponse*(request: HttpRequest, code: HttpCode, keepAlive = true,
datatype = "text/text", datatype = "text/text",
databody = ""): Future[bool] = databody = ""): Future[bool] =
sendErrorResponse(request.connection, request.version, code, keepAlive, sendErrorResponse(request.connection, request.version, code, keepAlive,
datatype, databody) datatype, databody)
proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} = proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
when defined(useChroniclesLogging):
logScope:
peer = $conn.transp.remoteAddress
try: try:
conn.buffer.setLen(conn.server.maxHeadersSize) conn.buffer.setLen(conn.server.maxHeadersSize)
let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer), let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer),
@ -392,42 +320,62 @@ proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
conn.buffer.setLen(res) conn.buffer.setLen(res)
let header = parseRequest(conn.buffer) let header = parseRequest(conn.buffer)
if header.failed(): if header.failed():
log debug "Malformed header received"
discard await conn.sendErrorResponse(HttpVersion11, Http400, false) discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
raise newHttpCriticalFailure("Malformed request recieved") raise newHttpCriticalError("Malformed request recieved")
else: else:
let res = prepareRequest(conn, header) let res = prepareRequest(conn, header)
if res.isErr(): if res.isErr():
discard await conn.sendErrorResponse(HttpVersion11, Http400, false) discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
raise newHttpCriticalFailure("Invalid request received") raise newHttpCriticalError("Invalid request received")
else: else:
return res.get() return res.get()
except TransportOsError: except TransportOsError:
log debug "Unexpected OS error" raise newHttpCriticalError("Unexpected OS error")
raise newHttpCriticalFailure("Unexpected OS error")
except TransportIncompleteError: except TransportIncompleteError:
log debug "Remote peer disconnected" raise newHttpCriticalError("Remote peer disconnected")
raise newHttpCriticalFailure("Remote peer disconnected")
except TransportLimitError: except TransportLimitError:
log debug "Maximum size of request headers reached"
discard await conn.sendErrorResponse(HttpVersion11, Http413, false) discard await conn.sendErrorResponse(HttpVersion11, Http413, false)
raise newHttpCriticalFailure("Maximum size of request headers reached") raise newHttpCriticalError("Maximum size of request headers reached")
proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} = proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
when defined(useChroniclesLogging):
logScope:
peer = $transp.remoteAddress
var conn = HttpConnection( var conn = HttpConnection(
transp: transp, buffer: newSeq[byte](server.maxHeadersSize), transp: transp, buffer: newSeq[byte](server.maxHeadersSize),
server: server server: server
) )
log info "Client connected"
var breakLoop = false var breakLoop = false
while true: while true:
var status: HttpProcessStatus
var arg: HttpProcessStatus[HttpRequest]
try: try:
let request = await conn.getRequest().wait(server.headersTimeout) let request = await conn.getRequest().wait(server.headersTimeout)
arg = ok(request)
except AsyncTimeoutError as exc:
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
breakLoop = true
arg = err(HttpProcessError(exc: exc, remote: transp.remoteAddress()))
except CancelledError:
breakLoop = true
except HttpRecoverableError:
breakLoop = false
arg = err()
except CatchableError:
breakLoop = true
if breakLoop:
break
breakLoop = false
let status =
try:
await conn.server.processCallback()
except CancelledError:
breakLoop = true
HttpCriticalError
except CatchableError:
breakLoop = true
HttpCriticalError
echo "== HEADERS TABLE" echo "== HEADERS TABLE"
echo request.headersTable echo request.headersTable
echo "== QUERY TABLE" echo "== QUERY TABLE"
@ -439,19 +387,17 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
echo cast[string](stream) echo cast[string](stream)
discard await conn.sendErrorResponse(HttpVersion11, Http200, true, discard await conn.sendErrorResponse(HttpVersion11, Http200, true,
databody = "OK") databody = "OK")
log debug "Response sent"
except AsyncTimeoutError: except AsyncTimeoutError:
log debug "Timeout reached while reading headers"
discard await conn.sendErrorResponse(HttpVersion11, Http408, false) discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
breakLoop = true breakLoop = true
except CancelledError: except CancelledError:
breakLoop = true breakLoop = true
except HttpRecoverableFailure: except HttpRecoverableError:
breakLoop = false breakLoop = false
except HttpCriticalFailure: except HttpCriticalError:
breakLoop = true breakLoop = true
except CatchableError as exc: except CatchableError as exc:
@ -464,7 +410,6 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
server.connections.del(transp.getId()) server.connections.del(transp.getId())
# if server.maxConnections > 0: # if server.maxConnections > 0:
# server.semaphore.release() # server.semaphore.release()
log info "Client got disconnected"
proc acceptClientLoop(server: HttpServer) {.async.} = proc acceptClientLoop(server: HttpServer) {.async.} =
var breakLoop = false var breakLoop = false

View File

@ -8,9 +8,11 @@
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
import std/[monotimes, strutils] import std/[monotimes, strutils]
import chronos, stew/results import stew/results
import ../../asyncloop
import ../../streams/[asyncstream, boundstream, chunkstream]
import httptable, httpcommon import httptable, httpcommon
export httptable, httpcommon export httptable, httpcommon, asyncstream
type type
MultiPartSource {.pure.} = enum MultiPartSource {.pure.} = enum
@ -20,9 +22,9 @@ type
case kind: MultiPartSource case kind: MultiPartSource
of MultiPartSource.Stream: of MultiPartSource.Stream:
stream: AsyncStreamReader stream: AsyncStreamReader
last: BoundedAsyncStreamReader
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
discard discard
firstTime: bool
buffer: seq[byte] buffer: seq[byte]
offset: int offset: int
boundary: seq[byte] boundary: seq[byte]
@ -30,18 +32,25 @@ type
MultiPartReaderRef* = ref MultiPartReader MultiPartReaderRef* = ref MultiPartReader
MultiPart* = object MultiPart* = object
case kind: MultiPartSource
of MultiPartSource.Stream:
stream*: BoundedStreamReader
of MultiPartSource.Buffer:
discard
buffer: seq[byte]
headers: HttpTable headers: HttpTable
stream: BoundedAsyncStreamReader
offset: int
size: int
MultipartError* = object of HttpError MultipartError* = object of HttpCriticalError
MultipartEOMError* = object of MultipartError MultipartEOMError* = object of MultipartError
MultiPartIncorrectError* = object of MultipartError MultipartIncorrectError* = object of MultipartError
MultiPartIncompleteError* = object of MultipartError MultipartIncompleteError* = object of MultipartError
MultipartReadError* = object of MultipartError
BChar* = byte | char BChar* = byte | char
proc newMultipartReadError(msg: string): ref MultipartReadError =
newException(MultipartReadError, msg)
proc startsWith*(s, prefix: openarray[byte]): bool = proc startsWith*(s, prefix: openarray[byte]): bool =
var i = 0 var i = 0
while true: while true:
@ -100,60 +109,106 @@ proc init*[B: BChar](mpt: typedesc[MultiPartReader],
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} = proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
doAssert(mpr.kind == MultiPartSource.Stream) doAssert(mpr.kind == MultiPartSource.Stream)
# According to RFC1521 says that a boundary "must be no longer than 70
# characters, not counting the two leading hyphens.
if mpr.firstTime: if mpr.firstTime:
try:
# Read and verify initial <-><-><boundary><CR><LF> # Read and verify initial <-><-><boundary><CR><LF>
mpr.firstTime = false
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2) await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
mpr.firstTime = false
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5), if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)): mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
if buffer[0] == byte('-') and buffer[1] == byte("-"): if mpr.buffer[0] == byte('-') and mpr.buffer[1] == byte('-'):
raise newException(MultiPartEOMError, "Unexpected EOM encountered") raise newException(MultiPartEOMError,
if buffer[0] != 0x0D'u8 or buffer[1] != 0x0A'u8: "Unexpected EOM encountered")
if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8:
raise newException(MultiPartIncorrectError, raise newException(MultiPartIncorrectError,
"Unexpected boundary suffix") "Unexpected boundary suffix")
else: else:
raise newException(MultiPartIncorrectError, raise newException(MultiPartIncorrectError,
"Unexpected boundary encountered") "Unexpected boundary encountered")
except CancelledError as exc:
raise exc
except AsyncStreamIncompleteError:
raise newMultipartReadError("Error reading multipart message")
except AsyncStreamReadError:
raise newMultipartReadError("Error reading multipart message")
# Reading part's headers # Reading part's headers
try:
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer), let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
HeadersMark) HeadersMark)
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1)) var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false)
if headersList.failed(): if headersList.failed():
raise newException(MultiPartIncorrectError, "Incorrect part headers found") raise newException(MultiPartIncorrectError,
"Incorrect part headers found")
var part = MultiPart() var part = MultiPart(
kind: MultiPartSource.Stream,
headers: HttpTable.init(),
stream: newBoundedStreamReader(mpr.stream, -1, mpr.boundary)
)
await mpr.stream.readExactly(addr buffer[0], len(mpr.boundary) - 4) for k, v in headersList.headers(mpr.buffer.toOpenArray(0, res - 1)):
if startsWith(buffer.toOpenArray(0, len(mpr.boundary) - 5), part.headers.add(k, v)
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
await mpr.stream.readExactly(addr buffer[0], 2)
if buffer[0] == byte('-') and buffer[1] == byte("-"):
raise newException(MultiPartEOMError, "")
if buffer[0] == 0x0D'u8 and buffer[1] == 0x0A'u8:
except: return part
discard
# if mpr.offset >= len(mpr.buffer):
# raise newException(MultiPartEOMError, "End of multipart form encountered")
proc getStream*(mp: MultiPart): AsyncStreamReader = except CancelledError as exc:
mp.stream raise exc
except AsyncStreamIncompleteError:
raise newMultipartReadError("Error reading multipart message")
except AsyncStreamLimitError:
raise newMultipartReadError("Multipart message headers size too big")
except AsyncStreamReadError:
raise newMultipartReadError("Error reading multipart message")
proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} = proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
case mp.kind
of MultiPartSource.Stream:
try: try:
let res = await mp.stream.read() let res = await mp.stream.read()
return res return res
except AsyncStreamError: except AsyncStreamError:
raise newException(HttpCriticalError, "Could not read multipart body") raise newException(HttpCriticalError, "Could not read multipart body")
of MultiPartSource.Buffer:
return mp.buffer
proc consumeBody*(mp: MultiPart) {.async.} = proc consumeBody*(mp: MultiPart) {.async.} =
case mp.kind
of MultiPartSource.Stream:
try: try:
await mp.stream.consume() await mp.stream.consume()
except AsyncStreamError: except AsyncStreamError:
raise newException(HttpCriticalError, "Could not consume multipart body") raise newException(HttpCriticalError, "Could not consume multipart body")
of MultiPartSource.Buffer:
discard
proc getBytes*(mp: MultiPart): seq[byte] =
## Returns MultiPart value as sequence of bytes.
case mp.kind
of MultiPartSource.Buffer:
mp.buffer
of MultiPartSource.Stream:
doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
mp.buffer
proc getString*(mp: MultiPart): string =
## Returns MultiPart value as string.
case mp.kind
of MultiPartSource.Buffer:
if len(mp.buffer) > 0:
var res = newString(len(mp.buffer))
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
res
else:
""
of MultiPartSource.Stream:
doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
if len(mp.buffer) > 0:
var res = newString(len(mp.buffer))
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
res
else:
""
proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] = proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
doAssert(mpr.kind == MultiPartSource.Buffer) doAssert(mpr.kind == MultiPartSource.Buffer)
@ -215,8 +270,11 @@ proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
# We set reader's offset to the place right after <CR><LF> # We set reader's offset to the place right after <CR><LF>
mpr.offset = start + pos2 + 2 mpr.offset = start + pos2 + 2
var part = MultiPart(
var part = MultiPart(offset: start, size: pos2, headers: HttpTable.init()) kind: MultiPartSource.Buffer,
headers: HttpTable.init(),
buffer: @(mpr.buffer.toOpenArray(start, start + pos2 - 1))
)
for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)): for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)):
part.headers.add(k, v) part.headers.add(k, v)
ok(part) ok(part)
@ -255,7 +313,7 @@ proc boundaryValue2(c: char): bool =
c in {'a'..'z', 'A' .. 'Z', '0' .. '9', c in {'a'..'z', 'A' .. 'Z', '0' .. '9',
'\'' .. ')', '+' .. '/', ':', '=', '?', '_'} '\'' .. ')', '+' .. '/', ':', '=', '?', '_'}
func getMultipartBoundary*(contentType: string): Result[string, string] = func getMultipartBoundary*(contentType: string): HttpResult[string] =
let mparts = contentType.split(";") let mparts = contentType.split(";")
if strip(mparts[0]).toLowerAscii() != "multipart/form-data": if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
return err("Content-Type is not multipart") return err("Content-Type is not multipart")
@ -270,7 +328,7 @@ func getMultipartBoundary*(contentType: string): Result[string, string] =
else: else:
ok(strip(bparts[1])) ok(strip(bparts[1]))
func getContentType*(contentHeader: seq[string]): Result[string, string] = func getContentType*(contentHeader: seq[string]): HttpResult[string] =
if len(contentHeader) > 1: if len(contentHeader) > 1:
return err("Multiple Content-Header values found") return err("Multiple Content-Header values found")
let mparts = contentHeader[0].split(";") let mparts = contentHeader[0].split(";")