Prepare for HttpResponse.

This commit is contained in:
cheatfate 2021-01-27 21:39:14 +02:00 committed by zah
parent 60e5396a9e
commit 0e5ea5b737
3 changed files with 285 additions and 181 deletions

View File

@ -6,8 +6,8 @@
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import stew/results, httputils
export results, httputils
import stew/results, httputils, strutils, uri
export results, httputils, strutils
const
useChroniclesLogging* {.booldefine.} = false
@ -19,9 +19,110 @@ type
HttpResultCode*[T] = Result[T, HttpCode]
HttpError* = object of CatchableError
HttpCriticalFailure* = object of HttpError
HttpRecoverableFailure* = object of HttpError
HttpCriticalError* = object of HttpError
HttpRecoverableError* = object of HttpError
TransferEncodingFlags* {.pure.} = enum
Identity, Chunked, Compress, Deflate, Gzip
ContentEncodingFlags* {.pure.} = enum
Identity, Br, Compress, Deflate, Gzip
template log*(body: untyped) =
when defined(useChroniclesLogging):
body
proc newHttpCriticalError*(msg: string): ref HttpCriticalError =
newException(HttpCriticalError, msg)
proc newHttpRecoverableError*(msg: string): ref HttpRecoverableError =
newException(HttpRecoverableError, msg)
iterator queryParams*(query: string): tuple[key: string, value: string] =
## Iterate over url-encoded query string.
for pair in query.split('&'):
let items = pair.split('=', maxsplit = 1)
let k = items[0]
let v = if len(items) > 1: items[1] else: ""
yield (decodeUrl(k), decodeUrl(v))
func getTransferEncoding*(ch: openarray[string]): HttpResult[
set[TransferEncodingFlags]] =
## Parse value of multiple HTTP headers ``Transfer-Encoding`` and return
## it as set of ``TransferEncodingFlags``.
var res: set[TransferEncodingFlags] = {}
if len(ch) == 0:
res.incl(TransferEncodingFlags.Identity)
ok(res)
else:
for header in ch:
for item in header.split(","):
case strip(item.toLowerAscii())
of "identity":
res.incl(TransferEncodingFlags.Identity)
of "chunked":
res.incl(TransferEncodingFlags.Chunked)
of "compress":
res.incl(TransferEncodingFlags.Compress)
of "deflate":
res.incl(TransferEncodingFlags.Deflate)
of "gzip":
res.incl(TransferEncodingFlags.Gzip)
of "":
res.incl(TransferEncodingFlags.Identity)
else:
return err("Incorrect Transfer-Encoding value")
ok(res)
func getContentEncoding*(ch: openarray[string]): HttpResult[
set[ContentEncodingFlags]] =
## Parse value of multiple HTTP headers ``Content-Encoding`` and return
## it as set of ``ContentEncodingFlags``.
var res: set[ContentEncodingFlags] = {}
if len(ch) == 0:
res.incl(ContentEncodingFlags.Identity)
ok(res)
else:
for header in ch:
for item in header.split(","):
case strip(item.toLowerAscii()):
of "identity":
res.incl(ContentEncodingFlags.Identity)
of "br":
res.incl(ContentEncodingFlags.Br)
of "compress":
res.incl(ContentEncodingFlags.Compress)
of "deflate":
res.incl(ContentEncodingFlags.Deflate)
of "gzip":
res.incl(ContentEncodingFlags.Gzip)
of "":
res.incl(ContentEncodingFlags.Identity)
else:
return err("Incorrect Content-Encoding value")
ok(res)
func getContentType*(ch: openarray[string]): HttpResult[string] =
## Check and prepare value of ``Content-Type`` header.
if len(ch) > 1:
err("Multiple Content-Type values found")
else:
let mparts = ch[0].split(";")
ok(strip(mparts[0]).toLowerAscii())
func getMultipartBoundary*(contentType: string): HttpResult[string] =
## Process ``multipart/form-data`` ``Content-Type`` header and return
## multipart boundary.
let mparts = contentType.split(";")
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
return err("Content-Type is not multipart")
if len(mparts) < 2:
return err("Content-Type missing boundary value")
let stripped = strip(mparts[1])
if not(stripped.toLowerAscii().startsWith("boundary")):
return err("Incorrect Content-Type boundary format")
let bparts = stripped.split("=")
if len(bparts) < 2:
err("Missing Content-Type boundary")
else:
ok(strip(bparts[1]))

View File

@ -21,15 +21,25 @@ type
HttpServerFlags* = enum
Secure
TransferEncodingFlags* {.pure.} = enum
Identity, Chunked, Compress, Deflate, Gzip
HttpConnectionStatus* = enum
DropConnection, KeepConnection
ContentEncodingFlags* {.pure.} = enum
Identity, Br, Compress, Deflate, Gzip
HttpErrorEnum* = enum
TimeoutError, CatchableError, RecoverableError, CriticalError
HttpProcessError* = object
error*: HttpErrorEnum
exc*: HttpError
remote*: TransportAddress
HttpProcessStatus*[T] = Result[T, HttpProcessError]
HttpRequestFlags* {.pure.} = enum
BoundBody, UnboundBody, MultipartForm, UrlencodedForm
HttpProcessCallback* =
proc(request: HttpProcessStatus[HttpRequest]): Future[HttpStatus]
HttpServer* = ref object of RootRef
instance*: StreamServer
# semaphore*: AsyncSemaphore
@ -43,6 +53,7 @@ type
bodyTimeout: Duration
maxHeadersSize: int
maxRequestBodySize: int
processCallback: HttpProcessCallback
HttpServerState* = enum
ServerRunning, ServerStopped, ServerClosed
@ -64,18 +75,24 @@ type
connection*: HttpConnection
mainReader*: AsyncStreamReader
HttpResponse* = object
code*: HttpCode
version*: HttpVersion
headersTable: HttpTable
body*: seq[byte]
connection*: HttpConnection
mainWriter: AsyncStreamWriter
HttpConnection* = ref object of RootRef
server: HttpServer
server*: HttpServer
transp: StreamTransport
buffer: seq[byte]
const
HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
proc new*(htype: typedesc[HttpServer],
address: TransportAddress,
flags: set[HttpServerFlags] = {},
serverUri = Uri(),
processCallback: HttpProcessCallback,
maxConnections: int = -1,
bufferSize: int = 4096,
backlogSize: int = 100,
@ -88,7 +105,8 @@ proc new*(htype: typedesc[HttpServer],
headersTimeout: httpHeadersTimeout,
bodyTimeout: httpBodyTimeout,
maxHeadersSize: maxHeadersSize,
maxRequestBodySize: maxRequestBodySize
maxRequestBodySize: maxRequestBodySize,
processCallback: processCallback
)
res.baseUri =
@ -118,92 +136,6 @@ proc getId(transp: StreamTransport): string {.inline.} =
## Returns string unique transport's identifier as string.
$transp.remoteAddress() & "_" & $transp.localAddress()
proc newHttpCriticalFailure*(msg: string): ref HttpCriticalFailure =
newException(HttpCriticalFailure, msg)
proc newHttpRecoverableFailure*(msg: string): ref HttpRecoverableFailure =
newException(HttpRecoverableFailure, msg)
iterator queryParams(query: string): tuple[key: string, value: string] =
for pair in query.split('&'):
let items = pair.split('=', maxsplit = 1)
let k = items[0]
let v = if len(items) > 1: items[1] else: ""
yield (decodeUrl(k), decodeUrl(v))
func getMultipartBoundary*(contentType: string): HttpResult[string] =
let mparts = contentType.split(";")
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
return err("Content-Type is not multipart")
if len(mparts) < 2:
return err("Content-Type missing boundary value")
let stripped = strip(mparts[1])
if not(stripped.toLowerAscii().startsWith("boundary")):
return err("Incorrect Content-Type boundary format")
let bparts = stripped.split("=")
if len(bparts) < 2:
err("Missing Content-Type boundary")
else:
ok(strip(bparts[1]))
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
if len(contentHeader) > 1:
return err("Multiple Content-Type values found")
let mparts = contentHeader[0].split(";")
ok(strip(mparts[0]).toLowerAscii())
func getTransferEncoding(contentHeader: seq[string]): HttpResult[
set[TransferEncodingFlags]] =
var res: set[TransferEncodingFlags] = {}
if len(contentHeader) == 0:
res.incl(TransferEncodingFlags.Identity)
ok(res)
else:
for header in contentHeader:
for item in header.split(","):
case strip(item.toLowerAscii())
of "identity":
res.incl(TransferEncodingFlags.Identity)
of "chunked":
res.incl(TransferEncodingFlags.Chunked)
of "compress":
res.incl(TransferEncodingFlags.Compress)
of "deflate":
res.incl(TransferEncodingFlags.Deflate)
of "gzip":
res.incl(TransferEncodingFlags.Gzip)
of "":
res.incl(TransferEncodingFlags.Identity)
else:
return err("Incorrect Transfer-Encoding value")
ok(res)
func getContentEncoding(contentHeader: seq[string]): HttpResult[
set[ContentEncodingFlags]] =
var res: set[ContentEncodingFlags] = {}
if len(contentHeader) == 0:
res.incl(ContentEncodingFlags.Identity)
ok(res)
else:
for header in contentHeader:
for item in header.split(","):
case strip(item.toLowerAscii()):
of "identity":
res.incl(ContentEncodingFlags.Identity)
of "br":
res.incl(ContentEncodingFlags.Br)
of "compress":
res.incl(ContentEncodingFlags.Compress)
of "deflate":
res.incl(ContentEncodingFlags.Deflate)
of "gzip":
res.incl(ContentEncodingFlags.Gzip)
of "":
res.incl(ContentEncodingFlags.Identity)
else:
return err("Incorrect Content-Encoding value")
ok(res)
proc hasBody*(request: HttpRequest): bool =
## Returns ``true`` if request has body.
request.requestFlags * {HttpRequestFlags.BoundBody,
@ -333,7 +265,7 @@ proc getBody*(request: HttpRequest): Future[seq[byte]] {.async.} =
try:
return await read(res.get())
except AsyncStreamError:
raise newHttpCriticalFailure("Read failure")
raise newHttpCriticalError("Read Error")
proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
## Consume/discard request's body.
@ -346,7 +278,7 @@ proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
discard await reader.consume()
return
except AsyncStreamError:
raise newHttpCriticalFailure("Read failure")
raise newHttpCriticalError("Read Error")
proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
code: HttpCode, keepAlive = true,
@ -374,17 +306,13 @@ proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
except CatchableError:
return false
proc sendErrorResponse(request: HttpRequest, code: HttpCode, keepAlive = true,
datatype = "text/text",
databody = ""): Future[bool] =
proc sendErrorResponse*(request: HttpRequest, code: HttpCode, keepAlive = true,
datatype = "text/text",
databody = ""): Future[bool] =
sendErrorResponse(request.connection, request.version, code, keepAlive,
datatype, databody)
proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
when defined(useChroniclesLogging):
logScope:
peer = $conn.transp.remoteAddress
try:
conn.buffer.setLen(conn.server.maxHeadersSize)
let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer),
@ -392,42 +320,62 @@ proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
conn.buffer.setLen(res)
let header = parseRequest(conn.buffer)
if header.failed():
log debug "Malformed header received"
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
raise newHttpCriticalFailure("Malformed request recieved")
raise newHttpCriticalError("Malformed request recieved")
else:
let res = prepareRequest(conn, header)
if res.isErr():
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
raise newHttpCriticalFailure("Invalid request received")
raise newHttpCriticalError("Invalid request received")
else:
return res.get()
except TransportOsError:
log debug "Unexpected OS error"
raise newHttpCriticalFailure("Unexpected OS error")
raise newHttpCriticalError("Unexpected OS error")
except TransportIncompleteError:
log debug "Remote peer disconnected"
raise newHttpCriticalFailure("Remote peer disconnected")
raise newHttpCriticalError("Remote peer disconnected")
except TransportLimitError:
log debug "Maximum size of request headers reached"
discard await conn.sendErrorResponse(HttpVersion11, Http413, false)
raise newHttpCriticalFailure("Maximum size of request headers reached")
raise newHttpCriticalError("Maximum size of request headers reached")
proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
when defined(useChroniclesLogging):
logScope:
peer = $transp.remoteAddress
var conn = HttpConnection(
transp: transp, buffer: newSeq[byte](server.maxHeadersSize),
server: server
)
log info "Client connected"
var breakLoop = false
while true:
var status: HttpProcessStatus
var arg: HttpProcessStatus[HttpRequest]
try:
let request = await conn.getRequest().wait(server.headersTimeout)
arg = ok(request)
except AsyncTimeoutError as exc:
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
breakLoop = true
arg = err(HttpProcessError(exc: exc, remote: transp.remoteAddress()))
except CancelledError:
breakLoop = true
except HttpRecoverableError:
breakLoop = false
arg = err()
except CatchableError:
breakLoop = true
if breakLoop:
break
breakLoop = false
let status =
try:
await conn.server.processCallback()
except CancelledError:
breakLoop = true
HttpCriticalError
except CatchableError:
breakLoop = true
HttpCriticalError
echo "== HEADERS TABLE"
echo request.headersTable
echo "== QUERY TABLE"
@ -439,19 +387,17 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
echo cast[string](stream)
discard await conn.sendErrorResponse(HttpVersion11, Http200, true,
databody = "OK")
log debug "Response sent"
except AsyncTimeoutError:
log debug "Timeout reached while reading headers"
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
breakLoop = true
except CancelledError:
breakLoop = true
except HttpRecoverableFailure:
except HttpRecoverableError:
breakLoop = false
except HttpCriticalFailure:
except HttpCriticalError:
breakLoop = true
except CatchableError as exc:
@ -464,7 +410,6 @@ proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
server.connections.del(transp.getId())
# if server.maxConnections > 0:
# server.semaphore.release()
log info "Client got disconnected"
proc acceptClientLoop(server: HttpServer) {.async.} =
var breakLoop = false

View File

@ -8,9 +8,11 @@
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[monotimes, strutils]
import chronos, stew/results
import stew/results
import ../../asyncloop
import ../../streams/[asyncstream, boundstream, chunkstream]
import httptable, httpcommon
export httptable, httpcommon
export httptable, httpcommon, asyncstream
type
MultiPartSource {.pure.} = enum
@ -20,9 +22,9 @@ type
case kind: MultiPartSource
of MultiPartSource.Stream:
stream: AsyncStreamReader
last: BoundedAsyncStreamReader
of MultiPartSource.Buffer:
discard
firstTime: bool
buffer: seq[byte]
offset: int
boundary: seq[byte]
@ -30,18 +32,25 @@ type
MultiPartReaderRef* = ref MultiPartReader
MultiPart* = object
case kind: MultiPartSource
of MultiPartSource.Stream:
stream*: BoundedStreamReader
of MultiPartSource.Buffer:
discard
buffer: seq[byte]
headers: HttpTable
stream: BoundedAsyncStreamReader
offset: int
size: int
MultipartError* = object of HttpError
MultipartError* = object of HttpCriticalError
MultipartEOMError* = object of MultipartError
MultiPartIncorrectError* = object of MultipartError
MultiPartIncompleteError* = object of MultipartError
MultipartIncorrectError* = object of MultipartError
MultipartIncompleteError* = object of MultipartError
MultipartReadError* = object of MultipartError
BChar* = byte | char
proc newMultipartReadError(msg: string): ref MultipartReadError =
newException(MultipartReadError, msg)
proc startsWith*(s, prefix: openarray[byte]): bool =
var i = 0
while true:
@ -100,60 +109,106 @@ proc init*[B: BChar](mpt: typedesc[MultiPartReader],
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
doAssert(mpr.kind == MultiPartSource.Stream)
# According to RFC1521 says that a boundary "must be no longer than 70
# characters, not counting the two leading hyphens.
if mpr.firstTime:
# Read and verify initial <-><-><boundary><CR><LF>
mpr.firstTime = false
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
if buffer[0] == byte('-') and buffer[1] == byte("-"):
raise newException(MultiPartEOMError, "Unexpected EOM encountered")
if buffer[0] != 0x0D'u8 or buffer[1] != 0x0A'u8:
try:
# Read and verify initial <-><-><boundary><CR><LF>
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
mpr.firstTime = false
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
if mpr.buffer[0] == byte('-') and mpr.buffer[1] == byte('-'):
raise newException(MultiPartEOMError,
"Unexpected EOM encountered")
if mpr.buffer[0] != 0x0D'u8 or mpr.buffer[1] != 0x0A'u8:
raise newException(MultiPartIncorrectError,
"Unexpected boundary suffix")
else:
raise newException(MultiPartIncorrectError,
"Unexpected boundary suffix")
else:
raise newException(MultiPartIncorrectError,
"Unexpected boundary encountered")
"Unexpected boundary encountered")
except CancelledError as exc:
raise exc
except AsyncStreamIncompleteError:
raise newMultipartReadError("Error reading multipart message")
except AsyncStreamReadError:
raise newMultipartReadError("Error reading multipart message")
# Reading part's headers
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
try:
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
HeadersMark)
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1))
if headersList.failed():
raise newException(MultiPartIncorrectError, "Incorrect part headers found")
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1), false)
if headersList.failed():
raise newException(MultiPartIncorrectError,
"Incorrect part headers found")
var part = MultiPart()
var part = MultiPart(
kind: MultiPartSource.Stream,
headers: HttpTable.init(),
stream: newBoundedStreamReader(mpr.stream, -1, mpr.boundary)
)
await mpr.stream.readExactly(addr buffer[0], len(mpr.boundary) - 4)
if startsWith(buffer.toOpenArray(0, len(mpr.boundary) - 5),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
await mpr.stream.readExactly(addr buffer[0], 2)
if buffer[0] == byte('-') and buffer[1] == byte("-"):
raise newException(MultiPartEOMError, "")
if buffer[0] == 0x0D'u8 and buffer[1] == 0x0A'u8:
for k, v in headersList.headers(mpr.buffer.toOpenArray(0, res - 1)):
part.headers.add(k, v)
except:
discard
# if mpr.offset >= len(mpr.buffer):
# raise newException(MultiPartEOMError, "End of multipart form encountered")
return part
proc getStream*(mp: MultiPart): AsyncStreamReader =
mp.stream
except CancelledError as exc:
raise exc
except AsyncStreamIncompleteError:
raise newMultipartReadError("Error reading multipart message")
except AsyncStreamLimitError:
raise newMultipartReadError("Multipart message headers size too big")
except AsyncStreamReadError:
raise newMultipartReadError("Error reading multipart message")
proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
try:
let res = await mp.stream.read()
return res
except AsyncStreamError:
raise newException(HttpCriticalError, "Could not read multipart body")
case mp.kind
of MultiPartSource.Stream:
try:
let res = await mp.stream.read()
return res
except AsyncStreamError:
raise newException(HttpCriticalError, "Could not read multipart body")
of MultiPartSource.Buffer:
return mp.buffer
proc consumeBody*(mp: MultiPart) {.async.} =
try:
await mp.stream.consume()
except AsyncStreamError:
raise newException(HttpCriticalError, "Could not consume multipart body")
case mp.kind
of MultiPartSource.Stream:
try:
await mp.stream.consume()
except AsyncStreamError:
raise newException(HttpCriticalError, "Could not consume multipart body")
of MultiPartSource.Buffer:
discard
proc getBytes*(mp: MultiPart): seq[byte] =
## Returns MultiPart value as sequence of bytes.
case mp.kind
of MultiPartSource.Buffer:
mp.buffer
of MultiPartSource.Stream:
doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
mp.buffer
proc getString*(mp: MultiPart): string =
## Returns MultiPart value as string.
case mp.kind
of MultiPartSource.Buffer:
if len(mp.buffer) > 0:
var res = newString(len(mp.buffer))
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
res
else:
""
of MultiPartSource.Stream:
doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
if len(mp.buffer) > 0:
var res = newString(len(mp.buffer))
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
res
else:
""
proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
doAssert(mpr.kind == MultiPartSource.Buffer)
@ -215,8 +270,11 @@ proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
# We set reader's offset to the place right after <CR><LF>
mpr.offset = start + pos2 + 2
var part = MultiPart(offset: start, size: pos2, headers: HttpTable.init())
var part = MultiPart(
kind: MultiPartSource.Buffer,
headers: HttpTable.init(),
buffer: @(mpr.buffer.toOpenArray(start, start + pos2 - 1))
)
for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)):
part.headers.add(k, v)
ok(part)
@ -255,7 +313,7 @@ proc boundaryValue2(c: char): bool =
c in {'a'..'z', 'A' .. 'Z', '0' .. '9',
'\'' .. ')', '+' .. '/', ':', '=', '?', '_'}
func getMultipartBoundary*(contentType: string): Result[string, string] =
func getMultipartBoundary*(contentType: string): HttpResult[string] =
let mparts = contentType.split(";")
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
return err("Content-Type is not multipart")
@ -270,7 +328,7 @@ func getMultipartBoundary*(contentType: string): Result[string, string] =
else:
ok(strip(bparts[1]))
func getContentType*(contentHeader: seq[string]): Result[string, string] =
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
if len(contentHeader) > 1:
return err("Multiple Content-Header values found")
let mparts = contentHeader[0].split(";")