Prepare for HttpResponse.
This commit is contained in:
parent
60e5396a9e
commit
0e5ea5b737
|
@ -6,8 +6,8 @@
|
|||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import stew/results, httputils
|
||||
export results, httputils
|
||||
import stew/results, httputils, strutils, uri
|
||||
export results, httputils, strutils
|
||||
|
||||
const
|
||||
useChroniclesLogging* {.booldefine.} = false
|
||||
|
@ -19,9 +19,110 @@ type
|
|||
HttpResultCode*[T] = Result[T, HttpCode]
|
||||
|
||||
HttpError* = object of CatchableError
|
||||
HttpCriticalFailure* = object of HttpError
|
||||
HttpRecoverableFailure* = object of HttpError
|
||||
HttpCriticalError* = object of HttpError
|
||||
HttpRecoverableError* = object of HttpError
|
||||
|
||||
TransferEncodingFlags* {.pure.} = enum
|
||||
Identity, Chunked, Compress, Deflate, Gzip
|
||||
|
||||
ContentEncodingFlags* {.pure.} = enum
|
||||
Identity, Br, Compress, Deflate, Gzip
|
||||
|
||||
template log*(body: untyped) =
|
||||
when defined(useChroniclesLogging):
|
||||
body
|
||||
|
||||
proc newHttpCriticalError*(msg: string): ref HttpCriticalError =
|
||||
newException(HttpCriticalError, msg)
|
||||
|
||||
proc newHttpRecoverableError*(msg: string): ref HttpRecoverableError =
|
||||
newException(HttpRecoverableError, msg)
|
||||
|
||||
iterator queryParams*(query: string): tuple[key: string, value: string] =
|
||||
## Iterate over url-encoded query string.
|
||||
for pair in query.split('&'):
|
||||
let items = pair.split('=', maxsplit = 1)
|
||||
let k = items[0]
|
||||
let v = if len(items) > 1: items[1] else: ""
|
||||
yield (decodeUrl(k), decodeUrl(v))
|
||||
|
||||
func getTransferEncoding*(ch: openarray[string]): HttpResult[
|
||||
set[TransferEncodingFlags]] =
|
||||
## Parse value of multiple HTTP headers ``Transfer-Encoding`` and return
|
||||
## it as set of ``TransferEncodingFlags``.
|
||||
var res: set[TransferEncodingFlags] = {}
|
||||
if len(ch) == 0:
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
ok(res)
|
||||
else:
|
||||
for header in ch:
|
||||
for item in header.split(","):
|
||||
case strip(item.toLowerAscii())
|
||||
of "identity":
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
of "chunked":
|
||||
res.incl(TransferEncodingFlags.Chunked)
|
||||
of "compress":
|
||||
res.incl(TransferEncodingFlags.Compress)
|
||||
of "deflate":
|
||||
res.incl(TransferEncodingFlags.Deflate)
|
||||
of "gzip":
|
||||
res.incl(TransferEncodingFlags.Gzip)
|
||||
of "":
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
else:
|
||||
return err("Incorrect Transfer-Encoding value")
|
||||
ok(res)
|
||||
|
||||
func getContentEncoding*(ch: openarray[string]): HttpResult[
|
||||
set[ContentEncodingFlags]] =
|
||||
## Parse value of multiple HTTP headers ``Content-Encoding`` and return
|
||||
## it as set of ``ContentEncodingFlags``.
|
||||
var res: set[ContentEncodingFlags] = {}
|
||||
if len(ch) == 0:
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
ok(res)
|
||||
else:
|
||||
for header in ch:
|
||||
for item in header.split(","):
|
||||
case strip(item.toLowerAscii()):
|
||||
of "identity":
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
of "br":
|
||||
res.incl(ContentEncodingFlags.Br)
|
||||
of "compress":
|
||||
res.incl(ContentEncodingFlags.Compress)
|
||||
of "deflate":
|
||||
res.incl(ContentEncodingFlags.Deflate)
|
||||
of "gzip":
|
||||
res.incl(ContentEncodingFlags.Gzip)
|
||||
of "":
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
else:
|
||||
return err("Incorrect Content-Encoding value")
|
||||
ok(res)
|
||||
|
||||
func getContentType*(ch: openarray[string]): HttpResult[string] =
|
||||
## Check and prepare value of ``Content-Type`` header.
|
||||
if len(ch) > 1:
|
||||
err("Multiple Content-Type values found")
|
||||
else:
|
||||
let mparts = ch[0].split(";")
|
||||
ok(strip(mparts[0]).toLowerAscii())
|
||||
|
||||
func getMultipartBoundary*(contentType: string): HttpResult[string] =
|
||||
## Process ``multipart/form-data`` ``Content-Type`` header and return
|
||||
## multipart boundary.
|
||||
let mparts = contentType.split(";")
|
||||
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
||||
return err("Content-Type is not multipart")
|
||||
if len(mparts) < 2:
|
||||
return err("Content-Type missing boundary value")
|
||||
let stripped = strip(mparts[1])
|
||||
if not(stripped.toLowerAscii().startsWith("boundary")):
|
||||
return err("Incorrect Content-Type boundary format")
|
||||
let bparts = stripped.split("=")
|
||||
if len(bparts) < 2:
|
||||
err("Missing Content-Type boundary")
|
||||
else:
|
||||
ok(strip(bparts[1]))
|
||||
|
|
|
@ -21,15 +21,25 @@ type
|
|||
HttpServerFlags* = enum
|
||||
Secure
|
||||
|
||||
TransferEncodingFlags* {.pure.} = enum
|
||||
Identity, Chunked, Compress, Deflate, Gzip
|
||||
HttpConnectionStatus* = enum
|
||||
DropConnection, KeepConnection
|
||||
|
||||
ContentEncodingFlags* {.pure.} = enum
|
||||
Identity, Br, Compress, Deflate, Gzip
|
||||
HttpErrorEnum* = enum
|
||||
TimeoutError, CatchableError, RecoverableError, CriticalError
|
||||
|
||||
HttpProcessError* = object
|
||||
error*: HttpErrorEnum
|
||||
exc*: HttpError
|
||||
remote*: TransportAddress
|
||||
|
||||
HttpProcessStatus*[T] = Result[T, HttpProcessError]
|
||||
|
||||
HttpRequestFlags* {.pure.} = enum
|
||||
BoundBody, UnboundBody, MultipartForm, UrlencodedForm
|
||||
|
||||
HttpProcessCallback* =
|
||||
proc(request: HttpProcessStatus[HttpRequest]): Future[HttpStatus]
|
||||
|
||||
HttpServer* = ref object of RootRef
|
||||
instance*: StreamServer
|
||||
# semaphore*: AsyncSemaphore
|
||||
|
@ -43,6 +53,7 @@ type
|
|||
bodyTimeout: Duration
|
||||
maxHeadersSize: int
|
||||
maxRequestBodySize: int
|
||||
processCallback: HttpProcessCallback
|
||||
|
||||
HttpServerState* = enum
|
||||
ServerRunning, ServerStopped, ServerClosed
|
||||
|
@ -64,18 +75,24 @@ type
|
|||
connection*: HttpConnection
|
||||
mainReader*: AsyncStreamReader
|
||||
|
||||
HttpResponse* = object
|
||||
code*: HttpCode
|
||||
version*: HttpVersion
|
||||
headersTable: HttpTable
|
||||
body*: seq[byte]
|
||||
connection*: HttpConnection
|
||||
mainWriter: AsyncStreamWriter
|
||||
|
||||
HttpConnection* = ref object of RootRef
|
||||
server: HttpServer
|
||||
server*: HttpServer
|
||||
transp: StreamTransport
|
||||
buffer: seq[byte]
|
||||
|
||||
const
|
||||
HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
|
||||
|
||||
proc new*(htype: typedesc[HttpServer],
|
||||
address: TransportAddress,
|
||||
flags: set[HttpServerFlags] = {},
|
||||
serverUri = Uri(),
|
||||
processCallback: HttpProcessCallback,
|
||||
maxConnections: int = -1,
|
||||
bufferSize: int = 4096,
|
||||
backlogSize: int = 100,
|
||||
|
@ -88,7 +105,8 @@ proc new*(htype: typedesc[HttpServer],
|
|||
headersTimeout: httpHeadersTimeout,
|
||||
bodyTimeout: httpBodyTimeout,
|
||||
maxHeadersSize: maxHeadersSize,
|
||||
maxRequestBodySize: maxRequestBodySize
|
||||
maxRequestBodySize: maxRequestBodySize,
|
||||
processCallback: processCallback
|
||||
)
|
||||
|
||||
res.baseUri =
|
||||
|
@ -118,92 +136,6 @@ proc getId(transp: StreamTransport): string {.inline.} =
|
|||
## Returns string unique transport's identifier as string.
|
||||
$transp.remoteAddress() & "_" & $transp.localAddress()
|
||||
|
||||
proc newHttpCriticalFailure*(msg: string): ref HttpCriticalFailure =
|
||||
newException(HttpCriticalFailure, msg)
|
||||
|
||||
proc newHttpRecoverableFailure*(msg: string): ref HttpRecoverableFailure =
|
||||
newException(HttpRecoverableFailure, msg)
|
||||
|
||||
iterator queryParams(query: string): tuple[key: string, value: string] =
|
||||
for pair in query.split('&'):
|
||||
let items = pair.split('=', maxsplit = 1)
|
||||
let k = items[0]
|
||||
let v = if len(items) > 1: items[1] else: ""
|
||||
yield (decodeUrl(k), decodeUrl(v))
|
||||
|
||||
func getMultipartBoundary*(contentType: string): HttpResult[string] =
|
||||
let mparts = contentType.split(";")
|
||||
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
||||
return err("Content-Type is not multipart")
|
||||
if len(mparts) < 2:
|
||||
return err("Content-Type missing boundary value")
|
||||
let stripped = strip(mparts[1])
|
||||
if not(stripped.toLowerAscii().startsWith("boundary")):
|
||||
return err("Incorrect Content-Type boundary format")
|
||||
let bparts = stripped.split("=")
|
||||
if len(bparts) < 2:
|
||||
err("Missing Content-Type boundary")
|
||||
else:
|
||||
ok(strip(bparts[1]))
|
||||
|
||||
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
|
||||
if len(contentHeader) > 1:
|
||||
return err("Multiple Content-Type values found")
|
||||
let mparts = contentHeader[0].split(";")
|
||||
ok(strip(mparts[0]).toLowerAscii())
|
||||
|
||||
func getTransferEncoding(contentHeader: seq[string]): HttpResult[
|
||||
set[TransferEncodingFlags]] =
|
||||
var res: set[TransferEncodingFlags] = {}
|
||||
if len(contentHeader) == 0:
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
ok(res)
|
||||
else:
|
||||
for header in contentHeader:
|
||||
for item in header.split(","):
|
||||
case strip(item.toLowerAscii())
|
||||
of "identity":
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
of "chunked":
|
||||
res.incl(TransferEncodingFlags.Chunked)
|
||||
of "compress":
|
||||
res.incl(TransferEncodingFlags.Compress)
|
||||
of "deflate":
|
||||
res.incl(TransferEncodingFlags.Deflate)
|
||||
of "gzip":
|
||||
res.incl(TransferEncodingFlags.Gzip)
|
||||
of "":
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
else:
|
||||
return err("Incorrect Transfer-Encoding value")
|
||||
ok(res)
|
||||
|
||||
func getContentEncoding(contentHeader: seq[string]): HttpResult[
|
||||
set[ContentEncodingFlags]] =
|
||||
var res: set[ContentEncodingFlags] = {}
|
||||
if len(contentHeader) == 0:
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
ok(res)
|
||||
else:
|
||||
for header in contentHeader:
|
||||
for item in header.split(","):
|
||||
case strip(item.toLowerAscii()):
|
||||
of "identity":
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
of "br":
|
||||
res.incl(ContentEncodingFlags.Br)
|
||||
of "compress":
|
||||
res.incl(ContentEncodingFlags.Compress)
|
||||
of "deflate":
|
||||
res.incl(ContentEncodingFlags.Deflate)
|
||||
of "gzip":
|
||||
res.incl(ContentEncodingFlags.Gzip)
|
||||
of "":
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
else:
|
||||
return err("Incorrect Content-Encoding value")
|
||||
ok(res)
|
||||
|
||||
proc hasBody*(request: HttpRequest): bool =
|
||||
## Returns ``true`` if request has body.
|
||||
request.requestFlags * {HttpRequestFlags.BoundBody,
|
||||
|
@ -333,7 +265,7 @@ proc getBody*(request: HttpRequest): Future[seq[byte]] {.async.} =
|
|||
try:
|
||||
return await read(res.get())
|
||||
except AsyncStreamError:
|
||||
raise newHttpCriticalFailure("Read failure")
|
||||
raise newHttpCriticalError("Read Error")
|
||||
|
||||
proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
|
||||
## Consume/discard request's body.
|
||||
|
@ -346,7 +278,7 @@ proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
|
|||
discard await reader.consume()
|
||||
return
|
||||
except AsyncStreamError:
|
||||
raise newHttpCriticalFailure("Read failure")
|
||||
raise newHttpCriticalError("Read Error")
|
||||
|
||||
proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
|
||||
code: HttpCode, keepAlive = true,
|
||||
|
@ -374,17 +306,13 @@ proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
|
|||
except CatchableError:
|
||||
return false
|
||||
|
||||
proc sendErrorResponse(request: HttpRequest, code: HttpCode, keepAlive = true,
|
||||
datatype = "text/text",
|
||||
databody = ""): Future[bool] =
|
||||
proc sendErrorResponse*(request: HttpRequest, code: HttpCode, keepAlive = true,
|
||||
datatype = "text/text",
|
||||
databody = ""): Future[bool] =
|
||||
sendErrorResponse(request.connection, request.version, code, keepAlive,
|
||||
datatype, databody)
|
||||
|
||||
proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
|
||||
when defined(useChroniclesLogging):
|
||||
logScope:
|
||||
peer = $conn.transp.remoteAddress
|
||||
|
||||
try:
|
||||
conn.buffer.setLen(conn.server.maxHeadersSize)
|
||||
let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer),
|
||||
|
@ -392,42 +320,62 @@ proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
|
|||
conn.buffer.setLen(res)
|
||||
let header = parseRequest(conn.buffer)
|
||||
if header.failed():
|
||||
log debug "Malformed header received"
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
|
||||
raise newHttpCriticalFailure("Malformed request recieved")
|
||||
raise newHttpCriticalError("Malformed request recieved")
|
||||
else:
|
||||
let res = prepareRequest(conn, header)
|
||||
if res.isErr():
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
|
||||
raise newHttpCriticalFailure("Invalid request received")
|
||||
raise newHttpCriticalError("Invalid request received")
|
||||
else:
|
||||
return res.get()
|
||||
except TransportOsError:
|
||||
log debug "Unexpected OS error"
|
||||
raise newHttpCriticalFailure("Unexpected OS error")
|
||||
raise newHttpCriticalError("Unexpected OS error")
|
||||
except TransportIncompleteError:
|
||||
log debug "Remote peer disconnected"
|
||||
raise newHttpCriticalFailure("Remote peer disconnected")
|
||||
raise newHttpCriticalError("Remote peer disconnected")
|
||||
except TransportLimitError:
|
||||
log debug "Maximum size of request headers reached"
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http413, false)
|
||||
raise newHttpCriticalFailure("Maximum size of request headers reached")
|
||||
raise newHttpCriticalError("Maximum size of request headers reached")
|
||||
|
||||
proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
|
||||
when defined(useChroniclesLogging):
|
||||
logScope:
|
||||
peer = $transp.remoteAddress
|
||||
|
||||
var conn = HttpConnection(
|
||||
transp: transp, buffer: newSeq[byte](server.maxHeadersSize),
|
||||
server: server
|
||||
)
|
||||
|
||||
log info "Client connected"
|
||||
var breakLoop = false
|
||||
while true:
|
||||
var status: HttpProcessStatus
|
||||
var arg: HttpProcessStatus[HttpRequest]
|
||||
try:
|
||||
let request = await conn.getRequest().wait(server.headersTimeout)
|
||||
arg = ok(request)
|
||||
except AsyncTimeoutError as exc:
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
|
||||
breakLoop = true
|
||||
arg = err(HttpProcessError(exc: exc, remote: transp.remoteAddress()))
|
||||
except CancelledError:
|
||||
breakLoop = true
|
||||
except HttpRecoverableError:
|
||||
breakLoop = false
|
||||
arg = err()
|
||||
except CatchableError:
|
||||
breakLoop = true
|
||||
|
||||
if breakLoop:
|
||||
break
|
||||
|
||||
breakLoop = false
|
||||
let status =
|
||||
try:
|
||||
await conn.server.processCallback()
|
||||
except CancelledError:
|
||||
breakLoop = true
|
||||
HttpCriticalError
|
||||
except CatchableError:
|
||||
breakLoop = true
|
||||
HttpCriticalError
|
||||
|
||||
echo "== HEADERS TABLE"
|
||||
echo request.headersTable
|
||||
echo "== QUERY TABLE"
|
||||
|
@ -439,19 +387,17 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
|
|||
echo cast[string](stream)
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http200, true,
|
||||
databody = "OK")
|
||||
log debug "Response sent"
|
||||
except AsyncTimeoutError:
|
||||
log debug "Timeout reached while reading headers"
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
|
||||
breakLoop = true
|
||||
|
||||
except CancelledError:
|
||||
breakLoop = true
|
||||
|
||||
except HttpRecoverableFailure:
|
||||
except HttpRecoverableError:
|
||||
breakLoop = false
|
||||
|
||||
except HttpCriticalFailure:
|
||||
except HttpCriticalError:
|
||||
breakLoop = true
|
||||
|
||||
except CatchableError as exc:
|
||||
|
@ -464,7 +410,6 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
|
|||
server.connections.del(transp.getId())
|
||||
# if server.maxConnections > 0:
|
||||
# server.semaphore.release()
|
||||
log info "Client got disconnected"
|
||||
|
||||
proc acceptClientLoop(server: HttpServer) {.async.} =
|
||||
var breakLoop = false
|
||||
|
|
|
@ -8,9 +8,11 @@
|
|||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/[monotimes, strutils]
|
||||
import chronos, stew/results
|
||||
import stew/results
|
||||
import ../../asyncloop
|
||||
import ../../streams/[asyncstream, boundstream, chunkstream]
|
||||
import httptable, httpcommon
|
||||
export httptable, httpcommon
|
||||
export httptable, httpcommon, asyncstream
|
||||
|
||||
type
|
||||
MultiPartSource {.pure.} = enum
|
||||
|
@ -20,9 +22,9 @@ type
|
|||
case kind: MultiPartSource
|
||||
of MultiPartSource.Stream:
|
||||
stream: AsyncStreamReader
|
||||
last: BoundedAsyncStreamReader
|
||||
of MultiPartSource.Buffer:
|
||||
discard
|
||||
firstTime: bool
|
||||
buffer: seq[byte]
|
||||
offset: int
|
||||
boundary: seq[byte]
|
||||
|
@ -30,18 +32,25 @@ type
|
|||
MultiPartReaderRef* = ref MultiPartReader
|
||||
|
||||
MultiPart* = object
|
||||
case kind: MultiPartSource
|
||||
of MultiPartSource.Stream:
|
||||
stream*: BoundedStreamReader
|
||||
of MultiPartSource.Buffer:
|
||||
discard
|
||||
buffer: seq[byte]
|
||||
headers: HttpTable
|
||||
stream: BoundedAsyncStreamReader
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
MultipartError* = object of HttpError
|
||||
MultipartError* = object of HttpCriticalError
|
||||
MultipartEOMError* = object of MultipartError
|
||||
MultiPartIncorrectError* = object of MultipartError
|
||||
MultiPartIncompleteError* = object of MultipartError
|
||||
MultipartIncorrectError* = object of MultipartError
|
||||
MultipartIncompleteError* = object of MultipartError
|
||||
MultipartReadError* = object of MultipartError
|
||||
|
||||
BChar* = byte | char
|
||||
|
||||
proc newMultipartReadError(msg: string): ref MultipartReadError =
|
||||
newException(MultipartReadError, msg)
|
||||
|
||||
proc startsWith*(s, prefix: openarray[byte]): bool =
|
||||
var i = 0
|
||||
while true:
|
||||
|
@ -100,60 +109,106 @@ proc init*[B: BChar](mpt: typedesc[MultiPartReader],
|
|||
|
||||
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
|
||||
doAssert(mpr.kind == MultiPartSource.Stream)
|
||||
# According to RFC1521 says that a boundary "must be no longer than 70
|
||||
# characters, not counting the two leading hyphens.
|
||||
if mpr.firstTime:
|
||||
# Read and verify initial <-><-><boundary><CR><LF>
|
||||
mpr.firstTime = false
|
||||
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
|
||||
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
|
||||
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
||||
if buffer[0] == byte('-') and buffer[1] == byte("-"):
|
||||
raise newException(MultiPartEOMError, "Unexpected EOM encountered")
|
||||
if buffer[0] != 0x0D'u8 or buffer[1] != 0x0A'u8:
|
||||
try:
|
||||
# Read and verify initial <-><-><boundary><CR><LF>
|
||||
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
|
||||
mpr.firstTime = false
|
||||
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
|
||||
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
||||
if mpr.buffer[0] == byte('-') and mpr.buffer[1] == byte('-'):
|
||||
raise newException(MultiPartEOMError,
|
||||
"Unexpected EOM encountered")
|
||||
if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8:
|
||||
raise newException(MultiPartIncorrectError,
|
||||
"Unexpected boundary suffix")
|
||||
else:
|
||||
raise newException(MultiPartIncorrectError,
|
||||
"Unexpected boundary suffix")
|
||||
else:
|
||||
raise newException(MultiPartIncorrectError,
|
||||
"Unexpected boundary encountered")
|
||||
"Unexpected boundary encountered")
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except AsyncStreamIncompleteError:
|
||||
raise newMultipartReadError("Error reading multipart message")
|
||||
except AsyncStreamReadError:
|
||||
raise newMultipartReadError("Error reading multipart message")
|
||||
|
||||
# Reading part's headers
|
||||
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
|
||||
try:
|
||||
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
|
||||
HeadersMark)
|
||||
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1))
|
||||
if headersList.failed():
|
||||
raise newException(MultiPartIncorrectError, "Incorrect part headers found")
|
||||
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false)
|
||||
if headersList.failed():
|
||||
raise newException(MultiPartIncorrectError,
|
||||
"Incorrect part headers found")
|
||||
|
||||
var part = MultiPart()
|
||||
var part = MultiPart(
|
||||
kind: MultiPartSource.Stream,
|
||||
headers: HttpTable.init(),
|
||||
stream: newBoundedStreamReader(mpr.stream, -1, mpr.boundary)
|
||||
)
|
||||
|
||||
await mpr.stream.readExactly(addr buffer[0], len(mpr.boundary) - 4)
|
||||
if startsWith(buffer.toOpenArray(0, len(mpr.boundary) - 5),
|
||||
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
||||
await mpr.stream.readExactly(addr buffer[0], 2)
|
||||
if buffer[0] == byte('-') and buffer[1] == byte("-"):
|
||||
raise newException(MultiPartEOMError, "")
|
||||
if buffer[0] == 0x0D'u8 and buffer[1] == 0x0A'u8:
|
||||
for k, v in headersList.headers(mpr.buffer.toOpenArray(0, res - 1)):
|
||||
part.headers.add(k, v)
|
||||
|
||||
except:
|
||||
discard
|
||||
# if mpr.offset >= len(mpr.buffer):
|
||||
# raise newException(MultiPartEOMError, "End of multipart form encountered")
|
||||
return part
|
||||
|
||||
proc getStream*(mp: MultiPart): AsyncStreamReader =
|
||||
mp.stream
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except AsyncStreamIncompleteError:
|
||||
raise newMultipartReadError("Error reading multipart message")
|
||||
except AsyncStreamLimitError:
|
||||
raise newMultipartReadError("Multipart message headers size too big")
|
||||
except AsyncStreamReadError:
|
||||
raise newMultipartReadError("Error reading multipart message")
|
||||
|
||||
proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
|
||||
try:
|
||||
let res = await mp.stream.read()
|
||||
return res
|
||||
except AsyncStreamError:
|
||||
raise newException(HttpCriticalError, "Could not read multipart body")
|
||||
case mp.kind
|
||||
of MultiPartSource.Stream:
|
||||
try:
|
||||
let res = await mp.stream.read()
|
||||
return res
|
||||
except AsyncStreamError:
|
||||
raise newException(HttpCriticalError, "Could not read multipart body")
|
||||
of MultiPartSource.Buffer:
|
||||
return mp.buffer
|
||||
|
||||
proc consumeBody*(mp: MultiPart) {.async.} =
|
||||
try:
|
||||
await mp.stream.consume()
|
||||
except AsyncStreamError:
|
||||
raise newException(HttpCriticalError, "Could not consume multipart body")
|
||||
case mp.kind
|
||||
of MultiPartSource.Stream:
|
||||
try:
|
||||
await mp.stream.consume()
|
||||
except AsyncStreamError:
|
||||
raise newException(HttpCriticalError, "Could not consume multipart body")
|
||||
of MultiPartSource.Buffer:
|
||||
discard
|
||||
|
||||
proc getBytes*(mp: MultiPart): seq[byte] =
|
||||
## Returns MultiPart value as sequence of bytes.
|
||||
case mp.kind
|
||||
of MultiPartSource.Buffer:
|
||||
mp.buffer
|
||||
of MultiPartSource.Stream:
|
||||
doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
|
||||
mp.buffer
|
||||
|
||||
proc getString*(mp: MultiPart): string =
|
||||
## Returns MultiPart value as string.
|
||||
case mp.kind
|
||||
of MultiPartSource.Buffer:
|
||||
if len(mp.buffer) > 0:
|
||||
var res = newString(len(mp.buffer))
|
||||
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
|
||||
res
|
||||
else:
|
||||
""
|
||||
of MultiPartSource.Stream:
|
||||
doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
|
||||
if len(mp.buffer) > 0:
|
||||
var res = newString(len(mp.buffer))
|
||||
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
|
||||
res
|
||||
else:
|
||||
""
|
||||
|
||||
proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
|
||||
doAssert(mpr.kind == MultiPartSource.Buffer)
|
||||
|
@ -215,8 +270,11 @@ proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
|
|||
|
||||
# We set reader's offset to the place right after <CR><LF>
|
||||
mpr.offset = start + pos2 + 2
|
||||
|
||||
var part = MultiPart(offset: start, size: pos2, headers: HttpTable.init())
|
||||
var part = MultiPart(
|
||||
kind: MultiPartSource.Buffer,
|
||||
headers: HttpTable.init(),
|
||||
buffer: @(mpr.buffer.toOpenArray(start, start + pos2 - 1))
|
||||
)
|
||||
for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)):
|
||||
part.headers.add(k, v)
|
||||
ok(part)
|
||||
|
@ -255,7 +313,7 @@ proc boundaryValue2(c: char): bool =
|
|||
c in {'a'..'z', 'A' .. 'Z', '0' .. '9',
|
||||
'\'' .. ')', '+' .. '/', ':', '=', '?', '_'}
|
||||
|
||||
func getMultipartBoundary*(contentType: string): Result[string, string] =
|
||||
func getMultipartBoundary*(contentType: string): HttpResult[string] =
|
||||
let mparts = contentType.split(";")
|
||||
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
||||
return err("Content-Type is not multipart")
|
||||
|
@ -270,7 +328,7 @@ func getMultipartBoundary*(contentType: string): Result[string, string] =
|
|||
else:
|
||||
ok(strip(bparts[1]))
|
||||
|
||||
func getContentType*(contentHeader: seq[string]): Result[string, string] =
|
||||
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
|
||||
if len(contentHeader) > 1:
|
||||
return err("Multiple Content-Header values found")
|
||||
let mparts = contentHeader[0].split(";")
|
||||
|
|
Loading…
Reference in New Issue