Initial commit.

This commit is contained in:
cheatfate 2021-01-27 08:14:17 +02:00 committed by zah
parent 0b396c34d8
commit 8381a40868
5 changed files with 997 additions and 0 deletions

View File

@ -0,0 +1,27 @@
#
# Chronos HTTP/S common types
# (c) Copyright 2019-Present
# Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import stew/results, httputils
export results, httputils
const
useChroniclesLogging* {.booldefine.} = false
HeadersMark* = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
type
HttpResult*[T] = Result[T, string]
HttpResultCode*[T] = Result[T, HttpCode]
HttpError* = object of CatchableError
HttpCriticalFailure* = object of HttpError
HttpRecoverableFailure* = object of HttpError
template log*(body: untyped) =
when defined(useChroniclesLogging):
body

Binary file not shown.

View File

@ -0,0 +1,560 @@
#
# Chronos HTTP/S server implementation
# (c) Copyright 2019-Present
# Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[tables, options, uri, strutils]
import stew/results, httputils
import ../../asyncloop, ../../asyncsync
import ../../streams/[asyncstream, boundstream, chunkstream]
import httptable, httpcommon
export httpcommon
when defined(useChroniclesLogging):
echo "Importing chronicles"
import chronicles
type
HttpServerFlags* = enum
Secure
TransferEncodingFlags* {.pure.} = enum
Identity, Chunked, Compress, Deflate, Gzip
ContentEncodingFlags* {.pure.} = enum
Identity, Br, Compress, Deflate, Gzip
HttpRequestFlags* {.pure.} = enum
BoundBody, UnboundBody, MultipartForm, UrlencodedForm
HttpServer* = ref object of RootRef
instance*: StreamServer
# semaphore*: AsyncSemaphore
maxConnections*: int
baseUri*: Uri
flags*: set[HttpServerFlags]
connections*: Table[string, Future[void]]
acceptLoop*: Future[void]
lifetime*: Future[void]
headersTimeout: Duration
bodyTimeout: Duration
maxHeadersSize: int
maxRequestBodySize: int
HttpServerState* = enum
ServerRunning, ServerStopped, ServerClosed
HttpRequest* = object
headersTable: HttpTable
queryTable: HttpTable
postTable: HttpTable
rawPath*: string
rawQuery*: string
uri*: Uri
scheme*: string
version*: HttpVersion
meth*: HttpMethod
contentEncoding*: set[ContentEncodingFlags]
transferEncoding*: set[TransferEncodingFlags]
requestFlags*: set[HttpRequestFlags]
contentLength: int
connection*: HttpConnection
mainReader*: AsyncStreamReader
HttpConnection* = ref object of RootRef
server: HttpServer
transp: StreamTransport
buffer: seq[byte]
const
HeadersMark = @[byte(0x0D), byte(0x0A), byte(0x0D), byte(0x0A)]
proc new*(htype: typedesc[HttpServer],
address: TransportAddress,
flags: set[HttpServerFlags] = {},
serverUri = Uri(),
maxConnections: int = -1,
bufferSize: int = 4096,
backlogSize: int = 100,
httpHeadersTimeout = 10.seconds,
httpBodyTimeout = 30.seconds,
maxHeadersSize: int = 8192,
maxRequestBodySize: int = 1_048_576): HttpResult[HttpServer] =
var res = HttpServer(
maxConnections: maxConnections,
headersTimeout: httpHeadersTimeout,
bodyTimeout: httpBodyTimeout,
maxHeadersSize: maxHeadersSize,
maxRequestBodySize: maxRequestBodySize
)
res.baseUri =
if len(serverUri.hostname) > 0 and isAbsolute(serverUri):
serverUri
else:
if HttpServerFlags.Secure in flags:
parseUri("https://" & $address & "/")
else:
parseUri("http://" & $address & "/")
try:
res.instance = createStreamServer(address, flags = {ReuseAddr},
bufferSize = bufferSize,
backlog = backlogSize)
# if maxConnections > 0:
# res.semaphore = newAsyncSemaphore(maxConnections)
res.lifetime = newFuture[void]("http.server.lifetime")
res.connections = initTable[string, Future[void]]()
return ok(res)
except TransportOsError as exc:
return err(exc.msg)
except CatchableError as exc:
return err(exc.msg)
proc getId(transp: StreamTransport): string {.inline.} =
## Returns string unique transport's identifier as string.
$transp.remoteAddress() & "_" & $transp.localAddress()
proc newHttpCriticalFailure*(msg: string): ref HttpCriticalFailure =
newException(HttpCriticalFailure, msg)
proc newHttpRecoverableFailure*(msg: string): ref HttpRecoverableFailure =
newException(HttpRecoverableFailure, msg)
iterator queryParams(query: string): tuple[key: string, value: string] =
for pair in query.split('&'):
let items = pair.split('=', maxsplit = 1)
let k = items[0]
let v = if len(items) > 1: items[1] else: ""
yield (decodeUrl(k), decodeUrl(v))
func getMultipartBoundary*(contentType: string): HttpResult[string] =
let mparts = contentType.split(";")
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
return err("Content-Type is not multipart")
if len(mparts) < 2:
return err("Content-Type missing boundary value")
let stripped = strip(mparts[1])
if not(stripped.toLowerAscii().startsWith("boundary")):
return err("Incorrect Content-Type boundary format")
let bparts = stripped.split("=")
if len(bparts) < 2:
err("Missing Content-Type boundary")
else:
ok(strip(bparts[1]))
func getContentType*(contentHeader: seq[string]): HttpResult[string] =
if len(contentHeader) > 1:
return err("Multiple Content-Type values found")
let mparts = contentHeader[0].split(";")
ok(strip(mparts[0]).toLowerAscii())
func getTransferEncoding(contentHeader: seq[string]): HttpResult[
set[TransferEncodingFlags]] =
var res: set[TransferEncodingFlags] = {}
if len(contentHeader) == 0:
res.incl(TransferEncodingFlags.Identity)
ok(res)
else:
for header in contentHeader:
for item in header.split(","):
case strip(item.toLowerAscii())
of "identity":
res.incl(TransferEncodingFlags.Identity)
of "chunked":
res.incl(TransferEncodingFlags.Chunked)
of "compress":
res.incl(TransferEncodingFlags.Compress)
of "deflate":
res.incl(TransferEncodingFlags.Deflate)
of "gzip":
res.incl(TransferEncodingFlags.Gzip)
of "":
res.incl(TransferEncodingFlags.Identity)
else:
return err("Incorrect Transfer-Encoding value")
ok(res)
func getContentEncoding(contentHeader: seq[string]): HttpResult[
set[ContentEncodingFlags]] =
var res: set[ContentEncodingFlags] = {}
if len(contentHeader) == 0:
res.incl(ContentEncodingFlags.Identity)
ok(res)
else:
for header in contentHeader:
for item in header.split(","):
case strip(item.toLowerAscii()):
of "identity":
res.incl(ContentEncodingFlags.Identity)
of "br":
res.incl(ContentEncodingFlags.Br)
of "compress":
res.incl(ContentEncodingFlags.Compress)
of "deflate":
res.incl(ContentEncodingFlags.Deflate)
of "gzip":
res.incl(ContentEncodingFlags.Gzip)
of "":
res.incl(ContentEncodingFlags.Identity)
else:
return err("Incorrect Content-Encoding value")
ok(res)
proc hasBody*(request: HttpRequest): bool =
## Returns ``true`` if request has body.
request.requestFlags * {HttpRequestFlags.BoundBody,
HttpRequestFlags.UnboundBody} != {}
proc prepareRequest(conn: HttpConnection,
req: HttpRequestHeader): HttpResultCode[HttpRequest] =
var request = HttpRequest()
if req.version notin {HttpVersion10, HttpVersion11}:
return err(Http505)
request.version = req.version
request.meth = req.meth
request.rawPath =
block:
let res = req.uri()
if len(res) == 0:
return err(Http400)
res
request.uri =
if request.rawPath != "*":
let uri = parseUri(request.rawPath)
if uri.scheme notin ["http", "https", ""]:
return err(Http400)
uri
else:
var uri = initUri()
uri.path = "*"
uri
request.queryTable =
block:
var table = HttpTable.init()
for key, value in queryParams(request.uri.query):
table.add(key, value)
table
request.headersTable =
block:
var table = HttpTable.init()
# Retrieve headers and values
for key, value in req.headers():
table.add(key, value)
# Validating HTTP request headers
# Some of the headers must be present only once.
if table.count("content-type") > 1:
return err(Http400)
if table.count("content-length") > 1:
return err(Http400)
if table.count("transfer-encoding") > 1:
return err(Http400)
table
# Preprocessing "Content-Encoding" header.
request.contentEncoding =
block:
let res = getContentEncoding(
request.headersTable.getList("content-encoding"))
if res.isErr():
return err(Http400)
else:
res.get()
# Preprocessing "Transfer-Encoding" header.
request.transferEncoding =
block:
let res = getTransferEncoding(
request.headersTable.getList("transfer-encoding"))
if res.isErr():
return err(Http400)
else:
res.get()
# Almost all HTTP requests could have body (except TRACE), we perform some
# steps to reveal information about body.
if "content-length" in request.headersTable:
let length = request.headersTable.getInt("content-length")
if length > 0:
if request.meth == MethodTrace:
return err(Http400)
if length > uint64(high(int)):
return err(Http413)
if length > uint64(conn.server.maxRequestBodySize):
return err(Http413)
request.contentLength = int(length)
request.requestFlags.incl(HttpRequestFlags.BoundBody)
else:
if TransferEncodingFlags.Chunked in request.transferEncoding:
if request.meth == MethodTrace:
return err(Http400)
request.requestFlags.incl(HttpRequestFlags.UnboundBody)
if request.hasBody():
# If request has body, we going to understand how its encoded.
const
UrlEncodedType = "application/x-www-form-urlencoded"
MultipartType = "multipart/form-data"
if "content-type" in request.headersTable:
let contentType = request.headersTable.getString("content-type")
let tmp = strip(contentType).toLowerAscii()
if tmp.startsWith(UrlEncodedType):
request.requestFlags.incl(UrlencodedForm)
elif tmp.startsWith(MultipartType):
request.requestFlags.incl(MultipartForm)
request.mainReader = newAsyncStreamReader(conn.transp)
ok(request)
proc getBodyStream*(request: HttpRequest): HttpResult[AsyncStreamReader] =
if HttpRequestFlags.BoundBody in request.requestFlags:
ok(newBoundedStreamReader(request.mainReader, request.contentLength))
elif HttpRequestFlags.UnboundBody in request.requestFlags:
ok(newChunkedStreamReader(request.mainReader))
else:
err("Request do not have body available")
proc getBody*(request: HttpRequest): Future[seq[byte]] {.async.} =
## Obtain request's body as sequence of bytes.
let res = request.getBodyStream()
if res.isErr():
return @[]
else:
try:
return await read(res.get())
except AsyncStreamError:
raise newHttpCriticalFailure("Read failure")
proc consumeBody*(request: HttpRequest): Future[void] {.async.} =
## Consume/discard request's body.
let res = request.getBodyStream()
if res.isErr():
return
else:
let reader = res.get()
try:
discard await reader.consume()
return
except AsyncStreamError:
raise newHttpCriticalFailure("Read failure")
proc sendErrorResponse(conn: HttpConnection, version: HttpVersion,
code: HttpCode, keepAlive = true,
datatype = "text/text",
databody = ""): Future[bool] {.async.} =
var answer = $version & " " & $code & "\r\n"
answer.add("Date: " & httpDate() & "\r\n")
if len(databody) > 0:
answer.add("Content-Type: " & datatype & "\r\n")
answer.add("Content-Length: " & $len(databody) & "\r\n")
if keepAlive:
answer.add("Connection: keep-alive\r\n")
else:
answer.add("Connection: close\r\n")
answer.add("\r\n")
if len(databody) > 0:
answer.add(databody)
try:
let res {.used.} = await conn.transp.write(answer)
return true
except CancelledError:
return false
except TransportOsError:
return false
except CatchableError:
return false
proc sendErrorResponse(request: HttpRequest, code: HttpCode, keepAlive = true,
datatype = "text/text",
databody = ""): Future[bool] =
sendErrorResponse(request.connection, request.version, code, keepAlive,
datatype, databody)
proc getRequest*(conn: HttpConnection): Future[HttpRequest] {.async.} =
when defined(useChroniclesLogging):
logScope:
peer = $conn.transp.remoteAddress
try:
conn.buffer.setLen(conn.server.maxHeadersSize)
let res = await conn.transp.readUntil(addr conn.buffer[0], len(conn.buffer),
HeadersMark)
conn.buffer.setLen(res)
let header = parseRequest(conn.buffer)
if header.failed():
log debug "Malformed header received"
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
raise newHttpCriticalFailure("Malformed request recieved")
else:
let res = prepareRequest(conn, header)
if res.isErr():
discard await conn.sendErrorResponse(HttpVersion11, Http400, false)
raise newHttpCriticalFailure("Invalid request received")
else:
return res.get()
except TransportOsError:
log debug "Unexpected OS error"
raise newHttpCriticalFailure("Unexpected OS error")
except TransportIncompleteError:
log debug "Remote peer disconnected"
raise newHttpCriticalFailure("Remote peer disconnected")
except TransportLimitError:
log debug "Maximum size of request headers reached"
discard await conn.sendErrorResponse(HttpVersion11, Http413, false)
raise newHttpCriticalFailure("Maximum size of request headers reached")
proc processLoop(server: HttpServer, transp: StreamTransport) {.async.} =
when defined(useChroniclesLogging):
logScope:
peer = $transp.remoteAddress
var conn = HttpConnection(
transp: transp, buffer: newSeq[byte](server.maxHeadersSize),
server: server
)
log info "Client connected"
var breakLoop = false
while true:
try:
let request = await conn.getRequest().wait(server.headersTimeout)
echo "== HEADERS TABLE"
echo request.headersTable
echo "== QUERY TABLE"
echo request.queryTable
echo "== TRANSFER ENCODING ", request.transferEncoding
echo "== CONTENT ENCODING ", request.contentEncoding
echo "== REQUEST FLAGS ", request.requestFlags
var stream = await request.getBody()
echo cast[string](stream)
discard await conn.sendErrorResponse(HttpVersion11, Http200, true,
databody = "OK")
log debug "Response sent"
except AsyncTimeoutError:
log debug "Timeout reached while reading headers"
discard await conn.sendErrorResponse(HttpVersion11, Http408, false)
breakLoop = true
except CancelledError:
breakLoop = true
except HttpRecoverableFailure:
breakLoop = false
except HttpCriticalFailure:
breakLoop = true
except CatchableError as exc:
echo "CatchableError received ", exc.name
if breakLoop:
break
await transp.closeWait()
server.connections.del(transp.getId())
# if server.maxConnections > 0:
# server.semaphore.release()
log info "Client got disconnected"
proc acceptClientLoop(server: HttpServer) {.async.} =
var breakLoop = false
while true:
try:
# if server.maxConnections > 0:
# await server.semaphore.acquire()
let transp = await server.instance.accept()
server.connections[transp.getId()] = processLoop(server, transp)
except CancelledError:
# Server was stopped
breakLoop = true
except TransportOsError:
# This is some critical unrecoverable error.
breakLoop = true
except TransportTooManyError:
# Non critical error
breakLoop = false
except CatchableError:
# Unexpected error
breakLoop = true
discard
if breakLoop:
break
proc state*(server: HttpServer): HttpServerState =
## Returns current HTTP server's state.
if server.lifetime.finished():
ServerClosed
else:
if isNil(server.acceptLoop):
ServerStopped
else:
if server.acceptLoop.finished():
ServerStopped
else:
ServerRunning
proc start*(server: HttpServer) =
## Starts HTTP server.
if server.state == ServerStopped:
server.acceptLoop = acceptClientLoop(server)
proc stop*(server: HttpServer) {.async.} =
## Stop HTTP server from accepting new connections.
if server.state == ServerRunning:
await server.acceptLoop.cancelAndWait()
proc drop*(server: HttpServer) {.async.} =
## Drop all pending HTTP connections.
if server.state in {ServerStopped, ServerRunning}:
discard
proc close*(server: HttpServer) {.async.} =
## Stop HTTP server and drop all the pending connections.
if server.state != ServerClosed:
await server.stop()
await server.drop()
await server.instance.closeWait()
server.lifetime.complete()
proc join*(server: HttpServer): Future[void] =
## Wait until HTTP server will not be closed.
var retFuture = newFuture[void]("http.server.join")
proc continuation(udata: pointer) {.gcsafe.} =
if not(retFuture.finished()):
retFuture.complete()
proc cancellation(udata: pointer) {.gcsafe.} =
if not(retFuture.finished()):
server.lifetime.removeCallback(continuation, cast[pointer](retFuture))
if server.state == ServerClosed:
retFuture.complete()
else:
server.lifetime.addCallback(continuation, cast[pointer](retFuture))
retFuture.cancelCallback = cancellation
retFuture
when isMainModule:
let res = HttpServer.new(initTAddress("127.0.0.1:30080"), maxConnections = 1)
if res.isOk():
let server = res.get()
server.start()
echo "HTTP server was started"
waitFor server.join()
else:
echo "Failed to start server: ", res.error

View File

@ -0,0 +1,124 @@
#
# Chronos HTTP/S case-insensitive non-unique
# key-value memory storage
# (c) Copyright 2019-Present
# Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[tables, strutils]
type
HttpTable* = object
table: Table[string, seq[string]]
HttpTableRef* = ref HttpTable
HttpTables* = HttpTable | HttpTableRef
proc `-`(x: uint32): uint32 {.inline.} =
(0xFFFF_FFFF'u32 - x) + 1'u32
proc LT(x, y: uint32): uint32 {.inline.} =
let z = x - y
(z xor ((y xor x) and (y xor z))) shr 31
proc decValue(c: byte): int =
let x = uint32(c) - 0x30'u32
let r = ((x + 1'u32) and -LT(x, 10))
int(r) - 1
proc bytesToDec*[T: byte|char](src: openarray[T]): uint64 =
var v = 0'u64
for i in 0 ..< len(src):
let d =
when T is byte:
decValue(src[i])
else:
decValue(byte(src[i]))
if d < 0:
# non-decimal character encountered
return v
else:
let nv = ((v shl 3) + (v shl 1)) + uint64(d)
if nv < v:
# overflow happened
return v
else:
v = nv
v
proc add*(ht: var HttpTables, key: string, value: string) =
let lowkey = key.toLowerAscii()
var nitem = @[value]
if ht.table.hasKeyOrPut(lowkey, nitem):
var oitem = ht.table[lowkey]
oitem.add(value)
ht.table[lowkey] = oitem
proc add*(ht: var HttpTables, key: string, value: SomeInteger) =
ht.add(key, $value)
proc contains*(ht: var HttpTables, key: string): bool =
ht.table.contains(key.toLowerAscii())
proc getList*(ht: HttpTables, key: string): seq[string] =
var default: seq[string]
ht.table.getOrDefault(key.toLowerAscii(), default)
proc getString*(ht: HttpTables, key: string): string =
var default: seq[string]
ht.table.getOrDefault(key.toLowerAscii(), default).join(",")
proc count*(ht: HttpTables, key: string): int =
var default: seq[string]
len(ht.table.getOrDefault(key, default))
proc getInt*(ht: HttpTables, key: string): uint64 =
bytesToDec(ht.getString(key))
proc getLastString*(ht: HttpTables, key: string): string =
var default: seq[string]
let item = ht.table.getOrDefault(key.toLowerAscii(), default)
if len(item) == 0:
""
else:
item[^1]
proc getLastInt*(ht: HttpTables, key: string): uint64 =
bytesToDec(ht.getLastString())
proc init*(htt: typedesc[HttpTable]): HttpTable =
HttpTable(table: initTable[string, seq[string]]())
proc new*(htt: typedesc[HttpTableRef]): HttpTableRef =
HttpTableRef(table: initTable[string, seq[string]]())
proc normalizeHeaderName*(value: string): string =
var res = value.toLowerAscii()
var k = 0
while k < len(res):
if k == 0:
res[k] = toUpperAscii(res[k])
inc(k, 1)
else:
if res[k] == '-':
if k + 1 < len(res):
res[k + 1] = toUpperAscii(res[k + 1])
inc(k, 2)
else:
break
else:
inc(k, 1)
res
proc `$`*(ht: HttpTables): string =
var res = ""
for key, value in ht.table.pairs():
for item in value:
res.add(key.normalizeHeaderName())
res.add(": ")
res.add(item)
res.add("\p")
res

View File

@ -0,0 +1,286 @@
#
# Chronos HTTP/S multipart/form
# encoding and decoding helper procedures
# (c) Copyright 2019-Present
# Status Research & Development GmbH
#
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[monotimes, strutils]
import chronos, stew/results
import httptable, httpcommon
export httptable, httpcommon
type
MultiPartSource {.pure.} = enum
Stream, Buffer
MultiPartReader* = object
case kind: MultiPartSource
of MultiPartSource.Stream:
stream: AsyncStreamReader
last: BoundedAsyncStreamReader
of MultiPartSource.Buffer:
discard
buffer: seq[byte]
offset: int
boundary: seq[byte]
MultiPartReaderRef* = ref MultiPartReader
MultiPart* = object
headers: HttpTable
stream: BoundedAsyncStreamReader
offset: int
size: int
MultipartError* = object of HttpError
MultipartEOMError* = object of MultipartError
MultiPartIncorrectError* = object of MultipartError
MultiPartIncompleteError* = object of MultipartError
BChar* = byte | char
proc startsWith*(s, prefix: openarray[byte]): bool =
var i = 0
while true:
if i >= len(prefix): return true
if i >= len(s) or s[i] != prefix[i]: return false
inc(i)
proc parseUntil*(s, until: openarray[byte]): int =
var i = 0
while i < len(s):
if len(until) > 0 and s[i] == until[0]:
var u = 1
while i + u < len(s) and u < len(until) and s[i + u] == until[u]:
inc u
if u >= len(until): return i
inc(i)
-1
proc init*[A: BChar, B: BChar](mpt: typedesc[MultiPartReader],
buffer: openarray[A],
boundary: openarray[B]): MultiPartReader =
# Boundary should not be empty.
doAssert(len(boundary) > 0)
# Our internal boundary has format `<CR><LF><-><-><boundary>`, so we can reuse
# different parts of this sequence for processing.
var fboundary = newSeq[byte](len(boundary) + 4)
fboundary[0] = 0x0D'u8
fboundary[1] = 0x0A'u8
fboundary[2] = byte('-')
fboundary[3] = byte('-')
copyMem(addr fboundary[4], unsafeAddr boundary[0], len(boundary))
# Make copy of buffer, because all the returned parts depending on it.
var buf = newSeq[byte](len(buffer))
if len(buf) > 0:
copyMem(addr buf[0], unsafeAddr buffer[0], len(buffer))
MultiPartReader(kind: MultiPartSource.Buffer,
buffer: buf, offset: 0, boundary: fboundary)
proc init*[B: BChar](mpt: typedesc[MultiPartReader],
stream: AsyncStreamReader,
boundary: openarray[B],
partHeadersMaxSize = 4096): MultiPartReader =
# Boundary should not be empty.
doAssert(len(boundary) > 0)
# Our internal boundary has format `<CR><LF><-><-><boundary>`, so we can reuse
# different parts of this sequence for processing.
var fboundary = newSeq[byte](len(boundary) + 4)
fboundary[0] = 0x0D'u8
fboundary[1] = 0x0A'u8
fboundary[2] = byte('-')
fboundary[3] = byte('-')
copyMem(addr fboundary[4], unsafeAddr boundary[0], len(boundary))
MultiPartReader(kind: MultiPartSource.Stream,
stream: stream, offset: 0, boundary: fboundary,
buffer: newSeq[byte](partHeadersMaxSize))
proc readPart*(mpr: MultiPartReaderRef): Future[MultiPart] {.async.} =
doAssert(mpr.kind == MultiPartSource.Stream)
# According to RFC1521 says that a boundary "must be no longer than 70
# characters, not counting the two leading hyphens.
if mpr.firstTime:
# Read and verify initial <-><-><boundary><CR><LF>
mpr.firstTime = false
await mpr.stream.readExactly(addr mpr.buffer[0], len(mpr.boundary) - 2)
if startsWith(mpr.buffer.toOpenArray(0, len(mpr.boundary) - 5),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
if buffer[0] == byte('-') and buffer[1] == byte("-"):
raise newException(MultiPartEOMError, "Unexpected EOM encountered")
if buffer[0] != 0x0D'u8 or buffer[1] != 0x0A'u8:
raise newException(MultiPartIncorrectError,
"Unexpected boundary suffix")
else:
raise newException(MultiPartIncorrectError,
"Unexpected boundary encountered")
# Reading part's headers
let res = await mpr.stream.readUntil(addr mpr.buffer[0], len(mpr.buffer),
HeadersMark)
var headersList = parseHeaders(mpr.buffer.toOpenArray(0, res - 1))
if headersList.failed():
raise newException(MultiPartIncorrectError, "Incorrect part headers found")
var part = MultiPart()
await mpr.stream.readExactly(addr buffer[0], len(mpr.boundary) - 4)
if startsWith(buffer.toOpenArray(0, len(mpr.boundary) - 5),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
await mpr.stream.readExactly(addr buffer[0], 2)
if buffer[0] == byte('-') and buffer[1] == byte("-"):
raise newException(MultiPartEOMError, "")
if buffer[0] == 0x0D'u8 and buffer[1] == 0x0A'u8:
except:
discard
# if mpr.offset >= len(mpr.buffer):
# raise newException(MultiPartEOMError, "End of multipart form encountered")
proc getStream*(mp: MultiPart): AsyncStreamReader =
mp.stream
proc getBody*(mp: MultiPart): Future[seq[byte]] {.async.} =
try:
let res = await mp.stream.read()
return res
except AsyncStreamError:
raise newException(HttpCriticalError, "Could not read multipart body")
proc consumeBody*(mp: MultiPart) {.async.} =
try:
await mp.stream.consume()
except AsyncStreamError:
raise newException(HttpCriticalError, "Could not consume multipart body")
proc getPart*(mpr: var MultiPartReader): Result[MultiPart, string] =
doAssert(mpr.kind == MultiPartSource.Buffer)
if mpr.offset >= len(mpr.buffer):
return err("End of multipart form encountered")
if startsWith(mpr.buffer.toOpenArray(mpr.offset, len(mpr.buffer) - 1),
mpr.boundary.toOpenArray(2, len(mpr.boundary) - 1)):
# Buffer must start at <-><-><boundary>
mpr.offset += (len(mpr.boundary) - 2)
# After boundary there should be at least 2 symbols <-><-> or <CR><LF>.
if len(mpr.buffer) <= mpr.offset + 1:
return err("Incomplete multipart form")
if mpr.buffer[mpr.offset] == byte('-') and
mpr.buffer[mpr.offset + 1] == byte('-'):
# If we have <-><-><boundary><-><-> it means we have found last boundary
# of multipart message.
mpr.offset += 2
return err("End of multipart form encountered")
if mpr.buffer[mpr.offset] == 0x0D'u8 and
mpr.buffer[mpr.offset + 1] == 0x0A'u8:
# If we have <-><-><boundary><CR><LF> it means that we have found another
# part of multipart message.
mpr.offset += 2
# Multipart form must always have at least single Content-Disposition
# header, so we searching position where all the headers should be
# finished <CR><LF><CR><LF>.
let pos1 = parseUntil(
mpr.buffer.toOpenArray(mpr.offset, len(mpr.buffer) - 1),
[0x0D'u8, 0x0A'u8, 0x0D'u8, 0x0A'u8]
)
if pos1 < 0:
return err("Incomplete multipart form")
# parseUntil returns 0-based position without `until` sequence.
let start = mpr.offset + pos1 + 4
# Multipart headers position
let hstart = mpr.offset
let hfinish = mpr.offset + pos1 + 4 - 1
let headersList = parseHeaders(mpr.buffer.toOpenArray(hstart, hfinish),
false)
if headersList.failed():
return err("Incorrect or incomplete multipart headers received")
# Searching for value's boundary <CR><LF><-><-><boundary>.
let pos2 = parseUntil(
mpr.buffer.toOpenArray(start, len(mpr.buffer) - 1),
mpr.boundary.toOpenArray(0, len(mpr.boundary) - 1)
)
if pos2 < 0:
return err("Incomplete multipart form")
# We set reader's offset to the place right after <CR><LF>
mpr.offset = start + pos2 + 2
var part = MultiPart(offset: start, size: pos2, headers: HttpTable.init())
for k, v in headersList.headers(mpr.buffer.toOpenArray(hstart, hfinish)):
part.headers.add(k, v)
ok(part)
else:
err("Incorrect multipart form")
else:
err("Incorrect multipart form")
template `-`(x: uint32): uint32 =
(0xFFFF_FFFF'u32 - x) + 1'u32
template LT(x, y: uint32): uint32 =
let z = x - y
(z xor ((y xor x) and (y xor z))) shr 31
proc boundaryValue(c: char): bool =
let a0 = uint32(c) - 0x27'u32
let a1 = uint32(c) - 0x2B'u32
let a2 = uint32(c) - 0x3A'u32
let a3 = uint32(c) - 0x3D'u32
let a4 = uint32(c) - 0x3F'u32
let a5 = uint32(c) - 0x41'u32
let a6 = uint32(c) - 0x5F
let a7 = uint32(c) - 0x61'u32
let r = ((a0 + 1'u32) and -LT(a0, 3)) or
((a1 + 1'u32) and -LT(a1, 15)) or
((a2 + 1'u32) and -LT(a2, 1)) or
((a3 + 1'u32) and -LT(a3, 1)) or
((a4 + 1'u32) and -LT(a4, 1)) or
((a5 + 1'u32) and -LT(a5, 26)) or
((a6 + 1'u32) and -LT(a6, 1)) or
((a7 + 1'u32) and -LT(a7, 26))
(int(r) - 1) > 0
proc boundaryValue2(c: char): bool =
c in {'a'..'z', 'A' .. 'Z', '0' .. '9',
'\'' .. ')', '+' .. '/', ':', '=', '?', '_'}
func getMultipartBoundary*(contentType: string): Result[string, string] =
let mparts = contentType.split(";")
if strip(mparts[0]).toLowerAscii() != "multipart/form-data":
return err("Content-Type is not multipart")
if len(mparts) < 2:
return err("Content-Type missing boundary value")
let stripped = strip(mparts[1])
if not(stripped.toLowerAscii().startsWith("boundary")):
return err("Incorrect Content-Type boundary format")
let bparts = stripped.split("=")
if len(bparts) < 2:
err("Missing Content-Type boundary")
else:
ok(strip(bparts[1]))
func getContentType*(contentHeader: seq[string]): Result[string, string] =
if len(contentHeader) > 1:
return err("Multiple Content-Header values found")
let mparts = contentHeader[0].split(";")
ok(strip(mparts[0]).toLowerAscii())
when isMainModule:
var buf = "--------------------------5e7d0dd0ed6eb849\r\nContent-Disposition: form-data; name=\"key1\"\r\n\r\nvalue1\r\n--------------------------5e7d0dd0ed6eb849\r\nContent-Disposition: form-data; name=\"key2\"\r\n\r\nvalue2\r\n--------------------------5e7d0dd0ed6eb849--"
var reader = MultiPartReader.init(buf, "------------------------5e7d0dd0ed6eb849")
echo getPart(reader)
echo "===="
echo getPart(reader)
echo "===="
echo getPart(reader)