Address review comments and fix issues found.

Adding more tests.
This commit is contained in:
cheatfate 2021-02-18 14:08:21 +02:00 committed by zah
parent fc0d1bcb43
commit eb81018d02
9 changed files with 349 additions and 157 deletions

View File

@ -153,3 +153,47 @@ func getContentType*(ch: openarray[string]): HttpResult[string] {.
else: else:
let mparts = ch[0].split(";") let mparts = ch[0].split(";")
ok(strip(mparts[0]).toLowerAscii()) ok(strip(mparts[0]).toLowerAscii())
proc bytesToString*(src: openarray[byte], dst: var openarray[char]) =
## Convert array of bytes to array of characters.
##
## Note, that this procedure assume that `sizeof(byte) == sizeof(char) == 1`.
## If this equation is not correct this procedures MUST not be used.
doAssert(len(src) == len(dst))
if len(src) > 0:
copyMem(addr dst[0], unsafeAddr src[0], len(src))
proc stringToBytes*(src: openarray[char], dst: var openarray[byte]) =
## Convert array of characters to array of bytes.
##
## Note, that this procedure assume that `sizeof(byte) == sizeof(char) == 1`.
## If this equation is not correct this procedures MUST not be used.
doAssert(len(src) == len(dst))
if len(src) > 0:
copyMem(addr dst[0], unsafeAddr src[0], len(src))
func bytesToString*(src: openarray[byte]): string =
## Convert array of bytes to a string.
##
## Note, that this procedure assume that `sizeof(byte) == sizeof(char) == 1`.
## If this equation is not correct this procedures MUST not be used.
var default: string
if len(src) > 0:
var dst = newString(len(src))
bytesToString(src, dst)
dst
else:
default
func stringToBytes*(src: openarray[char]): seq[byte] =
## Convert string to sequence of bytes.
##
## Note, that this procedure assume that `sizeof(byte) == sizeof(char) == 1`.
## If this equation is not correct this procedures MUST not be used.
var default: seq[byte]
if len(src) > 0:
var dst = newSeq[byte](len(src))
stringToBytes(src, dst)
dst
else:
default

View File

@ -29,7 +29,7 @@ type
exc*: ref CatchableError exc*: ref CatchableError
remote*: TransportAddress remote*: TransportAddress
RequestFence*[T] = Result[T, HttpProcessError] RequestFence* = Result[HttpRequestRef, HttpProcessError]
HttpRequestFlags* {.pure.} = enum HttpRequestFlags* {.pure.} = enum
BoundBody, UnboundBody, MultipartForm, UrlencodedForm, BoundBody, UnboundBody, MultipartForm, UrlencodedForm,
@ -42,7 +42,7 @@ type
Empty, Prepared, Sending, Finished, Failed, Cancelled, Dumb Empty, Prepared, Sending, Finished, Failed, Cancelled, Dumb
HttpProcessCallback* = HttpProcessCallback* =
proc(req: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {.gcsafe.} proc(req: RequestFence): Future[HttpResponseRef] {.gcsafe.}
HttpServer* = object of RootObj HttpServer* = object of RootObj
instance*: StreamServer instance*: StreamServer
@ -507,7 +507,7 @@ proc createConnection(server: HttpServerRef,
proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
var var
conn: HttpConnectionRef conn: HttpConnectionRef
connArg: RequestFence[HttpRequestRef] connArg: RequestFence
runLoop = false runLoop = false
try: try:
@ -520,7 +520,7 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
except HttpCriticalError as exc: except HttpCriticalError as exc:
let error = HttpProcessError.init(HTTPServerError.CriticalError, exc, let error = HttpProcessError.init(HTTPServerError.CriticalError, exc,
transp.remoteAddress(), exc.code) transp.remoteAddress(), exc.code)
connArg = RequestFence[HttpRequestRef].err(error) connArg = RequestFence.err(error)
runLoop = false runLoop = false
if not(runLoop): if not(runLoop):
@ -538,33 +538,33 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} =
var breakLoop = false var breakLoop = false
while runLoop: while runLoop:
var var
arg: RequestFence[HttpRequestRef] arg: RequestFence
resp: HttpResponseRef resp: HttpResponseRef
try: try:
let request = await conn.getRequest().wait(server.headersTimeout) let request = await conn.getRequest().wait(server.headersTimeout)
arg = RequestFence[HttpRequestRef].ok(request) arg = RequestFence.ok(request)
except CancelledError: except CancelledError:
breakLoop = true breakLoop = true
except AsyncTimeoutError as exc: except AsyncTimeoutError as exc:
let error = HttpProcessError.init(HTTPServerError.TimeoutError, exc, let error = HttpProcessError.init(HTTPServerError.TimeoutError, exc,
transp.remoteAddress(), Http408) transp.remoteAddress(), Http408)
arg = RequestFence[HttpRequestRef].err(error) arg = RequestFence.err(error)
except HttpRecoverableError as exc: except HttpRecoverableError as exc:
let error = HttpProcessError.init(HTTPServerError.RecoverableError, exc, let error = HttpProcessError.init(HTTPServerError.RecoverableError, exc,
transp.remoteAddress(), exc.code) transp.remoteAddress(), exc.code)
arg = RequestFence[HttpRequestRef].err(error) arg = RequestFence.err(error)
except HttpCriticalError as exc: except HttpCriticalError as exc:
let error = HttpProcessError.init(HTTPServerError.CriticalError, exc, let error = HttpProcessError.init(HTTPServerError.CriticalError, exc,
transp.remoteAddress(), exc.code) transp.remoteAddress(), exc.code)
arg = RequestFence[HttpRequestRef].err(error) arg = RequestFence.err(error)
except HttpDisconnectError: except HttpDisconnectError:
# If remote peer disconnected we just exiting loop # If remote peer disconnected we just exiting loop
breakLoop = true breakLoop = true
except CatchableError as exc: except CatchableError as exc:
let error = HttpProcessError.init(HTTPServerError.CatchableError, exc, let error = HttpProcessError.init(HTTPServerError.CatchableError, exc,
transp.remoteAddress(), Http500) transp.remoteAddress(), Http500)
arg = RequestFence[HttpRequestRef].err(error) arg = RequestFence.err(error)
if breakLoop: if breakLoop:
break break

View File

@ -27,6 +27,8 @@ proc LT(x, y: uint32): uint32 {.inline.} =
(z xor ((y xor x) and (y xor z))) shr 31 (z xor ((y xor x) and (y xor z))) shr 31
proc decValue(c: byte): int = proc decValue(c: byte): int =
# Procedure returns values [0..9] for character [`0`..`9`] and -1 for all
# other characters.
let x = uint32(c) - 0x30'u32 let x = uint32(c) - 0x30'u32
let r = ((x + 1'u32) and -LT(x, 10)) let r = ((x + 1'u32) and -LT(x, 10))
int(r) - 1 int(r) - 1
@ -46,39 +48,42 @@ proc bytesToDec*[T: byte|char](src: openarray[T]): uint64 =
let nv = ((v shl 3) + (v shl 1)) + uint64(d) let nv = ((v shl 3) + (v shl 1)) + uint64(d)
if nv < v: if nv < v:
# overflow happened # overflow happened
return v return 0xFFFF_FFFF_FFFF_FFFF'u64
else: else:
v = nv v = nv
v v
proc add*(ht: var HttpTables, key: string, value: string) = proc add*(ht: var HttpTables, key: string, value: string) =
## Add string ``value`` to header with key ``key``.
var default: seq[string] var default: seq[string]
let lowkey = key.toLowerAscii() ht.table.mgetOrPut(key.toLowerAscii(), default).add(value)
var nitem = @[value]
if ht.table.hasKeyOrPut(lowkey, nitem):
var oitem = ht.table.getOrDefault(lowkey, default)
oitem.add(value)
ht.table[lowkey] = oitem
proc add*(ht: var HttpTables, key: string, value: SomeInteger) = proc add*(ht: var HttpTables, key: string, value: SomeInteger) =
## Add integer ``value`` to header with key ``key``.
ht.add(key, $value) ht.add(key, $value)
proc set*(ht: var HttpTables, key: string, value: string) = proc set*(ht: var HttpTables, key: string, value: string) =
## Set/replace value of header with key ``key`` to value ``value``.
let lowkey = key.toLowerAscii() let lowkey = key.toLowerAscii()
ht.table[lowkey] = @[value] ht.table[lowkey] = @[value]
proc contains*(ht: var HttpTables, key: string): bool {. proc contains*(ht: var HttpTables, key: string): bool =
raises: [Defect].} = ## Returns ``true`` if header with name ``key`` is present in HttpTable/Ref.
ht.table.contains(key.toLowerAscii()) ht.table.contains(key.toLowerAscii())
proc getList*(ht: HttpTables, key: string, proc getList*(ht: HttpTables, key: string,
default: openarray[string] = []): seq[string] = default: openarray[string] = []): seq[string] =
## Returns sequence of headers with key ``key``.
var defseq = @default var defseq = @default
ht.table.getOrDefault(key.toLowerAscii(), defseq) ht.table.getOrDefault(key.toLowerAscii(), defseq)
proc getString*(ht: HttpTables, key: string, proc getString*(ht: HttpTables, key: string,
default: string = ""): string = default: string = ""): string =
var defseq = newSeq[string]() ## Returns concatenated value of headers with key ``key``.
##
## If there multiple headers with the same name ``key`` the result value will
## be concatenation using `,`.
var defseq: seq[string]
let res = ht.table.getOrDefault(key.toLowerAscii(), defseq) let res = ht.table.getOrDefault(key.toLowerAscii(), defseq)
if len(res) == 0: if len(res) == 0:
return default return default
@ -86,13 +91,33 @@ proc getString*(ht: HttpTables, key: string,
res.join(",") res.join(",")
proc count*(ht: HttpTables, key: string): int = proc count*(ht: HttpTables, key: string): int =
## Returns number of headers with key ``key``.
var default: seq[string] var default: seq[string]
len(ht.table.getOrDefault(key, default)) len(ht.table.getOrDefault(key.toLowerAscii(), default))
proc getInt*(ht: HttpTables, key: string): uint64 = proc getInt*(ht: HttpTables, key: string): uint64 =
## Parse header with key ``key`` as unsigned integer.
##
## Integers are parsed in safe way, there no exceptions or errors will be
## raised.
##
## If a non-decimal character is encountered during the parsing of the string
## the current accumulated value will be returned. So if string starts with
## non-decimal character, procedure will always return `0` (for example "-1"
## will be decoded as `0`). But if non-decimal character will be encountered
## later, only decimal part will be decoded, like `1234_5678` will be decoded
## as `1234`.
## Also, if in the parsing process result exceeds `uint64` maximum allowed
## value, then `0xFFFF_FFFF_FFFF_FFFF'u64` will be returned (for example
## `18446744073709551616` will be decoded as `18446744073709551615` because it
## overflows uint64 maximum value of `18446744073709551615`).
bytesToDec(ht.getString(key)) bytesToDec(ht.getString(key))
proc getLastString*(ht: HttpTables, key: string): string = proc getLastString*(ht: HttpTables, key: string): string =
## Returns "last" value of header ``key``.
##
## If there multiple headers with the same name ``key`` the value of last
## encountered header will be returned.
var default: seq[string] var default: seq[string]
let item = ht.table.getOrDefault(key.toLowerAscii(), default) let item = ht.table.getOrDefault(key.toLowerAscii(), default)
if len(item) == 0: if len(item) == 0:
@ -101,16 +126,25 @@ proc getLastString*(ht: HttpTables, key: string): string =
item[^1] item[^1]
proc getLastInt*(ht: HttpTables, key: string): uint64 = proc getLastInt*(ht: HttpTables, key: string): uint64 =
## Returns "last" value of header ``key`` as unsigned integer.
##
## If there multiple headers with the same name ``key`` the value of last
## encountered header will be returned.
##
## Unsigned integer will be parsed using rules of getInt() procedure.
bytesToDec(ht.getLastString()) bytesToDec(ht.getLastString())
proc init*(htt: typedesc[HttpTable]): HttpTable = proc init*(htt: typedesc[HttpTable]): HttpTable =
## Create empty HttpTable.
HttpTable(table: initTable[string, seq[string]]()) HttpTable(table: initTable[string, seq[string]]())
proc new*(htt: typedesc[HttpTableRef]): HttpTableRef = proc new*(htt: typedesc[HttpTableRef]): HttpTableRef =
## Create empty HttpTableRef.
HttpTableRef(table: initTable[string, seq[string]]()) HttpTableRef(table: initTable[string, seq[string]]())
proc init*(htt: typedesc[HttpTable], proc init*(htt: typedesc[HttpTable],
data: openArray[tuple[key: string, value: string]]): HttpTable = data: openArray[tuple[key: string, value: string]]): HttpTable =
## Create HttpTable using array of tuples with header names and values.
var res = HttpTable.init() var res = HttpTable.init()
for item in data: for item in data:
res.add(item.key, item.value) res.add(item.key, item.value)
@ -118,6 +152,7 @@ proc init*(htt: typedesc[HttpTable],
proc new*(htt: typedesc[HttpTableRef], proc new*(htt: typedesc[HttpTableRef],
data: openArray[tuple[key: string, value: string]]): HttpTableRef = data: openArray[tuple[key: string, value: string]]): HttpTableRef =
## Create HttpTableRef using array of tuples with header names and values.
var res = HttpTableRef.new() var res = HttpTableRef.new()
for item in data: for item in data:
res.add(item.key, item.value) res.add(item.key, item.value)
@ -152,7 +187,7 @@ proc normalizeHeaderName*(value: string): string =
iterator stringItems*(ht: HttpTables, iterator stringItems*(ht: HttpTables,
normKey = false): tuple[key: string, value: string] = normKey = false): tuple[key: string, value: string] =
## Iterate over HttpTable values. ## Iterate over HttpTable/Ref values.
## ##
## If ``normKey`` is true, key name value will be normalized using ## If ``normKey`` is true, key name value will be normalized using
## normalizeHeaderName() procedure. ## normalizeHeaderName() procedure.
@ -163,7 +198,7 @@ iterator stringItems*(ht: HttpTables,
iterator items*(ht: HttpTables, iterator items*(ht: HttpTables,
normKey = false): tuple[key: string, value: seq[string]] = normKey = false): tuple[key: string, value: seq[string]] =
## Iterate over HttpTable values. ## Iterate over HttpTable/Ref values.
## ##
## If ``normKey`` is true, key name value will be normalized using ## If ``normKey`` is true, key name value will be normalized using
## normalizeHeaderName() procedure. ## normalizeHeaderName() procedure.
@ -172,6 +207,7 @@ iterator items*(ht: HttpTables,
yield (key, v) yield (key, v)
proc `$`*(ht: HttpTables): string = proc `$`*(ht: HttpTables): string =
## Returns string representation of HttpTable/Ref.
var res = "" var res = ""
for key, value in ht.table.pairs(): for key, value in ht.table.pairs():
for item in value: for item in value:

View File

@ -52,6 +52,8 @@ type
proc startsWith(s, prefix: openarray[byte]): bool {. proc startsWith(s, prefix: openarray[byte]): bool {.
raises: [Defect].} = raises: [Defect].} =
# This procedure is copy of strutils.startsWith() procedure, however,
# it is intended to work with arrays of bytes, but not with strings.
var i = 0 var i = 0
while true: while true:
if i >= len(prefix): return true if i >= len(prefix): return true
@ -60,6 +62,8 @@ proc startsWith(s, prefix: openarray[byte]): bool {.
proc parseUntil(s, until: openarray[byte]): int {. proc parseUntil(s, until: openarray[byte]): int {.
raises: [Defect].} = raises: [Defect].} =
# This procedure is copy of parseutils.parseUntil() procedure, however,
# it is intended to work with arrays of bytes, but not with strings.
var i = 0 var i = 0
while i < len(s): while i < len(s):
if len(until) > 0 and s[i] == until[0]: if len(until) > 0 and s[i] == until[0]:
@ -127,7 +131,15 @@ proc new*[B: BChar](mpt: typedesc[MultiPartReaderRef],
## ``stream`` is stream used to read data. ## ``stream`` is stream used to read data.
## ``boundary`` is multipart boundary, this value must not be empty. ## ``boundary`` is multipart boundary, this value must not be empty.
## ``partHeadersMaxSize`` is maximum size of multipart's headers. ## ``partHeadersMaxSize`` is maximum size of multipart's headers.
doAssert(len(boundary) > 0) # According to specification length of boundary must be bigger then `0` and
# less or equal to `70`.
doAssert(len(boundary) > 0 and len(boundary) <= 70)
# 256 bytes is minimum value because we going to use single buffer for
# reading boundaries and for reading headers.
# Minimal buffer value for boundary is 5 bytes, maximum is 74 bytes. But at
# least one header should be present "Content-Disposition", so minimum value
# of multipart headers will be near 150 bytes.
doAssert(partHeadersMaxSize >= 256)
# Our internal boundary has format `<CR><LF><-><-><boundary>`, so we can # Our internal boundary has format `<CR><LF><-><-><boundary>`, so we can
# reuse different parts of this sequence for processing. # reuse different parts of this sequence for processing.
var fboundary = newSeq[byte](len(boundary) + 4) var fboundary = newSeq[byte](len(boundary) + 4)
@ -266,8 +278,7 @@ proc closeWait*(mpr: MultiPartReaderRef) {.async.} =
else: else:
discard discard
proc getBytes*(mp: MultiPart): seq[byte] {. proc getBytes*(mp: MultiPart): seq[byte] {.raises: [Defect].} =
raises: [Defect].} =
## Returns value for MultiPart ``mp`` as sequence of bytes. ## Returns value for MultiPart ``mp`` as sequence of bytes.
case mp.kind case mp.kind
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
@ -276,28 +287,16 @@ proc getBytes*(mp: MultiPart): seq[byte] {.
doAssert(not(mp.stream.atEof()), "Value is not obtained yet") doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
mp.buffer mp.buffer
proc getString*(mp: MultiPart): string {. proc getString*(mp: MultiPart): string {.raises: [Defect].} =
raises: [Defect].} =
## Returns value for MultiPart ``mp`` as string. ## Returns value for MultiPart ``mp`` as string.
case mp.kind case mp.kind
of MultiPartSource.Buffer: of MultiPartSource.Buffer:
if len(mp.buffer) > 0: bytesToString(mp.buffer)
var res = newString(len(mp.buffer))
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
res
else:
""
of MultiPartSource.Stream: of MultiPartSource.Stream:
doAssert(not(mp.stream.atEof()), "Value is not obtained yet") doAssert(not(mp.stream.atEof()), "Value is not obtained yet")
if len(mp.buffer) > 0: bytesToString(mp.buffer)
var res = newString(len(mp.buffer))
copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer))
res
else:
""
proc atEoM*(mpr: var MultiPartReader): bool {. proc atEoM*(mpr: var MultiPartReader): bool {.raises: [Defect].} =
raises: [Defect].} =
## Procedure returns ``true`` if MultiPartReader has reached the end of ## Procedure returns ``true`` if MultiPartReader has reached the end of
## multipart message. ## multipart message.
case mpr.kind case mpr.kind
@ -306,8 +305,7 @@ proc atEoM*(mpr: var MultiPartReader): bool {.
of MultiPartSource.Stream: of MultiPartSource.Stream:
mpr.stream.atEof() mpr.stream.atEof()
proc atEoM*(mpr: MultiPartReaderRef): bool {. proc atEoM*(mpr: MultiPartReaderRef): bool {.raises: [Defect].} =
raises: [Defect].} =
## Procedure returns ``true`` if MultiPartReader has reached the end of ## Procedure returns ``true`` if MultiPartReader has reached the end of
## multipart message. ## multipart message.
case mpr.kind case mpr.kind
@ -422,7 +420,7 @@ func getMultipartBoundary*(ch: openarray[string]): HttpResult[string] {.
## 2) `Content-Type` must be ``multipart/form-data``. ## 2) `Content-Type` must be ``multipart/form-data``.
## 3) `boundary` value must be present ## 3) `boundary` value must be present
## 4) `boundary` value must be less then 70 characters length and ## 4) `boundary` value must be less then 70 characters length and
## all characters should be part of alphabet. ## all characters should be part of specific alphabet.
if len(ch) > 1: if len(ch) > 1:
err("Multiple Content-Type headers found") err("Multiple Content-Type headers found")
else: else:
@ -455,17 +453,14 @@ func getMultipartBoundary*(ch: openarray[string]): HttpResult[string] {.
if len(bparts) < 2: if len(bparts) < 2:
err("Missing Content-Type boundary") err("Missing Content-Type boundary")
else: else:
if bparts[0].toLowerAscii() != "boundary": let candidate = strip(bparts[1])
err("Missing boundary key") if len(candidate) == 0:
err("Content-Type boundary must be at least 1 character size")
elif len(candidate) > 70:
err("Content-Type boundary must be less then 70 characters")
else: else:
let candidate = strip(bparts[1]) for ch in candidate:
if len(candidate) == 0: if ch notin {'a' .. 'z', 'A' .. 'Z', '0' .. '9',
err("Content-Type boundary must be at least 1 character size") '\'' .. ')', '+' .. '/', ':', '=', '?', '_'}:
elif len(candidate) > 70: return err("Content-Type boundary alphabet incorrect")
err("Content-Type boundary must be less then 70 characters") ok(candidate)
else:
for ch in candidate:
if ch notin {'a' .. 'z', 'A' .. 'Z', '0' .. '9',
'\'' .. ')', '+' .. '/', ':', '=', '?', '_'}:
return err("Content-Type boundary alphabet incorrect")
ok(candidate)

View File

@ -42,11 +42,11 @@ type
WriteItem* = object WriteItem* = object
case kind*: WriteType case kind*: WriteType
of Pointer: of Pointer:
data1*: pointer dataPtr*: pointer
of Sequence: of Sequence:
data2*: seq[byte] dataSeq*: seq[byte]
of String: of String:
data3*: string dataStr*: string
size*: int size*: int
offset*: int offset*: int
future*: Future[void] future*: Future[void]
@ -96,12 +96,11 @@ type
AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter
proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer = proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer =
var res = AsyncBuffer( AsyncBuffer(
buffer: newSeq[byte](size), buffer: newSeq[byte](size),
events: [newAsyncEvent(), newAsyncEvent()], events: [newAsyncEvent(), newAsyncEvent()],
offset: 0 offset: 0
) )
res
proc getBuffer*(sb: AsyncBuffer): pointer {.inline.} = proc getBuffer*(sb: AsyncBuffer): pointer {.inline.} =
unsafeAddr sb.buffer[sb.offset] unsafeAddr sb.buffer[sb.offset]
@ -171,12 +170,12 @@ template toBufferOpenArray*(sb: AsyncBuffer): auto =
template copyOut*(dest: pointer, item: WriteItem, length: int) = template copyOut*(dest: pointer, item: WriteItem, length: int) =
if item.kind == Pointer: if item.kind == Pointer:
let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) let p = cast[pointer](cast[uint](item.dataPtr) + uint(item.offset))
copyMem(dest, p, length) copyMem(dest, p, length)
elif item.kind == Sequence: elif item.kind == Sequence:
copyMem(dest, unsafeAddr item.data2[item.offset], length) copyMem(dest, unsafeAddr item.dataSeq[item.offset], length)
elif item.kind == String: elif item.kind == String:
copyMem(dest, unsafeAddr item.data3[item.offset], length) copyMem(dest, unsafeAddr item.dataStr[item.offset], length)
proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {. proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {.
noinline.} = noinline.} =
@ -226,7 +225,7 @@ template checkStreamClosed*(t: untyped) =
proc atEof*(rstream: AsyncStreamReader): bool = proc atEof*(rstream: AsyncStreamReader): bool =
## Returns ``true`` is reading stream is closed or finished and internal ## Returns ``true`` is reading stream is closed or finished and internal
## buffer do not have any bytes left. ## buffer do not have any bytes left.
rstream.state in {AsyncStreamState.Stopped, Finished, Closed} and rstream.state in {AsyncStreamState.Stopped, Finished, Closed, Error} and
(rstream.buffer.dataLen() == 0) (rstream.buffer.dataLen() == 0)
proc atEof*(wstream: AsyncStreamWriter): bool = proc atEof*(wstream: AsyncStreamWriter): bool =
@ -327,13 +326,13 @@ template readLoop(body: untyped): untyped =
raise rstream.error raise rstream.error
let (consumed, done) = body let (consumed, done) = body
rstream.buffer.shift(consumed) rstream.buffer.shift(consumed)
rstream.bytesCount = rstream.bytesCount + uint64(consumed) rstream.bytesCount = rstream.bytesCount + uint64(consumed)
if done: if done:
break break
else: else:
await rstream.buffer.wait() if not(rstream.atEof()):
await rstream.buffer.wait()
proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer,
nbytes: int) {.async.} = nbytes: int) {.async.} =
@ -711,7 +710,7 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer,
wstream.bytesCount = wstream.bytesCount + uint64(nbytes) wstream.bytesCount = wstream.bytesCount + uint64(nbytes)
else: else:
var item = WriteItem(kind: Pointer) var item = WriteItem(kind: Pointer)
item.data1 = pbytes item.dataPtr = pbytes
item.size = nbytes item.size = nbytes
item.future = newFuture[void]("async.stream.write(pointer)") item.future = newFuture[void]("async.stream.write(pointer)")
try: try:
@ -758,9 +757,9 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte],
else: else:
var item = WriteItem(kind: Sequence) var item = WriteItem(kind: Sequence)
if not isLiteral(sbytes): if not isLiteral(sbytes):
shallowCopy(item.data2, sbytes) shallowCopy(item.dataSeq, sbytes)
else: else:
item.data2 = sbytes item.dataSeq = sbytes
item.size = length item.size = length
item.future = newFuture[void]("async.stream.write(seq)") item.future = newFuture[void]("async.stream.write(seq)")
try: try:
@ -806,9 +805,9 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string,
else: else:
var item = WriteItem(kind: String) var item = WriteItem(kind: String)
if not isLiteral(sbytes): if not isLiteral(sbytes):
shallowCopy(item.data3, sbytes) shallowCopy(item.dataStr, sbytes)
else: else:
item.data3 = sbytes item.dataStr = sbytes
item.size = length item.size = length
item.future = newFuture[void]("async.stream.write(string)") item.future = newFuture[void]("async.stream.write(string)")
try: try:

View File

@ -117,9 +117,23 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
await upload(addr rstream.buffer, addr buffer[0], length) await upload(addr rstream.buffer, addr buffer[0], length)
rstream.state = AsyncStreamState.Finished rstream.state = AsyncStreamState.Finished
else: else:
if (res < toRead) and rstream.rsource.atEof():
case rstream.cmpop
of BoundCmp.Equal:
rstream.state = AsyncStreamState.Error
rstream.error = newBoundedStreamIncompleteError()
of BoundCmp.LessOrEqual:
rstream.state = AsyncStreamState.Finished
rstream.offset = rstream.offset + res rstream.offset = rstream.offset + res
await upload(addr rstream.buffer, addr buffer[0], res) await upload(addr rstream.buffer, addr buffer[0], res)
else: else:
if (res < toRead) and rstream.rsource.atEof():
case rstream.cmpop
of BoundCmp.Equal:
rstream.state = AsyncStreamState.Error
rstream.error = newBoundedStreamIncompleteError()
of BoundCmp.LessOrEqual:
rstream.state = AsyncStreamState.Finished
rstream.offset = rstream.offset + res rstream.offset = rstream.offset + res
await upload(addr rstream.buffer, addr buffer[0], res) await upload(addr rstream.buffer, addr buffer[0], res)
else: else:
@ -146,10 +160,6 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} =
rstream.state = AsyncStreamState.Finished rstream.state = AsyncStreamState.Finished
break break
# Without this additional wait, procedures such as `read()` could got stuck
# in `await.buffer.wait()` because procedures are unable to detect EOF while
# inside readLoop body.
await stepsAsync(1)
# We need to notify consumer about error/close, but we do not care about # We need to notify consumer about error/close, but we do not care about
# incoming data anymore. # incoming data anymore.
rstream.buffer.forget() rstream.buffer.forget()
@ -170,11 +180,11 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} =
# Writing chunk data. # Writing chunk data.
case item.kind case item.kind
of WriteType.Pointer: of WriteType.Pointer:
await wstream.wsource.write(item.data1, item.size) await wstream.wsource.write(item.dataPtr, item.size)
of WriteType.Sequence: of WriteType.Sequence:
await wstream.wsource.write(addr item.data2[0], item.size) await wstream.wsource.write(addr item.dataSeq[0], item.size)
of WriteType.String: of WriteType.String:
await wstream.wsource.write(addr item.data3[0], item.size) await wstream.wsource.write(addr item.dataStr[0], item.size)
wstream.offset = wstream.offset + item.size wstream.offset = wstream.offset + item.size
item.future.complete() item.future.complete()
else: else:

View File

@ -49,11 +49,18 @@ proc getChunkSize(buffer: openarray[byte]): Result[uint64, cstring] =
# We using `uint64` representation, but allow only 2^32 chunk size, # We using `uint64` representation, but allow only 2^32 chunk size,
# ChunkHeaderSize. # ChunkHeaderSize.
var res = 0'u64 var res = 0'u64
for i in 0 ..< min(len(buffer), ChunkHeaderSize): for i in 0 ..< min(len(buffer), ChunkHeaderSize + 1):
let value = hexValue(buffer[i]) let value = hexValue(buffer[i])
if value < 0: if value < 0:
return err("Incorrect chunk size encoding") if buffer[i] == byte(';'):
res = (res shl 4) or uint64(value) # chunk-extension is present, so chunk size is already decoded in res.
return ok(res)
else:
return err("Incorrect chunk size encoding")
else:
if i >= ChunkHeaderSize:
return err("The chunk size exceeds the limit")
res = (res shl 4) or uint64(value)
ok(res) ok(res)
proc setChunkSize(buffer: var openarray[byte], length: int64): int = proc setChunkSize(buffer: var openarray[byte], length: int64): int =
@ -135,7 +142,7 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} =
rstream.state = AsyncStreamState.Error rstream.state = AsyncStreamState.Error
rstream.error = exc rstream.error = exc
if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}: if rstream.state != AsyncStreamState.Running:
# We need to notify consumer about error/close, but we do not care about # We need to notify consumer about error/close, but we do not care about
# incoming data anymore. # incoming data anymore.
rstream.buffer.forget() rstream.buffer.forget()
@ -161,11 +168,11 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} =
# Writing chunk data. # Writing chunk data.
case item.kind case item.kind
of WriteType.Pointer: of WriteType.Pointer:
await wstream.wsource.write(item.data1, item.size) await wstream.wsource.write(item.dataPtr, item.size)
of WriteType.Sequence: of WriteType.Sequence:
await wstream.wsource.write(addr item.data2[0], item.size) await wstream.wsource.write(addr item.dataSeq[0], item.size)
of WriteType.String: of WriteType.String:
await wstream.wsource.write(addr item.data3[0], item.size) await wstream.wsource.write(addr item.dataStr[0], item.size)
# Writing chunk footer CRLF. # Writing chunk footer CRLF.
await wstream.wsource.write(CRLF) await wstream.wsource.write(CRLF)
# Everything is fine, completing queue item's future. # Everything is fine, completing queue item's future.

View File

@ -506,7 +506,10 @@ suite "ChunkedStream test suite":
"--f98f0\r\nContent-Disposition: form-data; name=\"key3\"" & "--f98f0\r\nContent-Disposition: form-data; name=\"key3\"" &
"\r\n\r\nC\r\n" & "\r\n\r\nC\r\n" &
"--f98f0--\r\n" "--f98f0--\r\n"
] ],
["4;position=1\r\nWiki\r\n5;position=2\r\npedia\r\nE;position=3\r\n" &
" in\r\n\r\nchunks.\r\n0;position=4\r\n\r\n",
"Wikipedia in\r\n\r\nchunks."],
] ]
proc checkVector(address: TransportAddress, proc checkVector(address: TransportAddress,
inputstr: string): Future[string] {.async.} = inputstr: string): Future[string] {.async.} =
@ -545,7 +548,16 @@ suite "ChunkedStream test suite":
check waitFor(testVectors(initTAddress("127.0.0.1:46001"))) == true check waitFor(testVectors(initTAddress("127.0.0.1:46001"))) == true
test "ChunkedStream incorrect chunk test": test "ChunkedStream incorrect chunk test":
const BadVectors = [ const BadVectors = [
["100000000 \r\n1"], ["10000000;\r\n1"],
["10000000\r\n1"],
["FFFFFFFF;extension1=value1;extension2=value2\r\n1"],
["FFFFFFFF\r\n1"],
["100000000\r\n1"],
["10000000 \r\n1"],
["100000000 ;\r\n"],
["FFFFFFFF0\r\n1"],
["FFFFFFFF \r\n1"],
["FFFFFFFF ;\r\n1"],
["z\r\n1"] ["z\r\n1"]
] ]
proc checkVector(address: TransportAddress, proc checkVector(address: TransportAddress,
@ -571,11 +583,36 @@ suite "ChunkedStream test suite":
var r = await rstream2.read() var r = await rstream2.read()
doAssert(len(r) > 0) doAssert(len(r) > 0)
except ChunkedStreamIncompleteError: except ChunkedStreamIncompleteError:
if inputstr == "100000000 \r\n1": case inputstr
of "10000000;\r\n1":
res = true res = true
of "10000000\r\n1":
res = true
of "FFFFFFFF;extension1=value1;extension2=value2\r\n1":
res = true
of "FFFFFFFF\r\n1":
res = true
else:
res = false
except ChunkedStreamProtocolError: except ChunkedStreamProtocolError:
if inputstr == "z\r\n1": case inputstr
of "100000000\r\n1":
res = true res = true
of "10000000 \r\n1":
res = true
of "100000000 ;\r\n":
res = true
of "z\r\n1":
res = true
of "FFFFFFFF0\r\n1":
res = true
of "FFFFFFFF \r\n1":
res = true
of "FFFFFFFF ;\r\n1":
res = true
else:
res = false
await rstream2.closeWait() await rstream2.closeWait()
await rstream.closeWait() await rstream.closeWait()
await transp.closeWait() await transp.closeWait()
@ -687,6 +724,13 @@ suite "TLSStream test suite":
suite "BoundedStream test suite": suite "BoundedStream test suite":
type
BoundarySizeTest = enum
SizeReadWrite, SizeOverflow, SizeIncomplete, SizeEmpty
BoundaryBytesTest = enum
BoundaryRead, BoundaryDouble, BoundarySize, BoundaryIncomplete,
BoundaryEmpty
proc createBigMessage(size: int): seq[byte] = proc createBigMessage(size: int): seq[byte] =
var message = "ABCDEFGHIJKLMNOP" var message = "ABCDEFGHIJKLMNOP"
var res = newSeq[byte](size) var res = newSeq[byte](size)
@ -697,8 +741,8 @@ suite "BoundedStream test suite":
for itemComp in [BoundCmp.Equal, BoundCmp.LessOrEqual]: for itemComp in [BoundCmp.Equal, BoundCmp.LessOrEqual]:
for itemSize in [100, 60000]: for itemSize in [100, 60000]:
proc boundaryTest(address: TransportAddress, test: int, size: int, proc boundaryTest(address: TransportAddress, btest: BoundaryBytesTest,
boundary: seq[byte], size: int, boundary: seq[byte],
cmp: BoundCmp): Future[bool] {.async.} = cmp: BoundCmp): Future[bool] {.async.} =
var message = createBigMessage(size) var message = createBigMessage(size)
var clientRes = false var clientRes = false
@ -706,20 +750,21 @@ suite "BoundedStream test suite":
proc processClient(server: StreamServer, proc processClient(server: StreamServer,
transp: StreamTransport) {.async.} = transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp) var wstream = newAsyncStreamWriter(transp)
if test == 0: case btest
of BoundaryRead:
await wstream.write(message) await wstream.write(message)
await wstream.write(boundary) await wstream.write(boundary)
await wstream.finish() await wstream.finish()
await wstream.closeWait() await wstream.closeWait()
clientRes = true clientRes = true
elif test == 1: of BoundaryDouble:
await wstream.write(message) await wstream.write(message)
await wstream.write(boundary) await wstream.write(boundary)
await wstream.write(message) await wstream.write(message)
await wstream.finish() await wstream.finish()
await wstream.closeWait() await wstream.closeWait()
clientRes = true clientRes = true
elif test == 2: of BoundarySize:
var ncmessage = message var ncmessage = message
ncmessage.setLen(len(message) - 2) ncmessage.setLen(len(message) - 2)
await wstream.write(ncmessage) await wstream.write(ncmessage)
@ -727,14 +772,14 @@ suite "BoundedStream test suite":
await wstream.finish() await wstream.finish()
await wstream.closeWait() await wstream.closeWait()
clientRes = true clientRes = true
elif test == 3: of BoundaryIncomplete:
var ncmessage = message var ncmessage = message
ncmessage.setLen(len(message) - 2) ncmessage.setLen(len(message) - 2)
await wstream.write(ncmessage) await wstream.write(ncmessage)
await wstream.finish() await wstream.finish()
await wstream.closeWait() await wstream.closeWait()
clientRes = true clientRes = true
elif test == 4: of BoundaryEmpty:
await wstream.write(boundary) await wstream.write(boundary)
await wstream.finish() await wstream.finish()
await wstream.closeWait() await wstream.closeWait()
@ -750,20 +795,21 @@ suite "BoundedStream test suite":
server.start() server.start()
var conn = await connect(address) var conn = await connect(address)
var rstream = newAsyncStreamReader(conn) var rstream = newAsyncStreamReader(conn)
if test == 0: case btest
of BoundaryRead:
var rbstream = newBoundedStreamReader(rstream, -1, boundary) var rbstream = newBoundedStreamReader(rstream, -1, boundary)
let response = await rbstream.read() let response = await rbstream.read()
if response == message: if response == message:
res = true res = true
await rbstream.closeWait() await rbstream.closeWait()
elif test == 1: of BoundaryDouble:
var rbstream = newBoundedStreamReader(rstream, -1, boundary) var rbstream = newBoundedStreamReader(rstream, -1, boundary)
let response1 = await rbstream.read() let response1 = await rbstream.read()
await rbstream.closeWait() await rbstream.closeWait()
let response2 = await rstream.read() let response2 = await rstream.read()
if (response1 == message) and (response2 == message): if (response1 == message) and (response2 == message):
res = true res = true
elif test == 2: of BoundarySize:
var expectMessage = message var expectMessage = message
expectMessage[^2] = 0x2D'u8 expectMessage[^2] = 0x2D'u8
expectMessage[^1] = 0x2D'u8 expectMessage[^1] = 0x2D'u8
@ -772,14 +818,14 @@ suite "BoundedStream test suite":
await rbstream.closeWait() await rbstream.closeWait()
if (len(response) == size) and response == expectMessage: if (len(response) == size) and response == expectMessage:
res = true res = true
elif test == 3: of BoundaryIncomplete:
var rbstream = newBoundedStreamReader(rstream, -1, boundary) var rbstream = newBoundedStreamReader(rstream, -1, boundary)
try: try:
let response {.used.} = await rbstream.read() let response {.used.} = await rbstream.read()
except BoundedStreamIncompleteError: except BoundedStreamIncompleteError:
res = true res = true
await rbstream.closeWait() await rbstream.closeWait()
elif test == 4: of BoundaryEmpty:
var rbstream = newBoundedStreamReader(rstream, -1, boundary) var rbstream = newBoundedStreamReader(rstream, -1, boundary)
let response = await rbstream.read() let response = await rbstream.read()
await rbstream.closeWait() await rbstream.closeWait()
@ -791,7 +837,7 @@ suite "BoundedStream test suite":
await server.join() await server.join()
return (res and clientRes) return (res and clientRes)
proc boundedTest(address: TransportAddress, test: int, proc boundedTest(address: TransportAddress, stest: BoundarySizeTest,
size: int, cmp: BoundCmp): Future[bool] {.async.} = size: int, cmp: BoundCmp): Future[bool] {.async.} =
var clientRes = false var clientRes = false
var res = false var res = false
@ -805,13 +851,14 @@ suite "BoundedStream test suite":
transp: StreamTransport) {.async.} = transp: StreamTransport) {.async.} =
var wstream = newAsyncStreamWriter(transp) var wstream = newAsyncStreamWriter(transp)
var wbstream = newBoundedStreamWriter(wstream, size, comparison = cmp) var wbstream = newBoundedStreamWriter(wstream, size, comparison = cmp)
if test == 0: case stest
of SizeReadWrite:
for i in 0 ..< 10: for i in 0 ..< 10:
await wbstream.write(messagePart) await wbstream.write(messagePart)
await wbstream.finish() await wbstream.finish()
await wbstream.closeWait() await wbstream.closeWait()
clientRes = true clientRes = true
elif test == 1: of SizeOverflow:
for i in 0 ..< 10: for i in 0 ..< 10:
await wbstream.write(messagePart) await wbstream.write(messagePart)
try: try:
@ -819,7 +866,7 @@ suite "BoundedStream test suite":
except BoundedStreamOverflowError: except BoundedStreamOverflowError:
clientRes = true clientRes = true
await wbstream.closeWait() await wbstream.closeWait()
elif test == 2: of SizeIncomplete:
for i in 0 ..< 9: for i in 0 ..< 9:
await wbstream.write(messagePart) await wbstream.write(messagePart)
case cmp case cmp
@ -835,7 +882,7 @@ suite "BoundedStream test suite":
except BoundedStreamIncompleteError: except BoundedStreamIncompleteError:
discard discard
await wbstream.closeWait() await wbstream.closeWait()
elif test == 3: of SizeEmpty:
case cmp case cmp
of BoundCmp.Equal: of BoundCmp.Equal:
try: try:
@ -861,17 +908,18 @@ suite "BoundedStream test suite":
var conn = await connect(address) var conn = await connect(address)
var rstream = newAsyncStreamReader(conn) var rstream = newAsyncStreamReader(conn)
var rbstream = newBoundedStreamReader(rstream, size, comparison = cmp) var rbstream = newBoundedStreamReader(rstream, size, comparison = cmp)
if test == 0: case stest
of SizeReadWrite:
let response = await rbstream.read() let response = await rbstream.read()
await rbstream.closeWait() await rbstream.closeWait()
if response == message: if response == message:
res = true res = true
elif test == 1: of SizeOverflow:
let response = await rbstream.read() let response = await rbstream.read()
await rbstream.closeWait() await rbstream.closeWait()
if response == message: if response == message:
res = true res = true
elif test == 2: of SizeIncomplete:
case cmp case cmp
of BoundCmp.Equal: of BoundCmp.Equal:
try: try:
@ -886,7 +934,7 @@ suite "BoundedStream test suite":
except BoundedStreamIncompleteError: except BoundedStreamIncompleteError:
res = false res = false
await rbstream.closeWait() await rbstream.closeWait()
elif test == 3: of SizeEmpty:
case cmp case cmp
of BoundCmp.Equal: of BoundCmp.Equal:
try: try:
@ -916,30 +964,34 @@ suite "BoundedStream test suite":
"<= " & $itemSize "<= " & $itemSize
test "BoundedStream(size) reading/writing test [" & suffix & "]": test "BoundedStream(size) reading/writing test [" & suffix & "]":
check waitFor(boundedTest(address, 0, itemSize, itemComp)) == true check waitFor(boundedTest(address, SizeReadWrite, itemSize,
itemComp)) == true
test "BoundedStream(size) overflow test [" & suffix & "]": test "BoundedStream(size) overflow test [" & suffix & "]":
check waitFor(boundedTest(address, 1, itemSize, itemComp)) == true check waitFor(boundedTest(address, SizeOverflow, itemSize,
itemComp)) == true
test "BoundedStream(size) incomplete test [" & suffix & "]": test "BoundedStream(size) incomplete test [" & suffix & "]":
check waitFor(boundedTest(address, 2, itemSize, itemComp)) == true check waitFor(boundedTest(address, SizeIncomplete, itemSize,
itemComp)) == true
test "BoundedStream(size) empty message test [" & suffix & "]": test "BoundedStream(size) empty message test [" & suffix & "]":
check waitFor(boundedTest(address, 3, itemSize, itemComp)) == true check waitFor(boundedTest(address, SizeEmpty, itemSize,
itemComp)) == true
test "BoundedStream(boundary) reading test [" & suffix & "]": test "BoundedStream(boundary) reading test [" & suffix & "]":
check waitFor(boundaryTest(address, 0, itemSize, check waitFor(boundaryTest(address, BoundaryRead, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) double message test [" & suffix & "]": test "BoundedStream(boundary) double message test [" & suffix & "]":
check waitFor(boundaryTest(address, 1, itemSize, check waitFor(boundaryTest(address, BoundaryDouble, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(size+boundary) reading size-bound test [" & test "BoundedStream(size+boundary) reading size-bound test [" &
suffix & "]": suffix & "]":
check waitFor(boundaryTest(address, 2, itemSize, check waitFor(boundaryTest(address, BoundarySize, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) reading incomplete test [" & test "BoundedStream(boundary) reading incomplete test [" &
suffix & "]": suffix & "]":
check waitFor(boundaryTest(address, 3, itemSize, check waitFor(boundaryTest(address, BoundaryIncomplete, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream(boundary) empty message test [" & test "BoundedStream(boundary) empty message test [" &
suffix & "]": suffix & "]":
check waitFor(boundaryTest(address, 4, itemSize, check waitFor(boundaryTest(address, BoundaryEmpty, itemSize,
@[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp))
test "BoundedStream leaks test": test "BoundedStream leaks test":

View File

@ -69,6 +69,10 @@ N8r5CwGcIX/XPC3lKazzbZ8baA==
""" """
suite "HTTP server testing suite": suite "HTTP server testing suite":
type
TooBigTest = enum
GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest
proc httpClient(address: TransportAddress, proc httpClient(address: TransportAddress,
data: string): Future[string] {.async.} = data: string): Future[string] {.async.} =
var transp: StreamTransport var transp: StreamTransport
@ -77,10 +81,7 @@ suite "HTTP server testing suite":
if len(data) > 0: if len(data) > 0:
let wres {.used.} = await transp.write(data) let wres {.used.} = await transp.write(data)
var rres = await transp.read() var rres = await transp.read()
var sres = newString(len(rres)) return bytesToString(rres)
if len(rres) > 0:
copyMem(addr sres[0], addr rres[0], len(rres))
return sres
except CatchableError: except CatchableError:
return "EXCEPTION" return "EXCEPTION"
finally: finally:
@ -104,10 +105,7 @@ suite "HTTP server testing suite":
if len(data) > 0: if len(data) > 0:
await tlsstream.writer.write(data) await tlsstream.writer.write(data)
var rres = await tlsstream.reader.read() var rres = await tlsstream.reader.read()
var sres = newString(len(rres)) return bytesToString(rres)
if len(rres) > 0:
copyMem(addr sres[0], addr rres[0], len(rres))
return sres
except CatchableError: except CatchableError:
return "EXCEPTION" return "EXCEPTION"
finally: finally:
@ -119,20 +117,21 @@ suite "HTTP server testing suite":
transp.closeWait()) transp.closeWait())
proc testTooBigBodyChunked(address: TransportAddress, proc testTooBigBodyChunked(address: TransportAddress,
operation: int): Future[bool] {.async.} = operation: TooBigTest): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
try: try:
if operation == 0: case operation
of GetBodyTest:
let body {.used.} = await request.getBody() let body {.used.} = await request.getBody()
elif operation == 1: of ConsumeBodyTest:
await request.consumeBody() await request.consumeBody()
elif operation == 2: of PostUrlTest:
let ptable {.used.} = await request.post() let ptable {.used.} = await request.post()
elif operation == 3: of PostMultipartTest:
let ptable {.used.} = await request.post() let ptable {.used.} = await request.post()
except HttpCriticalError as exc: except HttpCriticalError as exc:
if exc.code == Http413: if exc.code == Http413:
@ -153,14 +152,15 @@ suite "HTTP server testing suite":
server.start() server.start()
let request = let request =
if operation in [0, 1, 2]: case operation
of GetBodyTest, ConsumeBodyTest, PostUrlTest:
"POST / HTTP/1.0\r\n" & "POST / HTTP/1.0\r\n" &
"Content-Type: application/x-www-form-urlencoded\r\n" & "Content-Type: application/x-www-form-urlencoded\r\n" &
"Transfer-Encoding: chunked\r\n" & "Transfer-Encoding: chunked\r\n" &
"Cookie: 2\r\n\r\n" & "Cookie: 2\r\n\r\n" &
"5\r\na=a&b\r\n5\r\n=b&c=\r\n4\r\nc&d=\r\n4\r\n%D0%\r\n" & "5\r\na=a&b\r\n5\r\n=b&c=\r\n4\r\nc&d=\r\n4\r\n%D0%\r\n" &
"2\r\n9F\r\n0\r\n\r\n" "2\r\n9F\r\n0\r\n\r\n"
elif operation in [3]: of PostMultipartTest:
"POST / HTTP/1.0\r\n" & "POST / HTTP/1.0\r\n" &
"Host: 127.0.0.1:30080\r\n" & "Host: 127.0.0.1:30080\r\n" &
"Transfer-Encoding: chunked\r\n" & "Transfer-Encoding: chunked\r\n" &
@ -173,8 +173,6 @@ suite "HTTP server testing suite":
"\r\n\r\nC\r\n\r\n" & "\r\n\r\nC\r\n\r\n" &
"b\r\n--f98f0--\r\n\r\n" & "b\r\n--f98f0--\r\n\r\n" &
"0\r\n\r\n" "0\r\n\r\n"
else:
""
let data = await httpClient(address, request) let data = await httpClient(address, request)
await server.stop() await server.stop()
@ -184,7 +182,7 @@ suite "HTTP server testing suite":
test "Request headers timeout test": test "Request headers timeout test":
proc testTimeout(address: TransportAddress): Future[bool] {.async.} = proc testTimeout(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
@ -213,7 +211,7 @@ suite "HTTP server testing suite":
test "Empty headers test": test "Empty headers test":
proc testEmpty(address: TransportAddress): Future[bool] {.async.} = proc testEmpty(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
@ -241,7 +239,7 @@ suite "HTTP server testing suite":
test "Too big headers test": test "Too big headers test":
proc testTooBig(address: TransportAddress): Future[bool] {.async.} = proc testTooBig(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
@ -271,7 +269,7 @@ suite "HTTP server testing suite":
test "Too big request body test (content-length)": test "Too big request body test (content-length)":
proc testTooBigBody(address: TransportAddress): Future[bool] {.async.} = proc testTooBigBody(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
discard discard
@ -300,24 +298,28 @@ suite "HTTP server testing suite":
test "Too big request body test (getBody()/chunked encoding)": test "Too big request body test (getBody()/chunked encoding)":
check: check:
waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), 0)) == true waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"),
GetBodyTest)) == true
test "Too big request body test (consumeBody()/chunked encoding)": test "Too big request body test (consumeBody()/chunked encoding)":
check: check:
waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), 1)) == true waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"),
ConsumeBodyTest)) == true
test "Too big request body test (post()/urlencoded/chunked encoding)": test "Too big request body test (post()/urlencoded/chunked encoding)":
check: check:
waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), 2)) == true waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"),
PostUrlTest)) == true
test "Too big request body test (post()/multipart/chunked encoding)": test "Too big request body test (post()/multipart/chunked encoding)":
check: check:
waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), 3)) == true waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"),
PostMultipartTest)) == true
test "Query arguments test": test "Query arguments test":
proc testQuery(address: TransportAddress): Future[bool] {.async.} = proc testQuery(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
@ -357,7 +359,7 @@ suite "HTTP server testing suite":
test "Headers test": test "Headers test":
proc testHeaders(address: TransportAddress): Future[bool] {.async.} = proc testHeaders(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
@ -400,7 +402,7 @@ suite "HTTP server testing suite":
test "POST arguments (urlencoded/content-length) test": test "POST arguments (urlencoded/content-length) test":
proc testPostUrl(address: TransportAddress): Future[bool] {.async.} = proc testPostUrl(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
var kres = newSeq[string]() var kres = newSeq[string]()
@ -443,7 +445,7 @@ suite "HTTP server testing suite":
test "POST arguments (urlencoded/chunked encoding) test": test "POST arguments (urlencoded/chunked encoding) test":
proc testPostUrl2(address: TransportAddress): Future[bool] {.async.} = proc testPostUrl2(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
var kres = newSeq[string]() var kres = newSeq[string]()
@ -487,7 +489,7 @@ suite "HTTP server testing suite":
test "POST arguments (multipart/content-length) test": test "POST arguments (multipart/content-length) test":
proc testPostMultipart(address: TransportAddress): Future[bool] {.async.} = proc testPostMultipart(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
var kres = newSeq[string]() var kres = newSeq[string]()
@ -542,7 +544,7 @@ suite "HTTP server testing suite":
test "POST arguments (multipart/chunked encoding) test": test "POST arguments (multipart/chunked encoding) test":
proc testPostMultipart2(address: TransportAddress): Future[bool] {.async.} = proc testPostMultipart2(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
var kres = newSeq[string]() var kres = newSeq[string]()
@ -606,7 +608,7 @@ suite "HTTP server testing suite":
test "HTTPS server (successful handshake) test": test "HTTPS server (successful handshake) test":
proc testHTTPS(address: TransportAddress): Future[bool] {.async.} = proc testHTTPS(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
@ -644,7 +646,7 @@ suite "HTTP server testing suite":
proc testHTTPS2(address: TransportAddress): Future[bool] {.async.} = proc testHTTPS2(address: TransportAddress): Future[bool] {.async.} =
var serverRes = false var serverRes = false
var testFut = newFuture[void]() var testFut = newFuture[void]()
proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. proc process(r: RequestFence): Future[HttpResponseRef] {.
async.} = async.} =
if r.isOk(): if r.isOk():
let request = r.get() let request = r.get()
@ -758,7 +760,9 @@ suite "HTTP server testing suite":
("", 0'u64), ("0", 0'u64), ("-0", 0'u64), ("0-", 0'u64), ("", 0'u64), ("0", 0'u64), ("-0", 0'u64), ("0-", 0'u64),
("01", 1'u64), ("001", 1'u64), ("0000000000001", 1'u64), ("01", 1'u64), ("001", 1'u64), ("0000000000001", 1'u64),
("18446744073709551615", 0xFFFF_FFFF_FFFF_FFFF'u64), ("18446744073709551615", 0xFFFF_FFFF_FFFF_FFFF'u64),
("18446744073709551616", 1844674407370955161'u64), ("18446744073709551616", 0xFFFF_FFFF_FFFF_FFFF'u64),
("99999999999999999999", 0xFFFF_FFFF_FFFF_FFFF'u64),
("999999999999999999999999999999999999", 0xFFFF_FFFF_FFFF_FFFF'u64),
("FFFFFFFFFFFFFFFF", 0'u64), ("FFFFFFFFFFFFFFFF", 0'u64),
("0123456789ABCDEF", 123456789'u64) ("0123456789ABCDEF", 123456789'u64)
] ]
@ -770,6 +774,51 @@ suite "HTTP server testing suite":
for item in TestVectors: for item in TestVectors:
check bytesToDec(item[0]) == item[1] check bytesToDec(item[0]) == item[1]
test "HttpTable behavior test":
var table1 = HttpTable.init()
var table2 = HttpTable.init([("Header1", "value1"), ("Header2", "value2")])
check:
table1.isEmpty() == true
table2.isEmpty() == false
table1.add("Header1", "value1")
table1.add("Header2", "value2")
table1.add("HEADER2", "VALUE3")
check:
table1.getList("HeAdEr2") == @["value2", "VALUE3"]
table1.getString("HeAdEr2") == "value2,VALUE3"
table2.getString("HEADER1") == "value1"
table1.count("HEADER2") == 2
table1.count("HEADER1") == 1
table1.getLastString("HEADER1") == "value1"
table1.getLastString("HEADER2") == "VALUE3"
"header1" in table1 == true
"HEADER1" in table1 == true
"header2" in table1 == true
"HEADER2" in table1 == true
"HEADER3" in table1 == false
var
data1: seq[tuple[key: string, value: string]]
data2: seq[tuple[key: string, value: seq[string]]]
for key, value in table1.stringItems(true):
data1.add((key, value))
for key, value in table1.items(true):
data2.add((key, value))
check:
data1 == @[("Header2", "value2"), ("Header2", "VALUE3"),
("Header1", "value1")]
data2 == @[("Header2", @["value2", "VALUE3"]),
("Header1", @["value1"])]
table1.set("header2", "value4")
check:
table1.getList("header2") == @["value4"]
table1.getString("header2") == "value4"
table1.count("header2") == 1
table1.getLastString("header2") == "value4"
test "getTransferEncoding() test": test "getTransferEncoding() test":
var encodings = [ var encodings = [
"chunked", "compress", "deflate", "gzip", "identity", "x-gzip" "chunked", "compress", "deflate", "gzip", "identity", "x-gzip"