diff --git a/examples/server.nim b/examples/server.nim index 523b610624..f5c33da1d5 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -3,7 +3,7 @@ import pkg/[chronos, chronicles, httputils] -import ../ws/ws +import ../ws/[ws, frame, errors] proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if r.isOk(): @@ -12,7 +12,7 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if request.uri.path == "/ws": debug "Initiating web socket connection." try: - let ws = await createServer(request, "") + let ws = await WebSocket.createServer(request, "") if ws.readyState != Open: error "Failed to open websocket connection." return diff --git a/examples/tlsserver.nim b/examples/tlsserver.nim index 010bc02bea..35f1a25faf 100644 --- a/examples/tlsserver.nim +++ b/examples/tlsserver.nim @@ -4,7 +4,7 @@ import pkg/[chronos, httputils, stew/byteutils] -import ../ws/ws +import ../ws/[ws, frame, errors] import ../tests/keys let secureKey = TLSPrivateKey.init(SecureKey) @@ -18,7 +18,7 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if request.uri.path == "/wss": debug "Initiating web socket connection." try: - var ws = await createServer(request, "myfancyprotocol") + var ws = await WebSocket.createServer(request, "myfancyprotocol") if ws.readyState != Open: error "Failed to open websocket connection." return diff --git a/tests/testframes.nim b/tests/testframes.nim index dbb931829f..037209a311 100644 --- a/tests/testframes.nim +++ b/tests/testframes.nim @@ -1,6 +1,6 @@ -import unittest +import unittest, stew/byteutils -include ../ws/ws +include ../ws/frame include ../ws/utils # TODO: Fix Test. @@ -9,19 +9,18 @@ var maskKey: array[4, char] suite "Test data frames": test "# 7bit length text": - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, rsv3: false, opcode: Opcode.Text, mask: false, - data: toBytes("hi there"), - maskKey: maskKey - )) == toBytes("\1\8hi there") + data: toBytes("hi there") + ).encode() == toBytes("\1\8hi there") test "# 7bit length text fin bit": - check encodeFrame(Frame( + check Frame( fin: true, rsv1: false, rsv2: false, @@ -30,10 +29,10 @@ suite "Test data frames": mask: false, data: toBytes("hi there"), maskKey: maskKey - )) == toBytes("\129\8hi there") + ).encode() == toBytes("\129\8hi there") test "# 7bit length binary": - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -42,10 +41,10 @@ suite "Test data frames": mask: false, data: toBytes("hi there"), maskKey: maskKey - )) == toBytes("\2\8hi there") + ).encode() == toBytes("\2\8hi there") test "# 7bit length binary fin bit": - check encodeFrame(Frame( + check Frame( fin: true, rsv1: false, rsv2: false, @@ -54,10 +53,10 @@ suite "Test data frames": mask: false, data: toBytes("hi there"), maskKey: maskKey - )) == toBytes("\130\8hi there") + ).encode() == toBytes("\130\8hi there") test "# 7bit length continuation": - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -66,14 +65,14 @@ suite "Test data frames": mask: false, data: toBytes("hi there"), maskKey: maskKey - )) == toBytes("\0\8hi there") + ).encode() == toBytes("\0\8hi there") test "# 7+16 length text": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -82,14 +81,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\1\126\4\98" & data) + ).encode() == toBytes("\1\126\4\98" & data) test "# 7+16 length text fin bit": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -98,14 +97,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\1\126\4\98" & data) + ).encode() == toBytes("\1\126\4\98" & data) test "# 7+16 length binary": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -114,14 +113,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\2\126\4\98" & data) + ).encode() == toBytes("\2\126\4\98" & data) test "# 7+16 length binary fin bit": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: true, rsv1: false, rsv2: false, @@ -130,14 +129,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\130\126\4\98" & data) + ).encode() == toBytes("\130\126\4\98" & data) test "# 7+16 length continuation": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -146,14 +145,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\0\126\4\98" & data) + ).encode() == toBytes("\0\126\4\98" & data) test "# 7+64 length text": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -162,14 +161,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\1\127\0\0\0\0\0\1\169\34" & data) + ).encode() == toBytes("\1\127\0\0\0\0\0\1\169\34" & data) test "# 7+64 length fin bit": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: true, rsv1: false, rsv2: false, @@ -178,14 +177,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\129\127\0\0\0\0\0\1\169\34" & data) + ).encode() == toBytes("\129\127\0\0\0\0\0\1\169\34" & data) test "# 7+64 length binary": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -194,14 +193,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\2\127\0\0\0\0\0\1\169\34" & data) + ).encode() == toBytes("\2\127\0\0\0\0\0\1\169\34" & data) test "# 7+64 length binary fin bit": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: true, rsv1: false, rsv2: false, @@ -210,14 +209,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\130\127\0\0\0\0\0\1\169\34" & data) + ).encode() == toBytes("\130\127\0\0\0\0\0\1\169\34" & data) test "# 7+64 length binary": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -226,10 +225,10 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - )) == toBytes("\0\127\0\0\0\0\0\1\169\34" & data) + ).encode() == toBytes("\0\127\0\0\0\0\0\1\169\34" & data) test "# masking": - let data = encodeFrame(Frame( + let data = Frame( fin: true, rsv1: false, rsv2: false, @@ -238,14 +237,14 @@ suite "Test data frames": mask: true, data: toBytes("hi there"), maskKey: ['\xCF', '\xD8', '\x05', 'e'] - )) + ).encode() check data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0") suite "Test control frames": test "Close": - check encodeFrame(Frame( + check Frame( fin: true, rsv1: false, rsv2: false, @@ -254,10 +253,10 @@ suite "Test control frames": mask: false, data: @[3'u8, 232'u8] & toBytes("hi there"), maskKey: maskKey - )) == toBytes("\136\10\3\232hi there") + ).encode() == toBytes("\136\10\3\232hi there") test "Ping": - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -265,10 +264,10 @@ suite "Test control frames": opcode: Opcode.Ping, mask: false, maskKey: maskKey - )) == toBytes("\9\0") + ).encode() == toBytes("\9\0") test "Pong": - check encodeFrame(Frame( + check Frame( fin: false, rsv1: false, rsv2: false, @@ -276,4 +275,4 @@ suite "Test control frames": opcode: Opcode.Pong, mask: false, maskKey: maskKey - )) == toBytes("\10\0") + ).encode() == toBytes("\10\0") diff --git a/tests/testtlswebsockets.nim b/tests/testtlswebsockets.nim index 102fa33ef0..d1c49b0bd1 100644 --- a/tests/testtlswebsockets.nim +++ b/tests/testtlswebsockets.nim @@ -6,7 +6,7 @@ import pkg/[asynctest, chronos/apps/http/shttpserver, stew/byteutils] -import ../ws/[ws, stream], +import ../ws/[ws, stream, errors], ../examples/tlsserver import ./keys @@ -40,7 +40,7 @@ suite "Test websocket TLS handshake": let request = r.get() check request.uri.path == "/wss" expect WSProtoMismatchError: - discard await createServer(request, "proto") + discard await WebSocket.createServer(request, "proto") let res = SecureHttpServerRef.new( address, cb, @@ -68,7 +68,7 @@ suite "Test websocket TLS handshake": let request = r.get() check request.uri.path == "/wss" expect WSVersionError: - discard await createServer(request, "proto") + discard await WebSocket.createServer(request, "proto") let res = SecureHttpServerRef.new( address, cb, @@ -135,7 +135,7 @@ suite "Test websocket TLS transmission": let request = r.get() check request.uri.path == "/wss" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") let servRes = await ws.recv() check string.fromBytes(servRes) == testString await waitForClose(ws) @@ -169,7 +169,7 @@ suite "Test websocket TLS transmission": let request = r.get() check request.uri.path == "/wss" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") await ws.send(testString) await ws.close() return dumbResponse() diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index dd4c63cb4d..2f0ddf8831 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -6,7 +6,7 @@ import pkg/[asynctest, chronicles, stew/byteutils] -import ../ws/[ws, stream, utils] +import ../ws/[ws, stream, utils, frame, errors] var server: HttpServerRef let address = initTAddress("127.0.0.1:8888") @@ -39,7 +39,7 @@ suite "Test handshake": let request = r.get() check request.uri.path == "/ws" expect WSProtoMismatchError: - discard await createServer(request, "proto") + discard await WebSocket.createServer(request, "proto") let res = HttpServerRef.new(address, cb) server = res.get() @@ -59,7 +59,7 @@ suite "Test handshake": let request = r.get() check request.uri.path == "/ws" expect WSVersionError: - discard await createServer(request, "proto") + discard await WebSocket.createServer(request, "proto") let res = HttpServerRef.new(address, cb) server = res.get() @@ -111,7 +111,7 @@ suite "Test handshake": check request.uri.path == "/ws" expect WSProtoMismatchError: - var ws = await createServer(request, "proto") + var ws = await WebSocket.createServer(request, "proto") check ws.readyState == ReadyState.Closed return await request.respond(Http200, "Connection established") @@ -142,14 +142,14 @@ suite "Test transmission": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = if r.isErr(): return dumbResponse() + let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") let servRes = await ws.recv() check string.fromBytes(servRes) == testString - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -158,6 +158,7 @@ suite "Test transmission": Port(8888), path = "/ws", protocols = @["proto"]) + await wsClient.send(testString) await wsClient.close() @@ -168,7 +169,7 @@ suite "Test transmission": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") let servRes = await ws.recv() check string.fromBytes(servRes) == testString await waitForClose(ws) @@ -198,7 +199,7 @@ suite "Test transmission": let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") await ws.send(testString) await ws.close() @@ -238,7 +239,7 @@ suite "Test ping-pong": let request = r.get() check request.uri.path == "/ws" - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onPing = proc(data: openArray[byte]) = @@ -266,7 +267,7 @@ suite "Test ping-pong": let maskKey = genMaskKey(newRng()) await wsClient.stream.writer.write( - encodeFrame(Frame( + Frame( fin: false, rsv1: false, rsv2: false, @@ -274,12 +275,13 @@ suite "Test ping-pong": opcode: Opcode.Text, mask: true, data: msg[0..4], - maskKey: maskKey))) + maskKey: maskKey) + .encode()) await wsClient.ping() await wsClient.stream.writer.write( - encodeFrame(Frame( + Frame( fin: true, rsv1: false, rsv2: false, @@ -287,7 +289,8 @@ suite "Test ping-pong": opcode: Opcode.Cont, mask: true, data: msg[5..9], - maskKey: maskKey))) + maskKey: maskKey) + .encode()) await wsClient.close() check: @@ -307,7 +310,7 @@ suite "Test ping-pong": let request = r.get() check request.uri.path == "/ws" - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onPong = proc(data: openArray[byte]) = @@ -343,7 +346,7 @@ suite "Test ping-pong": let request = r.get() check request.uri.path == "/ws" - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onPing = proc(data: openArray[byte]) = @@ -390,7 +393,7 @@ suite "Test framing": let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") let frame1 = await ws.readFrame() check not isNil(frame1) var data1 = newSeq[byte](frame1.remainder().int) @@ -432,7 +435,7 @@ suite "Test framing": let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") await ws.send(testString) await ws.close() @@ -473,7 +476,7 @@ suite "Test Closing": let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") await ws.close() let res = HttpServerRef.new(address, cb) @@ -512,7 +515,7 @@ suite "Test Closing": return (Status.Fulfilled, "") - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onClose = closeServer) @@ -548,7 +551,7 @@ suite "Test Closing": let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") await waitForClose(ws) let res = HttpServerRef.new(address, cb) @@ -577,7 +580,7 @@ suite "Test Closing": except Exception as exc: raise newException(Defect, exc.msg) - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onClose = closeServer) @@ -612,7 +615,7 @@ suite "Test Closing": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto") @@ -655,7 +658,7 @@ suite "Test Closing": except Exception as exc: raise newException(Defect, exc.msg) - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onClose = closeServer) @@ -686,7 +689,7 @@ suite "Test Closing": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") # Close with payload of length 2 await ws.close(reason = "HH") @@ -708,7 +711,7 @@ suite "Test Closing": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") await waitForClose(ws) let res = HttpServerRef.new( @@ -731,7 +734,6 @@ suite "Test Closing": getTracker("stream.server").isLeaked() == false getTracker("stream.transport").isLeaked() == false - suite "Test Payload": teardown: await server.closeWait() @@ -742,7 +744,7 @@ suite "Test Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto") @@ -773,7 +775,7 @@ suite "Test Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") let servRes = await ws.recv() check string.fromBytes(servRes) == emptyStr await waitForClose(ws) @@ -799,7 +801,7 @@ suite "Test Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") let servRes = await ws.recv() check string.fromBytes(servRes) == emptyStr await waitForClose(ws) @@ -827,7 +829,7 @@ suite "Test Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onPing = proc(data: openArray[byte]) = @@ -874,7 +876,7 @@ suite "Test Binary message with Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") let servRes = await ws.recv() check: @@ -904,7 +906,7 @@ suite "Test Binary message with Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await WebSocket.createServer(request, "proto") let servRes = await ws.recv() check: @@ -937,7 +939,7 @@ suite "Test Binary message with Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onPing = proc(data: openArray[byte]) = @@ -976,7 +978,7 @@ suite "Test Binary message with Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await createServer( + let ws = await WebSocket.createServer( request, "proto", onPing = proc(data: openArray[byte]) = diff --git a/ws/errors.nim b/ws/errors.nim new file mode 100644 index 0000000000..d98b097982 --- /dev/null +++ b/ws/errors.nim @@ -0,0 +1,31 @@ +## Nim-Libp2p +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +{.push raises: [Defect].} + +type + WebSocketError* = object of CatchableError + WSMalformedHeaderError* = object of WebSocketError + WSFailedUpgradeError* = object of WebSocketError + WSVersionError* = object of WebSocketError + WSProtoMismatchError* = object of WebSocketError + WSMaskMismatchError* = object of WebSocketError + WSHandshakeError* = object of WebSocketError + WSOpcodeMismatchError* = object of WebSocketError + WSRsvMismatchError* = object of WebSocketError + WSWrongUriSchemeError* = object of WebSocketError + WSMaxMessageSizeError* = object of WebSocketError + WSClosedError* = object of WebSocketError + WSSendError* = object of WebSocketError + WSPayloadTooLarge* = object of WebSocketError + WSReserverdOpcodeError* = object of WebSocketError + WSFragmentedControlFrameError* = object of WebSocketError + WSInvalidCloseCodeError* = object of WebSocketError + WSPayloadLengthError* = object of WebSocketError + WSInvalidOpcodeError* = object of WebSocketError diff --git a/ws/frame.nim b/ws/frame.nim new file mode 100644 index 0000000000..b30b66f978 --- /dev/null +++ b/ws/frame.nim @@ -0,0 +1,207 @@ +## Nim-Libp2p +## Copyright (c) 2020 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +{.push raises: [Defect].} + +import pkg/[chronos, chronicles, stew/endians2, stew/results] +import ./errors + +#[ + +---------------------------------------------------------------+ + |0 1 2 3 | + |0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1| + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ +]# + +type + Opcode* {.pure.} = enum + ## 4 bits. Defines the interpretation of the "Payload data". + Cont = 0x0 ## Denotes a continuation frame. + Text = 0x1 ## Denotes a text frame. + Binary = 0x2 ## Denotes a binary frame. + # 3-7 are reserved for further non-control frames. + Close = 0x8 ## Denotes a connection close. + Ping = 0x9 ## Denotes a ping. + Pong = 0xa ## Denotes a pong. + # B-F are reserved for further control frames. + + HeaderFlag* {.pure, size: sizeof(uint8).} = enum + rsv3 + rsv2 + rsv1 + fin + + HeaderFlags = set[HeaderFlag] + + MaskKey = array[4, char] + + Frame* = ref object + fin*: bool ## Indicates that this is the final fragment in a message. + rsv1*: bool ## MUST be 0 unless negotiated that defines meanings + rsv2*: bool ## MUST be 0 + rsv3*: bool ## MUST be 0 + opcode*: Opcode ## Defines the interpretation of the "Payload data". + mask*: bool ## Defines whether the "Payload data" is masked. + data*: seq[byte] ## Payload data + maskKey*: MaskKey ## Masking key + length*: uint64 ## Message size. + consumed*: uint64 ## how much has been consumed from the frame + +proc mask*( + data: var openArray[byte], + maskKey: MaskKey, + offset = 0) = + ## Unmask a data payload using key + ## + + for i in 0 ..< data.len: + data[i] = (data[i].uint8 xor maskKey[(offset + i) mod 4].uint8) + +proc encode*(f: Frame, offset = 0): seq[byte] = + ## Encodes a frame into a string buffer. + ## See https://tools.ietf.org/html/rfc6455#section-5.2 + + var ret: seq[byte] + var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags. + if f.fin: + b0 = b0 or 128'u8 + + ret.add(b0) + + # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. + # 1st byte: payload len start and mask bit. + var b1 = 0'u8 + + if f.data.len <= 125: + b1 = f.data.len.uint8 + elif f.data.len > 125 and f.data.len <= 0xFFFF: + b1 = 126'u8 + else: + b1 = 127'u8 + + if f.mask: + b1 = b1 or (1 shl 7) + + ret.add(uint8 b1) + + # Only need more bytes if data len is 7+16 bits, or 7+64 bits. + if f.data.len > 125 and f.data.len <= 0xFFFF: + # Data len is 7+16 bits. + var len = f.data.len.uint16 + ret.add ((len shr 8) and 0xFF).uint8 + ret.add (len and 0xFF).uint8 + elif f.data.len > 0xFFFF: + # Data len is 7+64 bits. + var len = f.data.len.uint64 + ret.add(len.toBytesBE()) + + var data = f.data + + if f.mask: + # If we need to mask it generate random mask key and mask the data. + mask(data, f.maskKey, offset) + + # Write mask key next. + ret.add(f.maskKey[0].uint8) + ret.add(f.maskKey[1].uint8) + ret.add(f.maskKey[2].uint8) + ret.add(f.maskKey[3].uint8) + + # Write the data. + ret.add(data) + return ret + +proc decode*( + _: typedesc[Frame], + reader: AsyncStreamReader, + masked: bool): + Future[Frame] {.async.} = + ## Read and Decode incoming header + ## + + var header = newSeq[byte](2) + await reader.readExactly(addr header[0], 2) + if header.len != 2: + debug "Invalid websocket header length" + raise newException(WSMalformedHeaderError, + "Invalid websocket header length") + + let b0 = header[0].uint8 + let b1 = header[1].uint8 + + var frame = Frame() + # Read the flags and fin from the header. + + var hf = cast[HeaderFlags](b0 shr 4) + frame.fin = HeaderFlag.fin in hf + frame.rsv1 = HeaderFlag.rsv1 in hf + frame.rsv2 = HeaderFlag.rsv2 in hf + frame.rsv3 = HeaderFlag.rsv3 in hf + + let opcode = (b0 and 0x0f) + if opcode > ord(Opcode.high): + raise newException(WSOpcodeMismatchError, "Wrong opcode!") + + frame.opcode = (opcode).Opcode + + # If any of the rsv are set close the socket. + if frame.rsv1 or frame.rsv2 or frame.rsv3: + raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") + + # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. + var finalLen: uint64 = 0 + + let headerLen = uint(b1 and 0x7f) + if headerLen == 0x7e: + # Length must be 7+16 bits. + var length = newSeq[byte](2) + await reader.readExactly(addr length[0], 2) + finalLen = uint16.fromBytesBE(length) + elif headerLen == 0x7f: + # Length must be 7+64 bits. + var length = newSeq[byte](8) + await reader.readExactly(addr length[0], 8) + finalLen = uint64.fromBytesBE(length) + else: + # Length must be 7 bits. + finalLen = headerLen + + frame.length = finalLen + + # Do we need to apply mask? + frame.mask = (b1 and 0x80) == 0x80 + if masked == frame.mask: + # Server sends unmasked but accepts only masked. + # Client sends masked but accepts only unmasked. + raise newException(WSMaskMismatchError, + "Socket mask mismatch") + + var maskKey = newSeq[byte](4) + if frame.mask: + # Read the mask. + await reader.readExactly(addr maskKey[0], 4) + for i in 0.. 999: @@ -234,6 +162,7 @@ proc handshake*( ws.readyState = ReadyState.Open proc createServer*( + _: typedesc[WebSocket], request: HttpRequestRef, protocol: string = "", frameSize = WSDefaultFrameSize, @@ -263,60 +192,6 @@ proc createServer*( await ws.handshake(request) return ws -proc encodeFrame*(f: Frame, offset = 0): seq[byte] = - ## Encodes a frame into a string buffer. - ## See https://tools.ietf.org/html/rfc6455#section-5.2 - - var ret: seq[byte] - var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags. - if f.fin: - b0 = b0 or 128'u8 - - ret.add(b0) - - # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. - # 1st byte: payload len start and mask bit. - var b1 = 0'u8 - - if f.data.len <= 125: - b1 = f.data.len.uint8 - elif f.data.len > 125 and f.data.len <= 0xFFFF: - b1 = 126'u8 - else: - b1 = 127'u8 - - if f.mask: - b1 = b1 or (1 shl 7) - - ret.add(uint8 b1) - - # Only need more bytes if data len is 7+16 bits, or 7+64 bits. - if f.data.len > 125 and f.data.len <= 0xFFFF: - # Data len is 7+16 bits. - var len = f.data.len.uint16 - ret.add ((len shr 8) and 0xFF).uint8 - ret.add (len and 0xFF).uint8 - elif f.data.len > 0xFFFF: - # Data len is 7+64 bits. - var len = f.data.len.uint64 - ret.add(len.toBytesBE()) - - var data = f.data - - if f.mask: - # If we need to mask it generate random mask key and mask the data. - mask(data, f.maskKey, offset) - - # Write mask key next. - ret.add(f.maskKey[0].uint8) - ret.add(f.maskKey[1].uint8) - ret.add(f.maskKey[2].uint8) - ret.add(f.maskKey[3].uint8) - - # Write the data. - ret.add(data) - return ret - proc send*( ws: WebSocket, data: seq[byte] = @[], @@ -339,19 +214,21 @@ proc send*( maskKey = genMaskKey(ws.rng) if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: + if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}: return await ws.stream.writer.write( - encodeFrame(Frame( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: opcode, - mask: ws.masked, - data: data, # allow sending data with close messages - maskKey: maskKey))) + Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: opcode, + mask: ws.masked, + data: data, # allow sending data with close messages + maskKey: maskKey) + .encode()) return @@ -360,7 +237,7 @@ proc send*( while ws.readyState notin {ReadyState.Closing}: let len = min(data.len, (maxSize + i)) await ws.stream.writer.write( - encodeFrame(Frame( + Frame( fin: if (i + len >= data.len): true else: false, rsv1: false, rsv2: false, @@ -368,7 +245,8 @@ proc send*( opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames mask: ws.masked, data: data[i ..< len], - maskKey: maskKey))) + maskKey: maskKey) + .encode()) i += len if i >= data.len: @@ -491,69 +369,7 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = try: while ws.readyState != ReadyState.Closed: - # Grab the header. - var header = newSeq[byte](2) - await ws.stream.reader.readExactly(addr header[0], 2) - if header.len != 2: - debug "Invalid websocket header length" - raise newException(WSMalformedHeaderError, - "Invalid websocket header length") - - let b0 = header[0].uint8 - let b1 = header[1].uint8 - - var frame = Frame() - # Read the flags and fin from the header. - - var hf = cast[HeaderFlags](b0 shr 4) - frame.fin = fin in hf - frame.rsv1 = rsv1 in hf - frame.rsv2 = rsv2 in hf - frame.rsv3 = rsv3 in hf - - let opcode = (b0 and 0x0f) - if opcode > ord(Opcode.high): - raise newException(WSOpcodeMismatchError, "Wrong opcode!") - - frame.opcode = (opcode).Opcode - - # If any of the rsv are set close the socket. - if frame.rsv1 or frame.rsv2 or frame.rsv3: - raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") - - # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. - var finalLen: uint64 = 0 - - let headerLen = uint(b1 and 0x7f) - if headerLen == 0x7e: - # Length must be 7+16 bits. - var length = newSeq[byte](2) - await ws.stream.reader.readExactly(addr length[0], 2) - finalLen = uint16.fromBytesBE(length) - elif headerLen == 0x7f: - # Length must be 7+64 bits. - var length = newSeq[byte](8) - await ws.stream.reader.readExactly(addr length[0], 8) - finalLen = uint64.fromBytesBE(length) - else: - # Length must be 7 bits. - finalLen = headerLen - frame.length = finalLen - - # Do we need to apply mask? - frame.mask = (b1 and 0x80) == 0x80 - if ws.masked == frame.mask: - # Server sends unmasked but accepts only masked. - # Client sends masked but accepts only unmasked. - raise newException(WSMaskMismatchError, "Socket mask mismatch") - - var maskKey = newSeq[byte](4) - if frame.mask: - # Read the mask. - await ws.stream.reader.readExactly(addr maskKey[0], 4) - for i in 0..