From 3e1599d7908755ff087efd29bd71a370f80297ef Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Fri, 11 Jun 2021 14:04:09 -0600 Subject: [PATCH] Fix partial frame handling and allow extensions to hijack the flow (#56) * moving files around * wip * wip * move tls example into server example * add tls functionality * rename * rename * fix tests * move extension related files to own folder * use trace instead of debug * export extensions * rework partial frame handling and closing * rework status codes as distincts * logging * re-enable extensions processing for frames * enable all test for non-tls server * remove tlsserver * remove offset to mask - don't think we need it * pass sessions extensions when calling send/recv * adding encode/decode extensions flow test * move server/client setup to helpers * proper frame order execution on decode * fix tls tests --- .github/workflows/ci.yml | 4 +- autobahn/fuzzingclient.json | 2 +- examples/client.nim | 19 +- examples/server.nim | 30 +- examples/tlsclient.nim | 35 -- examples/tlsserver.nim | 54 -- tests/extensions/testextflow.nim | 276 +++++++++ tests/helpers.nim | 81 +++ tests/testcommon.nim | 2 +- .../{test_ext_utils.nim => testextutils.nim} | 20 +- tests/testutf8.nim | 16 +- tests/testwebsockets.nim | 541 ++++++++++-------- ws/extensions.nim | 6 + ws/extensions/compression/compression.nim | 212 +++++++ ws/{ext_utils.nim => extensions/extutils.nim} | 0 ws/frame.nim | 35 +- ws/http/client.nim | 5 +- ws/http/common.nim | 6 +- ws/http/server.nim | 26 +- ws/session.nim | 236 +++++--- ws/types.nim | 78 ++- ws/{utf8_dfa.nim => utf8dfa.nim} | 0 ws/ws.nim | 10 +- 23 files changed, 1159 insertions(+), 535 deletions(-) delete mode 100644 examples/tlsclient.nim delete mode 100644 examples/tlsserver.nim create mode 100644 tests/extensions/testextflow.nim create mode 100644 tests/helpers.nim rename tests/{test_ext_utils.nim => testextutils.nim} (98%) create mode 100644 ws/extensions.nim create mode 100644 ws/extensions/compression/compression.nim rename ws/{ext_utils.nim => extensions/extutils.nim} (100%) rename ws/{utf8_dfa.nim => utf8dfa.nim} (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe49e32..d66f0d5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -232,8 +232,8 @@ jobs: kill $pid cd .. - nim c examples/tlsserver.nim - examples/tlsserver & + nim -d:tls c examples/server.nim + examples/server & pid=$! cd autobahn wstest --mode fuzzingclient --spec fuzzingclient_tls.json diff --git a/autobahn/fuzzingclient.json b/autobahn/fuzzingclient.json index e41c882..dc036ce 100644 --- a/autobahn/fuzzingclient.json +++ b/autobahn/fuzzingclient.json @@ -7,6 +7,6 @@ } ], "cases": ["*"], - "exclude-cases": ["9.*", "12.*", "13.*"], + "exclude-cases": [], "exclude-agent-cases": {} } diff --git a/examples/client.nim b/examples/client.nim index f71ada4..0fcf5b2 100644 --- a/examples/client.nim +++ b/examples/client.nim @@ -6,12 +6,19 @@ import pkg/[ import ../ws/ws proc main() {.async.} = - let ws = await WebSocket.connect( - "127.0.0.1", - Port(8888), - path = "/ws") + let ws = when defined tls: + await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/wss", + flags = {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName}) + else: + await WebSocket.connect( + "127.0.0.1", + Port(8888), + path = "/ws") - debug "Websocket client: ", State = ws.readyState + trace "Websocket client: ", State = ws.readyState let reqData = "Hello Server" while true: @@ -22,7 +29,7 @@ proc main() {.async.} = break let dataStr = string.fromBytes(buff) - debug "Server Response: ", data = dataStr + trace "Server Response: ", data = dataStr assert dataStr == reqData break diff --git a/examples/server.nim b/examples/server.nim index 5bd1cbf..017fde6 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -5,13 +5,15 @@ import pkg/[chronos, httputils] import ../ws/ws +import ../tests/keys proc handle(request: HttpRequest) {.async.} = - debug "Handling request:", uri = request.uri.path - if request.uri.path != "/ws": + trace "Handling request:", uri = request.uri.path + let path = when defined tls: "/wss" else: "/ws" + if request.uri.path != path: return - debug "Initiating web socket connection." + trace "Initiating web socket connection." try: let server = WSServer.new() let ws = await server.handleRequest(request) @@ -19,16 +21,14 @@ proc handle(request: HttpRequest) {.async.} = error "Failed to open websocket connection" return - debug "Websocket handshake completed" - while true: + trace "Websocket handshake completed" + while ws.readyState != ReadyState.Closed: let recvData = await ws.recv() - if ws.readyState == ReadyState.Closed: - debug "Websocket closed" - break - debug "Client Response: ", size = recvData.len + trace "Client Response: ", size = recvData.len, binary = ws.binary await ws.send(recvData, if ws.binary: Opcode.Binary else: Opcode.Text) + except WebSocketError as exc: error "WebSocket error:", exception = exc.msg @@ -37,10 +37,18 @@ when isMainModule: let address = initTAddress("127.0.0.1:8888") socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - server = HttpServer.create(address, handle, flags = socketFlags) + server = when defined tls: + TlsHttpServer.create( + address = address, + handler = handle, + tlsPrivateKey = TLSPrivateKey.init(SecureKey), + tlsCertificate = TLSCertificate.init(SecureCert), + flags = socketFlags) + else: + HttpServer.create(address, handle, flags = socketFlags) server.start() - info "Server listening at ", data = $server.localAddress() + trace "Server listening on ", data = $server.localAddress() await server.join() waitFor(main()) diff --git a/examples/tlsclient.nim b/examples/tlsclient.nim deleted file mode 100644 index 3107964..0000000 --- a/examples/tlsclient.nim +++ /dev/null @@ -1,35 +0,0 @@ -import pkg/[chronos, - chronos/streams/tlsstream, - chronicles, - stew/byteutils] - -import ../ws/ws - -proc main() {.async.} = - let ws = await WebSocket.tlsConnect( - "127.0.0.1", - Port(8888), - path = "/wss", - protocols = @["myfancyprotocol"], - flags = {NoVerifyHost, NoVerifyServerName}) - debug "Websocket client: ", State = ws.readyState - - let reqData = "Hello Server" - try: - debug "sending client " - await ws.send(reqData) - let buff = await ws.recv() - if buff.len <= 0: - break - let dataStr = string.fromBytes(buff) - debug "Server:", data = dataStr - - assert dataStr == reqData - return # bail out - except WebSocketError as exc: - error "WebSocket error:", exception = exc.msg - - # close the websocket - await ws.close() - -waitFor(main()) diff --git a/examples/tlsserver.nim b/examples/tlsserver.nim deleted file mode 100644 index 9c32740..0000000 --- a/examples/tlsserver.nim +++ /dev/null @@ -1,54 +0,0 @@ -import pkg/[chronos, - chronicles, - httputils, - stew/byteutils] - -import pkg/[chronos/streams/tlsstream] - -import ../ws/ws -import ../tests/keys - -proc handle(request: HttpRequest) {.async.} = - debug "Handling request:", uri = request.uri.path - if request.uri.path != "/wss": - debug "Initiating web socket connection." - return - - try: - let server = WSServer.new(protos = ["myfancyprotocol"]) - var ws = await server.handleRequest(request) - if ws.readyState != Open: - error "Failed to open websocket connection." - return - debug "Websocket handshake completed." - # Only reads header for data frame. - echo "receiving server " - let recvData = await ws.recv() - if recvData.len <= 0: - debug "Empty messages" - break - - if ws.readyState == ReadyState.Closed: - return - debug "Response: ", data = string.fromBytes(recvData) - await ws.send(recvData, - if ws.binary: Opcode.Binary else: Opcode.Text) - except WebSocketError: - error "WebSocket error:", exception = getCurrentExceptionMsg() - -when isMainModule: - proc main() {.async.} = - let address = initTAddress("127.0.0.1:8888") - let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} - let server = TlsHttpServer.create( - address = address, - handler = handle, - tlsPrivateKey = TLSPrivateKey.init(SecureKey), - tlsCertificate = TLSCertificate.init(SecureCert), - flags = socketFlags) - - server.start() - info "Server listening at ", data = $server.localAddress() - await server.join() - - waitFor(main()) \ No newline at end of file diff --git a/tests/extensions/testextflow.nim b/tests/extensions/testextflow.nim new file mode 100644 index 0000000..9abb434 --- /dev/null +++ b/tests/extensions/testextflow.nim @@ -0,0 +1,276 @@ +import std/strutils +import pkg/[chronos, stew/byteutils] + +import ../../ws/ws +import ../asyncunit + +type + ExtHandler = proc(ext: Ext, frame: Frame): Future[Frame] {.raises: [Defect].} + + HelperExtension = ref object of Ext + handler*: ExtHandler + +proc new*( + T: typedesc[HelperExtension], + handler: ExtHandler, + session: WSSession = nil): HelperExtension = + HelperExtension( + handler: handler, + name: "HelperExtension") + +method decode*( + self: HelperExtension, + frame: Frame): Future[Frame] {.async.} = + return await self.handler(self, frame) + +method encode*( + self: HelperExtension, + frame: Frame): Future[Frame] {.async.} = + return await self.handler(self, frame) + +const TestString = "Hello" + +suite "Encode frame extensions flow": + test "should call extension on encode": + var data = "" + proc toUpper(ext: Ext, frame: Frame): Future[Frame] {.async.} = + checkpoint "toUpper executed" + data = string.fromBytes(frame.data).toUpper() + check TestString.toUpper() == data + frame.data = data.toBytes() + return frame + + var frame = Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: TestString.toBytes()) + + discard await frame.encode(@[HelperExtension.new(toUpper).Ext]) + check frame.data == TestString.toUpper().toBytes() + + test "should call extensions in correct order on encode": + var count = 0 + proc first(ext: Ext, frame: Frame): Future[Frame] {.async.} = + checkpoint "first executed" + check count == 0 + count.inc + + return frame + + proc second(ext: Ext, frame: Frame): Future[Frame] {.async.} = + checkpoint "second executed" + check count == 1 + count.inc + + return frame + + var frame = Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: false, + data: TestString.toBytes()) + + discard await frame.encode(@[ + HelperExtension.new(first).Ext, + HelperExtension.new(second).Ext]) + + check count == 2 + + test "should allow modifying frame headers": + proc changeHeader(ext: Ext, frame: Frame): Future[Frame] {.async.} = + checkpoint "changeHeader executed" + frame.rsv1 = true + frame.rsv2 = true + frame.rsv3 = true + frame.opcode = Opcode.Binary + return frame + + var frame = Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, # fragments have to be `Continuation` frames + mask: false, + data: TestString.toBytes()) + + discard await frame.encode(@[HelperExtension.new(changeHeader).Ext]) + check: + frame.rsv1 == true + frame.rsv2 == true + frame.rsv2 == true + frame.opcode == Opcode.Binary + +suite "Decode frame extensions flow": + var + address: TransportAddress + server: StreamServer + maskKey = genMaskKey(newRng()) + transport: StreamTransport + reader: AsyncStreamReader + frame: Frame + + setup: + server = createStreamServer( + initTAddress("127.0.0.1:0"), + flags = {ServerFlags.ReuseAddr}) + address = server.localAddress() + + teardown: + await transport.closeWait() + await server.closeWait() + server.stop() + + test "should call extension on decode": + var data = "" + proc toUpper(ext: Ext, frame: Frame): Future[Frame] {.async.} = + checkpoint "toUpper executed" + try: + var buf = newSeq[byte](frame.length) + # read data + await reader.readExactly(addr buf[0], buf.len) + if frame.mask: + mask(buf, maskKey) + frame.mask = false # we can reset the mask key here + + data = string.fromBytes(buf).toUpper() + check: + TestString.toUpper() == data + + frame.data = data.toBytes() + return frame + except CatchableError as exc: + checkpoint exc.msg + check false + + proc acceptHandler() {.async, gcsafe.} = + let transport = await server.accept() + reader = newAsyncStreamReader(transport) + frame = await Frame.decode( + reader, + false, + @[HelperExtension.new(toUpper).Ext]) + + await reader.closeWait() + await transport.closeWait() + + let handlerWait = acceptHandler() + var encodedFrame = (await Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: true, + maskKey: maskKey, + data: TestString.toBytes()) + .encode()) + + transport = await connect(address) + let wrote = await transport.write(encodedFrame) + + await handlerWait + check: + wrote == encodedFrame.len + frame.data == TestString.toUpper().toBytes() + + test "should call extensions in reverse order on decode": + var count = 0 + proc first(ext: Ext, frame: Frame): Future[Frame] {.async.} = + checkpoint "first executed" + check count == 1 + count.inc + + return frame + + proc second(ext: Ext, frame: Frame): Future[Frame] {.async.} = + checkpoint "second executed" + check count == 0 + count.inc + + return frame + + proc acceptHandler() {.async, gcsafe.} = + let transport = await server.accept() + reader = newAsyncStreamReader(transport) + frame = await Frame.decode( + reader, + false, + @[HelperExtension.new(first).Ext, + HelperExtension.new(second).Ext]) + + await reader.closeWait() + await transport.closeWait() + + let handlerWait = acceptHandler() + var encodedFrame = (await Frame( + fin: false, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode.Text, + mask: true, + maskKey: maskKey, + data: TestString.toBytes()) + .encode()) + + let transport = await connect(address) + let wrote = await transport.write(encodedFrame) + + await handlerWait + check: + wrote == encodedFrame.len + count == 2 + + test "should allow modifying frame headers": + proc changeHeader(ext: Ext, frame: Frame): Future[Frame] {.async.} = + checkpoint "changeHeader executed" + frame.rsv1 = false + frame.rsv2 = false + frame.rsv3 = false + frame.opcode = Opcode.Binary + + return frame + + proc acceptHandler() {.async, gcsafe.} = + let transport = await server.accept() + reader = newAsyncStreamReader(transport) + frame = await Frame.decode( + reader, + false, + @[HelperExtension.new(changeHeader).Ext]) + + check: + frame.rsv1 == false + frame.rsv2 == false + frame.rsv2 == false + frame.opcode == Opcode.Binary + + await reader.closeWait() + await transport.closeWait() + + let handlerWait = acceptHandler() + var encodedFrame = (await Frame( + fin: false, + rsv1: true, + rsv2: true, + rsv3: true, + opcode: Opcode.Text, + mask: true, + maskKey: maskKey, + data: TestString.toBytes()) + .encode()) + + let transport = await connect(address) + let wrote = await transport.write(encodedFrame) + + await handlerWait + check: + wrote == encodedFrame.len diff --git a/tests/helpers.nim b/tests/helpers.nim new file mode 100644 index 0000000..13c3753 --- /dev/null +++ b/tests/helpers.nim @@ -0,0 +1,81 @@ +import std/[strutils, random] +import pkg/[ + chronos, + chronos/streams/tlsstream, + httputils, + chronicles, + stew/byteutils] + +import ../ws/ws +import ./keys + +let + WSSecureKey* = TLSPrivateKey.init(SecureKey) + WSSecureCert* = TLSCertificate.init(SecureCert) + +const WSPath* = when defined secure: "/wss" else: "/ws" + +proc rndStr*(size: int): string = + for _ in 0.. 2: + return err("window bits expect 2 bytes, got " & $arg.value.len) + + for n in arg.value: + if n notin Digits: + return err("window bits value contains illegal char: " & $n) + + var winbit = 0 + for i in 0.. 15: + return err("window bits should between 8-15, got " & $winbit) + + res = winbit + return ok("=" & arg.value) + +proc validateParams(args: seq[ExtParam], + opts: var DeflateOpts): Result[string, string] = + # besides validating extensions params, this proc + # also constructing extension param for response + var resp = "" + for arg in args: + case arg.name + of "server_no_context_takeover": + if arg.value.len > 0: + return err("'server_no_context_takeover' should have no param") + if opts.isServer: + concatParam(resp, arg.name) + opts.serverNoContextTakeOver = true + of "client_no_context_takeover": + if arg.value.len > 0: + return err("'client_no_context_takeover' should have no param") + if opts.isServer: + concatParam(resp, arg.name) + opts.clientNoContextTakeOver = true + of "server_max_window_bits": + if opts.isServer: + concatParam(resp, arg.name) + let res = validateWindowBits(arg, opts.serverMaxWindowBits) + if res.isErr: + return res + resp.add res.get() + of "client_max_window_bits": + if opts.isServer: + concatParam(resp, arg.name) + let res = validateWindowBits(arg, opts.clientMaxWindowBits) + if res.isErr: + return res + resp.add res.get() + else: + return err("unrecognized param: " & arg.name) + + ok(resp) + +method decode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} = + if not ext.messageCompressed: + return data + + # TODO: append trailing bytes + var mz = MzStream( + next_in: cast[ptr cuchar](data[0].unsafeAddr), + avail_in: data.len.cuint + ) + + let windowBits = if ext.opts.serverMaxWindowBits == 0: + MZ_DEFAULT_WINDOW_BITS + else: + MzWindowBits(ext.opts.serverMaxWindowBits) + + doAssert(mz.inflateInit2(windowBits) == MZ_OK) + var res: seq[byte] + var buf: array[0xFFFF, byte] + + while true: + mz.next_out = cast[ptr cuchar](buf[0].addr) + mz.avail_out = buf.len.cuint + let r = mz.inflate(MZ_SYNC_FLUSH) + let outSize = buf.len - mz.avail_out.int + res.add toOpenArray(buf, 0, outSize-1) + if r == MZ_STREAM_END: + break + elif r == MZ_OK: + continue + else: + doAssert(false, "decompression error") + + doAssert(mz.inflateEnd() == MZ_OK) + return res + +method encode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} = + var mz = MzStream( + next_in: cast[ptr cuchar](data[0].unsafeAddr), + avail_in: data.len.cuint + ) + + let windowBits = if ext.opts.serverMaxWindowBits == 0: + MZ_DEFAULT_WINDOW_BITS + else: + MzWindowBits(ext.opts.serverMaxWindowBits) + + doAssert(mz.deflateInit2( + level = MZ_DEFAULT_LEVEL, + meth = MZ_DEFLATED, + windowBits, + 1, + strategy = MZ_DEFAULT_STRATEGY) == MZ_OK + ) + + let maxSize = mz.deflateBound(data.len.culong).int + var res: seq[byte] + var buf: array[0xFFFF, byte] + + while true: + mz.next_out = cast[ptr cuchar](buf[0].addr) + mz.avail_out = buf.len.cuint + let r = mz.deflate(MZ_FINISH) + let outSize = buf.len - mz.avail_out.int + res.add toOpenArray(buf, 0, outSize-1) + if r == MZ_STREAM_END: + break + elif r == MZ_OK: + continue + else: + doAssert(false, "compression error") + + # TODO: cut trailing bytes + doAssert(mz.deflateEnd() == MZ_OK) + ext.messageCompressed = res.len < data.len + if ext.messageCompressed: + return res + else: + return data + +method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = + if frame.opcode in {Opcode.Text, Opcode.Binary}: + # only data frame can be compressed + # and we want to know if this message is compressed or not + # if the frame opcode is text or binary, it should also the first frame + ext.messageCompressed = frame.rsv1 + # clear rsv1 bit because we already done with it + frame.rsv1 = false + return frame + +method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = + if frame.opcode in {Opcode.Text, Opcode.Binary}: + # only data frame can be compressed + # and we only set rsv1 bit to true if the message is compressible + # if the frame opcode is text or binary, it should also the first frame + frame.rsv1 = ext.messageCompressed + return frame + +method toHttpOptions(ext: DeflateExt): string = + # using paramStr here is a bit clunky + extID & "; " & ext.paramStr + +proc deflateExtFactory(isServer: bool, args: seq[ExtParam]): Result[Ext, string] {.raises: [Defect].} = + var opts = DeflateOpts(isServer: isServer) + let resp = validateParams(args, opts) + if resp.isErr: + return err(resp.error) + let ext = DeflateExt( + name: extID, + paramStr: resp.get(), + opts: opts + ) + ok(ext) + +const + deflateFactory* = (extID, deflateExtFactory) diff --git a/ws/ext_utils.nim b/ws/extensions/extutils.nim similarity index 100% rename from ws/ext_utils.nim rename to ws/extensions/extutils.nim diff --git a/ws/frame.nim b/ws/frame.nim index 5055a69..a0f52a0 100644 --- a/ws/frame.nim +++ b/ws/frame.nim @@ -15,8 +15,12 @@ import pkg/[ stew/byteutils, stew/endians2, stew/results] + import ./types +logScope: + topics = "ws-frame" + #[ +---------------------------------------------------------------+ |0 1 2 3 | @@ -54,7 +58,6 @@ template remainder*(frame: Frame): uint64 = proc encode*( frame: Frame, - offset = 0, extensions: seq[Ext] = @[]): Future[seq[byte]] {.async.} = ## Encodes a frame into a string buffer. ## See https://tools.ietf.org/html/rfc6455#section-5.2 @@ -65,7 +68,7 @@ proc encode*( f = await e.encode(f) var ret: seq[byte] - var b0 = (f.opcode.uint8 and 0x0F) # 0th byte: opcodes and flags. + var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags. if f.fin: b0 = b0 or 128'u8 @@ -77,7 +80,7 @@ proc encode*( if f.data.len <= 125: b1 = f.data.len.uint8 - elif f.data.len > 125 and f.data.len <= 0xFFFF: + elif f.data.len > 125 and f.data.len <= 0xffff: b1 = 126'u8 else: b1 = 127'u8 @@ -88,12 +91,12 @@ proc encode*( ret.add(uint8 b1) # Only need more bytes if data len is 7+16 bits, or 7+64 bits. - if f.data.len > 125 and f.data.len <= 0xFFFF: + if f.data.len > 125 and f.data.len <= 0xffff: # Data len is 7+16 bits. var len = f.data.len.uint16 - ret.add ((len shr 8) and 0xFF).uint8 - ret.add (len and 0xFF).uint8 - elif f.data.len > 0xFFFF: + ret.add ((len shr 8) and 0xff).uint8 + ret.add (len and 0xff).uint8 + elif f.data.len > 0xffff: # Data len is 7+64 bits. var len = f.data.len.uint64 ret.add(len.toBytesBE()) @@ -101,7 +104,7 @@ proc encode*( var data = f.data if f.mask: # If we need to mask it generate random mask key and mask the data. - mask(data, f.maskKey, offset) + mask(data, f.maskKey) # Write mask key next. ret.add(f.maskKey[0].uint8) @@ -122,10 +125,10 @@ proc decode*( ## var header = newSeq[byte](2) - debug "Reading new frame" + trace "Reading new frame" await reader.readExactly(addr header[0], 2) if header.len != 2: - debug "Invalid websocket header length" + trace "Invalid websocket header length" raise newException(WSMalformedHeaderError, "Invalid websocket header length") @@ -147,10 +150,6 @@ proc decode*( frame.opcode = (opcode).Opcode - # If any of the rsv are set close the socket. - if frame.rsv1 or frame.rsv2 or frame.rsv3: - raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") - # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. var finalLen: uint64 = 0 @@ -187,7 +186,11 @@ proc decode*( frame.maskKey[i] = cast[char](maskKey[i]) if extensions.len > 0: - for e in extensions[extensions.high..extensions.low]: - frame = await e.decode(frame) + for i in countdown(extensions.high, extensions.low): + frame = await extensions[i].decode(frame) + + # If any of the rsv are set close the socket. + if frame.rsv1 or frame.rsv2 or frame.rsv3: + raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") return frame diff --git a/ws/http/client.nim b/ws/http/client.nim index 2b63c12..7fcfc45 100644 --- a/ws/http/client.nim +++ b/ws/http/client.nim @@ -9,6 +9,9 @@ import pkg/[ import ./common +logScope: + topics = "http-client" + type HttpClient* = ref object of RootObj connected*: bool @@ -44,7 +47,7 @@ proc readResponse(stream: AsyncStreamReader): Future[HttpResponseHeader] {.async return buffer.parseResponse() except CatchableError as exc: - debug "Exception reading headers", exc = exc.msg + trace "Exception reading headers", exc = exc.msg buffer.setLen(0) raise exc diff --git a/ws/http/common.nim b/ws/http/common.nim index 5ad8666..85d1171 100644 --- a/ws/http/common.nim +++ b/ws/http/common.nim @@ -53,8 +53,7 @@ proc closeWait*(stream: AsyncStream) {.async.} = await allFutures( stream.reader.tsource.closeTransp(), stream.reader.closeStream(), - stream.writer.closeStream() - ) + stream.writer.closeStream()) proc sendResponse*( request: HttpRequest, @@ -112,8 +111,7 @@ proc sendError*( response.add(CRLF) await stream.write( - response.toBytes() & - content.toBytes()) + response.toBytes() & content.toBytes()) proc sendError*( request: HttpRequest, diff --git a/ws/http/server.nim b/ws/http/server.nim index 15218c4..2173653 100644 --- a/ws/http/server.nim +++ b/ws/http/server.nim @@ -30,13 +30,13 @@ proc validateRequest( ## if header.meth notin {MethodGet}: - debug "GET method is only allowed", address = stream.tsource.remoteAddress() + trace "GET method is only allowed", address = stream.tsource.remoteAddress() await stream.sendError(Http405, version = header.version) return ReqStatus.Error var hlen = header.contentLength() if hlen < 0 or hlen > MaxHttpRequestSize: - debug "Invalid header length", address = stream.tsource.remoteAddress() + trace "Invalid header length", address = stream.tsource.remoteAddress() await stream.sendError(Http413, version = header.version) return ReqStatus.Error @@ -50,14 +50,14 @@ proc handleRequest( var buffer = newSeq[byte](MaxHttpHeadersSize) let remoteAddr = stream.reader.tsource.remoteAddress() - debug "Received connection", address = $remoteAddr + trace "Received connection", address = $remoteAddr try: let hlenfut = stream.reader.readUntil( addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep) let ores = await withTimeout(hlenfut, HttpHeadersTimeout) if not ores: # Timeout - debug "Timeout expired while receiving headers", address = $remoteAddr + trace "Timeout expired while receiving headers", address = $remoteAddr await stream.writer.sendError(Http408, version = HttpVersion11) return @@ -66,7 +66,7 @@ proc handleRequest( let requestData = buffer.parseRequest() if requestData.failed(): # Header could not be parsed - debug "Malformed header received", address = $remoteAddr + trace "Malformed header received", address = $remoteAddr await stream.writer.sendError(Http400, version = HttpVersion11) return @@ -79,10 +79,10 @@ proc handleRequest( res if vres == ReqStatus.ErrorFailure: - debug "Remote peer disconnected", address = $remoteAddr + trace "Remote peer disconnected", address = $remoteAddr return - debug "Received valid HTTP request", address = $remoteAddr + trace "Received valid HTTP request", address = $remoteAddr # Call the user's handler. if server.handler != nil: await server.handler( @@ -92,15 +92,15 @@ proc handleRequest( uri: requestData.uri().parseUri())) except TransportLimitError: # size of headers exceeds `MaxHttpHeadersSize` - debug "Maximum size of headers limit reached", address = $remoteAddr + trace "maximum size of headers limit reached", address = $remoteAddr await stream.writer.sendError(Http413, version = HttpVersion11) except TransportIncompleteError: # remote peer disconnected - debug "Remote peer disconnected", address = $remoteAddr + trace "Remote peer disconnected", address = $remoteAddr except TransportOsError as exc: - debug "Problems with networking", address = $remoteAddr, error = exc.msg + trace "Problems with networking", address = $remoteAddr, error = exc.msg except CatchableError as exc: - debug "Unknown exception", address = $remoteAddr, error = exc.msg + trace "Unknown exception", address = $remoteAddr, error = exc.msg finally: await stream.closeWait() @@ -151,6 +151,8 @@ proc create*( flags, child = StreamServer(server))) + trace "Created HTTP Server", host = $address + return server proc create*( @@ -191,6 +193,8 @@ proc create*( flags, child = StreamServer(server))) + trace "Created TLS HTTP Server", host = $address + return server proc create*( diff --git a/ws/session.nim b/ws/session.nim index c9830a4..d200c6a 100644 --- a/ws/session.nim +++ b/ws/session.nim @@ -9,27 +9,27 @@ {.push raises: [Defect].} +import std/strformat import pkg/[chronos, chronicles, stew/byteutils, stew/endians2] -import ./types, ./frame, ./utils, ./utf8_dfa, ./http +import ./types, ./frame, ./utils, ./utf8dfa, ./http -import pkg/chronos/[streams/asyncstream] +import pkg/chronos/streams/asyncstream -type - WSSession* = ref object of WebSocket - stream*: AsyncStream - frame*: Frame - proto*: string +logScope: + topics = "ws-session" -proc prepareCloseBody(code: Status, reason: string): seq[byte] = +proc prepareCloseBody(code: StatusCodes, reason: string): seq[byte] = result = reason.toBytes if ord(code) > 999: result = @(ord(code).uint16.toBytesBE()) & result -proc send*( +proc writeMessage*( ws: WSSession, data: seq[byte] = @[], - opcode: Opcode) {.async.} = - ## Send a frame + opcode: Opcode, + extensions: seq[Ext]) {.async.} = + ## Send a frame applying the supplied + ## extensions ## if ws.readyState == ReadyState.Closed: @@ -40,7 +40,7 @@ proc send*( dataSize = data.len masked = ws.masked - debug "Sending data to remote" + trace "Sending data to remote" var maskKey: array[4, char] if ws.masked: @@ -61,30 +61,40 @@ proc send*( mask: ws.masked, data: data, # allow sending data with close messages maskKey: maskKey) - .encode(extensions = ws.extensions))) + .encode())) return let maxSize = ws.frameSize var i = 0 while ws.readyState notin {ReadyState.Closing}: - let len = min(data.len, (maxSize + i)) - await ws.stream.writer.write( - (await Frame( - fin: if (i + len >= data.len): true else: false, + let len = min(data.len, maxSize) + let frame = Frame( + fin: if (len + i >= data.len): true else: false, rsv1: false, rsv2: false, rsv3: false, opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames mask: ws.masked, - data: data[i ..< len], + data: data[i ..< len + i], maskKey: maskKey) - .encode())) + + let encoded = await frame.encode(extensions) + await ws.stream.writer.write(encoded) i += len if i >= data.len: break +proc send*( + ws: WSSession, + data: seq[byte] = @[], + opcode: Opcode): Future[void] = + ## Send a frame + ## + + return ws.writeMessage(data, opcode, ws.extensions) + proc send*(ws: WSSession, data: string): Future[void] = send(ws, data.toBytes(), Opcode.Text) @@ -101,54 +111,62 @@ proc handleClose*( opcode = frame.opcode readyState = ws.readyState - debug "Handling close" + trace "Handling close" - if ws.readyState notin {ReadyState.Open}: - debug "Connection isn't open, abortig close sequence!" + if ws.readyState != ReadyState.Open: + trace "Connection isn't open, aborting close sequence!" return var - code = Status.Fulfilled + code = StatusFulfilled reason = "" - if payLoad.len == 1: + case payload.len: + of 0: + code = StatusNoStatus + of 1: raise newException(WSPayloadLengthError, "Invalid close frame with payload length 1!") - - if payLoad.len > 1: - # first two bytes are the status - let ccode = uint16.fromBytesBE(payLoad[0..<2]) - if ccode <= 999 or ccode > 1015: - raise newException(WSInvalidCloseCodeError, - "Invalid code in close message!") - + else: try: - code = Status(ccode) + code = StatusCodes(uint16.fromBytesBE(payLoad[0..<2])) except RangeError: raise newException(WSInvalidCloseCodeError, "Status code out of range!") - # remining payload bytes are reason for closing + if code in StatusNotUsed or + code in StatusReservedProtocol: + raise newException(WSInvalidCloseCodeError, + &"Can't use reserved status code: {code}") + + if code == StatusReserved or + code == StatusNoStatus or + code == StatusClosedAbnormally: + raise newException(WSInvalidCloseCodeError, + &"Can't use reserved status code: {code}") + + # remaining payload bytes are reason for closing reason = string.fromBytes(payLoad[2..payLoad.high]) if not ws.binary and validateUTF8(reason) == false: raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected in close reason") - var rcode: Status - if code in {Status.Fulfilled}: - rcode = Status.Fulfilled - + trace "Handling close message", code, reason if not isNil(ws.onClose): try: - (rcode, reason) = ws.onClose(code, reason) + (code, reason) = ws.onClose(code, reason) except CatchableError as exc: - debug "Exception in Close callback, this is most likely a bug", exc = exc.msg + trace "Exception in Close callback, this is most likely a bug", exc = exc.msg + else: + code = StatusFulfilled + reason = "" # don't respond to a terminated connection if ws.readyState != ReadyState.Closing: ws.readyState = ReadyState.Closing - await ws.send(prepareCloseBody(rcode, reason), Opcode.Close) + trace "Sending close", code, reason + await ws.send(prepareCloseBody(code, reason), Opcode.Close) ws.readyState = ReadyState.Closed await ws.stream.closeWait() @@ -164,7 +182,7 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} = readyState = ws.readyState len = frame.length - debug "Handling control frame" + trace "Handling control frame" if not frame.fin: raise newException(WSFragmentedControlFrameError, @@ -191,7 +209,7 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} = try: ws.onPing(payLoad) except CatchableError as exc: - debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg + trace "Exception in Ping callback, this is most likely a bug", exc = exc.msg # send pong to remote await ws.send(payLoad, Opcode.Pong) @@ -200,21 +218,28 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} = try: ws.onPong(payLoad) except CatchableError as exc: - debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg + trace "Exception in Pong callback, this is most likely a bug", exc = exc.msg of Opcode.Close: await ws.handleClose(frame, payLoad) else: raise newException(WSInvalidOpcodeError, "Invalid control opcode!") -proc readFrame*(ws: WSSession): Future[Frame] {.async.} = +proc readFrame*(ws: WSSession, extensions: seq[Ext] = @[]): Future[Frame] {.async.} = ## Gets a frame from the WebSocket. ## See https://tools.ietf.org/html/rfc6455#section-5.2 ## while ws.readyState != ReadyState.Closed: let frame = await Frame.decode( - ws.stream.reader, ws.masked, ws.extensions) - debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask + ws.stream.reader, ws.masked, extensions) + + logScope: + opcode = frame.opcode + len = frame.length + mask = frame.mask + fin = frame.fin + + trace "Decoded new frame" # return the current frame if it's not one of the control frames if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: @@ -230,59 +255,70 @@ proc recv*( ws: WSSession, data: pointer, size: int): Future[int] {.async.} = - ## Attempts to read up to `size` bytes + ## Attempts to read up to ``size`` bytes ## - ## Will read as many frames as necessary - ## to fill the buffer until either - ## the message ends (frame.fin) or - ## the buffer is full. If no data is on - ## the pipe will await until at least - ## one byte is available + ## If ``size`` is less than the data in + ## the frame, allow reading partial frames + ## + ## If no data is left in the pipe await + ## until at least one byte is available + ## + ## Otherwise, read as many frames as needed + ## up to ``size`` bytes, note that we do break + ## at message boundaries (``fin`` flag set). + ## + ## Use this to stream data from frames ## var consumed = 0 var pbuffer = cast[ptr UncheckedArray[byte]](data) try: + var first = true + + # reset previous frame if nothing is left in it + if not isNil(ws.frame) and ws.frame.remainder <= 0: + trace "Resetting previous frame" + first = ws.frame.fin # set as first frame if last frame was final + ws.frame = nil + + if isNil(ws.frame): + ws.frame = await ws.readFrame(ws.extensions) + while consumed < size: - # we might have to read more than - # one frame to fill the buffer - - # TODO: Figure out a cleaner way to handle - # retrieving new frames if isNil(ws.frame): - ws.frame = await ws.readFrame() - - if isNil(ws.frame): - return consumed - - if ws.frame.opcode == Opcode.Cont: - raise newException(WSOpcodeMismatchError, - "Expected Text or Binary frame") - elif (not ws.frame.fin and ws.frame.remainder() <= 0): - ws.frame = await ws.readFrame() - # This could happen if the connection is closed. - - if isNil(ws.frame): - return consumed - - if ws.frame.opcode != Opcode.Cont: - raise newException(WSOpcodeMismatchError, - "Expected Continuation frame") - - ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag - if ws.frame.fin and ws.frame.remainder() <= 0: - ws.frame = nil + trace "Empty frame, breaking" break - let len = min(ws.frame.remainder().int, size - consumed) - if len == 0: - continue + logScope: + first = first + fin = ws.frame.fin + len = ws.frame.length + consumed = ws.frame.consumed + remainder = ws.frame.remainder + opcode = ws.frame.opcode + masked = ws.frame.mask + + if first == (ws.frame.opcode == Opcode.Cont): + error "Opcode mismatch!" + raise newException(WSOpcodeMismatchError, + &"Opcode mismatch: first: {first}, opcode: {ws.frame.opcode}") + + if first: + ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag + trace "Setting binary flag" + + let len = min(ws.frame.remainder.int, size - consumed) + if len <= 0: + trace "Nothing left to read, breaking!" + break let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len) if read <= 0: - continue + trace "Didn't read any bytes, breaking" + break if ws.frame.mask: + trace "Unmasking frame" # unmask data using offset mask( pbuffer.toOpenArray(consumed, (consumed + read) - 1), @@ -292,15 +328,31 @@ proc recv*( consumed += read ws.frame.consumed += read.uint64 + trace "Read data from frame", read + # all has been consumed from the frame + # read the next frame + if ws.frame.remainder <= 0: + first = false + + if ws.frame.fin: # we're at the end of the message, break + trace "Read all frames, breaking" + ws.frame = nil + break + + ws.frame = await ws.readFrame(ws.extensions) + if not ws.binary and validateUTF8(pbuffer.toOpenArray(0, consumed - 1)) == false: raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected") - return consumed.int + return consumed except CatchableError as exc: ws.readyState = ReadyState.Closed await ws.stream.closeWait() - debug "Exception reading frames", exc = exc.msg + trace "Exception reading frames", exc = exc.msg raise exc + finally: + if not isNil(ws.frame) and (ws.frame.fin and ws.frame.remainder <= 0): + ws.frame = nil proc recv*( ws: WSSession, @@ -318,15 +370,14 @@ proc recv*( ## var res: seq[byte] while ws.readyState != ReadyState.Closed: - var buf = newSeq[byte](ws.frameSize) + var buf = newSeq[byte](min(size, ws.frameSize)) let read = await ws.recv(addr buf[0], buf.len) - if read <= 0: - break buf.setLen(read) if res.len + buf.len > size: raise newException(WSMaxMessageSizeError, "Max message size exceeded") + trace "Read message", size = read res.add(buf) # no more frames @@ -335,13 +386,14 @@ proc recv*( # read the entire message, exit if ws.frame.fin and ws.frame.remainder().int <= 0: + trace "Read full message, breaking!" break return res proc close*( ws: WSSession, - code: Status = Status.Fulfilled, + code = StatusFulfilled, reason: string = "") {.async.} = ## Close the Socket, sends close packet. ## @@ -359,4 +411,4 @@ proc close*( while ws.readyState != ReadyState.Closed: discard await ws.recv() except CatchableError as exc: - debug "Exception closing", exc = exc.msg + trace "Exception closing", exc = exc.msg diff --git a/ws/types.nim b/ws/types.nim index 26729f8..7e95f6c 100644 --- a/ws/types.nim +++ b/ws/types.nim @@ -62,41 +62,18 @@ type length*: uint64 ## Message size. consumed*: uint64 ## how much has been consumed from the frame - Status* {.pure.} = enum - # 0-999 not used - Fulfilled = 1000 - GoingAway = 1001 - ProtocolError = 1002 - CannotAccept = 1003 - # 1004 reserved - NoStatus = 1005 # use by clients - ClosedAbnormally = 1006 # use by clients - Inconsistent = 1007 - PolicyError = 1008 - TooLarge = 1009 - NoExtensions = 1010 - UnexpectedError = 1011 - ReservedCode = 3999 # use by clients - # 3000-3999 reserved for libs - # 4000-4999 reserved for applications + StatusCodes* = distinct range[0..4999] ControlCb* = proc(data: openArray[byte] = []) {.gcsafe, raises: [Defect].} CloseResult* = tuple - code: Status + code: StatusCodes reason: string - CloseCb* = proc(code: Status, reason: string): + CloseCb* = proc(code: StatusCodes, reason: string): CloseResult {.gcsafe, raises: [Defect].} - Ext* = ref object of RootObj - name*: string - options*: Table[string, string] - - ExtFactory* = proc(name: string, options: Table[string, string]): - Ext {.raises: [Defect].} - WebSocket* = ref object of RootObj extensions*: seq[Ext] version*: uint @@ -111,6 +88,21 @@ type onPong*: ControlCb onClose*: CloseCb + WSSession* = ref object of WebSocket + stream*: AsyncStream + frame*: Frame + proto*: string + + Ext* = ref object of RootObj + name*: string + options*: Table[string, string] + session*: WSSession + + ExtFactory* = proc( + name: string, + session: WSSession, + options: Table[string, string]): Ext {.raises: [Defect].} + WebSocketError* = object of CatchableError WSMalformedHeaderError* = object of WebSocketError WSFailedUpgradeError* = object of WebSocketError @@ -125,13 +117,43 @@ type WSClosedError* = object of WebSocketError WSSendError* = object of WebSocketError WSPayloadTooLarge* = object of WebSocketError - WSReserverdOpcodeError* = object of WebSocketError + WSReservedOpcodeError* = object of WebSocketError WSFragmentedControlFrameError* = object of WebSocketError WSInvalidCloseCodeError* = object of WebSocketError WSPayloadLengthError* = object of WebSocketError WSInvalidOpcodeError* = object of WebSocketError WSInvalidUTF8* = object of WebSocketError +const + StatusNotUsed* = (StatusCodes(0)..StatusCodes(999)) + StatusFulfilled* = StatusCodes(1000) + StatusGoingAway* = StatusCodes(1001) + StatusProtocolError* = StatusCodes(1002) + StatusCannotAccept* = StatusCodes(1003) + StatusReserved* = StatusCodes(1004) # 1004 reserved + StatusNoStatus* = StatusCodes(1005) # use by clients + StatusClosedAbnormally* = StatusCodes(1006) # use by clients + StatusInconsistent* = StatusCodes(1007) + StatusPolicyError* = StatusCodes(1008) + StatusTooLarge* = StatusCodes(1009) + StatusNoExtensions* = StatusCodes(1010) + StatusUnexpectedError* = StatusCodes(1011) + StatusFailedTls* = StatusCodes(1015) # passed to applications to indicate TLS errors + StatusReservedProtocol* = StatusCodes(1016)..StatusCodes(2999) # reserved for this protocol + StatusLibsCodes* = (StatusCodes(3000)..StatusCodes(3999)) # 3000-3999 reserved for libs + StatusAppsCodes* = (StatusCodes(4000)..StatusCodes(4999)) # 4000-4999 reserved for apps + +proc `<=`*(a, b: StatusCodes): bool = a.uint16 <= b.uint16 +proc `>=`*(a, b: StatusCodes): bool = a.uint16 >= b.uint16 +proc `<`*(a, b: StatusCodes): bool = a.uint16 < b.uint16 +proc `>`*(a, b: StatusCodes): bool = a.uint16 > b.uint16 +proc `==`*(a, b: StatusCodes): bool = a.uint16 == b.uint16 + +proc high*(a: HSlice[StatusCodes, StatusCodes]): uint16 = a.b.uint16 +proc low*(a: HSlice[StatusCodes, StatusCodes]): uint16 = a.a.uint16 + +proc `$`*(a: StatusCodes): string = $(a.int) + proc `name=`*(self: Ext, name: string) = raiseAssert "Can't change extensions name!" @@ -141,5 +163,5 @@ method decode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} = method encode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} = raiseAssert "Not implemented!" -method toHttpOptions*(self: Ext): string = +method toHttpOptions*(self: Ext): string {.base.} = raiseAssert "Not implemented!" diff --git a/ws/utf8_dfa.nim b/ws/utf8dfa.nim similarity index 100% rename from ws/utf8_dfa.nim rename to ws/utf8dfa.nim diff --git a/ws/ws.nim b/ws/ws.nim index c2017ec..d239e0c 100644 --- a/ws/ws.nim +++ b/ws/ws.nim @@ -23,7 +23,6 @@ import pkg/[chronos, chronicles, httputils, stew/byteutils, - stew/endians2, stew/base64, stew/base10, nimcrypto/sha] @@ -32,6 +31,9 @@ import ./utils, ./frame, ./session, /types, ./http export utils, session, frame, types, http +logScope: + topics = "ws-server" + type WSServer* = ref object of WebSocket protocols: seq[string] @@ -86,7 +88,7 @@ proc connect*( let response = try: await client.request(uri, headers = headers) except CatchableError as exc: - debug "Websocket failed during handshake", exc = exc.msg + trace "Websocket failed during handshake", exc = exc.msg await client.close() raise exc @@ -207,7 +209,7 @@ proc handleRequest*( if ws.version != version: await request.stream.writer.sendError(Http426) - debug "Websocket version not supported", version = ws.version + trace "Websocket version not supported", version = ws.version raise newException(WSVersionError, &"Websocket version not supported, Version: {version}") @@ -236,7 +238,7 @@ proc handleRequest*( if protocol.len > 0: headers.add("Sec-WebSocket-Protocol", protocol) # send back the first matching proto else: - debug "Didn't match any protocol", supported = ws.protocols, requested = wantProtos + trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos try: await request.sendResponse(Http101, headers = headers)