Prepare for HttpResponse.
This commit is contained in:
parent
60e5396a9e
commit
0e5ea5b737
|
@ -6,8 +6,8 @@
|
||||||
# Licensed under either of
|
# Licensed under either of
|
||||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||||
# MIT license (LICENSE-MIT)
|
# MIT license (LICENSE-MIT)
|
||||||
import stew/results, httputils
|
import stew/results, httputils, strutils, uri
|
||||||
export results, httputils
|
export results, httputils, strutils
|
||||||
|
|
||||||
const
|
const
|
||||||
useChroniclesLogging* {.booldefine.} = false
|
useChroniclesLogging* {.booldefine.} = false
|
||||||
|
@ -19,9 +19,110 @@ type
|
||||||
HttpResultCode*[T] = Result[T, HttpCode]
|
HttpResultCode*[T] = Result[T, HttpCode]
|
||||||
|
|
||||||
HttpError* = object of CatchableError
|
HttpError* = object of CatchableError
|
||||||
HttpCriticalFailure* = object of HttpError
|
HttpCriticalError* = object of HttpError
|
||||||
HttpRecoverableFailure* = object of HttpError
|
HttpRecoverableError* = object of HttpError
|
||||||
|
|
||||||
|
TransferEncodingFlags* {.pure.} = enum
|
||||||
|
Identity, Chunked, Compress, Deflate, Gzip
|
||||||
|
|
||||||
|
ContentEncodingFlags* {.pure.} = enum
|
||||||
|
Identity, Br, Compress, Deflate, Gzip
|
||||||
|
|
||||||
template log*(body: untyped) =
|
template log*(body: untyped) =
|
||||||
when defined(useChroniclesLogging):
|
when defined(useChroniclesLogging):
|
||||||
body
|
body
|
||||||
|
|
||||||
|
proc newHttpCriticalError*(msg: string): ref HttpCriticalError =
|
||||||
|
newException(HttpCriticalError, msg)
|
||||||
|
|
||||||
|
proc newHttpRecoverableError*(msg: string): ref HttpRecoverableError =
|
||||||
|
newException(HttpRecoverableError, msg)
|
||||||
|
|
||||||
|
iterator queryParams*(query: string): tuple[key: string, value: string] =
|
||||||
|
## Iterate over url-encoded query string.
|
||||||
|
for pair in query.split('&'):
|
||||||
|
let items = pair.split('=', maxsplit = 1)
|
||||||
|
let k = items[0]
|
||||||
|
let v = if len(items) > 1: items[1] else: ""
|
||||||
|
yield (decodeUrl(k), decodeUrl(v))
|
||||||
|
|
||||||
|
func getTransferEncoding*(ch: openarray[string]): HttpResult[
|
||||||
|
set[TransferEncodingFlags]] =
|
||||||
|
## Parse value of multiple HTTP headers ``Transfer-Encoding`` and return
|
||||||
|
## it as set of ``TransferEncodingFlags``.
|
||||||
|
var res: set[TransferEncodingFlags] = {}
|
||||||
|
if len(ch) == 0:
|
||||||
|
res.incl(TransferEncodingFlags.Identity)
|
||||||
|
ok(res)
|
||||||
|
else:
|
||||||
|
for header in ch:
|
||||||
|
for item in header.split(","):
|
||||||
|
case strip(item.toLowerAscii())
|
||||||
|
of "identity":
|
||||||
|
res.incl(TransferEncodingFlags.Identity)
|
||||||
|
of "chunked":
|
||||||
|
res.incl(TransferEncodingFlags.Chunked)
|
||||||
|
of "compress":
|
||||||
|
res.incl(TransferEncodingFlags.Compress)
|
||||||
|
of "deflate":
|
||||||
|
res.incl(TransferEncodingFlags.Deflate)
|
||||||
|
of "gzip":
|
||||||
|
res.incl(TransferEncodingFlags.Gzip)
|
||||||
|
of "":
|
||||||
|
res.incl(TransferEncodingFlags.Identity)
|
||||||
|
else:
|
||||||
|
return err("Incorrect Transfer-Encoding value")
|
||||||
|
ok(res)
|
||||||
|
|
||||||
|
func getContentEncoding*(ch: openarray[string]): HttpResult[
|
||||||
|
set[ContentEncodingFlags]] =
|
||||||
|
## Parse value of multiple HTTP headers ``Content-Encoding`` and return
|
||||||
|
## it as set of ``ContentEncodingFlags``.
|
||||||
|
var res: set[ContentEncodingFlags] = {}
|
||||||
|
if len(ch) == 0:
|
||||||
|
res.incl(ContentEncodingFlags.Identity)
|
||||||
|
ok(res)
|
||||||
|
else:
|
||||||
|
for header in ch:
|
||||||
|
for item in header.split(","):
|
||||||
|
case strip(item.toLowerAscii()):
|
||||||
|
of "identity":
|
||||||
|
res.incl(ContentEncodingFlags.Identity)
|
||||||
|
of "br":
|
||||||
|
res.incl(ContentEncodingFlags.Br)
|
||||||
|
of "compress":
|
||||||
|
res.incl(ContentEncodingFlags.Compress)
|
||||||
|
of "deflate":
|
||||||
|
res.incl(ContentEncodingFlags.Deflate)
|
||||||
|
of "gzip":
|
||||||
|
res.incl(ContentEncodingFlags.Gzip)
|
||||||
|
of "":
|
||||||
|
res.incl(ContentEncodingFlags.Identity)
|
||||||
|
else:
|
||||||
|
return err("Incorrect Content-Encoding value")
|
||||||
|
ok(res)
|
||||||
|
|
||||||
|
func getContentType*(ch: openarray[string]): HttpResult[string] =
|
||||||
|
## Check and prepare value of ``Content-Type`` header.
|
||||||
|
if len(ch) > 1:
|
||||||
|
err("Multiple Content-Type values found")
|
||||||
|
else:
|
||||||
|
let mparts = ch[0].split(";")
|
||||||
|
ok(strip(mparts[0]).toLowerAscii())
|
||||||
|
|
||||||
|
func getMultipartBoundary*(contentType: string): HttpResult[string] =
|
||||||
|
## Process ``multipart/form-data`` ``Content-Type`` header and return
|
||||||
|
## multipart boundary.
|
||||||
|
let mparts = contentType.split(";")
|
||||||
|
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
||||||
|
return err("Content-Type is not multipart")
|
||||||
|
if len(mparts) < 2:
|
||||||
|
return err("Content-Type missing boundary value")
|
||||||
|
let stripped = strip(mparts[1])
|
||||||
|
if not(stripped.toLowerAscii().startsWith("boundary")):
|
||||||
|
return err("Incorrect Content-Type boundary format")
|
||||||
|
let bparts = stripped.split("=")
|
||||||
|
if len(bparts) < 2:
|
||||||
|
err("Missing Content-Type boundary")
|
||||||
|
else:
|
||||||
|
ok(strip(bparts[1]))
|
||||||
|
|
|
@ -21,15 +21,25 @@ type
|
||||||
HttpServerFlags* = enum
|
HttpServerFlags* = enum
|
||||||
Secure
|
Secure
|
||||||
|
|
||||||
TransferEncodingFlags* {.pure.} = enum
|
HttpConnectionStatus* = enum
|
||||||
Identity, Chunked, Compress, Deflate, Gzip
|
DropConnection, KeepConnection
|
||||||
|
|
||||||
ContentEncodingFlags* {.pure.} = enum
|
HttpErrorEnum* = enum
|
||||||
Identity, Br, Compress, Deflate, Gzip
|
TimeoutError, CatchableError, RecoverableError, CriticalError
|
||||||
|
|
||||||
|
HttpProcessError* = object
|
||||||
|
error*: HttpErrorEnum
|
||||||
|
exc*: HttpError
|
||||||
|
remote*: TransportAddress
|
||||||
|
|
||||||
|
HttpProcessStatus*[T] = Result[T, HttpProcessError]
|
||||||
|
|
||||||
HttpRequestFlags* {.pure.} = enum
|
HttpRequestFlags* {.pure.} = enum
|
||||||
BoundBody, UnboundBody, MultipartForm, UrlencodedForm
|
BoundBody, UnboundBody, MultipartForm, UrlencodedForm
|
||||||
|
|
||||||
|
HttpProcessCallback* =
|
||||||
|
proc(request: HttpProcessStatus[HttpRequest]): Future[HttpStatus]
|
||||||
|
|
||||||
HttpServer* = ref object of RootRef
|
HttpServer* = ref object of RootRef
|
||||||
instance*: StreamServer
|
instance*: StreamServer
|
||||||
# semaphore*: AsyncSemaphore
|
# semaphore*: AsyncSemaphore
|
||||||
|
@ -43,6 +53,7 @@ type
|
||||||
bodyTimeout: Duration
|
bodyTimeout: Duration
|
||||||
maxHeadersSize: int
|
maxHeadersSize: int
|
||||||
maxRequestBodySize: int
|
maxRequestBodySize: int
|
||||||
|
processCallback: HttpProcessCallback
|
||||||
|
|
||||||
HttpServerState* = enum
|
HttpServerState* = enum
|
||||||
ServerRunning, ServerStopped, ServerClosed
|
ServerRunning, ServerStopped, ServerClosed
|
||||||
|
@ -64,18 +75,24 @@ type
|
||||||
connection*: HttpConnection
|
connection*: HttpConnection
|
||||||
mainReader*: AsyncStreamReader
|
mainReader*: AsyncStreamReader
|
||||||
|
|
||||||
|
HttpResponse* = object
|
||||||
|
code*: HttpCode
|
||||||
|
version*: HttpVersion
|
||||||
|
headersTable: HttpTable
|
||||||
|
body*: seq[byte]
|
||||||
|
connection*: HttpConnection
|
||||||
|
mainWriter: AsyncStreamWriter
|
||||||
|
|
||||||
HttpConnection* = ref object of RootRef
|
HttpConnection* = ref object of RootRef
|
||||||
server: HttpServer
|
server*: HttpServer
|
||||||
transp: StreamTransport
|
transp: StreamTransport
|
||||||
buffer: seq[byte]
|
buffer: seq[byte]
|
||||||
|
|
||||||
const
|
|
||||||
HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
|
|
||||||
|
|
||||||
proc new*(htype: typedesc[HttpServer],
|
proc new*(htype: typedesc[HttpServer],
|
||||||
address: TransportAddress,
|
address: TransportAddress,
|
||||||
flags: set[HttpServerFlags] = {},
|
flags: set[HttpServerFlags] = {},
|
||||||
serverUri = Uri(),
|
serverUri = Uri(),
|
||||||
|
processCallback: HttpProcessCallback,
|
||||||
maxConnections: int = -1,
|
maxConnections: int = -1,
|
||||||
bufferSize: int = 4096,
|
bufferSize: int = 4096,
|
||||||
backlogSize: int = 100,
|
backlogSize: int = 100,
|
||||||
|
@ -88,7 +105,8 @@ proc new*(htype: typedesc[HttpServer],
|
||||||
headersTimeout: httpHeadersTimeout,
|
headersTimeout: httpHeadersTimeout,
|
||||||
bodyTimeout: httpBodyTimeout,
|
bodyTimeout: httpBodyTimeout,
|
||||||
maxHeadersSize: maxHeadersSize,
|
maxHeadersSize: maxHeadersSize,
|
||||||
maxRequestBodySize: maxRequestBodySize
|
maxRequestBodySize: maxRequestBodySize,
|
||||||
|
processCallback: processCallback
|
||||||
)
|
)
|
||||||
|
|
||||||
res.baseUri =
|
res.baseUri =
|
||||||
|
@ -118,92 +136,6 @@ proc getId(transp: StreamTransport): string {.inline.} =
|
||||||
## Returns string unique transport's identifier as string.
|
## Returns string unique transport's identifier as string.
|
||||||
$transp.remoteAddress() & "_" & $transp.localAddress()
|
$transp.remoteAddress() & "_" & $transp.localAddress()
|
||||||
|
|
||||||
proc newHttpCriticalFailure*(msg: string): ref HttpCriticalFailure =
|
|
||||||
newException(HttpCriticalFailure, msg)
|
|
||||||
|
|
||||||
proc newHttpRecoverableFailure*(msg: string): ref HttpRecoverableFailure =
|
|
||||||
newException(HttpRecoverableFailure, msg)
|
|
||||||
|
|
||||||
iterator queryParams(query: string): tuple[key: string, value: string] =
|
|
||||||
for pair in query.split('&'):
|
|
||||||
let items = pair.split('=', maxsplit = 1)
|
|
||||||
let k = items[0]
|
|
||||||
let v = if len(items) > 1: items[1] else: ""
|
|
||||||
yield (decodeUrl(k), decodeUrl(v))
|
|
||||||
|
|
||||||
func getMultipartBoundary*(contentType: string): HttpResult[string] =
|
|
||||||
let mparts = contentType.split(";")
|
|
||||||
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
|
||||||
return err("Content-Type is not multipart")
|
|
||||||
if len(mparts) < 2:
|
|
||||||
return err("Content-Type missing boundary value")
|
|
||||||
let stripped = strip(mparts[1])
|
|
||||||
if not(stripped.toLowerAscii().startsWith("boundary")):
|
|
||||||
return err("Incorrect Content-Type boundary format")
|
|
||||||
let bparts = stripped.split("=")
|
|
||||||
if len(bparts) < 2:
|
|
||||||
err("Missing Content-Type boundary")
|
|
||||||
else:
|
|
||||||
ok(strip(bparts[1]))
|
|
||||||
|
|
||||||
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
|
|
||||||
if len(contentHeader) > 1:
|
|
||||||
return err("Multiple Content-Type values found")
|
|
||||||
let mparts = contentHeader[0].split(";")
|
|
||||||
ok(strip(mparts[0]).toLowerAscii())
|
|
||||||
|
|
||||||
func getTransferEncoding(contentHeader: seq[string]): HttpResult[
|
|
||||||
set[TransferEncodingFlags]] =
|
|
||||||
var res: set[TransferEncodingFlags] = {}
|
|
||||||
if len(contentHeader) == 0:
|
|
||||||
res.incl(TransferEncodingFlags.Identity)
|
|
||||||
ok(res)
|
|
||||||
else:
|
|
||||||
for header in contentHeader:
|
|
||||||
for item in header.split(","):
|
|
||||||
case strip(item.toLowerAscii())
|
|
||||||
of "identity":
|
|
||||||
res.incl(TransferEncodingFlags.Identity)
|
|
||||||
of "chunked":
|
|
||||||
res.incl(TransferEncodingFlags.Chunked)
|
|
||||||
of "compress":
|
|
||||||
res.incl(TransferEncodingFlags.Compress)
|
|
||||||
of "deflate":
|
|
||||||
res.incl(TransferEncodingFlags.Deflate)
|
|
||||||
of "gzip":
|
|
||||||
res.incl(TransferEncodingFlags.Gzip)
|
|
||||||
of "":
|
|
||||||
res.incl(TransferEncodingFlags.Identity)
|
|
||||||
else:
|
|
||||||
return err("Incorrect Transfer-Encoding value")
|
|
||||||
ok(res)
|
|
||||||
|
|
||||||
func getContentEncoding(contentHeader: seq[string]): HttpResult[
|
|
||||||
set[ContentEncodingFlags]] =
|
|
||||||
var res: set[ContentEncodingFlags] = {}
|
|
||||||
if len(contentHeader) == 0:
|
|
||||||
res.incl(ContentEncodingFlags.Identity)
|
|
||||||
ok(res)
|
|
||||||
else:
|
|
||||||
for header in contentHeader:
|
|
||||||
for item in header.split(","):
|
|
||||||
case strip(item.toLowerAscii()):
|
|
||||||
of "identity":
|
|
||||||
res.incl(ContentEncodingFlags.Identity)
|
|
||||||
of "br":
|
|
||||||
res.incl(ContentEncodingFlags.Br)
|
|
||||||
of "compress":
|
|
||||||
res.incl(ContentEncodingFlags.Compress)
|
|
||||||
of "deflate":
|
|
||||||
res.incl(ContentEncodingFlags.Deflate)
|
|
||||||
of "gzip":
|
|
||||||
res.incl(ContentEncodingFlags.Gzip)
|
|
||||||
of "":
|
|
||||||
res.incl(ContentEncodingFlags.Identity)
|
|
||||||
else:
|
|
||||||
return err("Incorrect Content-Encoding value")
|
|
||||||
ok(res)
|
|
||||||
|
|
||||||
proc hasBody*(request: HttpRequest): bool =
|
proc hasBody*(request: HttpRequest): bool =
|
||||||
## Returns ``true`` if request has body.
|
## Returns ``true`` if request has body.
|
||||||
request.requestFlags * {HttpRequestFlags.BoundBody,
|
request.requestFlags * {HttpRequestFlags.BoundBody,
|
||||||
|
@ -333,7 +265,7 @@ proc getBody*(request: HttpRequest): Future[seq[byte]] {.async.} =
|
||||||
try:
|
try:
|
||||||
return await read(res.get())
|
return await read(res.get())
|
||||||
except AsyncStreamError:
|
except AsyncStreamError:
|
||||||
raise newHttpCriticalFailure("Read failure")
|
raise newHttpCriticalError("Read Error")
|
||||||
|
|
||||||
proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
|
proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
|
||||||
## Consume/discard request's body.
|
## Consume/discard request's body.
|
||||||
|
@ -346,7 +278,7 @@ proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
|
||||||
discard await reader.consume()
|
discard await reader.consume()
|
||||||
return
|
return
|
||||||
except AsyncStreamError:
|
except AsyncStreamError:
|
||||||
raise newHttpCriticalFailure("Read failure")
|
raise newHttpCriticalError("Read Error")
|
||||||
|
|
||||||
proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
|
proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
|
||||||
code: HttpCode, keepAlive = true,
|
code: HttpCode, keepAlive = true,
|
||||||
|
@ -374,17 +306,13 @@ proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
|
||||||
except CatchableError:
|
except CatchableError:
|
||||||
return false
|
return false
|
||||||
|
|
||||||
proc sendErrorResponse(request: HttpRequest, code: HttpCode, keepAlive = true,
|
proc sendErrorResponse*(request: HttpRequest, code: HttpCode, keepAlive = true,
|
||||||
datatype = "text/text",
|
datatype = "text/text",
|
||||||
databody = ""): Future[bool] =
|
databody = ""): Future[bool] =
|
||||||
sendErrorResponse(request.connection, request.version, code, keepAlive,
|
sendErrorResponse(request.connection, request.version, code, keepAlive,
|
||||||
datatype, databody)
|
datatype, databody)
|
||||||
|
|
||||||
proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
|
proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
|
||||||
when defined(useChroniclesLogging):
|
|
||||||
logScope:
|
|
||||||
peer = $conn.transp.remoteAddress
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn.buffer.setLen(conn.server.maxHeadersSize)
|
conn.buffer.setLen(conn.server.maxHeadersSize)
|
||||||
let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer),
|
let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer),
|
||||||
|
@ -392,42 +320,62 @@ proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
|
||||||
conn.buffer.setLen(res)
|
conn.buffer.setLen(res)
|
||||||
let header = parseRequest(conn.buffer)
|
let header = parseRequest(conn.buffer)
|
||||||
if header.failed():
|
if header.failed():
|
||||||
log debug "Malformed header received"
|
|
||||||
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
|
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
|
||||||
raise newHttpCriticalFailure("Malformed request recieved")
|
raise newHttpCriticalError("Malformed request recieved")
|
||||||
else:
|
else:
|
||||||
let res = prepareRequest(conn, header)
|
let res = prepareRequest(conn, header)
|
||||||
if res.isErr():
|
if res.isErr():
|
||||||
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
|
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
|
||||||
raise newHttpCriticalFailure("Invalid request received")
|
raise newHttpCriticalError("Invalid request received")
|
||||||
else:
|
else:
|
||||||
return res.get()
|
return res.get()
|
||||||
except TransportOsError:
|
except TransportOsError:
|
||||||
log debug "Unexpected OS error"
|
raise newHttpCriticalError("Unexpected OS error")
|
||||||
raise newHttpCriticalFailure("Unexpected OS error")
|
|
||||||
except TransportIncompleteError:
|
except TransportIncompleteError:
|
||||||
log debug "Remote peer disconnected"
|
raise newHttpCriticalError("Remote peer disconnected")
|
||||||
raise newHttpCriticalFailure("Remote peer disconnected")
|
|
||||||
except TransportLimitError:
|
except TransportLimitError:
|
||||||
log debug "Maximum size of request headers reached"
|
|
||||||
discard await conn.sendErrorResponse(HttpVersion11, Http413, false)
|
discard await conn.sendErrorResponse(HttpVersion11, Http413, false)
|
||||||
raise newHttpCriticalFailure("Maximum size of request headers reached")
|
raise newHttpCriticalError("Maximum size of request headers reached")
|
||||||
|
|
||||||
proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
|
proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
|
||||||
when defined(useChroniclesLogging):
|
|
||||||
logScope:
|
|
||||||
peer = $transp.remoteAddress
|
|
||||||
|
|
||||||
var conn = HttpConnection(
|
var conn = HttpConnection(
|
||||||
transp: transp, buffer: newSeq[byte](server.maxHeadersSize),
|
transp: transp, buffer: newSeq[byte](server.maxHeadersSize),
|
||||||
server: server
|
server: server
|
||||||
)
|
)
|
||||||
|
|
||||||
log info "Client connected"
|
|
||||||
var breakLoop = false
|
var breakLoop = false
|
||||||
while true:
|
while true:
|
||||||
|
var status: HttpProcessStatus
|
||||||
|
var arg: HttpProcessStatus[HttpRequest]
|
||||||
try:
|
try:
|
||||||
let request = await conn.getRequest().wait(server.headersTimeout)
|
let request = await conn.getRequest().wait(server.headersTimeout)
|
||||||
|
arg = ok(request)
|
||||||
|
except AsyncTimeoutError as exc:
|
||||||
|
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
|
||||||
|
breakLoop = true
|
||||||
|
arg = err(HttpProcessError(exc: exc, remote: transp.remoteAddress()))
|
||||||
|
except CancelledError:
|
||||||
|
breakLoop = true
|
||||||
|
except HttpRecoverableError:
|
||||||
|
breakLoop = false
|
||||||
|
arg = err()
|
||||||
|
except CatchableError:
|
||||||
|
breakLoop = true
|
||||||
|
|
||||||
|
if breakLoop:
|
||||||
|
break
|
||||||
|
|
||||||
|
breakLoop = false
|
||||||
|
let status =
|
||||||
|
try:
|
||||||
|
await conn.server.processCallback()
|
||||||
|
except CancelledError:
|
||||||
|
breakLoop = true
|
||||||
|
HttpCriticalError
|
||||||
|
except CatchableError:
|
||||||
|
breakLoop = true
|
||||||
|
HttpCriticalError
|
||||||
|
|
||||||
echo "== HEADERS TABLE"
|
echo "== HEADERS TABLE"
|
||||||
echo request.headersTable
|
echo request.headersTable
|
||||||
echo "== QUERY TABLE"
|
echo "== QUERY TABLE"
|
||||||
|
@ -439,19 +387,17 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
|
||||||
echo cast[string](stream)
|
echo cast[string](stream)
|
||||||
discard await conn.sendErrorResponse(HttpVersion11, Http200, true,
|
discard await conn.sendErrorResponse(HttpVersion11, Http200, true,
|
||||||
databody = "OK")
|
databody = "OK")
|
||||||
log debug "Response sent"
|
|
||||||
except AsyncTimeoutError:
|
except AsyncTimeoutError:
|
||||||
log debug "Timeout reached while reading headers"
|
|
||||||
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
|
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
|
||||||
breakLoop = true
|
breakLoop = true
|
||||||
|
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
breakLoop = true
|
breakLoop = true
|
||||||
|
|
||||||
except HttpRecoverableFailure:
|
except HttpRecoverableError:
|
||||||
breakLoop = false
|
breakLoop = false
|
||||||
|
|
||||||
except HttpCriticalFailure:
|
except HttpCriticalError:
|
||||||
breakLoop = true
|
breakLoop = true
|
||||||
|
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
|
@ -464,7 +410,6 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
|
||||||
server.connections.del(transp.getId())
|
server.connections.del(transp.getId())
|
||||||
# if server.maxConnections > 0:
|
# if server.maxConnections > 0:
|
||||||
# server.semaphore.release()
|
# server.semaphore.release()
|
||||||
log info "Client got disconnected"
|
|
||||||
|
|
||||||
proc acceptClientLoop(server: HttpServer) {.async.} =
|
proc acceptClientLoop(server: HttpServer) {.async.} =
|
||||||
var breakLoop = false
|
var breakLoop = false
|
||||||
|
|
|
@ -8,9 +8,11 @@
|
||||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||||
# MIT license (LICENSE-MIT)
|
# MIT license (LICENSE-MIT)
|
||||||
import std/[monotimes, strutils]
|
import std/[monotimes, strutils]
|
||||||
import chronos, stew/results
|
import stew/results
|
||||||
|
import ../../asyncloop
|
||||||
|
import ../../streams/[asyncstream, boundstream, chunkstream]
|
||||||
import httptable, httpcommon
|
import httptable, httpcommon
|
||||||
export httptable, httpcommon
|
export httptable, httpcommon, asyncstream
|
||||||
|
|
||||||
type
|
type
|
||||||
MultiPartSource {.pure.} = enum
|
MultiPartSource {.pure.} = enum
|
||||||
|
@ -20,9 +22,9 @@ type
|
||||||
case kind: MultiPartSource
|
case kind: MultiPartSource
|
||||||
of MultiPartSource.Stream:
|
of MultiPartSource.Stream:
|
||||||
stream: AsyncStreamReader
|
stream: AsyncStreamReader
|
||||||
last: BoundedAsyncStreamReader
|
|
||||||
of MultiPartSource.Buffer:
|
of MultiPartSource.Buffer:
|
||||||
discard
|
discard
|
||||||
|
firstTime: bool
|
||||||
buffer: seq[byte]
|
buffer: seq[byte]
|
||||||
offset: int
|
offset: int
|
||||||
boundary: seq[byte]
|
boundary: seq[byte]
|
||||||
|
@ -30,18 +32,25 @@ type
|
||||||
MultiPartReaderRef* = ref MultiPartReader
|
MultiPartReaderRef* = ref MultiPartReader
|
||||||
|
|
||||||
MultiPart* = object
|
MultiPart* = object
|
||||||
|
case kind: MultiPartSource
|
||||||
|
of MultiPartSource.Stream:
|
||||||
|
stream*: BoundedStreamReader
|
||||||
|
of MultiPartSource.Buffer:
|
||||||
|
discard
|
||||||
|
buffer: seq[byte]
|
||||||
headers: HttpTable
|
headers: HttpTable
|
||||||
stream: BoundedAsyncStreamReader
|
|
||||||
offset: int
|
|
||||||
size: int
|
|
||||||
|
|
||||||
MultipartError* = object of HttpError
|
MultipartError* = object of HttpCriticalError
|
||||||
MultipartEOMError* = object of MultipartError
|
MultipartEOMError* = object of MultipartError
|
||||||
MultiPartIncorrectError* = object of MultipartError
|
MultipartIncorrectError* = object of MultipartError
|
||||||
MultiPartIncompleteError* = object of MultipartError
|
MultipartIncompleteError* = object of MultipartError
|
||||||
|
MultipartReadError* = object of MultipartError
|
||||||
|
|
||||||
BChar* = byte | char
|
BChar* = byte | char
|
||||||
|
|
||||||
|
proc newMultipartReadError(msg: string): ref MultipartReadError =
|
||||||
|
newException(MultipartReadError, msg)
|
||||||
|
|
||||||
proc startsWith*(s, prefix: openarray[byte]): bool =
|
proc startsWith*(s, prefix: openarray[byte]): bool =
|
||||||
var i = 0
|
var i = 0
|
||||||
while true:
|
while true:
|
||||||
|
@ -100,60 +109,106 @@ proc init*[B: BChar](mpt: typedesc[MultiPartReader],
|
||||||
|
|
||||||
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
|
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
|
||||||
doAssert(mpr.kind == MultiPartSource.Stream)
|
doAssert(mpr.kind == MultiPartSource.Stream)
|
||||||
# According to RFC1521 says that a boundary "must be no longer than 70
|
|
||||||
# characters, not counting the two leading hyphens.
|
|
||||||
if mpr.firstTime:
|
if mpr.firstTime:
|
||||||
|
try:
|
||||||
# Read and verify initial <-><-><boundary><CR><LF>
|
# Read and verify initial <-><-><boundary><CR><LF>
|
||||||
mpr.firstTime = false
|
|
||||||
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
|
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
|
||||||
|
mpr.firstTime = false
|
||||||
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
|
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
|
||||||
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
||||||
if buffer[0] == byte('-') and buffer[1] == byte("-"):
|
if mpr.buffer[0] == byte('-') and mpr.buffer[1] == byte('-'):
|
||||||
raise newException(MultiPartEOMError, "Unexpected EOM encountered")
|
raise newException(MultiPartEOMError,
|
||||||
if buffer[0] != 0x0D'u8 or buffer[1] != 0x0A'u8:
|
"Unexpected EOM encountered")
|
||||||
|
if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8:
|
||||||
raise newException(MultiPartIncorrectError,
|
raise newException(MultiPartIncorrectError,
|
||||||
"Unexpected boundary suffix")
|
"Unexpected boundary suffix")
|
||||||
else:
|
else:
|
||||||
raise newException(MultiPartIncorrectError,
|
raise newException(MultiPartIncorrectError,
|
||||||
"Unexpected boundary encountered")
|
"Unexpected boundary encountered")
|
||||||
|
except CancelledError as exc:
|
||||||
|
raise exc
|
||||||
|
except AsyncStreamIncompleteError:
|
||||||
|
raise newMultipartReadError("Error reading multipart message")
|
||||||
|
except AsyncStreamReadError:
|
||||||
|
raise newMultipartReadError("Error reading multipart message")
|
||||||
|
|
||||||
# Reading part's headers
|
# Reading part's headers
|
||||||
|
try:
|
||||||
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
|
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
|
||||||
HeadersMark)
|
HeadersMark)
|
||||||
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1))
|
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false)
|
||||||
if headersList.failed():
|
if headersList.failed():
|
||||||
raise newException(MultiPartIncorrectError, "Incorrect part headers found")
|
raise newException(MultiPartIncorrectError,
|
||||||
|
"Incorrect part headers found")
|
||||||
|
|
||||||
var part = MultiPart()
|
var part = MultiPart(
|
||||||
|
kind: MultiPartSource.Stream,
|
||||||
|
headers: HttpTable.init(),
|
||||||
|
stream: newBoundedStreamReader(mpr.stream, -1, mpr.boundary)
|
||||||
|
)
|
||||||
|
|
||||||
await mpr.stream.readExactly(addr buffer[0], len(mpr.boundary) - 4)
|
for k, v in headersList.headers(mpr.buffer.toOpenArray(0, res - 1)):
|
||||||
if startsWith(buffer.toOpenArray(0, len(mpr.boundary) - 5),
|
part.headers.add(k, v)
|
||||||
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
|
||||||
await mpr.stream.readExactly(addr buffer[0], 2)
|
|
||||||
if buffer[0] == byte('-') and buffer[1] == byte("-"):
|
|
||||||
raise newException(MultiPartEOMError, "")
|
|
||||||
if buffer[0] == 0x0D'u8 and buffer[1] == 0x0A'u8:
|
|
||||||
|
|
||||||
except:
|
return part
|
||||||
discard
|
|
||||||
# if mpr.offset >= len(mpr.buffer):
|
|
||||||
# raise newException(MultiPartEOMError, "End of multipart form encountered")
|
|
||||||
|
|
||||||
proc getStream*(mp: MultiPart): AsyncStreamReader =
|
except CancelledError as exc:
|
||||||
mp.stream
|
raise exc
|
||||||
|
except AsyncStreamIncompleteError:
|
||||||
|
raise newMultipartReadError("Error reading multipart message")
|
||||||
|
except AsyncStreamLimitError:
|
||||||
|
raise newMultipartReadError("Multipart message headers size too big")
|
||||||
|
except AsyncStreamReadError:
|
||||||
|
raise newMultipartReadError("Error reading multipart message")
|
||||||
|
|
||||||
proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
|
proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
|
||||||
|
case mp.kind
|
||||||
|
of MultiPartSource.Stream:
|
||||||
try:
|
try:
|
||||||
let res = await mp.stream.read()
|
let res = await mp.stream.read()
|
||||||
return res
|
return res
|
||||||
except AsyncStreamError:
|
except AsyncStreamError:
|
||||||
raise newException(HttpCriticalError, "Could not read multipart body")
|
raise newException(HttpCriticalError, "Could not read multipart body")
|
||||||
|
of MultiPartSource.Buffer:
|
||||||
|
return mp.buffer
|
||||||
|
|
||||||
proc consumeBody*(mp: MultiPart) {.async.} =
|
proc consumeBody*(mp: MultiPart) {.async.} =
|
||||||
|
case mp.kind
|
||||||
|
of MultiPartSource.Stream:
|
||||||
try:
|
try:
|
||||||
await mp.stream.consume()
|
await mp.stream.consume()
|
||||||
except AsyncStreamError:
|
except AsyncStreamError:
|
||||||
raise newException(HttpCriticalError, "Could not consume multipart body")
|
raise newException(HttpCriticalError, "Could not consume multipart body")
|
||||||
|
of MultiPartSource.Buffer:
|
||||||
|
discard
|
||||||
|
|
||||||
|
proc getBytes*(mp: MultiPart): seq[byte] =
|
||||||
|
## Returns MultiPart value as sequence of bytes.
|
||||||
|
case mp.kind
|
||||||
|
of MultiPartSource.Buffer:
|
||||||
|
mp.buffer
|
||||||
|
of MultiPartSource.Stream:
|
||||||
|
doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
|
||||||
|
mp.buffer
|
||||||
|
|
||||||
|
proc getString*(mp: MultiPart): string =
|
||||||
|
## Returns MultiPart value as string.
|
||||||
|
case mp.kind
|
||||||
|
of MultiPartSource.Buffer:
|
||||||
|
if len(mp.buffer) > 0:
|
||||||
|
var res = newString(len(mp.buffer))
|
||||||
|
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
|
||||||
|
res
|
||||||
|
else:
|
||||||
|
""
|
||||||
|
of MultiPartSource.Stream:
|
||||||
|
doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
|
||||||
|
if len(mp.buffer) > 0:
|
||||||
|
var res = newString(len(mp.buffer))
|
||||||
|
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
|
||||||
|
res
|
||||||
|
else:
|
||||||
|
""
|
||||||
|
|
||||||
proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
|
proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
|
||||||
doAssert(mpr.kind == MultiPartSource.Buffer)
|
doAssert(mpr.kind == MultiPartSource.Buffer)
|
||||||
|
@ -215,8 +270,11 @@ proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
|
||||||
|
|
||||||
# We set reader's offset to the place right after <CR><LF>
|
# We set reader's offset to the place right after <CR><LF>
|
||||||
mpr.offset = start + pos2 + 2
|
mpr.offset = start + pos2 + 2
|
||||||
|
var part = MultiPart(
|
||||||
var part = MultiPart(offset: start, size: pos2, headers: HttpTable.init())
|
kind: MultiPartSource.Buffer,
|
||||||
|
headers: HttpTable.init(),
|
||||||
|
buffer: @(mpr.buffer.toOpenArray(start, start + pos2 - 1))
|
||||||
|
)
|
||||||
for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)):
|
for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)):
|
||||||
part.headers.add(k, v)
|
part.headers.add(k, v)
|
||||||
ok(part)
|
ok(part)
|
||||||
|
@ -255,7 +313,7 @@ proc boundaryValue2(c: char): bool =
|
||||||
c in {'a'..'z', 'A' .. 'Z', '0' .. '9',
|
c in {'a'..'z', 'A' .. 'Z', '0' .. '9',
|
||||||
'\'' .. ')', '+' .. '/', ':', '=', '?', '_'}
|
'\'' .. ')', '+' .. '/', ':', '=', '?', '_'}
|
||||||
|
|
||||||
func getMultipartBoundary*(contentType: string): Result[string, string] =
|
func getMultipartBoundary*(contentType: string): HttpResult[string] =
|
||||||
let mparts = contentType.split(";")
|
let mparts = contentType.split(";")
|
||||||
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
||||||
return err("Content-Type is not multipart")
|
return err("Content-Type is not multipart")
|
||||||
|
@ -270,7 +328,7 @@ func getMultipartBoundary*(contentType: string): Result[string, string] =
|
||||||
else:
|
else:
|
||||||
ok(strip(bparts[1]))
|
ok(strip(bparts[1]))
|
||||||
|
|
||||||
func getContentType*(contentHeader: seq[string]): Result[string, string] =
|
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
|
||||||
if len(contentHeader) > 1:
|
if len(contentHeader) > 1:
|
||||||
return err("Multiple Content-Header values found")
|
return err("Multiple Content-Header values found")
|
||||||
let mparts = contentHeader[0].split(";")
|
let mparts = contentHeader[0].split(";")
|
||||||
|
|
Loading…
Reference in New Issue