From eb81018d02e87cf0fcfab15ada8fe5c9931958b1 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Thu, 18 Feb 2021 14:08:21 +0200 Subject: [PATCH] Address review comments and fix issues found. Adding more tests. --- chronos/apps/http/httpcommon.nim | 44 ++++++++++++ chronos/apps/http/httpserver.nim | 20 +++--- chronos/apps/http/httptable.nim | 62 ++++++++++++---- chronos/apps/http/multipart.nim | 65 ++++++++--------- chronos/streams/asyncstream.nim | 31 ++++---- chronos/streams/boundstream.nim | 24 +++++-- chronos/streams/chunkstream.nim | 21 ++++-- tests/testasyncstream.nim | 120 ++++++++++++++++++++++--------- tests/testhttpserver.nim | 119 +++++++++++++++++++++--------- 9 files changed, 349 insertions(+), 157 deletions(-) diff --git a/chronos/apps/http/httpcommon.nim b/chronos/apps/http/httpcommon.nim index 52e440e..45fafec 100644 --- a/chronos/apps/http/httpcommon.nim +++ b/chronos/apps/http/httpcommon.nim @@ -153,3 +153,47 @@ func getContentType*(ch: openarray[string]): HttpResult[string] {. else: let mparts = ch[0].split(";") ok(strip(mparts[0]).toLowerAscii()) + +proc bytesToString*(src: openarray[byte], dst: var openarray[char]) = + ## Convert array of bytes to array of characters. + ## + ## Note, that this procedure assume that `sizeof(byte) == sizeof(char) == 1`. + ## If this equation is not correct this procedures MUST not be used. + doAssert(len(src) == len(dst)) + if len(src) > 0: + copyMem(addr dst[0], unsafeAddr src[0], len(src)) + +proc stringToBytes*(src: openarray[char], dst: var openarray[byte]) = + ## Convert array of characters to array of bytes. + ## + ## Note, that this procedure assume that `sizeof(byte) == sizeof(char) == 1`. + ## If this equation is not correct this procedures MUST not be used. + doAssert(len(src) == len(dst)) + if len(src) > 0: + copyMem(addr dst[0], unsafeAddr src[0], len(src)) + +func bytesToString*(src: openarray[byte]): string = + ## Convert array of bytes to a string. + ## + ## Note, that this procedure assume that `sizeof(byte) == sizeof(char) == 1`. + ## If this equation is not correct this procedures MUST not be used. + var default: string + if len(src) > 0: + var dst = newString(len(src)) + bytesToString(src, dst) + dst + else: + default + +func stringToBytes*(src: openarray[char]): seq[byte] = + ## Convert string to sequence of bytes. + ## + ## Note, that this procedure assume that `sizeof(byte) == sizeof(char) == 1`. + ## If this equation is not correct this procedures MUST not be used. + var default: seq[byte] + if len(src) > 0: + var dst = newSeq[byte](len(src)) + stringToBytes(src, dst) + dst + else: + default diff --git a/chronos/apps/http/httpserver.nim b/chronos/apps/http/httpserver.nim index f538d2f..a53e699 100644 --- a/chronos/apps/http/httpserver.nim +++ b/chronos/apps/http/httpserver.nim @@ -29,7 +29,7 @@ type exc*: ref CatchableError remote*: TransportAddress - RequestFence*[T] = Result[T, HttpProcessError] + RequestFence* = Result[HttpRequestRef, HttpProcessError] HttpRequestFlags* {.pure.} = enum BoundBody, UnboundBody, MultipartForm, UrlencodedForm, @@ -42,7 +42,7 @@ type Empty, Prepared, Sending, Finished, Failed, Cancelled, Dumb HttpProcessCallback* = - proc(req: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {.gcsafe.} + proc(req: RequestFence): Future[HttpResponseRef] {.gcsafe.} HttpServer* = object of RootObj instance*: StreamServer @@ -507,7 +507,7 @@ proc createConnection(server: HttpServerRef, proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = var conn: HttpConnectionRef - connArg: RequestFence[HttpRequestRef] + connArg: RequestFence runLoop = false try: @@ -520,7 +520,7 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = except HttpCriticalError as exc: let error = HttpProcessError.init(HTTPServerError.CriticalError, exc, transp.remoteAddress(), exc.code) - connArg = RequestFence[HttpRequestRef].err(error) + connArg = RequestFence.err(error) runLoop = false if not(runLoop): @@ -538,33 +538,33 @@ proc processLoop(server: HttpServerRef, transp: StreamTransport) {.async.} = var breakLoop = false while runLoop: var - arg: RequestFence[HttpRequestRef] + arg: RequestFence resp: HttpResponseRef try: let request = await conn.getRequest().wait(server.headersTimeout) - arg = RequestFence[HttpRequestRef].ok(request) + arg = RequestFence.ok(request) except CancelledError: breakLoop = true except AsyncTimeoutError as exc: let error = HttpProcessError.init(HTTPServerError.TimeoutError, exc, transp.remoteAddress(), Http408) - arg = RequestFence[HttpRequestRef].err(error) + arg = RequestFence.err(error) except HttpRecoverableError as exc: let error = HttpProcessError.init(HTTPServerError.RecoverableError, exc, transp.remoteAddress(), exc.code) - arg = RequestFence[HttpRequestRef].err(error) + arg = RequestFence.err(error) except HttpCriticalError as exc: let error = HttpProcessError.init(HTTPServerError.CriticalError, exc, transp.remoteAddress(), exc.code) - arg = RequestFence[HttpRequestRef].err(error) + arg = RequestFence.err(error) except HttpDisconnectError: # If remote peer disconnected we just exiting loop breakLoop = true except CatchableError as exc: let error = HttpProcessError.init(HTTPServerError.CatchableError, exc, transp.remoteAddress(), Http500) - arg = RequestFence[HttpRequestRef].err(error) + arg = RequestFence.err(error) if breakLoop: break diff --git a/chronos/apps/http/httptable.nim b/chronos/apps/http/httptable.nim index 3dede5a..4b9a210 100644 --- a/chronos/apps/http/httptable.nim +++ b/chronos/apps/http/httptable.nim @@ -27,6 +27,8 @@ proc LT(x, y: uint32): uint32 {.inline.} = (z xor ((y xor x) and (y xor z))) shr 31 proc decValue(c: byte): int = + # Procedure returns values [0..9] for character [`0`..`9`] and -1 for all + # other characters. let x = uint32(c) - 0x30'u32 let r = ((x + 1'u32) and -LT(x, 10)) int(r) - 1 @@ -46,39 +48,42 @@ proc bytesToDec*[T: byte|char](src: openarray[T]): uint64 = let nv = ((v shl 3) + (v shl 1)) + uint64(d) if nv < v: # overflow happened - return v + return 0xFFFF_FFFF_FFFF_FFFF'u64 else: v = nv v proc add*(ht: var HttpTables, key: string, value: string) = + ## Add string ``value`` to header with key ``key``. var default: seq[string] - let lowkey = key.toLowerAscii() - var nitem = @[value] - if ht.table.hasKeyOrPut(lowkey, nitem): - var oitem = ht.table.getOrDefault(lowkey, default) - oitem.add(value) - ht.table[lowkey] = oitem + ht.table.mgetOrPut(key.toLowerAscii(), default).add(value) proc add*(ht: var HttpTables, key: string, value: SomeInteger) = + ## Add integer ``value`` to header with key ``key``. ht.add(key, $value) proc set*(ht: var HttpTables, key: string, value: string) = + ## Set/replace value of header with key ``key`` to value ``value``. let lowkey = key.toLowerAscii() ht.table[lowkey] = @[value] -proc contains*(ht: var HttpTables, key: string): bool {. - raises: [Defect].} = +proc contains*(ht: var HttpTables, key: string): bool = + ## Returns ``true`` if header with name ``key`` is present in HttpTable/Ref. ht.table.contains(key.toLowerAscii()) proc getList*(ht: HttpTables, key: string, default: openarray[string] = []): seq[string] = + ## Returns sequence of headers with key ``key``. var defseq = @default ht.table.getOrDefault(key.toLowerAscii(), defseq) proc getString*(ht: HttpTables, key: string, default: string = ""): string = - var defseq = newSeq[string]() + ## Returns concatenated value of headers with key ``key``. + ## + ## If there multiple headers with the same name ``key`` the result value will + ## be concatenation using `,`. + var defseq: seq[string] let res = ht.table.getOrDefault(key.toLowerAscii(), defseq) if len(res) == 0: return default @@ -86,13 +91,33 @@ proc getString*(ht: HttpTables, key: string, res.join(",") proc count*(ht: HttpTables, key: string): int = + ## Returns number of headers with key ``key``. var default: seq[string] - len(ht.table.getOrDefault(key, default)) + len(ht.table.getOrDefault(key.toLowerAscii(), default)) proc getInt*(ht: HttpTables, key: string): uint64 = + ## Parse header with key ``key`` as unsigned integer. + ## + ## Integers are parsed in safe way, there no exceptions or errors will be + ## raised. + ## + ## If a non-decimal character is encountered during the parsing of the string + ## the current accumulated value will be returned. So if string starts with + ## non-decimal character, procedure will always return `0` (for example "-1" + ## will be decoded as `0`). But if non-decimal character will be encountered + ## later, only decimal part will be decoded, like `1234_5678` will be decoded + ## as `1234`. + ## Also, if in the parsing process result exceeds `uint64` maximum allowed + ## value, then `0xFFFF_FFFF_FFFF_FFFF'u64` will be returned (for example + ## `18446744073709551616` will be decoded as `18446744073709551615` because it + ## overflows uint64 maximum value of `18446744073709551615`). bytesToDec(ht.getString(key)) proc getLastString*(ht: HttpTables, key: string): string = + ## Returns "last" value of header ``key``. + ## + ## If there multiple headers with the same name ``key`` the value of last + ## encountered header will be returned. var default: seq[string] let item = ht.table.getOrDefault(key.toLowerAscii(), default) if len(item) == 0: @@ -101,16 +126,25 @@ proc getLastString*(ht: HttpTables, key: string): string = item[^1] proc getLastInt*(ht: HttpTables, key: string): uint64 = + ## Returns "last" value of header ``key`` as unsigned integer. + ## + ## If there multiple headers with the same name ``key`` the value of last + ## encountered header will be returned. + ## + ## Unsigned integer will be parsed using rules of getInt() procedure. bytesToDec(ht.getLastString()) proc init*(htt: typedesc[HttpTable]): HttpTable = + ## Create empty HttpTable. HttpTable(table: initTable[string, seq[string]]()) proc new*(htt: typedesc[HttpTableRef]): HttpTableRef = + ## Create empty HttpTableRef. HttpTableRef(table: initTable[string, seq[string]]()) proc init*(htt: typedesc[HttpTable], data: openArray[tuple[key: string, value: string]]): HttpTable = + ## Create HttpTable using array of tuples with header names and values. var res = HttpTable.init() for item in data: res.add(item.key, item.value) @@ -118,6 +152,7 @@ proc init*(htt: typedesc[HttpTable], proc new*(htt: typedesc[HttpTableRef], data: openArray[tuple[key: string, value: string]]): HttpTableRef = + ## Create HttpTableRef using array of tuples with header names and values. var res = HttpTableRef.new() for item in data: res.add(item.key, item.value) @@ -152,7 +187,7 @@ proc normalizeHeaderName*(value: string): string = iterator stringItems*(ht: HttpTables, normKey = false): tuple[key: string, value: string] = - ## Iterate over HttpTable values. + ## Iterate over HttpTable/Ref values. ## ## If ``normKey`` is true, key name value will be normalized using ## normalizeHeaderName() procedure. @@ -163,7 +198,7 @@ iterator stringItems*(ht: HttpTables, iterator items*(ht: HttpTables, normKey = false): tuple[key: string, value: seq[string]] = - ## Iterate over HttpTable values. + ## Iterate over HttpTable/Ref values. ## ## If ``normKey`` is true, key name value will be normalized using ## normalizeHeaderName() procedure. @@ -172,6 +207,7 @@ iterator items*(ht: HttpTables, yield (key, v) proc `$`*(ht: HttpTables): string = + ## Returns string representation of HttpTable/Ref. var res = "" for key, value in ht.table.pairs(): for item in value: diff --git a/chronos/apps/http/multipart.nim b/chronos/apps/http/multipart.nim index 88f4dd3..275df74 100644 --- a/chronos/apps/http/multipart.nim +++ b/chronos/apps/http/multipart.nim @@ -52,6 +52,8 @@ type proc startsWith(s, prefix: openarray[byte]): bool {. raises: [Defect].} = + # This procedure is copy of strutils.startsWith() procedure, however, + # it is intended to work with arrays of bytes, but not with strings. var i = 0 while true: if i >= len(prefix): return true @@ -60,6 +62,8 @@ proc startsWith(s, prefix: openarray[byte]): bool {. proc parseUntil(s, until: openarray[byte]): int {. raises: [Defect].} = + # This procedure is copy of parseutils.parseUntil() procedure, however, + # it is intended to work with arrays of bytes, but not with strings. var i = 0 while i < len(s): if len(until) > 0 and s[i] == until[0]: @@ -127,7 +131,15 @@ proc new*[B: BChar](mpt: typedesc[MultiPartReaderRef], ## ``stream`` is stream used to read data. ## ``boundary`` is multipart boundary, this value must not be empty. ## ``partHeadersMaxSize`` is maximum size of multipart's headers. - doAssert(len(boundary) > 0) + # According to specification length of boundary must be bigger then `0` and + # less or equal to `70`. + doAssert(len(boundary) > 0 and len(boundary) <= 70) + # 256 bytes is minimum value because we going to use single buffer for + # reading boundaries and for reading headers. + # Minimal buffer value for boundary is 5 bytes, maximum is 74 bytes. But at + # least one header should be present "Content-Disposition", so minimum value + # of multipart headers will be near 150 bytes. + doAssert(partHeadersMaxSize >= 256) # Our internal boundary has format `<-><->`, so we can # reuse different parts of this sequence for processing. var fboundary = newSeq[byte](len(boundary) + 4) @@ -266,8 +278,7 @@ proc closeWait*(mpr: MultiPartReaderRef) {.async.} = else: discard -proc getBytes*(mp: MultiPart): seq[byte] {. - raises: [Defect].} = +proc getBytes*(mp: MultiPart): seq[byte] {.raises: [Defect].} = ## Returns value for MultiPart ``mp`` as sequence of bytes. case mp.kind of MultiPartSource.Buffer: @@ -276,28 +287,16 @@ proc getBytes*(mp: MultiPart): seq[byte] {. doAssert(not(mp.stream.atEof()), "Value is not obtained yet") mp.buffer -proc getString*(mp: MultiPart): string {. - raises: [Defect].} = +proc getString*(mp: MultiPart): string {.raises: [Defect].} = ## Returns value for MultiPart ``mp`` as string. case mp.kind of MultiPartSource.Buffer: - if len(mp.buffer) > 0: - var res = newString(len(mp.buffer)) - copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer)) - res - else: - "" + bytesToString(mp.buffer) of MultiPartSource.Stream: doAssert(not(mp.stream.atEof()), "Value is not obtained yet") - if len(mp.buffer) > 0: - var res = newString(len(mp.buffer)) - copyMem(addr res[0], unsafeAddr mp.buffer[0], len(mp.buffer)) - res - else: - "" + bytesToString(mp.buffer) -proc atEoM*(mpr: var MultiPartReader): bool {. - raises: [Defect].} = +proc atEoM*(mpr: var MultiPartReader): bool {.raises: [Defect].} = ## Procedure returns ``true`` if MultiPartReader has reached the end of ## multipart message. case mpr.kind @@ -306,8 +305,7 @@ proc atEoM*(mpr: var MultiPartReader): bool {. of MultiPartSource.Stream: mpr.stream.atEof() -proc atEoM*(mpr: MultiPartReaderRef): bool {. - raises: [Defect].} = +proc atEoM*(mpr: MultiPartReaderRef): bool {.raises: [Defect].} = ## Procedure returns ``true`` if MultiPartReader has reached the end of ## multipart message. case mpr.kind @@ -422,7 +420,7 @@ func getMultipartBoundary*(ch: openarray[string]): HttpResult[string] {. ## 2) `Content-Type` must be ``multipart/form-data``. ## 3) `boundary` value must be present ## 4) `boundary` value must be less then 70 characters length and - ## all characters should be part of alphabet. + ## all characters should be part of specific alphabet. if len(ch) > 1: err("Multiple Content-Type headers found") else: @@ -455,17 +453,14 @@ func getMultipartBoundary*(ch: openarray[string]): HttpResult[string] {. if len(bparts) < 2: err("Missing Content-Type boundary") else: - if bparts[0].toLowerAscii() != "boundary": - err("Missing boundary key") + let candidate = strip(bparts[1]) + if len(candidate) == 0: + err("Content-Type boundary must be at least 1 character size") + elif len(candidate) > 70: + err("Content-Type boundary must be less then 70 characters") else: - let candidate = strip(bparts[1]) - if len(candidate) == 0: - err("Content-Type boundary must be at least 1 character size") - elif len(candidate) > 70: - err("Content-Type boundary must be less then 70 characters") - else: - for ch in candidate: - if ch notin {'a' .. 'z', 'A' .. 'Z', '0' .. '9', - '\'' .. ')', '+' .. '/', ':', '=', '?', '_'}: - return err("Content-Type boundary alphabet incorrect") - ok(candidate) + for ch in candidate: + if ch notin {'a' .. 'z', 'A' .. 'Z', '0' .. '9', + '\'' .. ')', '+' .. '/', ':', '=', '?', '_'}: + return err("Content-Type boundary alphabet incorrect") + ok(candidate) diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index e25d108..ce7b1f7 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -42,11 +42,11 @@ type WriteItem* = object case kind*: WriteType of Pointer: - data1*: pointer + dataPtr*: pointer of Sequence: - data2*: seq[byte] + dataSeq*: seq[byte] of String: - data3*: string + dataStr*: string size*: int offset*: int future*: Future[void] @@ -96,12 +96,11 @@ type AsyncStreamRW* = AsyncStreamReader | AsyncStreamWriter proc init*(t: typedesc[AsyncBuffer], size: int): AsyncBuffer = - var res = AsyncBuffer( + AsyncBuffer( buffer: newSeq[byte](size), events: [newAsyncEvent(), newAsyncEvent()], offset: 0 ) - res proc getBuffer*(sb: AsyncBuffer): pointer {.inline.} = unsafeAddr sb.buffer[sb.offset] @@ -171,12 +170,12 @@ template toBufferOpenArray*(sb: AsyncBuffer): auto = template copyOut*(dest: pointer, item: WriteItem, length: int) = if item.kind == Pointer: - let p = cast[pointer](cast[uint](item.data1) + uint(item.offset)) + let p = cast[pointer](cast[uint](item.dataPtr) + uint(item.offset)) copyMem(dest, p, length) elif item.kind == Sequence: - copyMem(dest, unsafeAddr item.data2[item.offset], length) + copyMem(dest, unsafeAddr item.dataSeq[item.offset], length) elif item.kind == String: - copyMem(dest, unsafeAddr item.data3[item.offset], length) + copyMem(dest, unsafeAddr item.dataStr[item.offset], length) proc newAsyncStreamReadError(p: ref CatchableError): ref AsyncStreamReadError {. noinline.} = @@ -226,7 +225,7 @@ template checkStreamClosed*(t: untyped) = proc atEof*(rstream: AsyncStreamReader): bool = ## Returns ``true`` is reading stream is closed or finished and internal ## buffer do not have any bytes left. - rstream.state in {AsyncStreamState.Stopped, Finished, Closed} and + rstream.state in {AsyncStreamState.Stopped, Finished, Closed, Error} and (rstream.buffer.dataLen() == 0) proc atEof*(wstream: AsyncStreamWriter): bool = @@ -327,13 +326,13 @@ template readLoop(body: untyped): untyped = raise rstream.error let (consumed, done) = body - rstream.buffer.shift(consumed) rstream.bytesCount = rstream.bytesCount + uint64(consumed) if done: break else: - await rstream.buffer.wait() + if not(rstream.atEof()): + await rstream.buffer.wait() proc readExactly*(rstream: AsyncStreamReader, pbytes: pointer, nbytes: int) {.async.} = @@ -711,7 +710,7 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, wstream.bytesCount = wstream.bytesCount + uint64(nbytes) else: var item = WriteItem(kind: Pointer) - item.data1 = pbytes + item.dataPtr = pbytes item.size = nbytes item.future = newFuture[void]("async.stream.write(pointer)") try: @@ -758,9 +757,9 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], else: var item = WriteItem(kind: Sequence) if not isLiteral(sbytes): - shallowCopy(item.data2, sbytes) + shallowCopy(item.dataSeq, sbytes) else: - item.data2 = sbytes + item.dataSeq = sbytes item.size = length item.future = newFuture[void]("async.stream.write(seq)") try: @@ -806,9 +805,9 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, else: var item = WriteItem(kind: String) if not isLiteral(sbytes): - shallowCopy(item.data3, sbytes) + shallowCopy(item.dataStr, sbytes) else: - item.data3 = sbytes + item.dataStr = sbytes item.size = length item.future = newFuture[void]("async.stream.write(string)") try: diff --git a/chronos/streams/boundstream.nim b/chronos/streams/boundstream.nim index e927768..afe9166 100644 --- a/chronos/streams/boundstream.nim +++ b/chronos/streams/boundstream.nim @@ -117,9 +117,23 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = await upload(addr rstream.buffer, addr buffer[0], length) rstream.state = AsyncStreamState.Finished else: + if (res < toRead) and rstream.rsource.atEof(): + case rstream.cmpop + of BoundCmp.Equal: + rstream.state = AsyncStreamState.Error + rstream.error = newBoundedStreamIncompleteError() + of BoundCmp.LessOrEqual: + rstream.state = AsyncStreamState.Finished rstream.offset = rstream.offset + res await upload(addr rstream.buffer, addr buffer[0], res) else: + if (res < toRead) and rstream.rsource.atEof(): + case rstream.cmpop + of BoundCmp.Equal: + rstream.state = AsyncStreamState.Error + rstream.error = newBoundedStreamIncompleteError() + of BoundCmp.LessOrEqual: + rstream.state = AsyncStreamState.Finished rstream.offset = rstream.offset + res await upload(addr rstream.buffer, addr buffer[0], res) else: @@ -146,10 +160,6 @@ proc boundedReadLoop(stream: AsyncStreamReader) {.async.} = rstream.state = AsyncStreamState.Finished break - # Without this additional wait, procedures such as `read()` could got stuck - # in `await.buffer.wait()` because procedures are unable to detect EOF while - # inside readLoop body. - await stepsAsync(1) # We need to notify consumer about error/close, but we do not care about # incoming data anymore. rstream.buffer.forget() @@ -170,11 +180,11 @@ proc boundedWriteLoop(stream: AsyncStreamWriter) {.async.} = # Writing chunk data. case item.kind of WriteType.Pointer: - await wstream.wsource.write(item.data1, item.size) + await wstream.wsource.write(item.dataPtr, item.size) of WriteType.Sequence: - await wstream.wsource.write(addr item.data2[0], item.size) + await wstream.wsource.write(addr item.dataSeq[0], item.size) of WriteType.String: - await wstream.wsource.write(addr item.data3[0], item.size) + await wstream.wsource.write(addr item.dataStr[0], item.size) wstream.offset = wstream.offset + item.size item.future.complete() else: diff --git a/chronos/streams/chunkstream.nim b/chronos/streams/chunkstream.nim index 0cf61fd..59e92c9 100644 --- a/chronos/streams/chunkstream.nim +++ b/chronos/streams/chunkstream.nim @@ -49,11 +49,18 @@ proc getChunkSize(buffer: openarray[byte]): Result[uint64, cstring] = # We using `uint64` representation, but allow only 2^32 chunk size, # ChunkHeaderSize. var res = 0'u64 - for i in 0 ..< min(len(buffer), ChunkHeaderSize): + for i in 0 ..< min(len(buffer), ChunkHeaderSize + 1): let value = hexValue(buffer[i]) if value < 0: - return err("Incorrect chunk size encoding") - res = (res shl 4) or uint64(value) + if buffer[i] == byte(';'): + # chunk-extension is present, so chunk size is already decoded in res. + return ok(res) + else: + return err("Incorrect chunk size encoding") + else: + if i >= ChunkHeaderSize: + return err("The chunk size exceeds the limit") + res = (res shl 4) or uint64(value) ok(res) proc setChunkSize(buffer: var openarray[byte], length: int64): int = @@ -135,7 +142,7 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = rstream.state = AsyncStreamState.Error rstream.error = exc - if rstream.state in {AsyncStreamState.Stopped, AsyncStreamState.Error}: + if rstream.state != AsyncStreamState.Running: # We need to notify consumer about error/close, but we do not care about # incoming data anymore. rstream.buffer.forget() @@ -161,11 +168,11 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = # Writing chunk data. case item.kind of WriteType.Pointer: - await wstream.wsource.write(item.data1, item.size) + await wstream.wsource.write(item.dataPtr, item.size) of WriteType.Sequence: - await wstream.wsource.write(addr item.data2[0], item.size) + await wstream.wsource.write(addr item.dataSeq[0], item.size) of WriteType.String: - await wstream.wsource.write(addr item.data3[0], item.size) + await wstream.wsource.write(addr item.dataStr[0], item.size) # Writing chunk footer CRLF. await wstream.wsource.write(CRLF) # Everything is fine, completing queue item's future. diff --git a/tests/testasyncstream.nim b/tests/testasyncstream.nim index 44462af..9bce6bf 100644 --- a/tests/testasyncstream.nim +++ b/tests/testasyncstream.nim @@ -506,7 +506,10 @@ suite "ChunkedStream test suite": "--f98f0\r\nContent-Disposition: form-data; name=\"key3\"" & "\r\n\r\nC\r\n" & "--f98f0--\r\n" - ] + ], + ["4;position=1\r\nWiki\r\n5;position=2\r\npedia\r\nE;position=3\r\n" & + " in\r\n\r\nchunks.\r\n0;position=4\r\n\r\n", + "Wikipedia in\r\n\r\nchunks."], ] proc checkVector(address: TransportAddress, inputstr: string): Future[string] {.async.} = @@ -545,7 +548,16 @@ suite "ChunkedStream test suite": check waitFor(testVectors(initTAddress("127.0.0.1:46001"))) == true test "ChunkedStream incorrect chunk test": const BadVectors = [ - ["100000000 \r\n1"], + ["10000000;\r\n1"], + ["10000000\r\n1"], + ["FFFFFFFF;extension1=value1;extension2=value2\r\n1"], + ["FFFFFFFF\r\n1"], + ["100000000\r\n1"], + ["10000000 \r\n1"], + ["100000000 ;\r\n"], + ["FFFFFFFF0\r\n1"], + ["FFFFFFFF \r\n1"], + ["FFFFFFFF ;\r\n1"], ["z\r\n1"] ] proc checkVector(address: TransportAddress, @@ -571,11 +583,36 @@ suite "ChunkedStream test suite": var r = await rstream2.read() doAssert(len(r) > 0) except ChunkedStreamIncompleteError: - if inputstr == "100000000 \r\n1": + case inputstr + of "10000000;\r\n1": res = true + of "10000000\r\n1": + res = true + of "FFFFFFFF;extension1=value1;extension2=value2\r\n1": + res = true + of "FFFFFFFF\r\n1": + res = true + else: + res = false except ChunkedStreamProtocolError: - if inputstr == "z\r\n1": + case inputstr + of "100000000\r\n1": res = true + of "10000000 \r\n1": + res = true + of "100000000 ;\r\n": + res = true + of "z\r\n1": + res = true + of "FFFFFFFF0\r\n1": + res = true + of "FFFFFFFF \r\n1": + res = true + of "FFFFFFFF ;\r\n1": + res = true + else: + res = false + await rstream2.closeWait() await rstream.closeWait() await transp.closeWait() @@ -687,6 +724,13 @@ suite "TLSStream test suite": suite "BoundedStream test suite": + type + BoundarySizeTest = enum + SizeReadWrite, SizeOverflow, SizeIncomplete, SizeEmpty + BoundaryBytesTest = enum + BoundaryRead, BoundaryDouble, BoundarySize, BoundaryIncomplete, + BoundaryEmpty + proc createBigMessage(size: int): seq[byte] = var message = "ABCDEFGHIJKLMNOP" var res = newSeq[byte](size) @@ -697,8 +741,8 @@ suite "BoundedStream test suite": for itemComp in [BoundCmp.Equal, BoundCmp.LessOrEqual]: for itemSize in [100, 60000]: - proc boundaryTest(address: TransportAddress, test: int, size: int, - boundary: seq[byte], + proc boundaryTest(address: TransportAddress, btest: BoundaryBytesTest, + size: int, boundary: seq[byte], cmp: BoundCmp): Future[bool] {.async.} = var message = createBigMessage(size) var clientRes = false @@ -706,20 +750,21 @@ suite "BoundedStream test suite": proc processClient(server: StreamServer, transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) - if test == 0: + case btest + of BoundaryRead: await wstream.write(message) await wstream.write(boundary) await wstream.finish() await wstream.closeWait() clientRes = true - elif test == 1: + of BoundaryDouble: await wstream.write(message) await wstream.write(boundary) await wstream.write(message) await wstream.finish() await wstream.closeWait() clientRes = true - elif test == 2: + of BoundarySize: var ncmessage = message ncmessage.setLen(len(message) - 2) await wstream.write(ncmessage) @@ -727,14 +772,14 @@ suite "BoundedStream test suite": await wstream.finish() await wstream.closeWait() clientRes = true - elif test == 3: + of BoundaryIncomplete: var ncmessage = message ncmessage.setLen(len(message) - 2) await wstream.write(ncmessage) await wstream.finish() await wstream.closeWait() clientRes = true - elif test == 4: + of BoundaryEmpty: await wstream.write(boundary) await wstream.finish() await wstream.closeWait() @@ -750,20 +795,21 @@ suite "BoundedStream test suite": server.start() var conn = await connect(address) var rstream = newAsyncStreamReader(conn) - if test == 0: + case btest + of BoundaryRead: var rbstream = newBoundedStreamReader(rstream, -1, boundary) let response = await rbstream.read() if response == message: res = true await rbstream.closeWait() - elif test == 1: + of BoundaryDouble: var rbstream = newBoundedStreamReader(rstream, -1, boundary) let response1 = await rbstream.read() await rbstream.closeWait() let response2 = await rstream.read() if (response1 == message) and (response2 == message): res = true - elif test == 2: + of BoundarySize: var expectMessage = message expectMessage[^2] = 0x2D'u8 expectMessage[^1] = 0x2D'u8 @@ -772,14 +818,14 @@ suite "BoundedStream test suite": await rbstream.closeWait() if (len(response) == size) and response == expectMessage: res = true - elif test == 3: + of BoundaryIncomplete: var rbstream = newBoundedStreamReader(rstream, -1, boundary) try: let response {.used.} = await rbstream.read() except BoundedStreamIncompleteError: res = true await rbstream.closeWait() - elif test == 4: + of BoundaryEmpty: var rbstream = newBoundedStreamReader(rstream, -1, boundary) let response = await rbstream.read() await rbstream.closeWait() @@ -791,7 +837,7 @@ suite "BoundedStream test suite": await server.join() return (res and clientRes) - proc boundedTest(address: TransportAddress, test: int, + proc boundedTest(address: TransportAddress, stest: BoundarySizeTest, size: int, cmp: BoundCmp): Future[bool] {.async.} = var clientRes = false var res = false @@ -805,13 +851,14 @@ suite "BoundedStream test suite": transp: StreamTransport) {.async.} = var wstream = newAsyncStreamWriter(transp) var wbstream = newBoundedStreamWriter(wstream, size, comparison = cmp) - if test == 0: + case stest + of SizeReadWrite: for i in 0 ..< 10: await wbstream.write(messagePart) await wbstream.finish() await wbstream.closeWait() clientRes = true - elif test == 1: + of SizeOverflow: for i in 0 ..< 10: await wbstream.write(messagePart) try: @@ -819,7 +866,7 @@ suite "BoundedStream test suite": except BoundedStreamOverflowError: clientRes = true await wbstream.closeWait() - elif test == 2: + of SizeIncomplete: for i in 0 ..< 9: await wbstream.write(messagePart) case cmp @@ -835,7 +882,7 @@ suite "BoundedStream test suite": except BoundedStreamIncompleteError: discard await wbstream.closeWait() - elif test == 3: + of SizeEmpty: case cmp of BoundCmp.Equal: try: @@ -861,17 +908,18 @@ suite "BoundedStream test suite": var conn = await connect(address) var rstream = newAsyncStreamReader(conn) var rbstream = newBoundedStreamReader(rstream, size, comparison = cmp) - if test == 0: + case stest + of SizeReadWrite: let response = await rbstream.read() await rbstream.closeWait() if response == message: res = true - elif test == 1: + of SizeOverflow: let response = await rbstream.read() await rbstream.closeWait() if response == message: res = true - elif test == 2: + of SizeIncomplete: case cmp of BoundCmp.Equal: try: @@ -886,7 +934,7 @@ suite "BoundedStream test suite": except BoundedStreamIncompleteError: res = false await rbstream.closeWait() - elif test == 3: + of SizeEmpty: case cmp of BoundCmp.Equal: try: @@ -916,30 +964,34 @@ suite "BoundedStream test suite": "<= " & $itemSize test "BoundedStream(size) reading/writing test [" & suffix & "]": - check waitFor(boundedTest(address, 0, itemSize, itemComp)) == true + check waitFor(boundedTest(address, SizeReadWrite, itemSize, + itemComp)) == true test "BoundedStream(size) overflow test [" & suffix & "]": - check waitFor(boundedTest(address, 1, itemSize, itemComp)) == true + check waitFor(boundedTest(address, SizeOverflow, itemSize, + itemComp)) == true test "BoundedStream(size) incomplete test [" & suffix & "]": - check waitFor(boundedTest(address, 2, itemSize, itemComp)) == true + check waitFor(boundedTest(address, SizeIncomplete, itemSize, + itemComp)) == true test "BoundedStream(size) empty message test [" & suffix & "]": - check waitFor(boundedTest(address, 3, itemSize, itemComp)) == true + check waitFor(boundedTest(address, SizeEmpty, itemSize, + itemComp)) == true test "BoundedStream(boundary) reading test [" & suffix & "]": - check waitFor(boundaryTest(address, 0, itemSize, + check waitFor(boundaryTest(address, BoundaryRead, itemSize, @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) test "BoundedStream(boundary) double message test [" & suffix & "]": - check waitFor(boundaryTest(address, 1, itemSize, + check waitFor(boundaryTest(address, BoundaryDouble, itemSize, @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) test "BoundedStream(size+boundary) reading size-bound test [" & suffix & "]": - check waitFor(boundaryTest(address, 2, itemSize, + check waitFor(boundaryTest(address, BoundarySize, itemSize, @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) test "BoundedStream(boundary) reading incomplete test [" & suffix & "]": - check waitFor(boundaryTest(address, 3, itemSize, + check waitFor(boundaryTest(address, BoundaryIncomplete, itemSize, @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) test "BoundedStream(boundary) empty message test [" & suffix & "]": - check waitFor(boundaryTest(address, 4, itemSize, + check waitFor(boundaryTest(address, BoundaryEmpty, itemSize, @[0x2D'u8, 0x2D'u8, 0x2D'u8], itemComp)) test "BoundedStream leaks test": diff --git a/tests/testhttpserver.nim b/tests/testhttpserver.nim index 3e6e809..549471b 100644 --- a/tests/testhttpserver.nim +++ b/tests/testhttpserver.nim @@ -69,6 +69,10 @@ N8r5CwGcIX/XPC3lKazzbZ8baA== """ suite "HTTP server testing suite": + type + TooBigTest = enum + GetBodyTest, ConsumeBodyTest, PostUrlTest, PostMultipartTest + proc httpClient(address: TransportAddress, data: string): Future[string] {.async.} = var transp: StreamTransport @@ -77,10 +81,7 @@ suite "HTTP server testing suite": if len(data) > 0: let wres {.used.} = await transp.write(data) var rres = await transp.read() - var sres = newString(len(rres)) - if len(rres) > 0: - copyMem(addr sres[0], addr rres[0], len(rres)) - return sres + return bytesToString(rres) except CatchableError: return "EXCEPTION" finally: @@ -104,10 +105,7 @@ suite "HTTP server testing suite": if len(data) > 0: await tlsstream.writer.write(data) var rres = await tlsstream.reader.read() - var sres = newString(len(rres)) - if len(rres) > 0: - copyMem(addr sres[0], addr rres[0], len(rres)) - return sres + return bytesToString(rres) except CatchableError: return "EXCEPTION" finally: @@ -119,20 +117,21 @@ suite "HTTP server testing suite": transp.closeWait()) proc testTooBigBodyChunked(address: TransportAddress, - operation: int): Future[bool] {.async.} = + operation: TooBigTest): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): let request = r.get() try: - if operation == 0: + case operation + of GetBodyTest: let body {.used.} = await request.getBody() - elif operation == 1: + of ConsumeBodyTest: await request.consumeBody() - elif operation == 2: + of PostUrlTest: let ptable {.used.} = await request.post() - elif operation == 3: + of PostMultipartTest: let ptable {.used.} = await request.post() except HttpCriticalError as exc: if exc.code == Http413: @@ -153,14 +152,15 @@ suite "HTTP server testing suite": server.start() let request = - if operation in [0, 1, 2]: + case operation + of GetBodyTest, ConsumeBodyTest, PostUrlTest: "POST / HTTP/1.0\r\n" & "Content-Type: application/x-www-form-urlencoded\r\n" & "Transfer-Encoding: chunked\r\n" & "Cookie: 2\r\n\r\n" & "5\r\na=a&b\r\n5\r\n=b&c=\r\n4\r\nc&d=\r\n4\r\n%D0%\r\n" & "2\r\n9F\r\n0\r\n\r\n" - elif operation in [3]: + of PostMultipartTest: "POST / HTTP/1.0\r\n" & "Host: 127.0.0.1:30080\r\n" & "Transfer-Encoding: chunked\r\n" & @@ -173,8 +173,6 @@ suite "HTTP server testing suite": "\r\n\r\nC\r\n\r\n" & "b\r\n--f98f0--\r\n\r\n" & "0\r\n\r\n" - else: - "" let data = await httpClient(address, request) await server.stop() @@ -184,7 +182,7 @@ suite "HTTP server testing suite": test "Request headers timeout test": proc testTimeout(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): let request = r.get() @@ -213,7 +211,7 @@ suite "HTTP server testing suite": test "Empty headers test": proc testEmpty(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): let request = r.get() @@ -241,7 +239,7 @@ suite "HTTP server testing suite": test "Too big headers test": proc testTooBig(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): let request = r.get() @@ -271,7 +269,7 @@ suite "HTTP server testing suite": test "Too big request body test (content-length)": proc testTooBigBody(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): discard @@ -300,24 +298,28 @@ suite "HTTP server testing suite": test "Too big request body test (getBody()/chunked encoding)": check: - waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), 0)) == true + waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), + GetBodyTest)) == true test "Too big request body test (consumeBody()/chunked encoding)": check: - waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), 1)) == true + waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), + ConsumeBodyTest)) == true test "Too big request body test (post()/urlencoded/chunked encoding)": check: - waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), 2)) == true + waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), + PostUrlTest)) == true test "Too big request body test (post()/multipart/chunked encoding)": check: - waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), 3)) == true + waitFor(testTooBigBodyChunked(initTAddress("127.0.0.1:30080"), + PostMultipartTest)) == true test "Query arguments test": proc testQuery(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): let request = r.get() @@ -357,7 +359,7 @@ suite "HTTP server testing suite": test "Headers test": proc testHeaders(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): let request = r.get() @@ -400,7 +402,7 @@ suite "HTTP server testing suite": test "POST arguments (urlencoded/content-length) test": proc testPostUrl(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): var kres = newSeq[string]() @@ -443,7 +445,7 @@ suite "HTTP server testing suite": test "POST arguments (urlencoded/chunked encoding) test": proc testPostUrl2(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): var kres = newSeq[string]() @@ -487,7 +489,7 @@ suite "HTTP server testing suite": test "POST arguments (multipart/content-length) test": proc testPostMultipart(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): var kres = newSeq[string]() @@ -542,7 +544,7 @@ suite "HTTP server testing suite": test "POST arguments (multipart/chunked encoding) test": proc testPostMultipart2(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): var kres = newSeq[string]() @@ -606,7 +608,7 @@ suite "HTTP server testing suite": test "HTTPS server (successful handshake) test": proc testHTTPS(address: TransportAddress): Future[bool] {.async.} = var serverRes = false - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): let request = r.get() @@ -644,7 +646,7 @@ suite "HTTP server testing suite": proc testHTTPS2(address: TransportAddress): Future[bool] {.async.} = var serverRes = false var testFut = newFuture[void]() - proc process(r: RequestFence[HttpRequestRef]): Future[HttpResponseRef] {. + proc process(r: RequestFence): Future[HttpResponseRef] {. async.} = if r.isOk(): let request = r.get() @@ -758,7 +760,9 @@ suite "HTTP server testing suite": ("", 0'u64), ("0", 0'u64), ("-0", 0'u64), ("0-", 0'u64), ("01", 1'u64), ("001", 1'u64), ("0000000000001", 1'u64), ("18446744073709551615", 0xFFFF_FFFF_FFFF_FFFF'u64), - ("18446744073709551616", 1844674407370955161'u64), + ("18446744073709551616", 0xFFFF_FFFF_FFFF_FFFF'u64), + ("99999999999999999999", 0xFFFF_FFFF_FFFF_FFFF'u64), + ("999999999999999999999999999999999999", 0xFFFF_FFFF_FFFF_FFFF'u64), ("FFFFFFFFFFFFFFFF", 0'u64), ("0123456789ABCDEF", 123456789'u64) ] @@ -770,6 +774,51 @@ suite "HTTP server testing suite": for item in TestVectors: check bytesToDec(item[0]) == item[1] + test "HttpTable behavior test": + var table1 = HttpTable.init() + var table2 = HttpTable.init([("Header1", "value1"), ("Header2", "value2")]) + check: + table1.isEmpty() == true + table2.isEmpty() == false + + table1.add("Header1", "value1") + table1.add("Header2", "value2") + table1.add("HEADER2", "VALUE3") + check: + table1.getList("HeAdEr2") == @["value2", "VALUE3"] + table1.getString("HeAdEr2") == "value2,VALUE3" + table2.getString("HEADER1") == "value1" + table1.count("HEADER2") == 2 + table1.count("HEADER1") == 1 + table1.getLastString("HEADER1") == "value1" + table1.getLastString("HEADER2") == "VALUE3" + "header1" in table1 == true + "HEADER1" in table1 == true + "header2" in table1 == true + "HEADER2" in table1 == true + "HEADER3" in table1 == false + + var + data1: seq[tuple[key: string, value: string]] + data2: seq[tuple[key: string, value: seq[string]]] + for key, value in table1.stringItems(true): + data1.add((key, value)) + for key, value in table1.items(true): + data2.add((key, value)) + + check: + data1 == @[("Header2", "value2"), ("Header2", "VALUE3"), + ("Header1", "value1")] + data2 == @[("Header2", @["value2", "VALUE3"]), + ("Header1", @["value1"])] + + table1.set("header2", "value4") + check: + table1.getList("header2") == @["value4"] + table1.getString("header2") == "value4" + table1.count("header2") == 1 + table1.getLastString("header2") == "value4" + test "getTransferEncoding() test": var encodings = [ "chunked", "compress", "deflate", "gzip", "identity", "x-gzip"