Initial commit.
This commit is contained in:
parent
0b396c34d8
commit
8381a40868
|
@ -0,0 +1,27 @@
|
|||
#
|
||||
# Chronos HTTP/S common types
|
||||
# (c) Copyright 2019-Present
|
||||
# Status Research & Development GmbH
|
||||
#
|
||||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import stew/results, httputils
|
||||
export results, httputils
|
||||
|
||||
const
|
||||
useChroniclesLogging* {.booldefine.} = false
|
||||
|
||||
HeadersMark* = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
|
||||
|
||||
type
|
||||
HttpResult*[T] = Result[T, string]
|
||||
HttpResultCode*[T] = Result[T, HttpCode]
|
||||
|
||||
HttpError* = object of CatchableError
|
||||
HttpCriticalFailure* = object of HttpError
|
||||
HttpRecoverableFailure* = object of HttpError
|
||||
|
||||
template log*(body: untyped) =
|
||||
when defined(useChroniclesLogging):
|
||||
body
|
Binary file not shown.
|
@ -0,0 +1,560 @@
|
|||
#
|
||||
# Chronos HTTP/S server implementation
|
||||
# (c) Copyright 2019-Present
|
||||
# Status Research & Development GmbH
|
||||
#
|
||||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/[tables, options, uri, strutils]
|
||||
import stew/results, httputils
|
||||
import ../../asyncloop, ../../asyncsync
|
||||
import ../../streams/[asyncstream, boundstream, chunkstream]
|
||||
import httptable, httpcommon
|
||||
export httpcommon
|
||||
|
||||
when defined(useChroniclesLogging):
|
||||
echo "Importing chronicles"
|
||||
import chronicles
|
||||
|
||||
type
|
||||
HttpServerFlags* = enum
|
||||
Secure
|
||||
|
||||
TransferEncodingFlags* {.pure.} = enum
|
||||
Identity, Chunked, Compress, Deflate, Gzip
|
||||
|
||||
ContentEncodingFlags* {.pure.} = enum
|
||||
Identity, Br, Compress, Deflate, Gzip
|
||||
|
||||
HttpRequestFlags* {.pure.} = enum
|
||||
BoundBody, UnboundBody, MultipartForm, UrlencodedForm
|
||||
|
||||
HttpServer* = ref object of RootRef
|
||||
instance*: StreamServer
|
||||
# semaphore*: AsyncSemaphore
|
||||
maxConnections*: int
|
||||
baseUri*: Uri
|
||||
flags*: set[HttpServerFlags]
|
||||
connections*: Table[string, Future[void]]
|
||||
acceptLoop*: Future[void]
|
||||
lifetime*: Future[void]
|
||||
headersTimeout: Duration
|
||||
bodyTimeout: Duration
|
||||
maxHeadersSize: int
|
||||
maxRequestBodySize: int
|
||||
|
||||
HttpServerState* = enum
|
||||
ServerRunning, ServerStopped, ServerClosed
|
||||
|
||||
HttpRequest* = object
|
||||
headersTable: HttpTable
|
||||
queryTable: HttpTable
|
||||
postTable: HttpTable
|
||||
rawPath*: string
|
||||
rawQuery*: string
|
||||
uri*: Uri
|
||||
scheme*: string
|
||||
version*: HttpVersion
|
||||
meth*: HttpMethod
|
||||
contentEncoding*: set[ContentEncodingFlags]
|
||||
transferEncoding*: set[TransferEncodingFlags]
|
||||
requestFlags*: set[HttpRequestFlags]
|
||||
contentLength: int
|
||||
connection*: HttpConnection
|
||||
mainReader*: AsyncStreamReader
|
||||
|
||||
HttpConnection* = ref object of RootRef
|
||||
server: HttpServer
|
||||
transp: StreamTransport
|
||||
buffer: seq[byte]
|
||||
|
||||
const
|
||||
HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
|
||||
|
||||
proc new*(htype: typedesc[HttpServer],
|
||||
address: TransportAddress,
|
||||
flags: set[HttpServerFlags] = {},
|
||||
serverUri = Uri(),
|
||||
maxConnections: int = -1,
|
||||
bufferSize: int = 4096,
|
||||
backlogSize: int = 100,
|
||||
httpHeadersTimeout = 10.seconds,
|
||||
httpBodyTimeout = 30.seconds,
|
||||
maxHeadersSize: int = 8192,
|
||||
maxRequestBodySize: int = 1_048_576): HttpResult[HttpServer] =
|
||||
var res = HttpServer(
|
||||
maxConnections: maxConnections,
|
||||
headersTimeout: httpHeadersTimeout,
|
||||
bodyTimeout: httpBodyTimeout,
|
||||
maxHeadersSize: maxHeadersSize,
|
||||
maxRequestBodySize: maxRequestBodySize
|
||||
)
|
||||
|
||||
res.baseUri =
|
||||
if len(serverUri.hostname) > 0 and isAbsolute(serverUri):
|
||||
serverUri
|
||||
else:
|
||||
if HttpServerFlags.Secure in flags:
|
||||
parseUri("https://" & $address & "/")
|
||||
else:
|
||||
parseUri("http://" & $address & "/")
|
||||
|
||||
try:
|
||||
res.instance = createStreamServer(address, flags = {ReuseAddr},
|
||||
bufferSize = bufferSize,
|
||||
backlog = backlogSize)
|
||||
# if maxConnections > 0:
|
||||
# res.semaphore = newAsyncSemaphore(maxConnections)
|
||||
res.lifetime = newFuture[void]("http.server.lifetime")
|
||||
res.connections = initTable[string, Future[void]]()
|
||||
return ok(res)
|
||||
except TransportOsError as exc:
|
||||
return err(exc.msg)
|
||||
except CatchableError as exc:
|
||||
return err(exc.msg)
|
||||
|
||||
proc getId(transp: StreamTransport): string {.inline.} =
|
||||
## Returns string unique transport's identifier as string.
|
||||
$transp.remoteAddress() & "_" & $transp.localAddress()
|
||||
|
||||
proc newHttpCriticalFailure*(msg: string): ref HttpCriticalFailure =
|
||||
newException(HttpCriticalFailure, msg)
|
||||
|
||||
proc newHttpRecoverableFailure*(msg: string): ref HttpRecoverableFailure =
|
||||
newException(HttpRecoverableFailure, msg)
|
||||
|
||||
iterator queryParams(query: string): tuple[key: string, value: string] =
|
||||
for pair in query.split('&'):
|
||||
let items = pair.split('=', maxsplit = 1)
|
||||
let k = items[0]
|
||||
let v = if len(items) > 1: items[1] else: ""
|
||||
yield (decodeUrl(k), decodeUrl(v))
|
||||
|
||||
func getMultipartBoundary*(contentType: string): HttpResult[string] =
|
||||
let mparts = contentType.split(";")
|
||||
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
||||
return err("Content-Type is not multipart")
|
||||
if len(mparts) < 2:
|
||||
return err("Content-Type missing boundary value")
|
||||
let stripped = strip(mparts[1])
|
||||
if not(stripped.toLowerAscii().startsWith("boundary")):
|
||||
return err("Incorrect Content-Type boundary format")
|
||||
let bparts = stripped.split("=")
|
||||
if len(bparts) < 2:
|
||||
err("Missing Content-Type boundary")
|
||||
else:
|
||||
ok(strip(bparts[1]))
|
||||
|
||||
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
|
||||
if len(contentHeader) > 1:
|
||||
return err("Multiple Content-Type values found")
|
||||
let mparts = contentHeader[0].split(";")
|
||||
ok(strip(mparts[0]).toLowerAscii())
|
||||
|
||||
func getTransferEncoding(contentHeader: seq[string]): HttpResult[
|
||||
set[TransferEncodingFlags]] =
|
||||
var res: set[TransferEncodingFlags] = {}
|
||||
if len(contentHeader) == 0:
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
ok(res)
|
||||
else:
|
||||
for header in contentHeader:
|
||||
for item in header.split(","):
|
||||
case strip(item.toLowerAscii())
|
||||
of "identity":
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
of "chunked":
|
||||
res.incl(TransferEncodingFlags.Chunked)
|
||||
of "compress":
|
||||
res.incl(TransferEncodingFlags.Compress)
|
||||
of "deflate":
|
||||
res.incl(TransferEncodingFlags.Deflate)
|
||||
of "gzip":
|
||||
res.incl(TransferEncodingFlags.Gzip)
|
||||
of "":
|
||||
res.incl(TransferEncodingFlags.Identity)
|
||||
else:
|
||||
return err("Incorrect Transfer-Encoding value")
|
||||
ok(res)
|
||||
|
||||
func getContentEncoding(contentHeader: seq[string]): HttpResult[
|
||||
set[ContentEncodingFlags]] =
|
||||
var res: set[ContentEncodingFlags] = {}
|
||||
if len(contentHeader) == 0:
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
ok(res)
|
||||
else:
|
||||
for header in contentHeader:
|
||||
for item in header.split(","):
|
||||
case strip(item.toLowerAscii()):
|
||||
of "identity":
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
of "br":
|
||||
res.incl(ContentEncodingFlags.Br)
|
||||
of "compress":
|
||||
res.incl(ContentEncodingFlags.Compress)
|
||||
of "deflate":
|
||||
res.incl(ContentEncodingFlags.Deflate)
|
||||
of "gzip":
|
||||
res.incl(ContentEncodingFlags.Gzip)
|
||||
of "":
|
||||
res.incl(ContentEncodingFlags.Identity)
|
||||
else:
|
||||
return err("Incorrect Content-Encoding value")
|
||||
ok(res)
|
||||
|
||||
proc hasBody*(request: HttpRequest): bool =
|
||||
## Returns ``true`` if request has body.
|
||||
request.requestFlags * {HttpRequestFlags.BoundBody,
|
||||
HttpRequestFlags.UnboundBody} != {}
|
||||
|
||||
proc prepareRequest(conn: HttpConnection,
|
||||
req: HttpRequestHeader): HttpResultCode[HttpRequest] =
|
||||
var request = HttpRequest()
|
||||
|
||||
if req.version notin {HttpVersion10, HttpVersion11}:
|
||||
return err(Http505)
|
||||
|
||||
request.version = req.version
|
||||
request.meth = req.meth
|
||||
|
||||
request.rawPath =
|
||||
block:
|
||||
let res = req.uri()
|
||||
if len(res) == 0:
|
||||
return err(Http400)
|
||||
res
|
||||
|
||||
request.uri =
|
||||
if request.rawPath != "*":
|
||||
let uri = parseUri(request.rawPath)
|
||||
if uri.scheme notin ["http", "https", ""]:
|
||||
return err(Http400)
|
||||
uri
|
||||
else:
|
||||
var uri = initUri()
|
||||
uri.path = "*"
|
||||
uri
|
||||
|
||||
request.queryTable =
|
||||
block:
|
||||
var table = HttpTable.init()
|
||||
for key, value in queryParams(request.uri.query):
|
||||
table.add(key, value)
|
||||
table
|
||||
|
||||
request.headersTable =
|
||||
block:
|
||||
var table = HttpTable.init()
|
||||
# Retrieve headers and values
|
||||
for key, value in req.headers():
|
||||
table.add(key, value)
|
||||
# Validating HTTP request headers
|
||||
# Some of the headers must be present only once.
|
||||
if table.count("content-type") > 1:
|
||||
return err(Http400)
|
||||
if table.count("content-length") > 1:
|
||||
return err(Http400)
|
||||
if table.count("transfer-encoding") > 1:
|
||||
return err(Http400)
|
||||
table
|
||||
|
||||
# Preprocessing "Content-Encoding" header.
|
||||
request.contentEncoding =
|
||||
block:
|
||||
let res = getContentEncoding(
|
||||
request.headersTable.getList("content-encoding"))
|
||||
if res.isErr():
|
||||
return err(Http400)
|
||||
else:
|
||||
res.get()
|
||||
|
||||
# Preprocessing "Transfer-Encoding" header.
|
||||
request.transferEncoding =
|
||||
block:
|
||||
let res = getTransferEncoding(
|
||||
request.headersTable.getList("transfer-encoding"))
|
||||
if res.isErr():
|
||||
return err(Http400)
|
||||
else:
|
||||
res.get()
|
||||
|
||||
# Almost all HTTP requests could have body (except TRACE), we perform some
|
||||
# steps to reveal information about body.
|
||||
if "content-length" in request.headersTable:
|
||||
let length = request.headersTable.getInt("content-length")
|
||||
if length > 0:
|
||||
if request.meth == MethodTrace:
|
||||
return err(Http400)
|
||||
if length > uint64(high(int)):
|
||||
return err(Http413)
|
||||
if length > uint64(conn.server.maxRequestBodySize):
|
||||
return err(Http413)
|
||||
request.contentLength = int(length)
|
||||
request.requestFlags.incl(HttpRequestFlags.BoundBody)
|
||||
else:
|
||||
if TransferEncodingFlags.Chunked in request.transferEncoding:
|
||||
if request.meth == MethodTrace:
|
||||
return err(Http400)
|
||||
request.requestFlags.incl(HttpRequestFlags.UnboundBody)
|
||||
|
||||
if request.hasBody():
|
||||
# If request has body, we going to understand how its encoded.
|
||||
const
|
||||
UrlEncodedType = "application/x-www-form-urlencoded"
|
||||
MultipartType = "multipart/form-data"
|
||||
|
||||
if "content-type" in request.headersTable:
|
||||
let contentType = request.headersTable.getString("content-type")
|
||||
let tmp = strip(contentType).toLowerAscii()
|
||||
if tmp.startsWith(UrlEncodedType):
|
||||
request.requestFlags.incl(UrlencodedForm)
|
||||
elif tmp.startsWith(MultipartType):
|
||||
request.requestFlags.incl(MultipartForm)
|
||||
|
||||
request.mainReader = newAsyncStreamReader(conn.transp)
|
||||
ok(request)
|
||||
|
||||
proc getBodyStream*(request: HttpRequest): HttpResult[AsyncStreamReader] =
|
||||
if HttpRequestFlags.BoundBody in request.requestFlags:
|
||||
ok(newBoundedStreamReader(request.mainReader, request.contentLength))
|
||||
elif HttpRequestFlags.UnboundBody in request.requestFlags:
|
||||
ok(newChunkedStreamReader(request.mainReader))
|
||||
else:
|
||||
err("Request do not have body available")
|
||||
|
||||
proc getBody*(request: HttpRequest): Future[seq[byte]] {.async.} =
|
||||
## Obtain request's body as sequence of bytes.
|
||||
let res = request.getBodyStream()
|
||||
if res.isErr():
|
||||
return @[]
|
||||
else:
|
||||
try:
|
||||
return await read(res.get())
|
||||
except AsyncStreamError:
|
||||
raise newHttpCriticalFailure("Read failure")
|
||||
|
||||
proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
|
||||
## Consume/discard request's body.
|
||||
let res = request.getBodyStream()
|
||||
if res.isErr():
|
||||
return
|
||||
else:
|
||||
let reader = res.get()
|
||||
try:
|
||||
discard await reader.consume()
|
||||
return
|
||||
except AsyncStreamError:
|
||||
raise newHttpCriticalFailure("Read failure")
|
||||
|
||||
proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
|
||||
code: HttpCode, keepAlive = true,
|
||||
datatype = "text/text",
|
||||
databody = ""): Future[bool] {.async.} =
|
||||
var answer = $version & " " & $code & "\r\n"
|
||||
answer.add("Date: " & httpDate() & "\r\n")
|
||||
if len(databody) > 0:
|
||||
answer.add("Content-Type: " & datatype & "\r\n")
|
||||
answer.add("Content-Length: " & $len(databody) & "\r\n")
|
||||
if keepAlive:
|
||||
answer.add("Connection: keep-alive\r\n")
|
||||
else:
|
||||
answer.add("Connection: close\r\n")
|
||||
answer.add("\r\n")
|
||||
if len(databody) > 0:
|
||||
answer.add(databody)
|
||||
try:
|
||||
let res {.used.} = await conn.transp.write(answer)
|
||||
return true
|
||||
except CancelledError:
|
||||
return false
|
||||
except TransportOsError:
|
||||
return false
|
||||
except CatchableError:
|
||||
return false
|
||||
|
||||
proc sendErrorResponse(request: HttpRequest, code: HttpCode, keepAlive = true,
|
||||
datatype = "text/text",
|
||||
databody = ""): Future[bool] =
|
||||
sendErrorResponse(request.connection, request.version, code, keepAlive,
|
||||
datatype, databody)
|
||||
|
||||
proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
|
||||
when defined(useChroniclesLogging):
|
||||
logScope:
|
||||
peer = $conn.transp.remoteAddress
|
||||
|
||||
try:
|
||||
conn.buffer.setLen(conn.server.maxHeadersSize)
|
||||
let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer),
|
||||
HeadersMark)
|
||||
conn.buffer.setLen(res)
|
||||
let header = parseRequest(conn.buffer)
|
||||
if header.failed():
|
||||
log debug "Malformed header received"
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
|
||||
raise newHttpCriticalFailure("Malformed request recieved")
|
||||
else:
|
||||
let res = prepareRequest(conn, header)
|
||||
if res.isErr():
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
|
||||
raise newHttpCriticalFailure("Invalid request received")
|
||||
else:
|
||||
return res.get()
|
||||
except TransportOsError:
|
||||
log debug "Unexpected OS error"
|
||||
raise newHttpCriticalFailure("Unexpected OS error")
|
||||
except TransportIncompleteError:
|
||||
log debug "Remote peer disconnected"
|
||||
raise newHttpCriticalFailure("Remote peer disconnected")
|
||||
except TransportLimitError:
|
||||
log debug "Maximum size of request headers reached"
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http413, false)
|
||||
raise newHttpCriticalFailure("Maximum size of request headers reached")
|
||||
|
||||
proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
|
||||
when defined(useChroniclesLogging):
|
||||
logScope:
|
||||
peer = $transp.remoteAddress
|
||||
|
||||
var conn = HttpConnection(
|
||||
transp: transp, buffer: newSeq[byte](server.maxHeadersSize),
|
||||
server: server
|
||||
)
|
||||
|
||||
log info "Client connected"
|
||||
var breakLoop = false
|
||||
while true:
|
||||
try:
|
||||
let request = await conn.getRequest().wait(server.headersTimeout)
|
||||
echo "== HEADERS TABLE"
|
||||
echo request.headersTable
|
||||
echo "== QUERY TABLE"
|
||||
echo request.queryTable
|
||||
echo "== TRANSFER ENCODING ", request.transferEncoding
|
||||
echo "== CONTENT ENCODING ", request.contentEncoding
|
||||
echo "== REQUEST FLAGS ", request.requestFlags
|
||||
var stream = await request.getBody()
|
||||
echo cast[string](stream)
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http200, true,
|
||||
databody = "OK")
|
||||
log debug "Response sent"
|
||||
except AsyncTimeoutError:
|
||||
log debug "Timeout reached while reading headers"
|
||||
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
|
||||
breakLoop = true
|
||||
|
||||
except CancelledError:
|
||||
breakLoop = true
|
||||
|
||||
except HttpRecoverableFailure:
|
||||
breakLoop = false
|
||||
|
||||
except HttpCriticalFailure:
|
||||
breakLoop = true
|
||||
|
||||
except CatchableError as exc:
|
||||
echo "CatchableError received ", exc.name
|
||||
|
||||
if breakLoop:
|
||||
break
|
||||
|
||||
await transp.closeWait()
|
||||
server.connections.del(transp.getId())
|
||||
# if server.maxConnections > 0:
|
||||
# server.semaphore.release()
|
||||
log info "Client got disconnected"
|
||||
|
||||
proc acceptClientLoop(server: HttpServer) {.async.} =
|
||||
var breakLoop = false
|
||||
while true:
|
||||
try:
|
||||
# if server.maxConnections > 0:
|
||||
# await server.semaphore.acquire()
|
||||
|
||||
let transp = await server.instance.accept()
|
||||
server.connections[transp.getId()] = processLoop(server, transp)
|
||||
|
||||
except CancelledError:
|
||||
# Server was stopped
|
||||
breakLoop = true
|
||||
except TransportOsError:
|
||||
# This is some critical unrecoverable error.
|
||||
breakLoop = true
|
||||
except TransportTooManyError:
|
||||
# Non critical error
|
||||
breakLoop = false
|
||||
except CatchableError:
|
||||
# Unexpected error
|
||||
breakLoop = true
|
||||
discard
|
||||
|
||||
if breakLoop:
|
||||
break
|
||||
|
||||
proc state*(server: HttpServer): HttpServerState =
|
||||
## Returns current HTTP server's state.
|
||||
if server.lifetime.finished():
|
||||
ServerClosed
|
||||
else:
|
||||
if isNil(server.acceptLoop):
|
||||
ServerStopped
|
||||
else:
|
||||
if server.acceptLoop.finished():
|
||||
ServerStopped
|
||||
else:
|
||||
ServerRunning
|
||||
|
||||
proc start*(server: HttpServer) =
|
||||
## Starts HTTP server.
|
||||
if server.state == ServerStopped:
|
||||
server.acceptLoop = acceptClientLoop(server)
|
||||
|
||||
proc stop*(server: HttpServer) {.async.} =
|
||||
## Stop HTTP server from accepting new connections.
|
||||
if server.state == ServerRunning:
|
||||
await server.acceptLoop.cancelAndWait()
|
||||
|
||||
proc drop*(server: HttpServer) {.async.} =
|
||||
## Drop all pending HTTP connections.
|
||||
if server.state in {ServerStopped, ServerRunning}:
|
||||
discard
|
||||
|
||||
proc close*(server: HttpServer) {.async.} =
|
||||
## Stop HTTP server and drop all the pending connections.
|
||||
if server.state != ServerClosed:
|
||||
await server.stop()
|
||||
await server.drop()
|
||||
await server.instance.closeWait()
|
||||
server.lifetime.complete()
|
||||
|
||||
proc join*(server: HttpServer): Future[void] =
|
||||
## Wait until HTTP server will not be closed.
|
||||
var retFuture = newFuture[void]("http.server.join")
|
||||
|
||||
proc continuation(udata: pointer) {.gcsafe.} =
|
||||
if not(retFuture.finished()):
|
||||
retFuture.complete()
|
||||
|
||||
proc cancellation(udata: pointer) {.gcsafe.} =
|
||||
if not(retFuture.finished()):
|
||||
server.lifetime.removeCallback(continuation, cast[pointer](retFuture))
|
||||
|
||||
if server.state == ServerClosed:
|
||||
retFuture.complete()
|
||||
else:
|
||||
server.lifetime.addCallback(continuation, cast[pointer](retFuture))
|
||||
retFuture.cancelCallback = cancellation
|
||||
|
||||
retFuture
|
||||
|
||||
when isMainModule:
|
||||
let res = HttpServer.new(initTAddress("127.0.0.1:30080"), maxConnections = 1)
|
||||
if res.isOk():
|
||||
let server = res.get()
|
||||
server.start()
|
||||
echo "HTTP server was started"
|
||||
waitFor server.join()
|
||||
else:
|
||||
echo "Failed to start server: ", res.error
|
|
@ -0,0 +1,124 @@
|
|||
#
|
||||
# Chronos HTTP/S case-insensitive non-unique
|
||||
# key-value memory storage
|
||||
# (c) Copyright 2019-Present
|
||||
# Status Research & Development GmbH
|
||||
#
|
||||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/[tables, strutils]
|
||||
|
||||
type
|
||||
HttpTable* = object
|
||||
table: Table[string, seq[string]]
|
||||
|
||||
HttpTableRef* = ref HttpTable
|
||||
|
||||
HttpTables* = HttpTable | HttpTableRef
|
||||
|
||||
proc `-`(x: uint32): uint32 {.inline.} =
|
||||
(0xFFFF_FFFF'u32 - x) + 1'u32
|
||||
|
||||
proc LT(x, y: uint32): uint32 {.inline.} =
|
||||
let z = x - y
|
||||
(z xor ((y xor x) and (y xor z))) shr 31
|
||||
|
||||
proc decValue(c: byte): int =
|
||||
let x = uint32(c) - 0x30'u32
|
||||
let r = ((x + 1'u32) and -LT(x, 10))
|
||||
int(r) - 1
|
||||
|
||||
proc bytesToDec*[T: byte|char](src: openarray[T]): uint64 =
|
||||
var v = 0'u64
|
||||
for i in 0 ..< len(src):
|
||||
let d =
|
||||
when T is byte:
|
||||
decValue(src[i])
|
||||
else:
|
||||
decValue(byte(src[i]))
|
||||
if d < 0:
|
||||
# non-decimal character encountered
|
||||
return v
|
||||
else:
|
||||
let nv = ((v shl 3) + (v shl 1)) + uint64(d)
|
||||
if nv < v:
|
||||
# overflow happened
|
||||
return v
|
||||
else:
|
||||
v = nv
|
||||
v
|
||||
|
||||
proc add*(ht: var HttpTables, key: string, value: string) =
|
||||
let lowkey = key.toLowerAscii()
|
||||
var nitem = @[value]
|
||||
if ht.table.hasKeyOrPut(lowkey, nitem):
|
||||
var oitem = ht.table[lowkey]
|
||||
oitem.add(value)
|
||||
ht.table[lowkey] = oitem
|
||||
|
||||
proc add*(ht: var HttpTables, key: string, value: SomeInteger) =
|
||||
ht.add(key, $value)
|
||||
|
||||
proc contains*(ht: var HttpTables, key: string): bool =
|
||||
ht.table.contains(key.toLowerAscii())
|
||||
|
||||
proc getList*(ht: HttpTables, key: string): seq[string] =
|
||||
var default: seq[string]
|
||||
ht.table.getOrDefault(key.toLowerAscii(), default)
|
||||
|
||||
proc getString*(ht: HttpTables, key: string): string =
|
||||
var default: seq[string]
|
||||
ht.table.getOrDefault(key.toLowerAscii(), default).join(",")
|
||||
|
||||
proc count*(ht: HttpTables, key: string): int =
|
||||
var default: seq[string]
|
||||
len(ht.table.getOrDefault(key, default))
|
||||
|
||||
proc getInt*(ht: HttpTables, key: string): uint64 =
|
||||
bytesToDec(ht.getString(key))
|
||||
|
||||
proc getLastString*(ht: HttpTables, key: string): string =
|
||||
var default: seq[string]
|
||||
let item = ht.table.getOrDefault(key.toLowerAscii(), default)
|
||||
if len(item) == 0:
|
||||
""
|
||||
else:
|
||||
item[^1]
|
||||
|
||||
proc getLastInt*(ht: HttpTables, key: string): uint64 =
|
||||
bytesToDec(ht.getLastString())
|
||||
|
||||
proc init*(htt: typedesc[HttpTable]): HttpTable =
|
||||
HttpTable(table: initTable[string, seq[string]]())
|
||||
|
||||
proc new*(htt: typedesc[HttpTableRef]): HttpTableRef =
|
||||
HttpTableRef(table: initTable[string, seq[string]]())
|
||||
|
||||
proc normalizeHeaderName*(value: string): string =
|
||||
var res = value.toLowerAscii()
|
||||
var k = 0
|
||||
while k < len(res):
|
||||
if k == 0:
|
||||
res[k] = toUpperAscii(res[k])
|
||||
inc(k, 1)
|
||||
else:
|
||||
if res[k] == '-':
|
||||
if k + 1 < len(res):
|
||||
res[k + 1] = toUpperAscii(res[k + 1])
|
||||
inc(k, 2)
|
||||
else:
|
||||
break
|
||||
else:
|
||||
inc(k, 1)
|
||||
res
|
||||
|
||||
proc `$`*(ht: HttpTables): string =
|
||||
var res = ""
|
||||
for key, value in ht.table.pairs():
|
||||
for item in value:
|
||||
res.add(key.normalizeHeaderName())
|
||||
res.add(": ")
|
||||
res.add(item)
|
||||
res.add("\p")
|
||||
res
|
|
@ -0,0 +1,286 @@
|
|||
#
|
||||
# Chronos HTTP/S multipart/form
|
||||
# encoding and decoding helper procedures
|
||||
# (c) Copyright 2019-Present
|
||||
# Status Research & Development GmbH
|
||||
#
|
||||
# Licensed under either of
|
||||
# Apache License, version 2.0, (LICENSE-APACHEv2)
|
||||
# MIT license (LICENSE-MIT)
|
||||
import std/[monotimes, strutils]
|
||||
import chronos, stew/results
|
||||
import httptable, httpcommon
|
||||
export httptable, httpcommon
|
||||
|
||||
type
|
||||
MultiPartSource {.pure.} = enum
|
||||
Stream, Buffer
|
||||
|
||||
MultiPartReader* = object
|
||||
case kind: MultiPartSource
|
||||
of MultiPartSource.Stream:
|
||||
stream: AsyncStreamReader
|
||||
last: BoundedAsyncStreamReader
|
||||
of MultiPartSource.Buffer:
|
||||
discard
|
||||
buffer: seq[byte]
|
||||
offset: int
|
||||
boundary: seq[byte]
|
||||
|
||||
MultiPartReaderRef* = ref MultiPartReader
|
||||
|
||||
MultiPart* = object
|
||||
headers: HttpTable
|
||||
stream: BoundedAsyncStreamReader
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
MultipartError* = object of HttpError
|
||||
MultipartEOMError* = object of MultipartError
|
||||
MultiPartIncorrectError* = object of MultipartError
|
||||
MultiPartIncompleteError* = object of MultipartError
|
||||
|
||||
BChar* = byte | char
|
||||
|
||||
proc startsWith*(s, prefix: openarray[byte]): bool =
|
||||
var i = 0
|
||||
while true:
|
||||
if i >= len(prefix): return true
|
||||
if i >= len(s) or s[i] != prefix[i]: return false
|
||||
inc(i)
|
||||
|
||||
proc parseUntil*(s, until: openarray[byte]): int =
|
||||
var i = 0
|
||||
while i < len(s):
|
||||
if len(until) > 0 and s[i] == until[0]:
|
||||
var u = 1
|
||||
while i + u < len(s) and u < len(until) and s[i + u] == until[u]:
|
||||
inc u
|
||||
if u >= len(until): return i
|
||||
inc(i)
|
||||
-1
|
||||
|
||||
proc init*[A: BChar, B: BChar](mpt: typedesc[MultiPartReader],
|
||||
buffer: openarray[A],
|
||||
boundary: openarray[B]): MultiPartReader =
|
||||
# Boundary should not be empty.
|
||||
doAssert(len(boundary) > 0)
|
||||
# Our internal boundary has format `<CR><LF><-><-><boundary>`, so we can reuse
|
||||
# different parts of this sequence for processing.
|
||||
var fboundary = newSeq[byte](len(boundary) + 4)
|
||||
fboundary[0] = 0x0D'u8
|
||||
fboundary[1] = 0x0A'u8
|
||||
fboundary[2] = byte('-')
|
||||
fboundary[3] = byte('-')
|
||||
copyMem(addr fboundary[4], unsafeAddr boundary[0], len(boundary))
|
||||
# Make copy of buffer, because all the returned parts depending on it.
|
||||
var buf = newSeq[byte](len(buffer))
|
||||
if len(buf) > 0:
|
||||
copyMem(addr buf[0], unsafeAddr buffer[0], len(buffer))
|
||||
MultiPartReader(kind: MultiPartSource.Buffer,
|
||||
buffer: buf, offset: 0, boundary: fboundary)
|
||||
|
||||
proc init*[B: BChar](mpt: typedesc[MultiPartReader],
|
||||
stream: AsyncStreamReader,
|
||||
boundary: openarray[B],
|
||||
partHeadersMaxSize = 4096): MultiPartReader =
|
||||
# Boundary should not be empty.
|
||||
doAssert(len(boundary) > 0)
|
||||
# Our internal boundary has format `<CR><LF><-><-><boundary>`, so we can reuse
|
||||
# different parts of this sequence for processing.
|
||||
var fboundary = newSeq[byte](len(boundary) + 4)
|
||||
fboundary[0] = 0x0D'u8
|
||||
fboundary[1] = 0x0A'u8
|
||||
fboundary[2] = byte('-')
|
||||
fboundary[3] = byte('-')
|
||||
copyMem(addr fboundary[4], unsafeAddr boundary[0], len(boundary))
|
||||
MultiPartReader(kind: MultiPartSource.Stream,
|
||||
stream: stream, offset: 0, boundary: fboundary,
|
||||
buffer: newSeq[byte](partHeadersMaxSize))
|
||||
|
||||
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
|
||||
doAssert(mpr.kind == MultiPartSource.Stream)
|
||||
# According to RFC1521 says that a boundary "must be no longer than 70
|
||||
# characters, not counting the two leading hyphens.
|
||||
if mpr.firstTime:
|
||||
# Read and verify initial <-><-><boundary><CR><LF>
|
||||
mpr.firstTime = false
|
||||
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
|
||||
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
|
||||
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
||||
if buffer[0] == byte('-') and buffer[1] == byte("-"):
|
||||
raise newException(MultiPartEOMError, "Unexpected EOM encountered")
|
||||
if buffer[0] != 0x0D'u8 or buffer[1] != 0x0A'u8:
|
||||
raise newException(MultiPartIncorrectError,
|
||||
"Unexpected boundary suffix")
|
||||
else:
|
||||
raise newException(MultiPartIncorrectError,
|
||||
"Unexpected boundary encountered")
|
||||
|
||||
# Reading part's headers
|
||||
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
|
||||
HeadersMark)
|
||||
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1))
|
||||
if headersList.failed():
|
||||
raise newException(MultiPartIncorrectError, "Incorrect part headers found")
|
||||
|
||||
var part = MultiPart()
|
||||
|
||||
await mpr.stream.readExactly(addr buffer[0], len(mpr.boundary) - 4)
|
||||
if startsWith(buffer.toOpenArray(0, len(mpr.boundary) - 5),
|
||||
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
||||
await mpr.stream.readExactly(addr buffer[0], 2)
|
||||
if buffer[0] == byte('-') and buffer[1] == byte("-"):
|
||||
raise newException(MultiPartEOMError, "")
|
||||
if buffer[0] == 0x0D'u8 and buffer[1] == 0x0A'u8:
|
||||
|
||||
except:
|
||||
discard
|
||||
# if mpr.offset >= len(mpr.buffer):
|
||||
# raise newException(MultiPartEOMError, "End of multipart form encountered")
|
||||
|
||||
proc getStream*(mp: MultiPart): AsyncStreamReader =
|
||||
mp.stream
|
||||
|
||||
proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
|
||||
try:
|
||||
let res = await mp.stream.read()
|
||||
return res
|
||||
except AsyncStreamError:
|
||||
raise newException(HttpCriticalError, "Could not read multipart body")
|
||||
|
||||
proc consumeBody*(mp: MultiPart) {.async.} =
|
||||
try:
|
||||
await mp.stream.consume()
|
||||
except AsyncStreamError:
|
||||
raise newException(HttpCriticalError, "Could not consume multipart body")
|
||||
|
||||
proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
|
||||
doAssert(mpr.kind == MultiPartSource.Buffer)
|
||||
if mpr.offset >= len(mpr.buffer):
|
||||
return err("End of multipart form encountered")
|
||||
|
||||
if startsWith(mpr.buffer.toOpenArray(mpr.offset, len(mpr.buffer) - 1),
|
||||
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
|
||||
# Buffer must start at <-><-><boundary>
|
||||
mpr.offset += (len(mpr.boundary) - 2)
|
||||
|
||||
# After boundary there should be at least 2 symbols <-><-> or <CR><LF>.
|
||||
if len(mpr.buffer) <= mpr.offset + 1:
|
||||
return err("Incomplete multipart form")
|
||||
|
||||
if mpr.buffer[mpr.offset] == byte('-') and
|
||||
mpr.buffer[mpr.offset + 1] == byte('-'):
|
||||
# If we have <-><-><boundary><-><-> it means we have found last boundary
|
||||
# of multipart message.
|
||||
mpr.offset += 2
|
||||
return err("End of multipart form encountered")
|
||||
|
||||
if mpr.buffer[mpr.offset] == 0x0D'u8 and
|
||||
mpr.buffer[mpr.offset + 1] == 0x0A'u8:
|
||||
# If we have <-><-><boundary><CR><LF> it means that we have found another
|
||||
# part of multipart message.
|
||||
mpr.offset += 2
|
||||
# Multipart form must always have at least single Content-Disposition
|
||||
# header, so we searching position where all the headers should be
|
||||
# finished <CR><LF><CR><LF>.
|
||||
let pos1 = parseUntil(
|
||||
mpr.buffer.toOpenArray(mpr.offset, len(mpr.buffer) - 1),
|
||||
[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8]
|
||||
)
|
||||
|
||||
if pos1 < 0:
|
||||
return err("Incomplete multipart form")
|
||||
|
||||
# parseUntil returns 0-based position without `until` sequence.
|
||||
let start = mpr.offset + pos1 + 4
|
||||
|
||||
# Multipart headers position
|
||||
let hstart = mpr.offset
|
||||
let hfinish = mpr.offset + pos1 + 4 - 1
|
||||
|
||||
let headersList = parseHeaders(mpr.buffer.toOpenArray(hstart, hfinish),
|
||||
false)
|
||||
if headersList.failed():
|
||||
return err("Incorrect or incomplete multipart headers received")
|
||||
|
||||
# Searching for value's boundary <CR><LF><-><-><boundary>.
|
||||
let pos2 = parseUntil(
|
||||
mpr.buffer.toOpenArray(start, len(mpr.buffer) - 1),
|
||||
mpr.boundary.toOpenArray(0, len(mpr.boundary) - 1)
|
||||
)
|
||||
|
||||
if pos2 < 0:
|
||||
return err("Incomplete multipart form")
|
||||
|
||||
# We set reader's offset to the place right after <CR><LF>
|
||||
mpr.offset = start + pos2 + 2
|
||||
|
||||
var part = MultiPart(offset: start, size: pos2, headers: HttpTable.init())
|
||||
for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)):
|
||||
part.headers.add(k, v)
|
||||
ok(part)
|
||||
else:
|
||||
err("Incorrect multipart form")
|
||||
else:
|
||||
err("Incorrect multipart form")
|
||||
|
||||
template `-`(x: uint32): uint32 =
|
||||
(0xFFFF_FFFF'u32 - x) + 1'u32
|
||||
|
||||
template LT(x, y: uint32): uint32 =
|
||||
let z = x - y
|
||||
(z xor ((y xor x) and (y xor z))) shr 31
|
||||
|
||||
proc boundaryValue(c: char): bool =
|
||||
let a0 = uint32(c) - 0x27'u32
|
||||
let a1 = uint32(c) - 0x2B'u32
|
||||
let a2 = uint32(c) - 0x3A'u32
|
||||
let a3 = uint32(c) - 0x3D'u32
|
||||
let a4 = uint32(c) - 0x3F'u32
|
||||
let a5 = uint32(c) - 0x41'u32
|
||||
let a6 = uint32(c) - 0x5F
|
||||
let a7 = uint32(c) - 0x61'u32
|
||||
let r = ((a0 + 1'u32) and -LT(a0, 3)) or
|
||||
((a1 + 1'u32) and -LT(a1, 15)) or
|
||||
((a2 + 1'u32) and -LT(a2, 1)) or
|
||||
((a3 + 1'u32) and -LT(a3, 1)) or
|
||||
((a4 + 1'u32) and -LT(a4, 1)) or
|
||||
((a5 + 1'u32) and -LT(a5, 26)) or
|
||||
((a6 + 1'u32) and -LT(a6, 1)) or
|
||||
((a7 + 1'u32) and -LT(a7, 26))
|
||||
(int(r) - 1) > 0
|
||||
|
||||
proc boundaryValue2(c: char): bool =
|
||||
c in {'a'..'z', 'A' .. 'Z', '0' .. '9',
|
||||
'\'' .. ')', '+' .. '/', ':', '=', '?', '_'}
|
||||
|
||||
func getMultipartBoundary*(contentType: string): Result[string, string] =
|
||||
let mparts = contentType.split(";")
|
||||
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
|
||||
return err("Content-Type is not multipart")
|
||||
if len(mparts) < 2:
|
||||
return err("Content-Type missing boundary value")
|
||||
let stripped = strip(mparts[1])
|
||||
if not(stripped.toLowerAscii().startsWith("boundary")):
|
||||
return err("Incorrect Content-Type boundary format")
|
||||
let bparts = stripped.split("=")
|
||||
if len(bparts) < 2:
|
||||
err("Missing Content-Type boundary")
|
||||
else:
|
||||
ok(strip(bparts[1]))
|
||||
|
||||
func getContentType*(contentHeader: seq[string]): Result[string, string] =
|
||||
if len(contentHeader) > 1:
|
||||
return err("Multiple Content-Header values found")
|
||||
let mparts = contentHeader[0].split(";")
|
||||
ok(strip(mparts[0]).toLowerAscii())
|
||||
|
||||
when isMainModule:
|
||||
var buf = "--------------------------5e7d0dd0ed6eb849\r\nContent-Disposition: form-data; name=\"key1\"\r\n\r\nvalue1\r\n--------------------------5e7d0dd0ed6eb849\r\nContent-Disposition: form-data; name=\"key2\"\r\n\r\nvalue2\r\n--------------------------5e7d0dd0ed6eb849--"
|
||||
var reader = MultiPartReader.init(buf, "------------------------5e7d0dd0ed6eb849")
|
||||
echo getPart(reader)
|
||||
echo "===="
|
||||
echo getPart(reader)
|
||||
echo "===="
|
||||
echo getPart(reader)
|
Loading…
Reference in New Issue