diff --git a/examples/autobahn_client.nim b/examples/autobahn_client.nim index 8629145c83..c0fa450cfd 100644 --- a/examples/autobahn_client.nim +++ b/examples/autobahn_client.nim @@ -10,7 +10,7 @@ import std/[strutils], pkg/[chronos, chronicles, stew/byteutils], - ../ws/[ws, types, frame] + ../ws/[ws, types, frame, extensions/compression/deflate] const clientFlags = {NoVerifyHost, NoVerifyServerName} @@ -28,13 +28,14 @@ else: secure = false serverPort = 9001 -proc connectServer(path: string): Future[WSSession] {.async.} = +proc connectServer(path: string, factories: seq[ExtFactory] = @[]): Future[WSSession] {.async.} = let ws = await WebSocket.connect( host = "127.0.0.1", port = Port(serverPort), path = path, secure=secure, - flags=clientFlags + flags=clientFlags, + factories = factories ) return ws @@ -71,11 +72,12 @@ proc main() {.async.} = let caseCount = await getCaseCount() trace "case count", count=caseCount + var deflateFactory = @[deflateFactory()] for i in 1..caseCount: trace "runcase", no=i let path = "/runCase?case=$1&agent=$2" % [$i, agent] try: - let ws = await connectServer(path) + let ws = await connectServer(path, deflateFactory) while ws.readystate != ReadyState.Closed: # echo back diff --git a/examples/server.nim b/examples/server.nim index ff826c19ff..e243e4c68d 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -12,7 +12,7 @@ import pkg/[chronos, chronicles, httputils] -import ../ws/ws +import ../ws/[ws, extensions/compression/deflate] import ../tests/keys proc handle(request: HttpRequest) {.async.} = @@ -23,7 +23,8 @@ proc handle(request: HttpRequest) {.async.} = trace "Initiating web socket connection." try: - let server = WSServer.new() + let deflateFactory = deflateFactory() + let server = WSServer.new(factories = [deflateFactory]) let ws = await server.handleRequest(request) if ws.readyState != Open: error "Failed to open websocket connection" @@ -32,8 +33,13 @@ proc handle(request: HttpRequest) {.async.} = trace "Websocket handshake completed" while ws.readyState != ReadyState.Closed: let recvData = await ws.recv() - trace "Client Response: ", size = recvData.len, binary = ws.binary + + if ws.readyState == ReadyState.Closed: + # if session already terminated by peer, + # no need to send response + break + await ws.send(recvData, if ws.binary: Opcode.Binary else: Opcode.Text) diff --git a/ws.nimble b/ws.nimble index 7e70265b79..28729cb2b7 100644 --- a/ws.nimble +++ b/ws.nimble @@ -22,6 +22,7 @@ requires "stew >= 0.1.0" requires "asynctest >= 0.2.0 & < 0.3.0" requires "nimcrypto" requires "bearssl" +requires "https://github.com/status-im/nim-zlib" task test, "run tests": exec "nim --hints:off c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testcommon.nim" diff --git a/ws/extensions/compression/compression.nim b/ws/extensions/compression/compression.nim deleted file mode 100644 index 1a015c9e6d..0000000000 --- a/ws/extensions/compression/compression.nim +++ /dev/null @@ -1,212 +0,0 @@ -## nim-ws -## Copyright (c) 2021 Status Research & Development GmbH -## Licensed under either of -## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) -## * MIT license ([LICENSE-MIT](LICENSE-MIT)) -## at your option. -## This file may not be copied, modified, or distributed except according to -## those terms. - - -import - std/[strutils], - pkg/[stew/results, chronos], - ../../types, ../../frame, ./miniz/miniz_api - -type - DeflateOpts = object - isServer: bool - serverNoContextTakeOver: bool - clientNoContextTakeOver: bool - serverMaxWindowBits: int - clientMaxWindowBits: int - - DeflateExt = ref object of Ext - paramStr: string - opts: DeflateOpts - # 'messageCompressed' is a two way flag: - # 1. the original message is compressible or not - # 2. the frame we received need decompression or not - messageCompressed: bool - -const - extID = "permessage-deflate" - -proc concatParam(resp: var string, param: string) = - resp.add ';' - resp.add param - -proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] = - if arg.value.len == 0: - return ok("") - - if arg.value.len > 2: - return err("window bits expect 2 bytes, got " & $arg.value.len) - - for n in arg.value: - if n notin Digits: - return err("window bits value contains illegal char: " & $n) - - var winbit = 0 - for i in 0.. 15: - return err("window bits should between 8-15, got " & $winbit) - - res = winbit - return ok("=" & arg.value) - -proc validateParams(args: seq[ExtParam], - opts: var DeflateOpts): Result[string, string] = - # besides validating extensions params, this proc - # also constructing extension param for response - var resp = "" - for arg in args: - case arg.name - of "server_no_context_takeover": - if arg.value.len > 0: - return err("'server_no_context_takeover' should have no param") - if opts.isServer: - concatParam(resp, arg.name) - opts.serverNoContextTakeOver = true - of "client_no_context_takeover": - if arg.value.len > 0: - return err("'client_no_context_takeover' should have no param") - if opts.isServer: - concatParam(resp, arg.name) - opts.clientNoContextTakeOver = true - of "server_max_window_bits": - if opts.isServer: - concatParam(resp, arg.name) - let res = validateWindowBits(arg, opts.serverMaxWindowBits) - if res.isErr: - return res - resp.add res.get() - of "client_max_window_bits": - if opts.isServer: - concatParam(resp, arg.name) - let res = validateWindowBits(arg, opts.clientMaxWindowBits) - if res.isErr: - return res - resp.add res.get() - else: - return err("unrecognized param: " & arg.name) - - ok(resp) - -method decode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} = - if not ext.messageCompressed: - return data - - # TODO: append trailing bytes - var mz = MzStream( - next_in: cast[ptr cuchar](data[0].unsafeAddr), - avail_in: data.len.cuint - ) - - let windowBits = if ext.opts.serverMaxWindowBits == 0: - MZ_DEFAULT_WINDOW_BITS - else: - MzWindowBits(ext.opts.serverMaxWindowBits) - - doAssert(mz.inflateInit2(windowBits) == MZ_OK) - var res: seq[byte] - var buf: array[0xFFFF, byte] - - while true: - mz.next_out = cast[ptr cuchar](buf[0].addr) - mz.avail_out = buf.len.cuint - let r = mz.inflate(MZ_SYNC_FLUSH) - let outSize = buf.len - mz.avail_out.int - res.add toOpenArray(buf, 0, outSize-1) - if r == MZ_STREAM_END: - break - elif r == MZ_OK: - continue - else: - doAssert(false, "decompression error") - - doAssert(mz.inflateEnd() == MZ_OK) - return res - -method encode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} = - var mz = MzStream( - next_in: cast[ptr cuchar](data[0].unsafeAddr), - avail_in: data.len.cuint - ) - - let windowBits = if ext.opts.serverMaxWindowBits == 0: - MZ_DEFAULT_WINDOW_BITS - else: - MzWindowBits(ext.opts.serverMaxWindowBits) - - doAssert(mz.deflateInit2( - level = MZ_DEFAULT_LEVEL, - meth = MZ_DEFLATED, - windowBits, - 1, - strategy = MZ_DEFAULT_STRATEGY) == MZ_OK - ) - - let maxSize = mz.deflateBound(data.len.culong).int - var res: seq[byte] - var buf: array[0xFFFF, byte] - - while true: - mz.next_out = cast[ptr cuchar](buf[0].addr) - mz.avail_out = buf.len.cuint - let r = mz.deflate(MZ_FINISH) - let outSize = buf.len - mz.avail_out.int - res.add toOpenArray(buf, 0, outSize-1) - if r == MZ_STREAM_END: - break - elif r == MZ_OK: - continue - else: - doAssert(false, "compression error") - - # TODO: cut trailing bytes - doAssert(mz.deflateEnd() == MZ_OK) - ext.messageCompressed = res.len < data.len - if ext.messageCompressed: - return res - else: - return data - -method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = - if frame.opcode in {Opcode.Text, Opcode.Binary}: - # only data frame can be compressed - # and we want to know if this message is compressed or not - # if the frame opcode is text or binary, it should also the first frame - ext.messageCompressed = frame.rsv1 - # clear rsv1 bit because we already done with it - frame.rsv1 = false - return frame - -method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = - if frame.opcode in {Opcode.Text, Opcode.Binary}: - # only data frame can be compressed - # and we only set rsv1 bit to true if the message is compressible - # if the frame opcode is text or binary, it should also the first frame - frame.rsv1 = ext.messageCompressed - return frame - -method toHttpOptions(ext: DeflateExt): string = - # using paramStr here is a bit clunky - extID & "; " & ext.paramStr - -proc deflateExtFactory(isServer: bool, args: seq[ExtParam]): Result[Ext, string] {.raises: [Defect].} = - var opts = DeflateOpts(isServer: isServer) - let resp = validateParams(args, opts) - if resp.isErr: - return err(resp.error) - let ext = DeflateExt( - name: extID, - paramStr: resp.get(), - opts: opts - ) - ok(ext) - -const - deflateFactory* = (extID, deflateExtFactory) diff --git a/ws/extensions/compression/deflate.nim b/ws/extensions/compression/deflate.nim new file mode 100644 index 0000000000..c9e2bf3c5c --- /dev/null +++ b/ws/extensions/compression/deflate.nim @@ -0,0 +1,394 @@ +## nim-ws +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +import + std/[strutils], + pkg/[stew/results, + chronos, + chronicles, + zlib], + ../../types, + ../../frame + +type + DeflateOpts = object + isServer: bool + decompressLimit: int # max allowed decompression size + threshold: int # size in bytes below which messages + # should not be compressed. + level: ZLevel # compression level + strategy: ZStrategy # compression strategy + memLevel: ZMemLevel # hint for miniz memory consumption + serverNoContextTakeOver: bool + clientNoContextTakeOver: bool + serverMaxWindowBits: int + clientMaxWindowBits: int + + ContextState {.pure.} = enum + Invalid + Initialized + Reset + + DeflateExt = ref object of Ext + paramStr : string + opts : DeflateOpts + compressedMsg : bool + compCtx : ZStream + compCtxState : ContextState + decompCtx : ZStream + decompCtxState: ContextState + +const + extID = "permessage-deflate" + TrailingBytes = [0x00.byte, 0x00.byte, 0xff.byte, 0xff.byte] + ExtDeflateThreshold* = 1024 + ExtDeflateDecompressLimit* = 10 shl 20 # 10mb + +proc concatParam(resp: var string, param: string) = + resp.add "; " + resp.add param + +proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] = + if arg.value.len == 0: + return ok("") + + if arg.value.len > 2: + return err("window bits expect 2 bytes, got " & $arg.value.len) + + for n in arg.value: + if n notin Digits: + return err("window bits value contains illegal char: " & $n) + + var winbit = 0 + for i in 0.. 15: + return err("window bits should between 8-15, got " & $winbit) + + res = winbit + return ok("=" & arg.value) + +proc createParams(args: seq[ExtParam], + opts: var DeflateOpts): Result[string, string] = + # besides validating extensions params, this proc + # also constructing extension params for response + var resp = "" + for arg in args: + case arg.name + of "server_no_context_takeover": + if arg.value.len > 0: + return err("'server_no_context_takeover' should have no param") + opts.serverNoContextTakeOver = true + if opts.isServer: + concatParam(resp, arg.name) + of "client_no_context_takeover": + if arg.value.len > 0: + return err("'client_no_context_takeover' should have no param") + opts.clientNoContextTakeOver = true + if opts.isServer: + concatParam(resp, arg.name) + of "server_max_window_bits": + let res = validateWindowBits(arg, opts.serverMaxWindowBits) + if res.isErr: + return res + if opts.isServer: + concatParam(resp, arg.name) + if opts.serverMaxWindowBits == 8: + # zlib does not support windowBits == 8 + resp.add "=9" + else: + resp.add res.get() + of "client_max_window_bits": + let res = validateWindowBits(arg, opts.clientMaxWindowBits) + if res.isErr: + return res + if not opts.isServer: + concatParam(resp, arg.name) + if opts.clientMaxWindowBits == 8: + # zlib does not support windowBits == 8 + resp.add "=9" + else: + resp.add res.get() + else: + return err("unrecognized param: " & arg.name) + + ok(resp) + +proc getWindowBits(opts: DeflateOpts, isServer: bool): ZWindowBits = + if isServer: + if opts.serverMaxWindowBits == 0: + Z_RAW_DEFLATE + else: + ZWindowBits(-opts.serverMaxWindowBits) + else: + if opts.clientMaxWindowBits == 0: + Z_RAW_DEFLATE + else: + ZWindowBits(-opts.clientMaxWindowBits) + +proc getContextTakeover(opts: DeflateOpts, isServer: bool): bool = + if isServer: + opts.serverNoContextTakeOver + else: + opts.clientNoContextTakeOver + +proc decompressInit(ext: DeflateExt) = + # decompression using `client_` prefixed config + let windowBits = getWindowBits(ext.opts, not ext.opts.isServer) + doAssert(ext.decompCtx.inflateInit2(windowBits) == Z_OK) + ext.decompCtxState = ContextState.Initialized + +proc compressInit(ext: DeflateExt) = + # compression using `server_` prefixed config + let windowBits = getWindowBits(ext.opts, ext.opts.isServer) + doAssert(ext.compCtx.deflateInit2( + level = ext.opts.level, + meth = Z_DEFLATED, + windowBits, + memLevel = ext.opts.memLevel, + strategy = ext.opts.strategy) == Z_OK + ) + ext.compCtxState = ContextState.Initialized + +proc compress(zs: var ZStream, data: openArray[byte]): seq[byte] = + var buf: array[0xFFFF, byte] + + # these casting is needed to prevent compilation + # error with CLANG + zs.next_in = cast[ptr cuchar](data[0].unsafeAddr) + zs.avail_in = data.len.cuint + + while true: + zs.next_out = cast[ptr cuchar](buf[0].addr) + zs.avail_out = buf.len.cuint + + let r = zs.deflate(Z_SYNC_FLUSH) + let outSize = buf.len - zs.avail_out.int + result.add toOpenArray(buf, 0, outSize-1) + + if r == Z_STREAM_END: + break + elif r == Z_OK: + # need more input or more output available + if zs.avail_in > 0 or zs.avail_out == 0: + continue + else: + break + else: + raise newException(WSExtError, "compression error " & $r) + +proc decompress(zs: var ZStream, limit: int, data: openArray[byte]): seq[byte] = + var buf: array[0xFFFF, byte] + + # these casting is needed to prevent compilation + # error with CLANG + zs.next_in = cast[ptr cuchar](data[0].unsafeAddr) + zs.avail_in = data.len.cuint + + while true: + zs.next_out = cast[ptr cuchar](buf[0].addr) + zs.avail_out = buf.len.cuint + + let r = zs.inflate(Z_NO_FLUSH) + let outSize = buf.len - zs.avail_out.int + result.add toOpenArray(buf, 0, outSize-1) + + if result.len > limit: + raise newException(WSExtError, "decompression exceeds allowed limit") + + if r == Z_STREAM_END: + break + elif r == Z_OK: + # need more input or more output available + if zs.avail_in > 0 or zs.avail_out == 0: + continue + else: + break + else: + raise newException(WSExtError, "decompression error " & $r) + + return result + +method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = + if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: + # only data frames can be decompressed + return frame + + if frame.opcode in {Opcode.Text, Opcode.Binary}: + # we want to know if this message is compressed or not + # if the frame opcode is text or binary, it should also the first frame + ext.compressedMsg = frame.rsv1 + # clear rsv1 bit because we already done with it + frame.rsv1 = false + + if not ext.compressedMsg: + # don't bother with uncompressed message + return frame + + if ext.decompCtxState == ContextState.Invalid: + ext.decompressInit() + + # even though the frame.data.len == 0, the stream needs + # to be closed with trailing bytes if it's a final frame + + var data: seq[byte] + var buf: array[0xFFFF, byte] + + while data.len < frame.length.int: + let len = min(frame.length.int - data.len, buf.len) + let read = await frame.read(ext.session.stream.reader, addr buf[0], len) + data.add toOpenArray(buf, 0, read - 1) + + if data.len > ext.session.frameSize: + raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size") + + if frame.fin: + data.add TrailingBytes + + frame.data = decompress(ext.decompCtx, ext.opts.decompressLimit, data) + trace "DeflateExt decompress", input=frame.length, output=frame.data.len + + frame.length = frame.data.len.uint64 + frame.offset = 0 + frame.consumed = 0 + frame.mask = false # clear mask flag, decompressed content is not masked + + if frame.fin: + # decompression using `client_` prefixed config + let noContextTakeover = getContextTakeover(ext.opts, not ext.opts.isServer) + if noContextTakeover: + doAssert(ext.decompCtx.inflateReset() == Z_OK) + ext.decompCtxState = ContextState.Reset + + return frame + +method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = + if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: + # only data frames can be compressed + return frame + + if frame.opcode in {Opcode.Text, Opcode.Binary}: + # we only set rsv1 bit to true if the message is compressible + # and only set the first frame's rsv1 + # if the frame opcode is text or binary, it should also the first frame + ext.compressedMsg = frame.data.len >= ext.opts.threshold + frame.rsv1 = ext.compressedMsg + + if not ext.compressedMsg: + # don't bother with incompressible message + return frame + + if ext.compCtxState == ContextState.Invalid: + ext.compressInit() + + frame.length = frame.data.len.uint64 + frame.data = compress(ext.compCtx, frame.data) + trace "DeflateExt compress", input=frame.length, output=frame.data.len + + if frame.fin: + # remove trailing bytes + when not defined(release): + var trailer: array[4, byte] + trailer[0] = frame.data[^4] + trailer[1] = frame.data[^3] + trailer[2] = frame.data[^2] + trailer[3] = frame.data[^1] + doAssert trailer == TrailingBytes + frame.data.setLen(frame.data.len - 4) + + frame.length = frame.data.len.uint64 + frame.offset = 0 + frame.consumed = 0 + + if frame.fin: + # compression using `server_` prefixed config + let noContextTakeover = getContextTakeover(ext.opts, ext.opts.isServer) + if noContextTakeover: + doAssert(ext.compCtx.deflateReset() == Z_OK) + ext.compCtxState = ContextState.Reset + + return frame + +method toHttpOptions(ext: DeflateExt): string = + # using paramStr here is a bit clunky + extID & ext.paramStr + +proc destroyExt(ext: DeflateExt) = + if ext.compCtxState != ContextState.Invalid: + # zlib.deflateEnd somehow return DATA_ERROR + # when compression succeed some cases. + # we forget to do something? + discard ext.compCtx.deflateEnd() + ext.compCtxState = ContextState.Invalid + + if ext.decompCtxState != ContextState.Invalid: + doAssert(ext.decompCtx.inflateEnd() == Z_OK) + ext.decompCtxState = ContextState.Invalid + +proc makeOffer( + clientNoContextTakeOver: bool, + clientMaxWindowBits: int): string = + + var param = extID + if clientMaxWindowBits in {9..15}: + param.add "; client_max_window_bits=" & $clientMaxWindowBits + else: + param.add "; client_max_window_bits" + + if clientNoContextTakeOver: + param.add "; client_no_context_takeover" + + param + +proc deflateFactory*( + threshold = ExtDeflateThreshold, + decompressLimit = ExtDeflateDecompressLimit, + level = Z_DEFAULT_LEVEL, + strategy = Z_DEFAULT_STRATEGY, + memLevel = Z_DEFAULT_MEM_LEVEL, + clientNoContextTakeOver = false, + clientMaxWindowBits = 15): ExtFactory = + + proc factory(isServer: bool, + args: seq[ExtParam]): Result[Ext, string] {. + gcsafe, raises: [Defect].} = + + # capture user configuration via closure + var opts = DeflateOpts( + isServer: isServer, + threshold: threshold, + decompressLimit: decompressLimit, + level: level, + strategy: strategy, + memLevel: memLevel + ) + let resp = createParams(args, opts) + if resp.isErr: + return err(resp.error) + + var ext: DeflateExt + ext.new(destroyExt) + ext.name = extID + ext.paramStr = resp.get() + ext.opts = opts + ext.compressedMsg = false + ext.compCtxState = ContextState.Invalid + ext.decompCtxState= ContextState.Invalid + + ok(ext) + + ExtFactory( + name: extID, + factory: factory, + clientOffer: makeOffer( + clientNoContextTakeOver, + clientMaxWindowBits + ) + )