diff --git a/tests/extensions/testcompression.nim b/tests/extensions/testcompression.nim index e6b2208..89b01e1 100644 --- a/tests/extensions/testcompression.nim +++ b/tests/extensions/testcompression.nim @@ -7,13 +7,13 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import std/os +import std/[os, strutils] import pkg/[chronos/unittest2/asynctests, stew/io2] import ../../websock/websock import ../../websock/extensions/compression/deflate const - dataFolder = "tests" / "extensions" / "data" + dataFolder = currentSourcePath.rsplit(os.DirSep, 1)[0] / "data" suite "permessage deflate compression": setup: diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index db10082..8096023 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -7,7 +7,10 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import std/strutils +import std/[ + random, + sequtils, + strutils] import pkg/[ httputils, chronos/unittest2/asynctests, @@ -352,6 +355,120 @@ suite "Test framing": discard await session.recvMsg(5) await waitForClose(session) + asyncTest "should serialize long messages": + const numMessages = 10 + let testData = newSeqWith(10 * 1024 * 1024, byte.rand()) + + proc handle(request: HttpRequest) {.async.} = + check request.uri.path == WSPath + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + + for i in 0 ..< numMessages: + try: + let message = await ws.recvMsg() + let matchesExpectedMessage = (message == testData) + check matchesExpectedMessage + except CatchableError: + fail() + + await waitForClose(ws) + + server = createServer( + address = address, + handler = handle, + flags = {ReuseAddr}) + + let session = await connectClient( + address = initTAddress("127.0.0.1:8888"), + frameSize = 1 * 1024 * 1024) + + var futs: seq[Future[void]] + for i in 0 ..< numMessages: + futs.add session.send(testData, Opcode.Binary) + await allFutures(futs) + await session.close() + + asyncTest "should handle cancellations": + const numMessages = 10 + let expectedNumMessages = numMessages - 1 + let testData = newSeqWith(10 * 1024 * 1024, byte.rand()) + + proc handle(request: HttpRequest) {.async.} = + check request.uri.path == WSPath + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + + for i in 0 ..< expectedNumMessages: + try: + let message = await ws.recvMsg() + let matchesExpectedMessage = (message == testData) + check matchesExpectedMessage + except CatchableError: + fail() + + expect WSClosedError: + discard await ws.recvMsg() # try to receive canceled message + + await waitForClose(ws) + + server = createServer( + address = address, + handler = handle, + flags = {ReuseAddr}) + + let session = await connectClient( + address = initTAddress("127.0.0.1:8888"), + frameSize = 1 * 1024 * 1024) + + var futs: seq[Future[void]] + for i in 0 ..< numMessages: + futs.add session.send(testData, Opcode.Binary) + futs[0].cancel() # expected to complete as it already started sending + futs[^2].cancel() # expected to be canceled as it has not started yet + await allFutures(futs) + await session.close() + + asyncTest "should prioritize control packets": + const numMessages = 10 + let testData = newSeqWith(10 * 1024 * 1024, byte.rand()) + + proc handle(request: HttpRequest) {.async.} = + check request.uri.path == WSPath + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + + expect WSClosedError: + discard await ws.recvMsg() + + await waitForClose(ws) + + server = createServer( + address = address, + handler = handle, + flags = {ReuseAddr}) + + let session = await connectClient( + address = initTAddress("127.0.0.1:8888"), + frameSize = 1 * 1024 * 1024) + + let messageFut = session.send(testData, Opcode.Binary) + + # interleave ping packets + var futs: seq[Future[void]] + for i in 0 ..< numMessages: + futs.add session.send(opcode = Opcode.Ping) + await allFutures(futs) + check not messageFut.finished + + # interleave close packet + await session.close() + check messageFut.finished + await messageFut + suite "Test Closing": setup: var diff --git a/websock/session.nim b/websock/session.nim index f243318..571e399 100644 --- a/websock/session.nim +++ b/websock/session.nim @@ -23,7 +23,7 @@ proc prepareCloseBody(code: StatusCodes, reason: string): seq[byte] = if ord(code) > 999: result = @(ord(code).uint16.toBytesBE()) & result -proc writeMessage*(ws: WSSession, +proc writeMessage(ws: WSSession, data: seq[byte] = @[], opcode: Opcode, maskKey: MaskKey, @@ -36,7 +36,7 @@ proc writeMessage*(ws: WSSession, let maxSize = ws.frameSize var i = 0 - while ws.readyState notin {ReadyState.Closing}: + while ws.readyState notin {ReadyState.Closing, ReadyState.Closed}: let canSend = min(data.len - i, maxSize) let frame = Frame( fin: if (canSend + i >= data.len): true else: false, @@ -55,7 +55,7 @@ proc writeMessage*(ws: WSSession, if i >= data.len: break -proc writeControl*( +proc writeControl( ws: WSSession, data: seq[byte] = @[], opcode: Opcode, @@ -89,11 +89,14 @@ proc writeControl*( trace "Wrote control frame" -proc send*( +func isControl(opcode: Opcode): bool = + opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary} + +proc nonCancellableSend( ws: WSSession, data: seq[byte] = @[], opcode: Opcode): Future[void] - {.async, raises: [Defect, WSClosedError].} = + {.async.} = ## Send a frame ## @@ -116,13 +119,61 @@ proc send*( else: default(MaskKey) - if opcode in {Opcode.Text, Opcode.Cont, Opcode.Binary}: - await ws.writeMessage( - data, opcode, maskKey, ws.extensions) + if opcode.isControl: + await ws.writeControl(data, opcode, maskKey) + else: + await ws.writeMessage(data, opcode, maskKey, ws.extensions) - return +proc doSend( + ws: WSSession, + data: seq[byte] = @[], + opcode: Opcode + ): Future[void] = + let + retFut = newFuture[void]("doSend") + sendFut = ws.nonCancellableSend(data, opcode) - await ws.writeControl(data, opcode, maskKey) + proc handleSend {.async.} = + try: + await sendFut + retFut.complete() + except CatchableError as exc: + retFut.fail(exc) + + asyncSpawn handleSend() + retFut + +proc sendLoop(ws: WSSession) {.gcsafe, async.} = + while ws.sendQueue.len > 0: + let task = ws.sendQueue.popFirst() + if task.fut.cancelled: + continue + + try: + await ws.doSend(task.data, task.opcode) + task.fut.complete() + except CatchableError as exc: + task.fut.fail(exc) + +proc send*( + ws: WSSession, + data: seq[byte] = @[], + opcode: Opcode): Future[void] = + if opcode.isControl: + # Control frames (see Section 5.5) MAY be injected in the middle of + # a fragmented message. Control frames themselves MUST NOT be + # fragmented. + # See RFC 6455 Section 5.4 Fragmentation + return ws.doSend(data, opcode) + + let fut = newFuture[void]("send") + + ws.sendQueue.addLast (data: data, opcode: opcode, fut: fut) + + if isNil(ws.sendLoop) or ws.sendLoop.finished: + ws.sendLoop = sendLoop(ws) + + fut proc send*( ws: WSSession, @@ -420,6 +471,10 @@ proc recvMsg*( trace "Read full message, breaking!" break + if ws.readyState == ReadyState.Closed: + # avoid reporting incomplete message + raise newException(WSClosedError, "WebSocket is closed!") + if not ws.binary and validateUTF8(res.toOpenArray(0, res.high)) == false: raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected") diff --git a/websock/types.nim b/websock/types.nim index e843582..13def98 100644 --- a/websock/types.nim +++ b/websock/types.nim @@ -1,5 +1,5 @@ ## nim-websock -## Copyright (c) 2021 Status Research & Development GmbH +## Copyright (c) 2021-2022 Status Research & Development GmbH ## Licensed under either of ## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) ## * MIT license ([LICENSE-MIT](LICENSE-MIT)) @@ -9,6 +9,7 @@ {.push raises: [Defect].} +import std/deques import pkg/[chronos, chronos/streams/tlsstream, chronos/apps/http/httptable, @@ -16,6 +17,8 @@ import pkg/[chronos, stew/results] import ./utils +export deques + const SHA1DigestSize* = 20 WSHeaderSize* = 12 @@ -99,6 +102,14 @@ type reading*: bool proto*: string + # The fragments of one message MUST NOT be interleaved between the + # fragments of another message unless an extension has been + # negotiated that can interpret the interleaving. + # See RFC 6455 Section 5.4 Fragmentation + sendLoop*: Future[void] + sendQueue*: Deque[ + tuple[data: seq[byte], opcode: Opcode, fut: Future[void]]] + Ext* = ref object of RootObj name*: string session*: WSSession