diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4a26716..e60d1c6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: - target: os: windows builder: windows-2019 - name: '${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})' + name: "${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})" runs-on: ${{ matrix.builder }} steps: - name: Checkout nim-ws diff --git a/examples/server.nim b/examples/server.nim index fceee46..7950bb7 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -1,8 +1,7 @@ - import pkg/[chronos, +import pkg/[chronos, chronos/apps/http/httpserver, chronicles, - httputils, - stew/byteutils] + httputils] import ../ws/ws @@ -13,22 +12,18 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if request.uri.path == "/ws": debug "Initiating web socket connection." try: - var ws = await createServer(request,"") + let ws = await createServer(request, "") if ws.readyState != Open: error "Failed to open websocket connection." return debug "Websocket handshake completed." - while ws.readyState != ReadyState.Closed: - # Only reads header for data frame. - var recvData = await ws.recv() - if recvData.len <= 0: - debug "Empty messages" + while true: + let recvData = await ws.recv() + if ws.readyState == ReadyState.Closed: + debug "Websocket closed." break - - # debug "Client Response: ", data = string.fromBytes(recvData), size = recvData.len debug "Client Response: ", size = recvData.len await ws.send(recvData) - # await ws.close() except WebSocketError as exc: error "WebSocket error:", exception = exc.msg diff --git a/tests/testframes.nim b/tests/testframes.nim index 18b32be..dbb9318 100644 --- a/tests/testframes.nim +++ b/tests/testframes.nim @@ -1,7 +1,7 @@ import unittest include ../ws/ws -include ../ws/random +include ../ws/utils # TODO: Fix Test. diff --git a/tests/testtlswebsockets.nim b/tests/testtlswebsockets.nim index 455d371..78c1ea0 100644 --- a/tests/testtlswebsockets.nim +++ b/tests/testtlswebsockets.nim @@ -2,22 +2,32 @@ import std/strutils, httputils import pkg/[asynctest, chronos, + chronicles, chronos/apps/http/shttpserver, stew/byteutils] -import ../ws/[ws, stream], +import ../ws/[ws, stream], ../examples/tlsserver import ./keys -var server: SecureHttpServerRef -let address = initTAddress("127.0.0.1:8888") -let serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect} -let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} -let clientFlags = {NoVerifyHost, NoVerifyServerName} +proc waitForClose(ws: WebSocket) {.async.} = + try: + while ws.readystate != ReadyState.Closed: + discard await ws.recv() + except CatchableError: + debug "Closing websocket" + +var server: SecureHttpServerRef + +let + address = initTAddress("127.0.0.1:8888") + serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect} + socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + clientFlags = {NoVerifyHost, NoVerifyServerName} + secureKey = TLSPrivateKey.init(SecureKey) + secureCert = TLSCertificate.init(SecureCert) -let secureKey = TLSPrivateKey.init(SecureKey) -let secureCert = TLSCertificate.init(SecureCert) suite "Test websocket TLS handshake": teardown: @@ -31,10 +41,7 @@ suite "Test websocket TLS handshake": let request = r.get() check request.uri.path == "/wss" expect WSProtoMismatchError: - var ws = await createServer(request, "proto") - check ws.readyState == ReadyState.Closed - - return await request.respond(Http200, "Connection established") + discard await createServer(request, "proto") let res = SecureHttpServerRef.new( address, cb, @@ -62,10 +69,7 @@ suite "Test websocket TLS handshake": let request = r.get() check request.uri.path == "/wss" expect WSVersionError: - var ws = await createServer(request, "proto") - check ws.readyState == ReadyState.Closed - - return await request.respond(Http200, "Connection established") + discard await createServer(request, "proto") let res = SecureHttpServerRef.new( address, cb, @@ -91,14 +95,17 @@ suite "Test websocket TLS handshake": check r.isOk() let request = r.get() check request.uri.path == "/wss" - check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii() - check request.headers.getString("Upgrade").toUpperAscii() == "websocket".toUpperAscii() - check request.headers.getString("Cache-Control").toUpperAscii() == "no-cache".toUpperAscii() + check request.headers.getString("Connection").toUpperAscii() == + "Upgrade".toUpperAscii() + check request.headers.getString("Upgrade").toUpperAscii() == + "websocket".toUpperAscii() + check request.headers.getString("Cache-Control").toUpperAscii() == + "no-cache".toUpperAscii() check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion check request.headers.contains("Sec-WebSocket-Key") - discard await request.respond( Http200,"Connection established") + discard await request.respond(Http200, "Connection established") let res = SecureHttpServerRef.new( address, cb, @@ -133,7 +140,7 @@ suite "Test websocket TLS transmission": let ws = await createServer(request, "proto") let servRes = await ws.recv() check string.fromBytes(servRes) == testString - await ws.close() + await waitForClose(ws) return dumbResponse() let res = SecureHttpServerRef.new( @@ -158,41 +165,7 @@ suite "Test websocket TLS transmission": test "Client - test reading simple frame": let testString = "Hello!" - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return dumbResponse() - - let request = r.get() - check request.uri.path == "/wss" - let ws = await createServer(request, "proto") - let servRes = await ws.recv() - check string.fromBytes(servRes) == testString - await ws.close() - return dumbResponse() - - let res = SecureHttpServerRef.new( - address, cb, - serverFlags = serverFlags, - socketFlags = socketFlags, - tlsPrivateKey = secureKey, - tlsCertificate = secureCert) - - server = res.get() - server.start() - - let wsClient = await WebSocket.tlsConnect( - "127.0.0.1", - Port(8888), - path = "/wss", - protocols = @["proto"], - clientFlags) - - await wsClient.send(testString) - await wsClient.close() - - test "Client - test reading simple frame": - let testString = "Hello!" - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = if r.isErr(): return dumbResponse() @@ -222,4 +195,4 @@ suite "Test websocket TLS transmission": var clientRes = await wsClient.recv() check string.fromBytes(clientRes) == testString - await wsClient.close() + await waitForClose(wsClient) diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index 5a12f8d..526a58e 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -1,458 +1,827 @@ -import std/strutils,httputils +import std/[strutils, random], httputils import pkg/[asynctest, chronos, chronos/apps/http/httpserver, + chronicles, stew/byteutils] -import ../ws/[ws, stream] +import ../ws/[ws, stream] + +include ../ws/ws var server: HttpServerRef let address = initTAddress("127.0.0.1:8888") +proc rndStr*(size: int): string = + for _ in .. size: + add(result, char(rand(int('A') .. int('z')))) + +proc waitForClose(ws: WebSocket) {.async.} = + try: + while ws.readystate != ReadyState.Closed: + discard await ws.recv() + except CatchableError: + debug "Closing websocket" + suite "Test handshake": - teardown: - await server.stop() - await server.closeWait() + teardown: + await server.stop() + await server.closeWait() - test "Test for incorrect protocol": - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return + test "Test for incorrect protocol": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() - let request = r.get() - check request.uri.path == "/ws" - expect WSProtoMismatchError: - var ws = await createServer(request, "proto") - check ws.readyState == ReadyState.Closed + let request = r.get() + check request.uri.path == "/ws" + expect WSProtoMismatchError: + discard await createServer(request, "proto") - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - expect WSFailedUpgradeError: - discard await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["wrongproto"]) + expect WSFailedUpgradeError: + discard await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["wrongproto"]) - test "Test for incorrect version": - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return + test "Test for incorrect version": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return - let request = r.get() - check request.uri.path == "/ws" - expect WSVersionError: - var ws = await createServer(request, "proto") - check ws.readyState == ReadyState.Closed + let request = r.get() + check request.uri.path == "/ws" + expect WSVersionError: + discard await createServer(request, "proto") - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - expect WSFailedUpgradeError: - discard await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["wrongproto"], - version = 14) + expect WSFailedUpgradeError: + discard await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["wrongproto"], + version = 14) - test "Test for client headers": - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return + test "Test for client headers": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() - let request = r.get() - check request.uri.path == "/ws" - check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii() - check request.headers.getString("Upgrade").toUpperAscii() == "websocket".toUpperAscii() - check request.headers.getString("Cache-Control").toUpperAscii() == "no-cache".toUpperAscii() - check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion + let request = r.get() + check request.uri.path == "/ws" + check request.headers.getString("Connection").toUpperAscii() == + "Upgrade".toUpperAscii() + check request.headers.getString("Upgrade").toUpperAscii() == + "websocket".toUpperAscii() + check request.headers.getString("Cache-Control").toUpperAscii() == + "no-cache".toUpperAscii() + check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion - check request.headers.contains("Sec-WebSocket-Key") + check request.headers.contains("Sec-WebSocket-Key") - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + discard await request.respond(Http200, "Connection established") - expect WSFailedUpgradeError: - discard await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"]) + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - test "Test for incorrect scheme": - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return + expect WSFailedUpgradeError: + discard await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) - let request = r.get() - check request.uri.path == "/ws" + test "Test for incorrect scheme": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() - expect WSProtoMismatchError: - var ws = await createServer(request, "proto") - check ws.readyState == ReadyState.Closed + let request = r.get() + check request.uri.path == "/ws" - return await request.respond(Http200, "Connection established") + expect WSProtoMismatchError: + var ws = await createServer(request, "proto") + check ws.readyState == ReadyState.Closed - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + return await request.respond(Http200, "Connection established") - let uri = "wx://127.0.0.1:8888/ws" - expect WSWrongUriSchemeError: - discard await WebSocket.connect( - parseUri(uri), - protocols = @["proto"]) + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() + + let uri = "wx://127.0.0.1:8888/ws" + expect WSWrongUriSchemeError: + discard await WebSocket.connect( + parseUri(uri), + protocols = @["proto"]) + + test "AsyncStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false suite "Test transmission": - teardown: - await server.closeWait() + teardown: + await server.closeWait() - test "Server - test reading simple frame": - let testString = "Hello!" - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return + test "Send text message message with payload of length 65535": + let testString = rndStr(65535) + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + let servRes = await ws.recv() + check string.fromBytes(servRes) == testString - let request = r.get() - check request.uri.path == "/ws" - let ws = await createServer(request, "proto") - let servRes = await ws.recv() + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() - check string.fromBytes(servRes) == testString - await ws.close() + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + await wsClient.send(testString) + await wsClient.close() - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + test "Server - test reading simple frame": + let testString = "Hello!" + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + let servRes = await ws.recv() + check string.fromBytes(servRes) == testString + await waitForClose(ws) - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"]) - await wsClient.send(testString) - await wsClient.close() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - test "Client - test reading simple frame": - let testString = "Hello!" - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + await wsClient.send(testString) + await wsClient.close() - let request = r.get() - check request.uri.path == "/ws" - let ws = await createServer(request, "proto") - await ws.send(testString) - await ws.close() + test "Client - test reading simple frame": + let testString = "Hello!" + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + await ws.send(testString) + await ws.close() - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"]) + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - var clientRes = await wsClient.recv() - await wsClient.close() - check string.fromBytes(clientRes) == testString + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + var clientRes = await wsClient.recv() + check string.fromBytes(clientRes) == testString + await waitForClose(wsClient) + + test "AsyncStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false suite "Test ping-pong": - teardown: - await server.closeWait() + teardown: + await server.closeWait() + test "Send text Message fragmented into 2 fragments, one ping with payload in-between": + var ping, pong = false + let testString = "1234567890" + let msg = toBytes(testString) + let maxFrameSize = 5 - test "Server - test ping-pong control messages": - var ping, pong = false - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer( + request, + "proto", + onPing = proc() = + ping = true + ) - let request = r.get() - check request.uri.path == "/ws" - let ws = await createServer( - request, - "proto", + let respData = await ws.recv() + check string.fromBytes(respData) == testString + await waitForClose(ws) + + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + frameSize = maxFrameSize, onPong = proc() = - pong = true - ) - - await ws.ping() - await ws.close() - - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() - - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"], - onPing = proc() = - ping = true + pong = true ) - discard await wsClient.recv() - check: - ping - pong + let maskKey = genMaskKey(newRng()) + let encframe = encodeFrame(Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: true, + data: msg[0..4], + maskKey: maskKey)) - test "Client - test ping-pong control messages": - var ping, pong = false - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return + await wsClient.stream.writer.write(encframe) + await wsClient.ping() + let encframe1 = encodeFrame(Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Cont, + mask: true, + data: msg[5..9], + maskKey: maskKey)) - let request = r.get() - check request.uri.path == "/ws" - let ws = await createServer( - request, - "proto", + await wsClient.stream.writer.write(encframe1) + await wsClient.close() + check: + ping + pong + test "Server - test ping-pong control messages": + var ping, pong = false + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer( + request, + "proto", + onPong = proc() = + pong = true + ) + + await ws.ping() + await ws.close() + + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], onPing = proc() = - ping = true - ) - - discard await ws.recv() - await ws.close() - - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() - - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"], - onPong = proc() = - pong = true + ping = true ) - await wsClient.ping() - await wsClient.close() - check: - ping - pong + await waitForClose(wsClient) + check: + ping + pong + + test "Client - test ping-pong control messages": + var ping, pong = false + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer( + request, + "proto", + onPing = proc() = + ping = true + ) + await waitForClose(ws) + check: + ping + pong + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onPong = proc() = + pong = true + ) + + await wsClient.ping() + await wsClient.close() + + test "AsyncStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false suite "Test framing": - teardown: - await server.closeWait() + teardown: + await server.closeWait() - test "should split message into frames": - let testString = "1234567890" - proc cb(r: RequestFence): Future[HttpResponseRef]{.async.} = - if r.isErr(): - return dumbResponse() + test "should split message into frames": + let testString = "1234567890" + proc cb(r: RequestFence): Future[HttpResponseRef]{.async.} = + if r.isErr(): + return dumbResponse() - let request = r.get() - check request.uri.path == "/ws" + let request = r.get() + check request.uri.path == "/ws" - let ws = await createServer(request, "proto") - let frame1 = await ws.readFrame() - check not isNil(frame1) - var data1 = newSeq[byte](frame1.remainder().int) - let read1 = await ws.stream.reader.readOnce(addr data1[0], data1.len) - check read1 == 5 + let ws = await createServer(request, "proto") + let frame1 = await ws.readFrame() + check not isNil(frame1) + var data1 = newSeq[byte](frame1.remainder().int) + let read1 = await ws.stream.reader.readOnce(addr data1[0], data1.len) + check read1 == 5 - let frame2 = await ws.readFrame() - check not isNil(frame2) - var data2 = newSeq[byte](frame2.remainder().int) - let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len) - check read2 == 5 + let frame2 = await ws.readFrame() + check not isNil(frame2) + var data2 = newSeq[byte](frame2.remainder().int) + let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len) + check read2 == 5 - await ws.close() - return dumbResponse() + await waitForClose(ws) - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"], - frameSize = 5) + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + frameSize = 5) - await wsClient.send(testString) - await wsClient.close() + await wsClient.send(testString) + await wsClient.close() - test "should fail to read past max message size": - let testString = "1234567890" - proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} = - if r.isErr(): - return dumbResponse() + test "should fail to read past max message size": + let testString = "1234567890" + proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} = + if r.isErr(): + return dumbResponse() - let request = r.get() - check request.uri.path == "/ws" - let ws = await createServer(request, "proto") - await ws.send(testString) - await ws.close() - return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + await ws.send(testString) + await ws.close() - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"]) + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) - expect WSMaxMessageSizeError: - discard await wsClient.recv(5) + expect WSMaxMessageSizeError: + discard await wsClient.recv(5) + await waitForClose(wsClient) - await wsClient.close() + test "AsyncStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false suite "Test Closing": - teardown: - await server.closeWait() + teardown: + await server.closeWait() - test "Server closing": - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return dumbResponse() + test "Server closing": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() - let request = r.get() - check request.uri.path == "/ws" - let ws = await createServer(request, "proto") - await ws.close() - return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + await ws.close() - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"]) + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) - discard await wsClient.recv() - check wsClient.readyState == ReadyState.Closed + await waitForClose(wsClient) + check wsClient.readyState == ReadyState.Closed - test "Server closing with status": - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return dumbResponse() + test "Server closing with status": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() - let request = r.get() - check request.uri.path == "/ws" - proc closeServer(status: Status, reason: string): CloseResult - {.gcsafe, raises: [Defect].} = - try: - check status == Status.TooLarge - check reason == "Message too big!" - except Exception as exc: - raise newException(Defect, exc.msg) + let request = r.get() + check request.uri.path == "/ws" + proc closeServer(status: Status, reason: string): CloseResult{.gcsafe, + raises: [Defect].} = + try: + check status == Status.TooLarge + check reason == "Message too big!" + except Exception as exc: + raise newException(Defect, exc.msg) - return (Status.Fulfilled, "") + return (Status.Fulfilled, "") - let ws = await createServer( - request, - "proto", - onClose = closeServer) + let ws = await createServer( + request, + "proto", + onClose = closeServer) - await ws.close() - return dumbResponse() + await ws.close() - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - proc clientClose(status: Status, reason: string): CloseResult - {.gcsafe, raises: [Defect].} = - try: - check status == Status.Fulfilled - return (Status.TooLarge, "Message too big!") - except Exception as exc: - raise newException(Defect, exc.msg) + proc clientClose(status: Status, reason: string): CloseResult {.gcsafe, + raises: [Defect].} = + try: + check status == Status.Fulfilled + return (Status.TooLarge, "Message too big!") + except Exception as exc: + raise newException(Defect, exc.msg) - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"], - onClose = clientClose) + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onClose = clientClose) - discard await wsClient.recv() - check wsClient.readyState == ReadyState.Closed + await waitForClose(wsClient) + check wsClient.readyState == ReadyState.Closed - test "Client closing": - proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - if r.isErr(): - return dumbResponse() + test "Client closing": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() - let request = r.get() - check request.uri.path == "/ws" - let ws = await createServer(request, "proto") - discard await ws.recv() - await ws.close() - return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + await waitForClose(ws) - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"]) - await wsClient.close() + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + await wsClient.close() - test "Client closing with status": - proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} = - if r.isErr(): - return dumbResponse() + test "Client closing with status": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} = + if r.isErr(): + return dumbResponse() - let request = r.get() - check request.uri.path == "/ws" - proc closeServer(status: Status, reason: string): CloseResult - {.gcsafe, raises: [Defect].} = - try: - check status == Status.Fulfilled - return (Status.TooLarge, "Message too big!") - except Exception as exc: - raise newException(Defect, exc.msg) + let request = r.get() + check request.uri.path == "/ws" + proc closeServer(status: Status, reason: string): CloseResult{.gcsafe, + raises: [Defect].} = + try: + check status == Status.Fulfilled + return (Status.TooLarge, "Message too big!") + except Exception as exc: + raise newException(Defect, exc.msg) - let ws = await createServer( - request, - "proto", - onClose = closeServer) - discard await ws.recv() - await ws.close() - return dumbResponse() + let ws = await createServer( + request, + "proto", + onClose = closeServer) + await waitForClose(ws) - let res = HttpServerRef.new(address, cb) - server = res.get() - server.start() + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() - proc clientClose(status: Status, reason: string): CloseResult - {.gcsafe, raises: [Defect].} = - try: - check status == Status.TooLarge - check reason == "Message too big!" - return (Status.Fulfilled, "") - except Exception as exc: - raise newException(Defect, exc.msg) + proc clientClose(status: Status, reason: string): CloseResult {.gcsafe, + raises: [Defect].} = + try: + check status == Status.TooLarge + check reason == "Message too big!" + return (Status.Fulfilled, "") + except Exception as exc: + raise newException(Defect, exc.msg) - let wsClient = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws", - protocols = @["proto"], - onClose = clientClose) + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onClose = clientClose) - await wsClient.close() - check wsClient.readyState == ReadyState.Closed + await wsClient.close() + check wsClient.readyState == ReadyState.Closed + + test "Server closing with valid close code 3999": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer( + request, + "proto") + + await ws.close(code = Status.ReservedCode) + + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() + + proc closeClient(status: Status, reason: string): CloseResult{.gcsafe, + raises: [Defect].} = + try: + check status == Status.ReservedCode + return (Status.ReservedCode, "Reserved Status") + except Exception as exc: + raise newException(Defect, exc.msg) + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onClose = closeClient) + + await waitForClose(wsClient) + + test "Client closing with valid close code 3999": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + + proc closeServer(status: Status, reason: string): CloseResult{.gcsafe, + raises: [Defect].} = + try: + check status == Status.ReservedCode + return (Status.ReservedCode, "Reserved Status") + except Exception as exc: + raise newException(Defect, exc.msg) + + let ws = await createServer( + request, + "proto", + onClose = closeServer) + + await waitForClose(ws) + return dumbResponse() + + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + await wsClient.close(code = Status.ReservedCode) + + test "Server closing with Payload of length 2": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + # Close with payload of length 2 + await ws.close(reason = "HH") + + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + await waitForClose(wsClient) + + test "Client closing with Payload of length 2": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + await waitForClose(ws) + + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + # Close with payload of length 2 + await wsClient.close(reason = "HH") + + + test "AsyncStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false + +suite "Test Payload": + teardown: + await server.closeWait() + + test "Test payload message length": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer( + request, + "proto") + + expect WSPayloadTooLarge: + discard await ws.recv() + await waitForClose(ws) + + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() + + let str = rndStr(126) + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + await wsClient.send(toBytes(str), Opcode.Ping) + await wsClient.close() + + test "Test single empty payload": + let emptyStr = "" + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + let servRes = await ws.recv() + check string.fromBytes(servRes) == emptyStr + await waitForClose(ws) + + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + await wsClient.send(emptyStr) + await wsClient.close() + + test "Test multiple empty payload": + let emptyStr = "" + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer(request, "proto") + let servRes = await ws.recv() + check string.fromBytes(servRes) == emptyStr + await waitForClose(ws) + + let res = HttpServerRef.new( + address, cb) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"]) + + for i in 0..3: + await wsClient.send(emptyStr) + await wsClient.close() + + test "Send ping with small text payload": + let testData = toBytes("Hello, world!") + var ping, pong = false + proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + let request = r.get() + check request.uri.path == "/ws" + let ws = await createServer( + request, + "proto", + onPing = proc() = + ping = true + ) + + await waitForClose(ws) + + let res = HttpServerRef.new( + address, process) + server = res.get() + server.start() + + let wsClient = await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws", + protocols = @["proto"], + onPong = proc() = + pong = true + ) + + await wsClient.send(testData, Opcode.Ping) + await wsClient.close() + check: + ping + pong + test "AsyncStream leaks test": + check: + getTracker("async.stream.reader").isLeaked() == false + getTracker("async.stream.writer").isLeaked() == false + getTracker("stream.server").isLeaked() == false + getTracker("stream.transport").isLeaked() == false diff --git a/ws/stream.nim b/ws/stream.nim index fdd05df..bb48c1b 100644 --- a/ws/stream.nim +++ b/ws/stream.nim @@ -1,4 +1,4 @@ - import pkg/[chronos, +import pkg/[chronos, chronos/apps/http/httpserver, chronos/timer, chronicles, @@ -8,7 +8,7 @@ import strutils const HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')] HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec) - MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets + MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = var buffer = newSeq[byte](MaxHttpHeadersSize) @@ -45,7 +45,11 @@ proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = return buffer -proc closeWait*(wsStream : AsyncStream) {.async.} = +proc closeWait*(wsStream: AsyncStream) {.async.} = + + await allFutures( + wsStream.writer.closeWait(), + wsStream.reader.closeWait()) await allFutures( wsStream.writer.tsource.closeWait(), wsStream.reader.tsource.closeWait()) diff --git a/ws/random.nim b/ws/utils.nim similarity index 100% rename from ws/random.nim rename to ws/utils.nim diff --git a/ws/ws.nim b/ws/ws.nim index c35f9d5..1c88130 100644 --- a/ws/ws.nim +++ b/ws/ws.nim @@ -18,7 +18,7 @@ import pkg/[chronos, stew/base10, nimcrypto/sha] -import ./random, ./stream +import ./utils, ./stream #[ +---------------------------------------------------------------+ @@ -47,7 +47,7 @@ const WSHeaderSize* = 12 WSDefaultVersion* = 13 WSDefaultFrameSize* = 1 shl 20 # 1mb - WSMaxMessageSize* = 20 shl 20 # 20mb + WSMaxMessageSize* = 20 shl 20 # 20mb WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" CRLF* = "\r\n" @@ -71,7 +71,13 @@ type WSMaxMessageSizeError* = object of WebSocketError WSClosedError* = object of WebSocketError WSSendError* = object of WebSocketError - WSPayloadTooLarge = object of WebSocketError + WSPayloadTooLarge* = object of WebSocketError + WSReserverdOpcodeError* = object of WebSocketError + WSFragmentedControlFrameError* = object of WebSocketError + WSInvalidCloseCodeError* = object of WebSocketError + WSPayloadLengthError* = object of WebSocketError + WSInvalidOpcodeError* = object of WebSocketError + Base16Error* = object of CatchableError ## Base16 specific exception type @@ -111,21 +117,21 @@ type TooLarge = 1009 NoExtensions = 1010 UnexpectedError = 1011 - TlsError # use by clients - # 3000-3999 reserved for libs - # 4000-4999 reserved for applications + ReservedCode = 3999 # use by clients + # 3000-3999 reserved for libs + # 4000-4999 reserved for applications Frame = ref object - fin: bool ## Indicates that this is the final fragment in a message. - rsv1: bool ## MUST be 0 unless negotiated that defines meanings - rsv2: bool ## MUST be 0 - rsv3: bool ## MUST be 0 - opcode: Opcode ## Defines the interpretation of the "Payload data". - mask: bool ## Defines whether the "Payload data" is masked. - data: seq[byte] ## Payload data - maskKey: array[4, char] ## Masking key - length: uint64 ## Message size. - consumed: uint64 ## how much has been consumed from the frame + fin: bool ## Indicates that this is the final fragment in a message. + rsv1: bool ## MUST be 0 unless negotiated that defines meanings + rsv2: bool ## MUST be 0 + rsv3: bool ## MUST be 0 + opcode: Opcode ## Defines the interpretation of the "Payload data". + mask: bool ## Defines whether the "Payload data" is masked. + data: seq[byte] ## Payload data + maskKey: array[4, char] ## Masking key + length: uint64 ## Message size. + consumed: uint64 ## how much has been consumed from the frame ControlCb* = proc() {.gcsafe, raises: [Defect].} @@ -156,11 +162,11 @@ template remainder*(frame: Frame): uint64 = proc `$`(ht: HttpTables): string = ## Returns string representation of HttpTable/Ref. var res = "" - for key,value in ht.stringItems(true): - res.add(key.normalizeHeaderName()) - res.add(": ") - res.add(value) - res.add(CRLF) + for key, value in ht.stringItems(true): + res.add(key.normalizeHeaderName()) + res.add(": ") + res.add(value) + res.add(CRLF) ## add for end of header mark res.add(CRLF) @@ -209,10 +215,11 @@ proc handshake*( wantProtocol & ")") let cKey = ws.key & WSGuid - let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, cKey.high)).data) + let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, + cKey.high)).data) var headerData = [ ("Connection", "Upgrade"), - ("Upgrade", "webSocket" ), + ("Upgrade", "webSocket"), ("Sec-WebSocket-Accept", acceptKey)] var headers = HttpTable.init(headerData) @@ -222,7 +229,8 @@ proc handshake*( try: discard await request.respond(httputils.Http101, "", headers) except CatchableError as exc: - raise newException(WSHandshakeError, "Failed to sent handshake response. Error: " & exc.msg) + raise newException(WSHandshakeError, + "Failed to sent handshake response. Error: " & exc.msg) ws.readyState = ReadyState.Open proc createServer*( @@ -330,39 +338,47 @@ proc send*( maskKey = genMaskKey(ws.rng) if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: + if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}: + return await ws.stream.writer.write(encodeFrame(Frame( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: opcode, - mask: ws.masked, - data: data, # allow sending data with close messages - maskKey: maskKey))) + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: opcode, + mask: ws.masked, + data: data, # allow sending data with close messages + maskKey: maskKey))) return let maxSize = ws.frameSize var i = 0 - while i < data.len: + while ws.readyState notin {ReadyState.Closing}: let len = min(data.len, (maxSize + i)) - let inFrame = Frame( - fin: if (i + len >= data.len): true else: false, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames - mask: ws.masked, - data: data[i ..< len], - maskKey: maskKey) + let encFrame = encodeFrame(Frame( + fin: if (i + len >= data.len): true else: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames + mask: ws.masked, + data: data[i ..< len], + maskKey: maskKey)) - await ws.stream.writer.write(encodeFrame(inFrame)) + await ws.stream.writer.write(encFrame) i += len + if i >= data.len: + break + proc send*(ws: WebSocket, data: string): Future[void] = send(ws, toBytes(data), Opcode.Text) -proc handleClose*(ws: WebSocket, frame: Frame) {.async.} = +proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} = + + if ws.readyState notin {ReadyState.Open}: + return logScope: fin = frame.fin masked = frame.mask @@ -370,46 +386,46 @@ proc handleClose*(ws: WebSocket, frame: Frame) {.async.} = serverState = ws.readyState debug "Handling close sequence" - if ws.readyState == ReadyState.Open or ws.readyState == ReadyState.Closing: - # Read control frame payload. - var data = newSeq[byte](frame.length) - if frame.length > 0: - # Read the data. - await ws.stream.reader.readExactly(addr data[0], int frame.length) - unmask(data.toOpenArray(0, data.high), frame.maskKey) + var + code = Status.Fulfilled + reason = "" - var code: Status - if data.len > 0: - let ccode = uint16.fromBytesBE(data[0..<2]) # first two bytes are the status - doAssert(ccode > 999, "No valid code in close message!") + if payLoad.len == 1: + raise newException(WSPayloadLengthError, "Invalid close frame with payload length 1!") + elif payLoad.len > 1: + # first two bytes are the status + let ccode = uint16.fromBytesBE(payLoad[0..<2]) + if ccode <= 999 or ccode > 1015: + raise newException(WSInvalidCloseCodeError, "Invalid code in close message!") + try: code = Status(ccode) - data = data[2..data.high] + except RangeError: + code = Status.Fulfilled + # remining payload bytes are reason for closing + reason = string.fromBytes(payLoad[2..payLoad.high]) - var rcode = Status.Fulfilled - var reason = "" - if not isNil(ws.onClose): - try: - (rcode, reason) = ws.onClose(code, string.fromBytes(data)) - except CatchableError as exc: - debug "Exception in Close callback, this is most likely a bug", exc = exc.msg + var rcode: Status + if code in {Status.Fulfilled}: + rcode = Status.Fulfilled - # don't respond to a terminated connection - if ws.readyState != ReadyState.Closing: - await ws.send(prepareCloseBody(rcode, reason), Opcode.Close) + if not isNil(ws.onClose): + try: + (rcode, reason) = ws.onClose(code, reason) + except CatchableError as exc: + debug "Exception in Close callback, this is most likely a bug", exc = exc.msg + + # don't respond to a terminated connection + if ws.readyState != ReadyState.Closing: + ws.readyState = ReadyState.Closing + await ws.send(prepareCloseBody(rcode, reason), Opcode.Close) - await ws.stream.closeWait() ws.readyState = ReadyState.Closed - else: - raiseAssert("Invalid state during close!") + await ws.stream.closeWait() -proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = +proc handleControl*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} = ## handle control frames ## - if frame.length > 125: - raise newException(WSPayloadTooLarge, - "Control message payload is greater than 125 bytes!") - try: # Process control frame payload. case frame.opcode: @@ -421,7 +437,7 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg # send pong to remote - await ws.send(@[], Opcode.Pong) + await ws.send(payLoad, Opcode.Pong) of Opcode.Pong: if not isNil(ws.onPong): try: @@ -429,9 +445,12 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = except CatchableError as exc: debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg of Opcode.Close: - await ws.handleClose(frame) + await ws.handleClose(frame, payLoad) else: - raiseAssert("Invalid control opcode") + raise newException(WSInvalidOpcodeError, "Invalid control opcode!") + except WebSocketError as exc: + debug "Handled websocket exception", exc = exc.msg + raise exc except CatchableError as exc: trace "Exception handling control messages", exc = exc.msg ws.readyState = ReadyState.Closed @@ -467,8 +486,12 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = if opcode > ord(Opcode.high): raise newException(WSOpcodeMismatchError, "Wrong opcode!") - frame.opcode = (opcode).Opcode + let frameOpcode = (opcode).Opcode + if frameOpcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary, + Opcode.Ping, Opcode.Pong, Opcode.Close}: + raise newException(WSReserverdOpcodeError, "Unknown opcode received!") + frame.opcode = frameOpcode # If any of the rsv are set close the socket. if frame.rsv1 or frame.rsv2 or frame.rsv3: raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") @@ -508,12 +531,33 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = # return the current frame if it's not one of the control frames if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: - asyncSpawn ws.handleControl(frame) # process control frames + if not frame.fin: + raise newException(WSFragmentedControlFrameError, "Control frame cannot be fragmented!") + if frame.length > 125: + raise newException(WSPayloadTooLarge, + "Control message payload is greater than 125 bytes!") + var payLoad = newSeq[byte](frame.length) + if frame.length > 0: + # Read control frame payload. + await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int) + unmask(payLoad.toOpenArray(0, payLoad.high), frame.maskKey) + await ws.handleControl(frame, payLoad) # process control frames# process control frames continue - debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask + debug "Decoded new frame", opcode = frame.opcode, len = frame.length, + mask = frame.mask return frame + + except WSReserverdOpcodeError as exc: + trace "Handled websocket opcode exception", exc = exc.msg + raise exc + except WSPayloadTooLarge as exc: + debug "Handled payload too large exception", exc = exc.msg + raise exc + except WebSocketError as exc: + debug "Handled websocket exception", exc = exc.msg + raise exc except CatchableError as exc: debug "Exception reading frame, dropping socket", exc = exc.msg ws.readyState = ReadyState.Closed @@ -543,17 +587,32 @@ proc recv*( while consumed < size: # we might have to read more than # one frame to fill the buffer - if isNil(ws.frame): - ws.frame = await ws.readFrame() - # all has been consumed from the frame # read the next frame - if ws.frame.remainder() <= 0: + if isNil(ws.frame): ws.frame = await ws.readFrame() + # This could happen if the connection is closed. + if isNil(ws.frame): + return consumed.int + if ws.frame.opcode == Opcode.Cont: + raise newException(WSOpcodeMismatchError, "First frame cannot be continue frame") + + elif (not ws.frame.fin and ws.frame.remainder() <= 0): + ws.frame = await ws.readFrame() + # This could happen if the connection is closed. + if isNil(ws.frame): + return consumed.int + + if ws.frame.opcode != Opcode.Cont: + raise newException(WSOpcodeMismatchError, "expected continue frame") + if ws.frame.fin and ws.frame.remainder().int <= 0: + ws.frame = nil + break let len = min(ws.frame.remainder().int, size - consumed) + if len == 0: + continue let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len) - if read <= 0: continue @@ -562,14 +621,18 @@ proc recv*( unmask( pbuffer.toOpenArray(consumed, (consumed + read) - 1), ws.frame.maskKey, - consumed) + ws.frame.consumed.int) consumed += read ws.frame.consumed += read.uint64 - if ws.frame.fin and ws.frame.remainder().int <= 0: - break return consumed.int + + except WebSocketError as exc: + debug "Websocket error", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.stream.closeWait() + raise exc except CancelledError as exc: debug "Cancelling reading", exc = exc.msg raise exc @@ -611,7 +674,8 @@ proc recv*( # read the entire message, exit if ws.frame.fin and ws.frame.remainder().int <= 0: break - except WSMaxMessageSizeError as exc: + except WebSocketError as exc: + debug "Websocket error", exc = exc.msg raise exc except CancelledError as exc: debug "Cancelling reading", exc = exc.msg @@ -659,41 +723,48 @@ proc initiateHandshake( TransportError, "Cannot connect to " & $transp.remoteAddress() & " Error: " & exc.msg) - let requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers - let reader = newAsyncStreamReader(transp) - let writer = newAsyncStreamWriter(transp) + let + requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers + reader = newAsyncStreamReader(transp) + writer = newAsyncStreamWriter(transp) + var stream: AsyncStream - var res: seq[byte] - if uri.scheme == "https": - let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags) - stream = AsyncStream( - reader: tlsstream.reader, - writer: tlsstream.writer) + try: + var res: seq[byte] + if uri.scheme == "https": + let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags) + stream = AsyncStream( + reader: tlsstream.reader, + writer: tlsstream.writer) - await tlsstream.writer.write(requestHeader) - res = await tlsstream.reader.readHeaders() - else: - stream = AsyncStream( - reader: reader, - writer: writer) - await stream.writer.write(requestHeader) - res = await stream.reader.readHeaders() + await tlsstream.writer.write(requestHeader) + res = await tlsstream.reader.readHeaders() + else: + stream = AsyncStream( + reader: reader, + writer: writer) + await stream.writer.write(requestHeader) + res = await stream.reader.readHeaders() - if res.len == 0: - raise newException(ValueError, "Empty response from server") + if res.len == 0: + raise newException(ValueError, "Empty response from server") - let resHeader = res.parseResponse() - if resHeader.failed(): - # Header could not be parsed - raise newException(WSMalformedHeaderError, "Malformed header received.") + let resHeader = res.parseResponse() + if resHeader.failed(): + # Header could not be parsed + raise newException(WSMalformedHeaderError, "Malformed header received.") - if resHeader.code != ord(Http101): - raise newException(WSFailedUpgradeError, - "Server did not reply with a websocket upgrade:" & - " Header code: " & $resHeader.code & - " Header reason: " & resHeader.reason() & - " Address: " & $transp.remoteAddress()) + if resHeader.code != ord(Http101): + raise newException(WSFailedUpgradeError, + "Server did not reply with a websocket upgrade:" & + " Header code: " & $resHeader.code & + " Header reason: " & resHeader.reason() & + " Address: " & $transp.remoteAddress()) + except CatchableError as exc: + debug "Websocket failed during handshake", exc = exc.msg + await stream.closeWait() + raise exc return stream @@ -738,7 +809,7 @@ proc connect*( # Client data should be masked. return WebSocket( stream: stream, - readyState: Open, + readyState: ReadyState.Open, masked: true, rng: newRng(), frameSize: frameSize,