diff --git a/examples/autobahn_client.nim b/examples/autobahn_client.nim index 6c7a0b3..5345d79 100644 --- a/examples/autobahn_client.nim +++ b/examples/autobahn_client.nim @@ -43,7 +43,7 @@ proc getCaseCount(): Future[int] {.async.} = block: try: let ws = await connectServer("/getCaseCount") - let buff = await ws.recv() + let buff = await ws.recvMsg() let dataStr = string.fromBytes(buff) caseCount = parseInt(dataStr) await ws.close() @@ -60,7 +60,7 @@ proc generateReport() {.async.} = trace "request autobahn server to generate report" let ws = await connectServer("/updateReports?agent=" & agent) while true: - let buff = await ws.recv() + let buff = await ws.recvMsg() if buff.len <= 0: break await ws.close() @@ -80,7 +80,7 @@ proc main() {.async.} = while ws.readystate != ReadyState.Closed: # echo back - let data = await ws.recv() + let data = await ws.recvMsg() let opCode = if ws.binary: Opcode.Binary else: diff --git a/examples/server.nim b/examples/server.nim index 84a8b00..2215dbc 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -28,7 +28,7 @@ proc handle(request: HttpRequest) {.async.} = trace "Websocket handshake completed" while ws.readyState != ReadyState.Closed: - let recvData = await ws.recv() + let recvData = await ws.recvMsg() trace "Client Response: ", size = recvData.len, binary = ws.binary if ws.readyState == ReadyState.Closed: diff --git a/tests/extensions/testcompression.nim b/tests/extensions/testcompression.nim index 72eeeea..ec95db5 100644 --- a/tests/extensions/testcompression.nim +++ b/tests/extensions/testcompression.nim @@ -35,7 +35,7 @@ suite "permessage deflate compression": let ws = await server.handleRequest(request) while ws.readyState != ReadyState.Closed: - let recvData = await ws.recv() + let recvData = await ws.recvMsg() if ws.readyState == ReadyState.Closed: break await ws.send(recvData, @@ -58,7 +58,7 @@ suite "permessage deflate compression": var recvData: seq[byte] while recvData.len < textData.len: - let res = await client.recv() + let res = await client.recvMsg() recvData.add res if client.readyState == ReadyState.Closed: break @@ -75,7 +75,7 @@ suite "permessage deflate compression": ) let ws = await server.handleRequest(request) while ws.readyState != ReadyState.Closed: - let recvData = await ws.recv() + let recvData = await ws.recvMsg() if ws.readyState == ReadyState.Closed: break await ws.send(recvData, @@ -98,7 +98,7 @@ suite "permessage deflate compression": var recvData: seq[byte] while recvData.len < binaryData.len: - let res = await client.recv() + let res = await client.recvMsg() recvData.add res if client.readyState == ReadyState.Closed: break diff --git a/tests/extensions/testexts.nim b/tests/extensions/testexts.nim index 8269cd5..96be628 100644 --- a/tests/extensions/testexts.nim +++ b/tests/extensions/testexts.nim @@ -30,7 +30,7 @@ suite "multiple extensions flow": factories = [hexFactory, base64Factory], ) let ws = await server.handleRequest(request) - let recvData = await ws.recv() + let recvData = await ws.recvMsg() await ws.send(recvData, if ws.binary: Opcode.Binary else: Opcode.Text) @@ -50,7 +50,7 @@ suite "multiple extensions flow": ) await client.send(testData) - let res = await client.recv() + let res = await client.recvMsg() check testData.toBytes() == res await client.close() @@ -62,7 +62,7 @@ suite "multiple extensions flow": factories = [hexFactory, base64Factory], ) let ws = await server.handleRequest(request) - let recvData = await ws.recv() + let recvData = await ws.recvMsg() await ws.send(recvData, if ws.binary: Opcode.Binary else: Opcode.Text) @@ -82,6 +82,6 @@ suite "multiple extensions flow": ) await client.send(testData) - let res = await client.recv() + let res = await client.recvMsg() check testData.toBytes() == res await client.close() diff --git a/tests/helpers.nim b/tests/helpers.nim index 2b65e4b..1a82a0a 100644 --- a/tests/helpers.nim +++ b/tests/helpers.nim @@ -37,7 +37,7 @@ proc rndBin*(size: int): seq[byte] = proc waitForClose*(ws: WSSession) {.async.} = try: while ws.readystate != ReadyState.Closed: - discard await ws.recv() + discard await ws.recvMsg() except CatchableError: trace "Closing websocket" diff --git a/tests/testutf8.nim b/tests/testutf8.nim index 1b5f38f..4443646 100644 --- a/tests/testutf8.nim +++ b/tests/testutf8.nim @@ -74,7 +74,7 @@ suite "UTF-8 DFA validator": proc waitForClose(ws: WSSession) {.async.} = try: while ws.readystate != ReadyState.Closed: - discard await ws.recv() + discard await ws.recvMsg() except CatchableError: trace "Closing websocket" @@ -96,7 +96,7 @@ suite "UTF-8 validator in action": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let res = await ws.recv() + let res = await ws.recvMsg() check: string.fromBytes(res) == testData ws.binary == false @@ -135,7 +135,7 @@ suite "UTF-8 validator in action": let server = WSServer.new(protos = ["proto"], onClose = onClose) let ws = await server.handleRequest(request) - let res = await ws.recv() + let res = await ws.recvMsg() await waitForClose(ws) check: @@ -182,7 +182,7 @@ suite "UTF-8 validator in action": ) expect WSInvalidUTF8: - let data = await session.recv() + let data = await session.recvMsg() test "invalid UTF-8 sequence close code": let closeReason = "i want to close\xc0\xaf" @@ -207,4 +207,4 @@ suite "UTF-8 validator in action": ) expect WSInvalidUTF8: - let data = await session.recv() + let data = await session.recvMsg() diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index 9b3a241..d4d9487 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -109,7 +109,7 @@ suite "Test transmission": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let servRes = await ws.recv() + let servRes = await ws.recvMsg() check string.fromBytes(servRes) == testString await ws.waitForClose() @@ -129,7 +129,7 @@ suite "Test transmission": check request.uri.path == WSPath let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let servRes = await ws.recv() + let servRes = await ws.recvMsg() check string.fromBytes(servRes) == testString await ws.waitForClose() @@ -158,7 +158,7 @@ suite "Test transmission": flags = {ReuseAddr}) let session = await connectClient() - var clientRes = await session.recv() + var clientRes = await session.recvMsg() check string.fromBytes(clientRes) == testString await waitForClose(session) @@ -265,7 +265,7 @@ suite "Test ping-pong": let ws = await server.handleRequest(request) expect WSPayloadTooLarge: - discard await ws.recv() + discard await ws.recvMsg() await waitForClose(ws) @@ -334,7 +334,7 @@ suite "Test framing": let session = await connectClient() expect WSMaxMessageSizeError: - discard await session.recv(5) + discard await session.recvMsg(5) await waitForClose(session) suite "Test Closing": @@ -575,7 +575,7 @@ suite "Test Payload": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let servRes = await ws.recv() + let servRes = await ws.recvMsg() check: servRes.len == 0 @@ -591,7 +591,7 @@ suite "Test Payload": let session = await connectClient() await session.send(emptyStr) - let clientRes = await session.recv() + let clientRes = await session.recvMsg() check: clientRes.len == 0 @@ -607,7 +607,7 @@ suite "Test Payload": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) for _ in 0..<3: - let servRes = await ws.recv() + let servRes = await ws.recvMsg() check: servRes.len == 0 @@ -629,7 +629,7 @@ suite "Test Payload": await session.send(emptyStr) for _ in 0..<3: - let clientRes = await session.recv() + let clientRes = await session.recvMsg() check: clientRes.len == 0 @@ -648,7 +648,7 @@ suite "Test Payload": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let respData = await ws.recv() + let respData = await ws.recvMsg() check: string.fromBytes(respData) == testString @@ -706,7 +706,7 @@ suite "Test Payload": ) let ws = await server.handleRequest(request) - let respData = await ws.recv() + let respData = await ws.recvMsg() check: string.fromBytes(respData) == testString @@ -764,7 +764,7 @@ suite "Test Payload": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let res = await ws.recv() + let res = await ws.recvMsg() check ws.binary == false await ws.send(res, Opcode.Text) @@ -781,7 +781,7 @@ suite "Test Payload": ) await ws.send(testData) - let echoed = await ws.recv() + let echoed = await ws.recvMsg() await ws.close() check: @@ -800,7 +800,7 @@ suite "Test Binary message with Payload": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let servRes = await ws.recv() + let servRes = await ws.recvMsg() check: servRes == emptyData @@ -825,7 +825,7 @@ suite "Test Binary message with Payload": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let servRes = await ws.recv() + let servRes = await ws.recvMsg() check: servRes == emptyData @@ -858,7 +858,7 @@ suite "Test Binary message with Payload": ) let ws = await server.handleRequest(request) - let res = await ws.recv() + let res = await ws.recvMsg() check: res == testData ws.binary == true @@ -892,7 +892,7 @@ suite "Test Binary message with Payload": ) let ws = await server.handleRequest(request) - let res = await ws.recv() + let res = await ws.recvMsg() check: res == testData ws.binary == true @@ -921,7 +921,7 @@ suite "Test Binary message with Payload": let server = WSServer.new(protos = ["proto"]) let ws = await server.handleRequest(request) - let res = await ws.recv() + let res = await ws.recvMsg() check: ws.binary == true @@ -941,7 +941,7 @@ suite "Test Binary message with Payload": ) await ws.send(testData, Opcode.Binary) - let echoed = await ws.recv() + let echoed = await ws.recvMsg() check: echoed == testData @@ -951,3 +951,61 @@ suite "Test Binary message with Payload": check: echoed == testData ws.binary == true + +suite "Partial frames": + teardown: + server.stop() + await server.closeWait() + + proc lowLevelRecv( + senderFrameSize, receiverFrameSize, readChunkSize: int) {.async.} = + + const + howMuchWood = "How much wood could a wood chuck chuck ..." + + proc handle(request: HttpRequest) {.async.} = + check request.uri.path == WSPath + + let + server = WSServer.new(frameSize = receiverFrameSize) + ws = await server.handleRequest(request) + + var + res = newSeq[byte](howMuchWood.len) + pos = 0 + + while ws.readyState != ReadyState.Closed: + let read = await ws.recv(addr res[pos], min(res.len - pos, readChunkSize)) + pos += read + + if pos >= res.len: + break + + res.setlen(pos) + check res.len == howMuchWood.toBytes().len + check res == howMuchWood.toBytes() + await ws.waitForClose() + + server = createServer( + address = address, + handler = handle, + flags = {ReuseAddr}) + + let session = await connectClient( + address = address, + frameSize = senderFrameSize) + + await session.send(howMuchWood) + await session.close() + + test "read in chunks less than sender frameSize": + await lowLevelRecv(7, 7, 5) + + test "read in chunks greater than sender frameSize": + await lowLevelRecv(3, 7, 5) + + test "sender frameSize greater than receiver": + await lowLevelRecv(7, 5, 5) + + test "receiver frameSize greater than sender": + await lowLevelRecv(7, 10, 5) diff --git a/websock/session.nim b/websock/session.nim index a79e660..7459384 100644 --- a/websock/session.nim +++ b/websock/session.nim @@ -315,23 +315,9 @@ proc recv*( var consumed = 0 var pbuffer = cast[ptr UncheckedArray[byte]](data) try: - var first = true - if not isNil(ws.frame): - if ws.frame.fin and ws.frame.remainder > 0: - trace "Continue reading from the same frame" - first = true - elif not ws.frame.fin and ws.frame.remainder > 0: - trace "Restarting reads in the middle of a frame in a multiframe message" - first = false - elif ws.frame.fin and ws.frame.remainder <= 0: - trace "Resetting an already consumed frame" - ws.frame = nil - elif not ws.frame.fin and ws.frame.remainder <= 0: - trace "No more bytes left and message EOF, resetting frame" - ws.frame = nil - if isNil(ws.frame): ws.frame = await ws.readFrame(ws.extensions) + ws.first = true while consumed < size: if isNil(ws.frame): @@ -339,7 +325,7 @@ proc recv*( break logScope: - first = first + first = ws.first fin = ws.frame.fin len = ws.frame.length consumed = ws.frame.consumed @@ -347,12 +333,12 @@ proc recv*( opcode = ws.frame.opcode masked = ws.frame.mask - if first == (ws.frame.opcode == Opcode.Cont): + if ws.first == (ws.frame.opcode == Opcode.Cont): error "Opcode mismatch!" raise newException(WSOpcodeMismatchError, - &"Opcode mismatch: first: {first}, opcode: {ws.frame.opcode}") + &"Opcode mismatch: first: {ws.first}, opcode: {ws.frame.opcode}") - if first: + if ws.first: ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag trace "Setting binary flag" @@ -370,18 +356,15 @@ proc recv*( # all has been consumed from the frame # read the next frame if ws.frame.remainder <= 0: - first = false + ws.first = false if ws.frame.fin: # we're at the end of the message, break trace "Read all frames, breaking" ws.frame = nil break + # read next frame ws.frame = await ws.readFrame(ws.extensions) - - if not ws.binary and validateUTF8(pbuffer.toOpenArray(0, consumed - 1)) == false: - raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected") - except CatchableError as exc: trace "Exception reading frames", exc = exc.msg ws.readyState = ReadyState.Closed @@ -396,42 +379,58 @@ proc recv*( return consumed -proc recv*( +proc recvMsg*( ws: WSSession, size = WSMaxMessageSize): Future[seq[byte]] {.async.} = ## Attempt to read a full message up to max `size` ## bytes in `frameSize` chunks. ## - ## If no `fin` flag arrives await until either - ## cancelled or the `fin` flag arrives. + ## If no `fin` flag arrives await until cancelled or + ## closed. ## ## If message is larger than `size` a `WSMaxMessageSizeError` ## exception is thrown. ## ## In all other cases it awaits a full message. ## - var res: seq[byte] - while ws.readyState != ReadyState.Closed: - var buf = newSeq[byte](min(size, ws.frameSize)) - let read = await ws.recv(addr buf[0], buf.len) + try: + var res: seq[byte] + while ws.readyState != ReadyState.Closed: + var buf = newSeq[byte](min(size, ws.frameSize)) + let read = await ws.recv(addr buf[0], buf.len) - buf.setLen(read) - if res.len + buf.len > size: - raise newException(WSMaxMessageSizeError, "Max message size exceeded") + buf.setLen(read) + if res.len + buf.len > size: + raise newException(WSMaxMessageSizeError, "Max message size exceeded") - trace "Read message", size = read - res.add(buf) + trace "Read message", size = read + res.add(buf) - # no more frames - if isNil(ws.frame): - break + # no more frames + if isNil(ws.frame): + break - # read the entire message, exit - if ws.frame.fin and ws.frame.remainder <= 0: - trace "Read full message, breaking!" - break + # read the entire message, exit + if ws.frame.fin and ws.frame.remainder <= 0: + trace "Read full message, breaking!" + break - return res + if not ws.binary and validateUTF8(res.toOpenArray(0, res.high)) == false: + raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected") + + return res + except CatchableError as exc: + trace "Exception reading message", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.stream.closeWait() + + raise exc + +proc recv*( + ws: WSSession, + size = WSMaxMessageSize): Future[seq[byte]] + {.deprecated: "deprecated in favor of recvMsg()".} = + ws.recvMsg(size) proc close*( ws: WSSession, @@ -451,6 +450,6 @@ proc close*( # read frames until closed while ws.readyState != ReadyState.Closed: - discard await ws.recv() + discard await ws.recvMsg() except CatchableError as exc: trace "Exception closing", exc = exc.msg diff --git a/websock/types.nim b/websock/types.nim index 1f172ef..9538bb2 100644 --- a/websock/types.nim +++ b/websock/types.nim @@ -79,11 +79,11 @@ type version*: uint key*: string readyState*: ReadyState - masked*: bool # send masked packets - binary*: bool # is payload binary? + masked*: bool # send masked packets + binary*: bool # is payload binary? flags*: set[TLSFlags] rng*: Rng - frameSize*: int + frameSize*: int # max frame buffer size onPing*: ControlCb onPong*: ControlCb onClose*: CloseCb @@ -91,6 +91,7 @@ type WSSession* = ref object of WebSocket stream*: AsyncStream frame*: Frame + first*: bool proto*: string Ext* = ref object of RootObj