From eb62ec1725d895ad636bcb718cded4f68fcd2205 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Tue, 25 May 2021 16:39:10 -0600 Subject: [PATCH] Extract session (#31) * extract websocket session * fix tests * fix frame tests --- examples/client.nim | 2 +- examples/server.nim | 2 +- examples/tlsclient.nim | 3 +- examples/tlsserver.nim | 2 +- tests/testframes.nim | 78 +++---- tests/testtlswebsockets.nim | 3 +- tests/testwebsockets.nim | 22 +- ws/errors.nim | 31 --- ws/extension.nim | 26 --- ws/frame.nim | 59 ++--- ws/session.nim | 380 ++++++++++++++++++++++++++++++++ ws/types.nim | 136 ++++++++++++ ws/utils.nim | 12 +- ws/ws.nim | 421 +----------------------------------- 14 files changed, 600 insertions(+), 577 deletions(-) delete mode 100644 ws/errors.nim delete mode 100644 ws/extension.nim create mode 100644 ws/session.nim create mode 100644 ws/types.nim diff --git a/examples/client.nim b/examples/client.nim index 274c75064b..ea61b9917e 100644 --- a/examples/client.nim +++ b/examples/client.nim @@ -3,7 +3,7 @@ import pkg/[ chronicles, stew/byteutils] -import ../ws/ws, ../ws/errors +import ../ws/ws proc main() {.async.} = let ws = await WebSocket.connect( diff --git a/examples/server.nim b/examples/server.nim index fa477fe9f6..c9e80a1210 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -3,7 +3,7 @@ import pkg/[chronos, chronicles, httputils] -import ../ws/[ws, frame, errors] +import ../ws/ws proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if r.isOk(): diff --git a/examples/tlsclient.nim b/examples/tlsclient.nim index c80258b5d5..65bdf5097d 100644 --- a/examples/tlsclient.nim +++ b/examples/tlsclient.nim @@ -3,7 +3,7 @@ import pkg/[chronos, chronicles, stew/byteutils] -import ../ws/ws, ../ws/errors +import ../ws/ws proc main() {.async.} = let ws = await WebSocket.tlsConnect( @@ -31,4 +31,5 @@ proc main() {.async.} = # close the websocket await ws.close() + waitFor(main()) diff --git a/examples/tlsserver.nim b/examples/tlsserver.nim index c2ba9e3d3f..a73a02adcc 100644 --- a/examples/tlsserver.nim +++ b/examples/tlsserver.nim @@ -4,7 +4,7 @@ import pkg/[chronos, httputils, stew/byteutils] -import ../ws/[ws, frame, errors] +import ../ws/ws import ../tests/keys let secureKey = TLSPrivateKey.init(SecureKey) diff --git a/tests/testframes.nim b/tests/testframes.nim index 037209a311..71db1a19d6 100644 --- a/tests/testframes.nim +++ b/tests/testframes.nim @@ -1,4 +1,4 @@ -import unittest, stew/byteutils +import pkg/[asynctest, stew/byteutils] include ../ws/frame include ../ws/utils @@ -9,7 +9,7 @@ var maskKey: array[4, char] suite "Test data frames": test "# 7bit length text": - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -17,10 +17,10 @@ suite "Test data frames": opcode: Opcode.Text, mask: false, data: toBytes("hi there") - ).encode() == toBytes("\1\8hi there") + ).encode()) == toBytes("\1\8hi there") test "# 7bit length text fin bit": - check Frame( + check (await Frame( fin: true, rsv1: false, rsv2: false, @@ -29,10 +29,10 @@ suite "Test data frames": mask: false, data: toBytes("hi there"), maskKey: maskKey - ).encode() == toBytes("\129\8hi there") + ).encode()) == toBytes("\129\8hi there") test "# 7bit length binary": - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -41,10 +41,10 @@ suite "Test data frames": mask: false, data: toBytes("hi there"), maskKey: maskKey - ).encode() == toBytes("\2\8hi there") + ).encode()) == toBytes("\2\8hi there") test "# 7bit length binary fin bit": - check Frame( + check (await Frame( fin: true, rsv1: false, rsv2: false, @@ -53,10 +53,10 @@ suite "Test data frames": mask: false, data: toBytes("hi there"), maskKey: maskKey - ).encode() == toBytes("\130\8hi there") + ).encode()) == toBytes("\130\8hi there") test "# 7bit length continuation": - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -65,14 +65,14 @@ suite "Test data frames": mask: false, data: toBytes("hi there"), maskKey: maskKey - ).encode() == toBytes("\0\8hi there") + ).encode()) == toBytes("\0\8hi there") test "# 7+16 length text": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -81,14 +81,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\1\126\4\98" & data) + ).encode()) == toBytes("\1\126\4\98" & data) test "# 7+16 length text fin bit": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -97,14 +97,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\1\126\4\98" & data) + ).encode()) == toBytes("\1\126\4\98" & data) test "# 7+16 length binary": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -113,14 +113,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\2\126\4\98" & data) + ).encode()) == toBytes("\2\126\4\98" & data) test "# 7+16 length binary fin bit": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: true, rsv1: false, rsv2: false, @@ -129,14 +129,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\130\126\4\98" & data) + ).encode()) == toBytes("\130\126\4\98" & data) test "# 7+16 length continuation": var data = "" for i in 0..32: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -145,14 +145,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\0\126\4\98" & data) + ).encode()) == toBytes("\0\126\4\98" & data) test "# 7+64 length text": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -161,14 +161,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\1\127\0\0\0\0\0\1\169\34" & data) + ).encode()) == toBytes("\1\127\0\0\0\0\0\1\169\34" & data) test "# 7+64 length fin bit": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: true, rsv1: false, rsv2: false, @@ -177,14 +177,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\129\127\0\0\0\0\0\1\169\34" & data) + ).encode()) == toBytes("\129\127\0\0\0\0\0\1\169\34" & data) test "# 7+64 length binary": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -193,14 +193,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\2\127\0\0\0\0\0\1\169\34" & data) + ).encode()) == toBytes("\2\127\0\0\0\0\0\1\169\34" & data) test "# 7+64 length binary fin bit": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: true, rsv1: false, rsv2: false, @@ -209,14 +209,14 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\130\127\0\0\0\0\0\1\169\34" & data) + ).encode()) == toBytes("\130\127\0\0\0\0\0\1\169\34" & data) test "# 7+64 length binary": var data = "" for i in 0..3200: data.add "How are you this is the payload!!!" - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -225,10 +225,10 @@ suite "Test data frames": mask: false, data: toBytes(data), maskKey: maskKey - ).encode() == toBytes("\0\127\0\0\0\0\0\1\169\34" & data) + ).encode()) == toBytes("\0\127\0\0\0\0\0\1\169\34" & data) test "# masking": - let data = Frame( + let data = (await Frame( fin: true, rsv1: false, rsv2: false, @@ -237,14 +237,14 @@ suite "Test data frames": mask: true, data: toBytes("hi there"), maskKey: ['\xCF', '\xD8', '\x05', 'e'] - ).encode() + ).encode()) check data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0") suite "Test control frames": test "Close": - check Frame( + check (await Frame( fin: true, rsv1: false, rsv2: false, @@ -253,10 +253,10 @@ suite "Test control frames": mask: false, data: @[3'u8, 232'u8] & toBytes("hi there"), maskKey: maskKey - ).encode() == toBytes("\136\10\3\232hi there") + ).encode()) == toBytes("\136\10\3\232hi there") test "Ping": - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -264,10 +264,10 @@ suite "Test control frames": opcode: Opcode.Ping, mask: false, maskKey: maskKey - ).encode() == toBytes("\9\0") + ).encode()) == toBytes("\9\0") test "Pong": - check Frame( + check (await Frame( fin: false, rsv1: false, rsv2: false, @@ -275,4 +275,4 @@ suite "Test control frames": opcode: Opcode.Pong, mask: false, maskKey: maskKey - ).encode() == toBytes("\10\0") + ).encode()) == toBytes("\10\0") diff --git a/tests/testtlswebsockets.nim b/tests/testtlswebsockets.nim index 786cc9c6c3..7bc7fcf22f 100644 --- a/tests/testtlswebsockets.nim +++ b/tests/testtlswebsockets.nim @@ -6,8 +6,7 @@ import pkg/[asynctest, chronos/apps/http/shttpserver, stew/byteutils] -import ../ws/[ws, stream, errors], - ../examples/tlsserver +import ../ws/ws, ../examples/tlsserver import ./keys diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index 2d1dd2d4f2..7abfcdfe36 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -1,12 +1,13 @@ import std/[strutils, random], httputils -import pkg/[asynctest, - chronos, - chronos/apps/http/httpserver, - chronicles, - stew/byteutils] +import pkg/[ + asynctest, + chronos, + chronos/apps/http/httpserver, + chronicles, + stew/byteutils] -import ../ws/[ws, stream, utils, frame, errors] +import ../ws/ws var server: HttpServerRef let address = initTAddress("127.0.0.1:8888") @@ -138,7 +139,6 @@ suite "Test handshake": suite "Test transmission": teardown: - await server.stop() await server.closeWait() test "Send text message message with payload of length 65535": @@ -279,7 +279,7 @@ suite "Test ping-pong": let maskKey = genMaskKey(newRng()) await wsClient.stream.writer.write( - Frame( + (await Frame( fin: false, rsv1: false, rsv2: false, @@ -288,12 +288,12 @@ suite "Test ping-pong": mask: true, data: msg[0..4], maskKey: maskKey) - .encode()) + .encode())) await wsClient.ping() await wsClient.stream.writer.write( - Frame( + (await Frame( fin: true, rsv1: false, rsv2: false, @@ -302,7 +302,7 @@ suite "Test ping-pong": mask: true, data: msg[5..9], maskKey: maskKey) - .encode()) + .encode())) await wsClient.close() check: diff --git a/ws/errors.nim b/ws/errors.nim deleted file mode 100644 index d98b097982..0000000000 --- a/ws/errors.nim +++ /dev/null @@ -1,31 +0,0 @@ -## Nim-Libp2p -## Copyright (c) 2021 Status Research & Development GmbH -## Licensed under either of -## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) -## * MIT license ([LICENSE-MIT](LICENSE-MIT)) -## at your option. -## This file may not be copied, modified, or distributed except according to -## those terms. - -{.push raises: [Defect].} - -type - WebSocketError* = object of CatchableError - WSMalformedHeaderError* = object of WebSocketError - WSFailedUpgradeError* = object of WebSocketError - WSVersionError* = object of WebSocketError - WSProtoMismatchError* = object of WebSocketError - WSMaskMismatchError* = object of WebSocketError - WSHandshakeError* = object of WebSocketError - WSOpcodeMismatchError* = object of WebSocketError - WSRsvMismatchError* = object of WebSocketError - WSWrongUriSchemeError* = object of WebSocketError - WSMaxMessageSizeError* = object of WebSocketError - WSClosedError* = object of WebSocketError - WSSendError* = object of WebSocketError - WSPayloadTooLarge* = object of WebSocketError - WSReserverdOpcodeError* = object of WebSocketError - WSFragmentedControlFrameError* = object of WebSocketError - WSInvalidCloseCodeError* = object of WebSocketError - WSPayloadLengthError* = object of WebSocketError - WSInvalidOpcodeError* = object of WebSocketError diff --git a/ws/extension.nim b/ws/extension.nim deleted file mode 100644 index b4c875d3ff..0000000000 --- a/ws/extension.nim +++ /dev/null @@ -1,26 +0,0 @@ -## Nim-Libp2p -## Copyright (c) 2021 Status Research & Development GmbH -## Licensed under either of -## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) -## * MIT license ([LICENSE-MIT](LICENSE-MIT)) -## at your option. -## This file may not be copied, modified, or distributed except according to -## those terms. - -{.push raises: [Defect].} - -import pkg/[chronos, chronicles] -import ./frame - -type - Extension* = ref object of RootObj - name*: string - -proc `name=`*(self: Extension, name: string) = - raiseAssert "Can't change extensions name!" - -method decode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} = - raiseAssert "Not implemented!" - -method encode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} = - raiseAssert "Not implemented!" diff --git a/ws/frame.nim b/ws/frame.nim index b30b66f978..b1171a2122 100644 --- a/ws/frame.nim +++ b/ws/frame.nim @@ -1,5 +1,5 @@ ## Nim-Libp2p -## Copyright (c) 2020 Status Research & Development GmbH +## Copyright (c) 2021 Status Research & Development GmbH ## Licensed under either of ## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) ## * MIT license ([LICENSE-MIT](LICENSE-MIT)) @@ -10,7 +10,7 @@ {.push raises: [Defect].} import pkg/[chronos, chronicles, stew/endians2, stew/results] -import ./errors +import ./types #[ +---------------------------------------------------------------+ @@ -34,40 +34,6 @@ import ./errors +---------------------------------------------------------------+ ]# -type - Opcode* {.pure.} = enum - ## 4 bits. Defines the interpretation of the "Payload data". - Cont = 0x0 ## Denotes a continuation frame. - Text = 0x1 ## Denotes a text frame. - Binary = 0x2 ## Denotes a binary frame. - # 3-7 are reserved for further non-control frames. - Close = 0x8 ## Denotes a connection close. - Ping = 0x9 ## Denotes a ping. - Pong = 0xa ## Denotes a pong. - # B-F are reserved for further control frames. - - HeaderFlag* {.pure, size: sizeof(uint8).} = enum - rsv3 - rsv2 - rsv1 - fin - - HeaderFlags = set[HeaderFlag] - - MaskKey = array[4, char] - - Frame* = ref object - fin*: bool ## Indicates that this is the final fragment in a message. - rsv1*: bool ## MUST be 0 unless negotiated that defines meanings - rsv2*: bool ## MUST be 0 - rsv3*: bool ## MUST be 0 - opcode*: Opcode ## Defines the interpretation of the "Payload data". - mask*: bool ## Defines whether the "Payload data" is masked. - data*: seq[byte] ## Payload data - maskKey*: MaskKey ## Masking key - length*: uint64 ## Message size. - consumed*: uint64 ## how much has been consumed from the frame - proc mask*( data: var openArray[byte], maskKey: MaskKey, @@ -78,10 +44,21 @@ proc mask*( for i in 0 ..< data.len: data[i] = (data[i].uint8 xor maskKey[(offset + i) mod 4].uint8) -proc encode*(f: Frame, offset = 0): seq[byte] = +template remainder*(frame: Frame): uint64 = + frame.length - frame.consumed + +proc encode*( + frame: Frame, + offset = 0, + extensions: seq[Extension] = @[]): Future[seq[byte]] {.async.} = ## Encodes a frame into a string buffer. ## See https://tools.ietf.org/html/rfc6455#section-5.2 + var f = frame + if extensions.len > 0: + for e in extensions: + f = await e.encode(f) + var ret: seq[byte] var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags. if f.fin: @@ -135,8 +112,8 @@ proc encode*(f: Frame, offset = 0): seq[byte] = proc decode*( _: typedesc[Frame], reader: AsyncStreamReader, - masked: bool): - Future[Frame] {.async.} = + masked: bool, + extensions: seq[Extension] = @[]): Future[Frame] {.async.} = ## Read and Decode incoming header ## @@ -204,4 +181,8 @@ proc decode*( for i in 0.. 0: + for e in extensions[extensions.high..extensions.low]: + frame = await e.decode(frame) + return frame diff --git a/ws/session.nim b/ws/session.nim new file mode 100644 index 0000000000..91933c7a98 --- /dev/null +++ b/ws/session.nim @@ -0,0 +1,380 @@ +## Nim-Libp2p +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +{.push raises: [Defect].} + +import pkg/[chronos, chronicles, stew/byteutils, stew/endians2] +import ./types, ./frame, ./utils, ./stream + +import pkg/chronos/[ + streams/asyncstream, + streams/tlsstream] + +type + WSSession* = ref object of WebSocket + stream*: AsyncStream + frame*: Frame + +proc prepareCloseBody(code: Status, reason: string): seq[byte] = + result = reason.toBytes + if ord(code) > 999: + result = @(ord(code).uint16.toBytesBE()) & result + +proc send*( + ws: WSSession, + data: seq[byte] = @[], + opcode: Opcode) {.async.} = + ## Send a frame + ## + + if ws.readyState == ReadyState.Closed: + raise newException(WSClosedError, "Socket is closed!") + + logScope: + opcode = opcode + dataSize = data.len + masked = ws.masked + + debug "Sending data to remote" + + var maskKey: array[4, char] + if ws.masked: + maskKey = genMaskKey(ws.rng) + + if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: + + if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}: + return + + await ws.stream.writer.write( + (await Frame( + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: opcode, + mask: ws.masked, + data: data, # allow sending data with close messages + maskKey: maskKey) + .encode())) + + return + + let maxSize = ws.frameSize + var i = 0 + while ws.readyState notin {ReadyState.Closing}: + let len = min(data.len, (maxSize + i)) + await ws.stream.writer.write( + (await Frame( + fin: if (i + len >= data.len): true else: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames + mask: ws.masked, + data: data[i ..< len], + maskKey: maskKey) + .encode())) + + i += len + if i >= data.len: + break + +proc send*(ws: WSSession, data: string): Future[void] = + send(ws, data.toBytes(), Opcode.Text) + +proc handleClose*( + ws: WSSession, + frame: Frame, + payLoad: seq[byte] = @[]) {.async.} = + ## Handle close sequence + ## + + logScope: + fin = frame.fin + masked = frame.mask + opcode = frame.opcode + readyState = ws.readyState + + debug "Handling close sequence" + + if ws.readyState notin {ReadyState.Open}: + debug "Connection isn't open, abortig close sequence!" + return + + var + code = Status.Fulfilled + reason = "" + + if payLoad.len == 1: + raise newException(WSPayloadLengthError, + "Invalid close frame with payload length 1!") + + if payLoad.len > 1: + # first two bytes are the status + let ccode = uint16.fromBytesBE(payLoad[0..<2]) + if ccode <= 999 or ccode > 1015: + raise newException(WSInvalidCloseCodeError, + "Invalid code in close message!") + + try: + code = Status(ccode) + except RangeError: + raise newException(WSInvalidCloseCodeError, + "Status code out of range!") + + # remining payload bytes are reason for closing + reason = string.fromBytes(payLoad[2..payLoad.high]) + + var rcode: Status + if code in {Status.Fulfilled}: + rcode = Status.Fulfilled + + if not isNil(ws.onClose): + try: + (rcode, reason) = ws.onClose(code, reason) + except CatchableError as exc: + debug "Exception in Close callback, this is most likely a bug", exc = exc.msg + + # don't respond to a terminated connection + if ws.readyState != ReadyState.Closing: + ws.readyState = ReadyState.Closing + await ws.send(prepareCloseBody(rcode, reason), Opcode.Close) + + ws.readyState = ReadyState.Closed + await ws.stream.closeWait() + +proc handleControl*(ws: WSSession, frame: Frame) {.async.} = + ## Handle control frames + ## + + if not frame.fin: + raise newException(WSFragmentedControlFrameError, + "Control frame cannot be fragmented!") + + if frame.length > 125: + raise newException(WSPayloadTooLarge, + "Control message payload is greater than 125 bytes!") + + try: + var payLoad = newSeq[byte](frame.length.int) + if frame.length > 0: + payLoad.setLen(frame.length.int) + # Read control frame payload. + await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int) + if frame.mask: + mask( + payLoad.toOpenArray(0, payLoad.high), + frame.maskKey) + + # Process control frame payload. + case frame.opcode: + of Opcode.Ping: + if not isNil(ws.onPing): + try: + ws.onPing(payLoad) + except CatchableError as exc: + debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg + + # send pong to remote + await ws.send(payLoad, Opcode.Pong) + of Opcode.Pong: + if not isNil(ws.onPong): + try: + ws.onPong(payLoad) + except CatchableError as exc: + debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg + of Opcode.Close: + await ws.handleClose(frame, payLoad) + else: + raise newException(WSInvalidOpcodeError, "Invalid control opcode!") + except WebSocketError as exc: + debug "Handled websocket exception", exc = exc.msg + raise exc + except CatchableError as exc: + trace "Exception handling control messages", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.stream.closeWait() + +proc readFrame*(ws: WSSession): Future[Frame] {.async.} = + ## Gets a frame from the WebSocket. + ## See https://tools.ietf.org/html/rfc6455#section-5.2 + ## + + try: + while ws.readyState != ReadyState.Closed: + let frame = await Frame.decode(ws.stream.reader, ws.masked) + debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask + + # return the current frame if it's not one of the control frames + if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: + await ws.handleControl(frame) # process control frames# process control frames + continue + + return frame + except WebSocketError as exc: + trace "Websocket error", exc = exc.msg + raise exc + except CatchableError as exc: + debug "Exception reading frame, dropping socket", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.stream.closeWait() + raise exc + +proc ping*(ws: WSSession, data: seq[byte] = @[]): Future[void] = + ws.send(data, opcode = Opcode.Ping) + +proc recv*( + ws: WSSession, + data: pointer, + size: int): Future[int] {.async.} = + ## Attempts to read up to `size` bytes + ## + ## Will read as many frames as necessary + ## to fill the buffer until either + ## the message ends (frame.fin) or + ## the buffer is full. If no data is on + ## the pipe will await until at least + ## one byte is available + ## + + var consumed = 0 + var pbuffer = cast[ptr UncheckedArray[byte]](data) + try: + while consumed < size: + # we might have to read more than + # one frame to fill the buffer + + # TODO: Figure out a cleaner way to handle + # retrieving new frames + if isNil(ws.frame): + ws.frame = await ws.readFrame() + + if isNil(ws.frame): + return consumed + + if ws.frame.opcode == Opcode.Cont: + raise newException(WSOpcodeMismatchError, + "Expected Text or Binary frame") + elif (not ws.frame.fin and ws.frame.remainder() <= 0): + ws.frame = await ws.readFrame() + # This could happen if the connection is closed. + + if isNil(ws.frame): + return consumed + + if ws.frame.opcode != Opcode.Cont: + raise newException(WSOpcodeMismatchError, + "Expected Continuation frame") + + ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag + if ws.frame.fin and ws.frame.remainder() <= 0: + ws.frame = nil + break + + let len = min(ws.frame.remainder().int, size - consumed) + if len == 0: + continue + + let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len) + if read <= 0: + continue + + if ws.frame.mask: + # unmask data using offset + mask( + pbuffer.toOpenArray(consumed, (consumed + read) - 1), + ws.frame.maskKey, + ws.frame.consumed.int) + + consumed += read + ws.frame.consumed += read.uint64 + + return consumed.int + + except WebSocketError as exc: + debug "Websocket error", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.stream.closeWait() + raise exc + except CancelledError as exc: + debug "Cancelling reading", exc = exc.msg + raise exc + except CatchableError as exc: + debug "Exception reading frames", exc = exc.msg + +proc recv*( + ws: WSSession, + size = WSMaxMessageSize): Future[seq[byte]] {.async.} = + ## Attempt to read a full message up to max `size` + ## bytes in `frameSize` chunks. + ## + ## If no `fin` flag arrives await until either + ## cancelled or the `fin` flag arrives. + ## + ## If message is larger than `size` a `WSMaxMessageSizeError` + ## exception is thrown. + ## + ## In all other cases it awaits a full message. + ## + var res: seq[byte] + try: + while ws.readyState != ReadyState.Closed: + var buf = newSeq[byte](ws.frameSize) + let read = await ws.recv(addr buf[0], buf.len) + if read <= 0: + break + + buf.setLen(read) + if res.len + buf.len > size: + raise newException(WSMaxMessageSizeError, "Max message size exceeded") + + res.add(buf) + + # no more frames + if isNil(ws.frame): + break + + # read the entire message, exit + if ws.frame.fin and ws.frame.remainder().int <= 0: + break + except WebSocketError as exc: + debug "Websocket error", exc = exc.msg + raise exc + except CancelledError as exc: + debug "Cancelling reading", exc = exc.msg + raise exc + except CatchableError as exc: + debug "Exception reading frames", exc = exc.msg + + return res + +proc close*( + ws: WSSession, + code: Status = Status.Fulfilled, + reason: string = "") {.async.} = + ## Close the Socket, sends close packet. + ## + + if ws.readyState != ReadyState.Open: + return + + try: + ws.readyState = ReadyState.Closing + await ws.send( + prepareCloseBody(code, reason), + opcode = Opcode.Close) + + # read frames until closed + while ws.readyState != ReadyState.Closed: + discard await ws.recv() + + except CatchableError as exc: + debug "Exception closing", exc = exc.msg + diff --git a/ws/types.nim b/ws/types.nim new file mode 100644 index 0000000000..f4ee492763 --- /dev/null +++ b/ws/types.nim @@ -0,0 +1,136 @@ +## Nim-Libp2p +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +{.push raises: [Defect].} + +import chronos +import ./utils + +const + SHA1DigestSize* = 20 + WSHeaderSize* = 12 + WSDefaultVersion* = 13 + WSDefaultFrameSize* = 1 shl 20 # 1mb + WSMaxMessageSize* = 20 shl 20 # 20mb + WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + CRLF* = "\r\n" + +type + ReadyState* {.pure.} = enum + Connecting = 0 # The connection is not yet open. + Open = 1 # The connection is open and ready to communicate. + Closing = 2 # The connection is in the process of closing. + Closed = 3 # The connection is closed or couldn't be opened. + + Opcode* {.pure.} = enum + ## 4 bits. Defines the interpretation of the "Payload data". + Cont = 0x0 ## Denotes a continuation frame. + Text = 0x1 ## Denotes a text frame. + Binary = 0x2 ## Denotes a binary frame. + # 3-7 are reserved for further non-control frames. + Close = 0x8 ## Denotes a connection close. + Ping = 0x9 ## Denotes a ping. + Pong = 0xa ## Denotes a pong. + # B-F are reserved for further control frames. + + HeaderFlag* {.pure, size: sizeof(uint8).} = enum + rsv3 + rsv2 + rsv1 + fin + + HeaderFlags* = set[HeaderFlag] + + MaskKey* = array[4, char] + + Frame* = ref object + fin*: bool ## Indicates that this is the final fragment in a message. + rsv1*: bool ## MUST be 0 unless negotiated that defines meanings + rsv2*: bool ## MUST be 0 + rsv3*: bool ## MUST be 0 + opcode*: Opcode ## Defines the interpretation of the "Payload data". + mask*: bool ## Defines whether the "Payload data" is masked. + data*: seq[byte] ## Payload data + maskKey*: MaskKey ## Masking key + length*: uint64 ## Message size. + consumed*: uint64 ## how much has been consumed from the frame + + Status* {.pure.} = enum + # 0-999 not used + Fulfilled = 1000 + GoingAway = 1001 + ProtocolError = 1002 + CannotAccept = 1003 + # 1004 reserved + NoStatus = 1005 # use by clients + ClosedAbnormally = 1006 # use by clients + Inconsistent = 1007 + PolicyError = 1008 + TooLarge = 1009 + NoExtensions = 1010 + UnexpectedError = 1011 + ReservedCode = 3999 # use by clients + # 3000-3999 reserved for libs + # 4000-4999 reserved for applications + + ControlCb* = proc(data: openArray[byte] = []) + {.gcsafe, raises: [Defect].} + + CloseResult* = tuple + code: Status + reason: string + + CloseCb* = proc(code: Status, reason: string): + CloseResult {.gcsafe, raises: [Defect].} + + Extension* = ref object of RootObj + name*: string + + WebSocket* = ref object of RootObj + extensions: seq[Extension] # extension active for this session + version*: uint + key*: string + proto*: string + readyState*: ReadyState + masked*: bool # send masked packets + binary*: bool # is payload binary? + rng*: Rng + frameSize*: int + onPing*: ControlCb + onPong*: ControlCb + onClose*: CloseCb + + WebSocketError* = object of CatchableError + WSMalformedHeaderError* = object of WebSocketError + WSFailedUpgradeError* = object of WebSocketError + WSVersionError* = object of WebSocketError + WSProtoMismatchError* = object of WebSocketError + WSMaskMismatchError* = object of WebSocketError + WSHandshakeError* = object of WebSocketError + WSOpcodeMismatchError* = object of WebSocketError + WSRsvMismatchError* = object of WebSocketError + WSWrongUriSchemeError* = object of WebSocketError + WSMaxMessageSizeError* = object of WebSocketError + WSClosedError* = object of WebSocketError + WSSendError* = object of WebSocketError + WSPayloadTooLarge* = object of WebSocketError + WSReserverdOpcodeError* = object of WebSocketError + WSFragmentedControlFrameError* = object of WebSocketError + WSInvalidCloseCodeError* = object of WebSocketError + WSPayloadLengthError* = object of WebSocketError + WSInvalidOpcodeError* = object of WebSocketError + +proc `name=`*(self: Extension, name: string) = + raiseAssert "Can't change extensions name!" + +method decode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} = + raiseAssert "Not implemented!" + +method encode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} = + raiseAssert "Not implemented!" diff --git a/ws/utils.nim b/ws/utils.nim index 66302a3178..80b36321ee 100644 --- a/ws/utils.nim +++ b/ws/utils.nim @@ -21,22 +21,22 @@ proc newRng*(): ref BrHmacDrbgContext = return nil rng -proc rand*(rng: var BrHmacDrbgContext, max: Natural): int = +proc rand*(rng: Rng, max: Natural): int = if max == 0: return 0 var x: uint64 while true: - brHmacDrbgGenerate(addr rng, addr x, csize_t(sizeof(x))) + brHmacDrbgGenerate(addr rng[], addr x, csize_t(sizeof(x))) if x < randMax - (randMax mod (uint64(max) + 1'u64)): # against modulo bias return int(x mod (uint64(max) + 1'u64)) -proc genMaskKey*(rng: ref BrHmacDrbgContext): array[4, char] = +proc genMaskKey*(rng: Rng): array[4, char] = ## Generates a random key of 4 random chars. - proc r(): char = char(rand(rng[], 255)) + proc r(): char = char(rand(rng, 255)) return [r(), r(), r(), r()] -proc genWebSecKey*(rng: ref BrHmacDrbgContext): seq[byte] = +proc genWebSecKey*(rng: Rng): seq[byte] = var key = newSeq[byte](16) - proc r(): byte = byte(rand(rng[], 255)) + proc r(): byte = byte(rand(rng, 255)) ## Generates a random key of 16 random chars. for i in 0..15: key.add(r()) diff --git a/ws/ws.nim b/ws/ws.nim index 988d697b47..2ece0a17c4 100644 --- a/ws/ws.nim +++ b/ws/ws.nim @@ -7,7 +7,6 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. - {.push raises: [Defect].} import std/[tables, @@ -29,79 +28,17 @@ import pkg/[chronos, stew/base10, nimcrypto/sha] -import ./utils, ./stream, ./frame, ./errors, ./extension +import ./utils, ./stream, ./frame, ./session, /types -const - SHA1DigestSize* = 20 - WSHeaderSize* = 12 - WSDefaultVersion* = 13 - WSDefaultFrameSize* = 1 shl 20 # 1mb - WSMaxMessageSize* = 20 shl 20 # 20mb - WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - CRLF* = "\r\n" +export utils, session, frame, stream, types type - ReadyState* {.pure.} = enum - Connecting = 0 # The connection is not yet open. - Open = 1 # The connection is open and ready to communicate. - Closing = 2 # The connection is in the process of closing. - Closed = 3 # The connection is closed or couldn't be opened. - HttpCode* = enum Http101 = 101 # Switching Protocols - Status* {.pure.} = enum - # 0-999 not used - Fulfilled = 1000 - GoingAway = 1001 - ProtocolError = 1002 - CannotAccept = 1003 - # 1004 reserved - NoStatus = 1005 # use by clients - ClosedAbnormally = 1006 # use by clients - Inconsistent = 1007 - PolicyError = 1008 - TooLarge = 1009 - NoExtensions = 1010 - UnexpectedError = 1011 - ReservedCode = 3999 # use by clients - # 3000-3999 reserved for libs - # 4000-4999 reserved for applications - - ControlCb* = proc(data: openArray[byte] = []) - {.gcsafe, raises: [Defect].} - - CloseResult* = tuple - code: Status - reason: string - - CloseCb* = proc(code: Status, reason: string): - CloseResult {.gcsafe, raises: [Defect].} - - WebSocket* = ref object of RootObj - extensions: seq[Extension] # extension active for this session - version*: uint - key*: string - proto*: string - readyState*: ReadyState - masked*: bool # send masked packets - binary*: bool # is payload binary? - rng*: ref BrHmacDrbgContext - frameSize: int - onPing: ControlCb - onPong: ControlCb - onClose: CloseCb - WSServer* = ref object of WebSocket protocols: seq[string] - WSSession* = ref object of WebSocket - stream*: AsyncStream - frame*: Frame - -template remainder*(frame: Frame): uint64 = - frame.length - frame.consumed - proc `$`(ht: HttpTables): string = ## Returns string representation of HttpTable/Ref. var res = "" @@ -115,11 +52,6 @@ proc `$`(ht: HttpTables): string = res.add(CRLF) res -proc prepareCloseBody(code: Status, reason: string): seq[byte] = - result = reason.toBytes - if ord(code) > 999: - result = @(ord(code).uint16.toBytesBE()) & result - proc handshake*( ws: WSServer, request: HttpRequestRef, @@ -187,355 +119,6 @@ proc handshake*( onPong: ws.onPong, onClose: ws.onClose) -proc send*( - ws: WSSession, - data: seq[byte] = @[], - opcode: Opcode) {.async.} = - ## Send a frame - ## - - if ws.readyState == ReadyState.Closed: - raise newException(WSClosedError, "Socket is closed!") - - logScope: - opcode = opcode - dataSize = data.len - masked = ws.masked - - debug "Sending data to remote" - - var maskKey: array[4, char] - if ws.masked: - maskKey = genMaskKey(ws.rng) - - if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: - - if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}: - return - - await ws.stream.writer.write( - Frame( - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: opcode, - mask: ws.masked, - data: data, # allow sending data with close messages - maskKey: maskKey) - .encode()) - - return - - let maxSize = ws.frameSize - var i = 0 - while ws.readyState notin {ReadyState.Closing}: - let len = min(data.len, (maxSize + i)) - await ws.stream.writer.write( - Frame( - fin: if (i + len >= data.len): true else: false, - rsv1: false, - rsv2: false, - rsv3: false, - opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames - mask: ws.masked, - data: data[i ..< len], - maskKey: maskKey) - .encode()) - - i += len - if i >= data.len: - break - -proc send*(ws: WSSession, data: string): Future[void] = - send(ws, toBytes(data), Opcode.Text) - -proc handleClose*(ws: WSSession, frame: Frame, payLoad: seq[byte] = @[]) {.async.} = - ## Handle close sequence - ## - - logScope: - fin = frame.fin - masked = frame.mask - opcode = frame.opcode - readyState = ws.readyState - - debug "Handling close sequence" - - if ws.readyState notin {ReadyState.Open}: - debug "Connection isn't open, abortig close sequence!" - return - - var - code = Status.Fulfilled - reason = "" - - if payLoad.len == 1: - raise newException(WSPayloadLengthError, - "Invalid close frame with payload length 1!") - - if payLoad.len > 1: - # first two bytes are the status - let ccode = uint16.fromBytesBE(payLoad[0..<2]) - if ccode <= 999 or ccode > 1015: - raise newException(WSInvalidCloseCodeError, - "Invalid code in close message!") - - try: - code = Status(ccode) - except RangeError: - raise newException(WSInvalidCloseCodeError, - "Status code out of range!") - - # remining payload bytes are reason for closing - reason = string.fromBytes(payLoad[2..payLoad.high]) - - var rcode: Status - if code in {Status.Fulfilled}: - rcode = Status.Fulfilled - - if not isNil(ws.onClose): - try: - (rcode, reason) = ws.onClose(code, reason) - except CatchableError as exc: - debug "Exception in Close callback, this is most likely a bug", exc = exc.msg - - # don't respond to a terminated connection - if ws.readyState != ReadyState.Closing: - ws.readyState = ReadyState.Closing - await ws.send(prepareCloseBody(rcode, reason), Opcode.Close) - - ws.readyState = ReadyState.Closed - await ws.stream.closeWait() - -proc handleControl*(ws: WSSession, frame: Frame) {.async.} = - ## Handle control frames - ## - - if not frame.fin: - raise newException(WSFragmentedControlFrameError, - "Control frame cannot be fragmented!") - - if frame.length > 125: - raise newException(WSPayloadTooLarge, - "Control message payload is greater than 125 bytes!") - - try: - var payLoad = newSeq[byte](frame.length.int) - if frame.length > 0: - payLoad.setLen(frame.length.int) - # Read control frame payload. - await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int) - if frame.mask: - mask( - payLoad.toOpenArray(0, payLoad.high), - frame.maskKey) - - # Process control frame payload. - case frame.opcode: - of Opcode.Ping: - if not isNil(ws.onPing): - try: - ws.onPing(payLoad) - except CatchableError as exc: - debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg - - # send pong to remote - await ws.send(payLoad, Opcode.Pong) - of Opcode.Pong: - if not isNil(ws.onPong): - try: - ws.onPong(payLoad) - except CatchableError as exc: - debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg - of Opcode.Close: - await ws.handleClose(frame, payLoad) - else: - raise newException(WSInvalidOpcodeError, "Invalid control opcode!") - except WebSocketError as exc: - debug "Handled websocket exception", exc = exc.msg - raise exc - except CatchableError as exc: - trace "Exception handling control messages", exc = exc.msg - ws.readyState = ReadyState.Closed - await ws.stream.closeWait() - -proc readFrame*(ws: WSSession): Future[Frame] {.async.} = - ## Gets a frame from the WebSocket. - ## See https://tools.ietf.org/html/rfc6455#section-5.2 - ## - - try: - while ws.readyState != ReadyState.Closed: - let frame = await Frame.decode(ws.stream.reader, ws.masked) - debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask - - # return the current frame if it's not one of the control frames - if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: - await ws.handleControl(frame) # process control frames# process control frames - continue - - return frame - except WebSocketError as exc: - trace "Websocket error", exc = exc.msg - raise exc - except CatchableError as exc: - debug "Exception reading frame, dropping socket", exc = exc.msg - ws.readyState = ReadyState.Closed - await ws.stream.closeWait() - raise exc - -proc ping*(ws: WSSession, data: seq[byte] = @[]): Future[void] = - ws.send(data, opcode = Opcode.Ping) - -proc recv*( - ws: WSSession, - data: pointer, - size: int): Future[int] {.async.} = - ## Attempts to read up to `size` bytes - ## - ## Will read as many frames as necessary - ## to fill the buffer until either - ## the message ends (frame.fin) or - ## the buffer is full. If no data is on - ## the pipe will await until at least - ## one byte is available - ## - - var consumed = 0 - var pbuffer = cast[ptr UncheckedArray[byte]](data) - try: - while consumed < size: - # we might have to read more than - # one frame to fill the buffer - - # TODO: Figure out a cleaner way to handle - # retrieving new frames - if isNil(ws.frame): - ws.frame = await ws.readFrame() - - if isNil(ws.frame): - return consumed - - if ws.frame.opcode == Opcode.Cont: - raise newException(WSOpcodeMismatchError, - "Expected Text or Binary frame") - elif (not ws.frame.fin and ws.frame.remainder() <= 0): - ws.frame = await ws.readFrame() - # This could happen if the connection is closed. - - if isNil(ws.frame): - return consumed - - if ws.frame.opcode != Opcode.Cont: - raise newException(WSOpcodeMismatchError, - "Expected Continuation frame") - - ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag - if ws.frame.fin and ws.frame.remainder() <= 0: - ws.frame = nil - break - - let len = min(ws.frame.remainder().int, size - consumed) - if len == 0: - continue - - let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len) - if read <= 0: - continue - - if ws.frame.mask: - # unmask data using offset - mask( - pbuffer.toOpenArray(consumed, (consumed + read) - 1), - ws.frame.maskKey, - ws.frame.consumed.int) - - consumed += read - ws.frame.consumed += read.uint64 - - return consumed.int - - except WebSocketError as exc: - debug "Websocket error", exc = exc.msg - ws.readyState = ReadyState.Closed - await ws.stream.closeWait() - raise exc - except CancelledError as exc: - debug "Cancelling reading", exc = exc.msg - raise exc - except CatchableError as exc: - debug "Exception reading frames", exc = exc.msg - -proc recv*( - ws: WSSession, - size = WSMaxMessageSize): Future[seq[byte]] {.async.} = - ## Attempt to read a full message up to max `size` - ## bytes in `frameSize` chunks. - ## - ## If no `fin` flag arrives await until either - ## cancelled or the `fin` flag arrives. - ## - ## If message is larger than `size` a `WSMaxMessageSizeError` - ## exception is thrown. - ## - ## In all other cases it awaits a full message. - ## - var res: seq[byte] - try: - while ws.readyState != ReadyState.Closed: - var buf = newSeq[byte](ws.frameSize) - let read = await ws.recv(addr buf[0], buf.len) - if read <= 0: - break - - buf.setLen(read) - if res.len + buf.len > size: - raise newException(WSMaxMessageSizeError, "Max message size exceeded") - - res.add(buf) - - # no more frames - if isNil(ws.frame): - break - - # read the entire message, exit - if ws.frame.fin and ws.frame.remainder().int <= 0: - break - except WebSocketError as exc: - debug "Websocket error", exc = exc.msg - raise exc - except CancelledError as exc: - debug "Cancelling reading", exc = exc.msg - raise exc - except CatchableError as exc: - debug "Exception reading frames", exc = exc.msg - - return res - -proc close*( - ws: WSSession, - code: Status = Status.Fulfilled, - reason: string = "") {.async.} = - ## Close the Socket, sends close packet. - ## - - if ws.readyState != ReadyState.Open: - return - - try: - ws.readyState = ReadyState.Closing - await ws.send( - prepareCloseBody(code, reason), - opcode = Opcode.Close) - - # read frames until closed - while ws.readyState != ReadyState.Closed: - discard await ws.recv() - - except CatchableError as exc: - debug "Exception closing", exc = exc.msg - proc initiateHandshake( uri: Uri, address: TransportAddress,