diff --git a/tests/extensions/base64ext.nim b/tests/extensions/base64ext.nim new file mode 100644 index 0000000..35c0753 --- /dev/null +++ b/tests/extensions/base64ext.nim @@ -0,0 +1,120 @@ +## nim-ws +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +import + std/[strutils], + pkg/[stew/results, + stew/base64, + chronos, + chronicles], + ../../ws/types, + ../../ws/frame + +type + Base64Ext = ref object of Ext + padding: bool + transform: bool + +const + extID = "base64" + +method decode(ext: Base64Ext, frame: Frame): Future[Frame] {.async.} = + if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: + return frame + + if frame.opcode in {Opcode.Text, Opcode.Binary}: + ext.transform = frame.rsv2 + frame.rsv2 = false + + if not ext.transform: + return frame + + if frame.length > 0: + var data: seq[byte] + var buf: array[0xFFFF, byte] + + while data.len < frame.length.int: + let len = min(frame.length.int - data.len, buf.len) + let read = await frame.read(ext.session.stream.reader, addr buf[0], len) + data.add toOpenArray(buf, 0, read - 1) + + if data.len > ext.session.frameSize: + raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size") + + # bug in Base64.Decode when accepts seq[byte] + let instr = cast[string](data) + if ext.padding: + frame.data = Base64Pad.decode(instr) + else: + frame.data = Base64.decode(instr) + + trace "Base64Ext decode", input=frame.length, output=frame.data.len + + frame.length = frame.data.len.uint64 + frame.offset = 0 + frame.consumed = 0 + frame.mask = false + + return frame + +method encode(ext: Base64Ext, frame: Frame): Future[Frame] {.async.} = + if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: + return frame + + if frame.opcode in {Opcode.Text, Opcode.Binary}: + ext.transform = true + frame.rsv2 = ext.transform + + if not ext.transform: + return frame + + frame.length = frame.data.len.uint64 + + if ext.padding: + frame.data = cast[seq[byte]](Base64Pad.encode(frame.data)) + else: + frame.data = cast[seq[byte]](Base64.encode(frame.data)) + + trace "Base64Ext encode", input=frame.length, output=frame.data.len + + frame.length = frame.data.len.uint64 + frame.offset = 0 + frame.consumed = 0 + + return frame + +method toHttpOptions(ext: Base64Ext): string = + extID & "; pad=" & $ext.padding + +proc base64Factory*(padding: bool): ExtFactory = + + proc factory(isServer: bool, + args: seq[ExtParam]): Result[Ext, string] {. + gcsafe, raises: [Defect].} = + + # you can capture configuration variables via closure + # if you want + + var ext = Base64Ext( + name : extID, + transform: false + ) + + for arg in args: + if arg.name == "pad": + ext.padding = arg.value == "true" + break + + ok(ext) + + ExtFactory( + name: extID, + factory: factory, + clientOffer: extID & "; pad=" & $padding + ) diff --git a/tests/extensions/hexext.nim b/tests/extensions/hexext.nim new file mode 100644 index 0000000..32ade7c --- /dev/null +++ b/tests/extensions/hexext.nim @@ -0,0 +1,103 @@ +## nim-ws +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +import + std/[strutils], + pkg/[stew/results, + stew/byteutils, + chronos, + chronicles], + ../../ws/types, + ../../ws/frame + +type + HexExt = ref object of Ext + transform: bool + +const + extID = "hex" + +method decode(ext: HexExt, frame: Frame): Future[Frame] {.async.} = + if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: + return frame + + if frame.opcode in {Opcode.Text, Opcode.Binary}: + ext.transform = frame.rsv3 + frame.rsv3 = false + + if not ext.transform: + return frame + + if frame.length > 0: + var data: seq[byte] + var buf: array[0xFFFF, byte] + + while data.len < frame.length.int: + let len = min(frame.length.int - data.len, buf.len) + let read = await frame.read(ext.session.stream.reader, addr buf[0], len) + data.add toOpenArray(buf, 0, read - 1) + + if data.len > ext.session.frameSize: + raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size") + + frame.data = hexToSeqByte(cast[string](data)) + trace "HexExt decode", input=frame.length, output=frame.data.len + + frame.length = frame.data.len.uint64 + frame.offset = 0 + frame.consumed = 0 + frame.mask = false + + return frame + +method encode(ext: HexExt, frame: Frame): Future[Frame] {.async.} = + if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: + return frame + + if frame.opcode in {Opcode.Text, Opcode.Binary}: + ext.transform = true + frame.rsv3 = ext.transform + + if not ext.transform: + return frame + + frame.length = frame.data.len.uint64 + frame.data = cast[seq[byte]](toHex(frame.data)) + trace "HexExt encode", input=frame.length, output=frame.data.len + + frame.length = frame.data.len.uint64 + frame.offset = 0 + frame.consumed = 0 + + return frame + +method toHttpOptions(ext: HexExt): string = + extID + +proc hexFactory*(): ExtFactory = + + proc factory(isServer: bool, + args: seq[ExtParam]): Result[Ext, string] {. + gcsafe, raises: [Defect].} = + + # you can capture configuration variables via closure + # if you want + + var ext = HexExt( + name : extID, + transform: false + ) + + ok(ext) + + ExtFactory( + name: extID, + factory: factory, + clientOffer: extID + ) diff --git a/tests/extensions/testexts.nim b/tests/extensions/testexts.nim new file mode 100644 index 0000000..4fe5a18 --- /dev/null +++ b/tests/extensions/testexts.nim @@ -0,0 +1,89 @@ +## nim-ws +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +import pkg/[chronos, stew/byteutils] +import ../asyncunit +import ./base64ext, ./hexext +import ../../ws/ws, ../helpers + +suite "UTF-8 validator in action": + var server: HttpServer + let address = initTAddress("127.0.0.1:8888") + let hexFactory = hexFactory() + let base64Factory = base64Factory(padding = true) + + teardown: + server.stop() + await server.closeWait() + + test "hex to base64 ext flow": + let testData = "hello world" + proc handle(request: HttpRequest) {.async.} = + let server = WSServer.new( + protos = ["proto"], + factories = [hexFactory, base64Factory], + ) + let ws = await server.handleRequest(request) + let recvData = await ws.recv() + await ws.send(recvData, + if ws.binary: Opcode.Binary else: Opcode.Text) + + await waitForClose(ws) + + server = HttpServer.create( + address, + handle, + flags = {ReuseAddr}) + server.start() + + let client = await WebSocket.connect( + host = "127.0.0.1", + port = Port(8888), + path = "/ws", + protocols = @["proto"], + factories = @[hexFactory, base64Factory] + ) + + await client.send(testData) + let res = await client.recv() + check testData.toBytes() == res + await client.close() + + test "base64 to hex ext flow": + let testData = "hello world" + proc handle(request: HttpRequest) {.async.} = + let server = WSServer.new( + protos = ["proto"], + factories = [hexFactory, base64Factory], + ) + let ws = await server.handleRequest(request) + let recvData = await ws.recv() + await ws.send(recvData, + if ws.binary: Opcode.Binary else: Opcode.Text) + + await waitForClose(ws) + + server = HttpServer.create( + address, + handle, + flags = {ReuseAddr}) + server.start() + + let client = await WebSocket.connect( + host = "127.0.0.1", + port = Port(8888), + path = "/ws", + protocols = @["proto"], + factories = @[base64Factory, hexFactory] + ) + + await client.send(testData) + let res = await client.recv() + check testData.toBytes() == res + await client.close() diff --git a/tests/testcommon.nim b/tests/testcommon.nim index 93c05a4..dba53af 100644 --- a/tests/testcommon.nim +++ b/tests/testcommon.nim @@ -3,3 +3,4 @@ import ./testframes import ./testutf8 import ./testextutils +import ./extensions/testexts