implement permessage-deflate compression extension
depends on zlib as it's backend compressor pass both client and server tests in autobahn test suite
This commit is contained in:
parent
fef04a1595
commit
14d8e51f53
|
@ -10,7 +10,7 @@
|
||||||
import
|
import
|
||||||
std/[strutils],
|
std/[strutils],
|
||||||
pkg/[chronos, chronicles, stew/byteutils],
|
pkg/[chronos, chronicles, stew/byteutils],
|
||||||
../ws/[ws, types, frame]
|
../ws/[ws, types, frame, extensions/compression/deflate]
|
||||||
|
|
||||||
const
|
const
|
||||||
clientFlags = {NoVerifyHost, NoVerifyServerName}
|
clientFlags = {NoVerifyHost, NoVerifyServerName}
|
||||||
|
@ -28,13 +28,14 @@ else:
|
||||||
secure = false
|
secure = false
|
||||||
serverPort = 9001
|
serverPort = 9001
|
||||||
|
|
||||||
proc connectServer(path: string): Future[WSSession] {.async.} =
|
proc connectServer(path: string, factories: seq[ExtFactory] = @[]): Future[WSSession] {.async.} =
|
||||||
let ws = await WebSocket.connect(
|
let ws = await WebSocket.connect(
|
||||||
host = "127.0.0.1",
|
host = "127.0.0.1",
|
||||||
port = Port(serverPort),
|
port = Port(serverPort),
|
||||||
path = path,
|
path = path,
|
||||||
secure=secure,
|
secure=secure,
|
||||||
flags=clientFlags
|
flags=clientFlags,
|
||||||
|
factories = factories
|
||||||
)
|
)
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
@ -71,11 +72,12 @@ proc main() {.async.} =
|
||||||
let caseCount = await getCaseCount()
|
let caseCount = await getCaseCount()
|
||||||
trace "case count", count=caseCount
|
trace "case count", count=caseCount
|
||||||
|
|
||||||
|
var deflateFactory = @[deflateFactory()]
|
||||||
for i in 1..caseCount:
|
for i in 1..caseCount:
|
||||||
trace "runcase", no=i
|
trace "runcase", no=i
|
||||||
let path = "/runCase?case=$1&agent=$2" % [$i, agent]
|
let path = "/runCase?case=$1&agent=$2" % [$i, agent]
|
||||||
try:
|
try:
|
||||||
let ws = await connectServer(path)
|
let ws = await connectServer(path, deflateFactory)
|
||||||
|
|
||||||
while ws.readystate != ReadyState.Closed:
|
while ws.readystate != ReadyState.Closed:
|
||||||
# echo back
|
# echo back
|
||||||
|
|
|
@ -12,7 +12,7 @@ import pkg/[chronos,
|
||||||
chronicles,
|
chronicles,
|
||||||
httputils]
|
httputils]
|
||||||
|
|
||||||
import ../ws/ws
|
import ../ws/[ws, extensions/compression/deflate]
|
||||||
import ../tests/keys
|
import ../tests/keys
|
||||||
|
|
||||||
proc handle(request: HttpRequest) {.async.} =
|
proc handle(request: HttpRequest) {.async.} =
|
||||||
|
@ -23,7 +23,8 @@ proc handle(request: HttpRequest) {.async.} =
|
||||||
|
|
||||||
trace "Initiating web socket connection."
|
trace "Initiating web socket connection."
|
||||||
try:
|
try:
|
||||||
let server = WSServer.new()
|
let deflateFactory = deflateFactory()
|
||||||
|
let server = WSServer.new(factories = [deflateFactory])
|
||||||
let ws = await server.handleRequest(request)
|
let ws = await server.handleRequest(request)
|
||||||
if ws.readyState != Open:
|
if ws.readyState != Open:
|
||||||
error "Failed to open websocket connection"
|
error "Failed to open websocket connection"
|
||||||
|
@ -32,8 +33,13 @@ proc handle(request: HttpRequest) {.async.} =
|
||||||
trace "Websocket handshake completed"
|
trace "Websocket handshake completed"
|
||||||
while ws.readyState != ReadyState.Closed:
|
while ws.readyState != ReadyState.Closed:
|
||||||
let recvData = await ws.recv()
|
let recvData = await ws.recv()
|
||||||
|
|
||||||
trace "Client Response: ", size = recvData.len, binary = ws.binary
|
trace "Client Response: ", size = recvData.len, binary = ws.binary
|
||||||
|
|
||||||
|
if ws.readyState == ReadyState.Closed:
|
||||||
|
# if session already terminated by peer,
|
||||||
|
# no need to send response
|
||||||
|
break
|
||||||
|
|
||||||
await ws.send(recvData,
|
await ws.send(recvData,
|
||||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ requires "stew >= 0.1.0"
|
||||||
requires "asynctest >= 0.2.0 & < 0.3.0"
|
requires "asynctest >= 0.2.0 & < 0.3.0"
|
||||||
requires "nimcrypto"
|
requires "nimcrypto"
|
||||||
requires "bearssl"
|
requires "bearssl"
|
||||||
|
requires "https://github.com/status-im/nim-zlib"
|
||||||
|
|
||||||
task test, "run tests":
|
task test, "run tests":
|
||||||
exec "nim --hints:off c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testcommon.nim"
|
exec "nim --hints:off c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testcommon.nim"
|
||||||
|
|
|
@ -1,212 +0,0 @@
|
||||||
## nim-ws
|
|
||||||
## Copyright (c) 2021 Status Research & Development GmbH
|
|
||||||
## Licensed under either of
|
|
||||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
|
||||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
|
||||||
## at your option.
|
|
||||||
## This file may not be copied, modified, or distributed except according to
|
|
||||||
## those terms.
|
|
||||||
|
|
||||||
|
|
||||||
import
|
|
||||||
std/[strutils],
|
|
||||||
pkg/[stew/results, chronos],
|
|
||||||
../../types, ../../frame, ./miniz/miniz_api
|
|
||||||
|
|
||||||
type
|
|
||||||
DeflateOpts = object
|
|
||||||
isServer: bool
|
|
||||||
serverNoContextTakeOver: bool
|
|
||||||
clientNoContextTakeOver: bool
|
|
||||||
serverMaxWindowBits: int
|
|
||||||
clientMaxWindowBits: int
|
|
||||||
|
|
||||||
DeflateExt = ref object of Ext
|
|
||||||
paramStr: string
|
|
||||||
opts: DeflateOpts
|
|
||||||
# 'messageCompressed' is a two way flag:
|
|
||||||
# 1. the original message is compressible or not
|
|
||||||
# 2. the frame we received need decompression or not
|
|
||||||
messageCompressed: bool
|
|
||||||
|
|
||||||
const
|
|
||||||
extID = "permessage-deflate"
|
|
||||||
|
|
||||||
proc concatParam(resp: var string, param: string) =
|
|
||||||
resp.add ';'
|
|
||||||
resp.add param
|
|
||||||
|
|
||||||
proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] =
|
|
||||||
if arg.value.len == 0:
|
|
||||||
return ok("")
|
|
||||||
|
|
||||||
if arg.value.len > 2:
|
|
||||||
return err("window bits expect 2 bytes, got " & $arg.value.len)
|
|
||||||
|
|
||||||
for n in arg.value:
|
|
||||||
if n notin Digits:
|
|
||||||
return err("window bits value contains illegal char: " & $n)
|
|
||||||
|
|
||||||
var winbit = 0
|
|
||||||
for i in 0..<arg.value.len:
|
|
||||||
winbit = winbit * 10 + arg.value[i].int - '0'.int
|
|
||||||
|
|
||||||
if winbit < 8 or winbit > 15:
|
|
||||||
return err("window bits should between 8-15, got " & $winbit)
|
|
||||||
|
|
||||||
res = winbit
|
|
||||||
return ok("=" & arg.value)
|
|
||||||
|
|
||||||
proc validateParams(args: seq[ExtParam],
|
|
||||||
opts: var DeflateOpts): Result[string, string] =
|
|
||||||
# besides validating extensions params, this proc
|
|
||||||
# also constructing extension param for response
|
|
||||||
var resp = ""
|
|
||||||
for arg in args:
|
|
||||||
case arg.name
|
|
||||||
of "server_no_context_takeover":
|
|
||||||
if arg.value.len > 0:
|
|
||||||
return err("'server_no_context_takeover' should have no param")
|
|
||||||
if opts.isServer:
|
|
||||||
concatParam(resp, arg.name)
|
|
||||||
opts.serverNoContextTakeOver = true
|
|
||||||
of "client_no_context_takeover":
|
|
||||||
if arg.value.len > 0:
|
|
||||||
return err("'client_no_context_takeover' should have no param")
|
|
||||||
if opts.isServer:
|
|
||||||
concatParam(resp, arg.name)
|
|
||||||
opts.clientNoContextTakeOver = true
|
|
||||||
of "server_max_window_bits":
|
|
||||||
if opts.isServer:
|
|
||||||
concatParam(resp, arg.name)
|
|
||||||
let res = validateWindowBits(arg, opts.serverMaxWindowBits)
|
|
||||||
if res.isErr:
|
|
||||||
return res
|
|
||||||
resp.add res.get()
|
|
||||||
of "client_max_window_bits":
|
|
||||||
if opts.isServer:
|
|
||||||
concatParam(resp, arg.name)
|
|
||||||
let res = validateWindowBits(arg, opts.clientMaxWindowBits)
|
|
||||||
if res.isErr:
|
|
||||||
return res
|
|
||||||
resp.add res.get()
|
|
||||||
else:
|
|
||||||
return err("unrecognized param: " & arg.name)
|
|
||||||
|
|
||||||
ok(resp)
|
|
||||||
|
|
||||||
method decode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
|
|
||||||
if not ext.messageCompressed:
|
|
||||||
return data
|
|
||||||
|
|
||||||
# TODO: append trailing bytes
|
|
||||||
var mz = MzStream(
|
|
||||||
next_in: cast[ptr cuchar](data[0].unsafeAddr),
|
|
||||||
avail_in: data.len.cuint
|
|
||||||
)
|
|
||||||
|
|
||||||
let windowBits = if ext.opts.serverMaxWindowBits == 0:
|
|
||||||
MZ_DEFAULT_WINDOW_BITS
|
|
||||||
else:
|
|
||||||
MzWindowBits(ext.opts.serverMaxWindowBits)
|
|
||||||
|
|
||||||
doAssert(mz.inflateInit2(windowBits) == MZ_OK)
|
|
||||||
var res: seq[byte]
|
|
||||||
var buf: array[0xFFFF, byte]
|
|
||||||
|
|
||||||
while true:
|
|
||||||
mz.next_out = cast[ptr cuchar](buf[0].addr)
|
|
||||||
mz.avail_out = buf.len.cuint
|
|
||||||
let r = mz.inflate(MZ_SYNC_FLUSH)
|
|
||||||
let outSize = buf.len - mz.avail_out.int
|
|
||||||
res.add toOpenArray(buf, 0, outSize-1)
|
|
||||||
if r == MZ_STREAM_END:
|
|
||||||
break
|
|
||||||
elif r == MZ_OK:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
doAssert(false, "decompression error")
|
|
||||||
|
|
||||||
doAssert(mz.inflateEnd() == MZ_OK)
|
|
||||||
return res
|
|
||||||
|
|
||||||
method encode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
|
|
||||||
var mz = MzStream(
|
|
||||||
next_in: cast[ptr cuchar](data[0].unsafeAddr),
|
|
||||||
avail_in: data.len.cuint
|
|
||||||
)
|
|
||||||
|
|
||||||
let windowBits = if ext.opts.serverMaxWindowBits == 0:
|
|
||||||
MZ_DEFAULT_WINDOW_BITS
|
|
||||||
else:
|
|
||||||
MzWindowBits(ext.opts.serverMaxWindowBits)
|
|
||||||
|
|
||||||
doAssert(mz.deflateInit2(
|
|
||||||
level = MZ_DEFAULT_LEVEL,
|
|
||||||
meth = MZ_DEFLATED,
|
|
||||||
windowBits,
|
|
||||||
1,
|
|
||||||
strategy = MZ_DEFAULT_STRATEGY) == MZ_OK
|
|
||||||
)
|
|
||||||
|
|
||||||
let maxSize = mz.deflateBound(data.len.culong).int
|
|
||||||
var res: seq[byte]
|
|
||||||
var buf: array[0xFFFF, byte]
|
|
||||||
|
|
||||||
while true:
|
|
||||||
mz.next_out = cast[ptr cuchar](buf[0].addr)
|
|
||||||
mz.avail_out = buf.len.cuint
|
|
||||||
let r = mz.deflate(MZ_FINISH)
|
|
||||||
let outSize = buf.len - mz.avail_out.int
|
|
||||||
res.add toOpenArray(buf, 0, outSize-1)
|
|
||||||
if r == MZ_STREAM_END:
|
|
||||||
break
|
|
||||||
elif r == MZ_OK:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
doAssert(false, "compression error")
|
|
||||||
|
|
||||||
# TODO: cut trailing bytes
|
|
||||||
doAssert(mz.deflateEnd() == MZ_OK)
|
|
||||||
ext.messageCompressed = res.len < data.len
|
|
||||||
if ext.messageCompressed:
|
|
||||||
return res
|
|
||||||
else:
|
|
||||||
return data
|
|
||||||
|
|
||||||
method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
|
||||||
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
|
||||||
# only data frame can be compressed
|
|
||||||
# and we want to know if this message is compressed or not
|
|
||||||
# if the frame opcode is text or binary, it should also the first frame
|
|
||||||
ext.messageCompressed = frame.rsv1
|
|
||||||
# clear rsv1 bit because we already done with it
|
|
||||||
frame.rsv1 = false
|
|
||||||
return frame
|
|
||||||
|
|
||||||
method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
|
||||||
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
|
||||||
# only data frame can be compressed
|
|
||||||
# and we only set rsv1 bit to true if the message is compressible
|
|
||||||
# if the frame opcode is text or binary, it should also the first frame
|
|
||||||
frame.rsv1 = ext.messageCompressed
|
|
||||||
return frame
|
|
||||||
|
|
||||||
method toHttpOptions(ext: DeflateExt): string =
|
|
||||||
# using paramStr here is a bit clunky
|
|
||||||
extID & "; " & ext.paramStr
|
|
||||||
|
|
||||||
proc deflateExtFactory(isServer: bool, args: seq[ExtParam]): Result[Ext, string] {.raises: [Defect].} =
|
|
||||||
var opts = DeflateOpts(isServer: isServer)
|
|
||||||
let resp = validateParams(args, opts)
|
|
||||||
if resp.isErr:
|
|
||||||
return err(resp.error)
|
|
||||||
let ext = DeflateExt(
|
|
||||||
name: extID,
|
|
||||||
paramStr: resp.get(),
|
|
||||||
opts: opts
|
|
||||||
)
|
|
||||||
ok(ext)
|
|
||||||
|
|
||||||
const
|
|
||||||
deflateFactory* = (extID, deflateExtFactory)
|
|
|
@ -0,0 +1,394 @@
|
||||||
|
## nim-ws
|
||||||
|
## Copyright (c) 2021 Status Research & Development GmbH
|
||||||
|
## Licensed under either of
|
||||||
|
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||||
|
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||||
|
## at your option.
|
||||||
|
## This file may not be copied, modified, or distributed except according to
|
||||||
|
## those terms.
|
||||||
|
|
||||||
|
import
|
||||||
|
std/[strutils],
|
||||||
|
pkg/[stew/results,
|
||||||
|
chronos,
|
||||||
|
chronicles,
|
||||||
|
zlib],
|
||||||
|
../../types,
|
||||||
|
../../frame
|
||||||
|
|
||||||
|
type
|
||||||
|
DeflateOpts = object
|
||||||
|
isServer: bool
|
||||||
|
decompressLimit: int # max allowed decompression size
|
||||||
|
threshold: int # size in bytes below which messages
|
||||||
|
# should not be compressed.
|
||||||
|
level: ZLevel # compression level
|
||||||
|
strategy: ZStrategy # compression strategy
|
||||||
|
memLevel: ZMemLevel # hint for miniz memory consumption
|
||||||
|
serverNoContextTakeOver: bool
|
||||||
|
clientNoContextTakeOver: bool
|
||||||
|
serverMaxWindowBits: int
|
||||||
|
clientMaxWindowBits: int
|
||||||
|
|
||||||
|
ContextState {.pure.} = enum
|
||||||
|
Invalid
|
||||||
|
Initialized
|
||||||
|
Reset
|
||||||
|
|
||||||
|
DeflateExt = ref object of Ext
|
||||||
|
paramStr : string
|
||||||
|
opts : DeflateOpts
|
||||||
|
compressedMsg : bool
|
||||||
|
compCtx : ZStream
|
||||||
|
compCtxState : ContextState
|
||||||
|
decompCtx : ZStream
|
||||||
|
decompCtxState: ContextState
|
||||||
|
|
||||||
|
const
|
||||||
|
extID = "permessage-deflate"
|
||||||
|
TrailingBytes = [0x00.byte, 0x00.byte, 0xff.byte, 0xff.byte]
|
||||||
|
ExtDeflateThreshold* = 1024
|
||||||
|
ExtDeflateDecompressLimit* = 10 shl 20 # 10mb
|
||||||
|
|
||||||
|
proc concatParam(resp: var string, param: string) =
|
||||||
|
resp.add "; "
|
||||||
|
resp.add param
|
||||||
|
|
||||||
|
proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] =
|
||||||
|
if arg.value.len == 0:
|
||||||
|
return ok("")
|
||||||
|
|
||||||
|
if arg.value.len > 2:
|
||||||
|
return err("window bits expect 2 bytes, got " & $arg.value.len)
|
||||||
|
|
||||||
|
for n in arg.value:
|
||||||
|
if n notin Digits:
|
||||||
|
return err("window bits value contains illegal char: " & $n)
|
||||||
|
|
||||||
|
var winbit = 0
|
||||||
|
for i in 0..<arg.value.len:
|
||||||
|
winbit = winbit * 10 + arg.value[i].int - '0'.int
|
||||||
|
|
||||||
|
if winbit < 8 or winbit > 15:
|
||||||
|
return err("window bits should between 8-15, got " & $winbit)
|
||||||
|
|
||||||
|
res = winbit
|
||||||
|
return ok("=" & arg.value)
|
||||||
|
|
||||||
|
proc createParams(args: seq[ExtParam],
|
||||||
|
opts: var DeflateOpts): Result[string, string] =
|
||||||
|
# besides validating extensions params, this proc
|
||||||
|
# also constructing extension params for response
|
||||||
|
var resp = ""
|
||||||
|
for arg in args:
|
||||||
|
case arg.name
|
||||||
|
of "server_no_context_takeover":
|
||||||
|
if arg.value.len > 0:
|
||||||
|
return err("'server_no_context_takeover' should have no param")
|
||||||
|
opts.serverNoContextTakeOver = true
|
||||||
|
if opts.isServer:
|
||||||
|
concatParam(resp, arg.name)
|
||||||
|
of "client_no_context_takeover":
|
||||||
|
if arg.value.len > 0:
|
||||||
|
return err("'client_no_context_takeover' should have no param")
|
||||||
|
opts.clientNoContextTakeOver = true
|
||||||
|
if opts.isServer:
|
||||||
|
concatParam(resp, arg.name)
|
||||||
|
of "server_max_window_bits":
|
||||||
|
let res = validateWindowBits(arg, opts.serverMaxWindowBits)
|
||||||
|
if res.isErr:
|
||||||
|
return res
|
||||||
|
if opts.isServer:
|
||||||
|
concatParam(resp, arg.name)
|
||||||
|
if opts.serverMaxWindowBits == 8:
|
||||||
|
# zlib does not support windowBits == 8
|
||||||
|
resp.add "=9"
|
||||||
|
else:
|
||||||
|
resp.add res.get()
|
||||||
|
of "client_max_window_bits":
|
||||||
|
let res = validateWindowBits(arg, opts.clientMaxWindowBits)
|
||||||
|
if res.isErr:
|
||||||
|
return res
|
||||||
|
if not opts.isServer:
|
||||||
|
concatParam(resp, arg.name)
|
||||||
|
if opts.clientMaxWindowBits == 8:
|
||||||
|
# zlib does not support windowBits == 8
|
||||||
|
resp.add "=9"
|
||||||
|
else:
|
||||||
|
resp.add res.get()
|
||||||
|
else:
|
||||||
|
return err("unrecognized param: " & arg.name)
|
||||||
|
|
||||||
|
ok(resp)
|
||||||
|
|
||||||
|
proc getWindowBits(opts: DeflateOpts, isServer: bool): ZWindowBits =
|
||||||
|
if isServer:
|
||||||
|
if opts.serverMaxWindowBits == 0:
|
||||||
|
Z_RAW_DEFLATE
|
||||||
|
else:
|
||||||
|
ZWindowBits(-opts.serverMaxWindowBits)
|
||||||
|
else:
|
||||||
|
if opts.clientMaxWindowBits == 0:
|
||||||
|
Z_RAW_DEFLATE
|
||||||
|
else:
|
||||||
|
ZWindowBits(-opts.clientMaxWindowBits)
|
||||||
|
|
||||||
|
proc getContextTakeover(opts: DeflateOpts, isServer: bool): bool =
|
||||||
|
if isServer:
|
||||||
|
opts.serverNoContextTakeOver
|
||||||
|
else:
|
||||||
|
opts.clientNoContextTakeOver
|
||||||
|
|
||||||
|
proc decompressInit(ext: DeflateExt) =
|
||||||
|
# decompression using `client_` prefixed config
|
||||||
|
let windowBits = getWindowBits(ext.opts, not ext.opts.isServer)
|
||||||
|
doAssert(ext.decompCtx.inflateInit2(windowBits) == Z_OK)
|
||||||
|
ext.decompCtxState = ContextState.Initialized
|
||||||
|
|
||||||
|
proc compressInit(ext: DeflateExt) =
|
||||||
|
# compression using `server_` prefixed config
|
||||||
|
let windowBits = getWindowBits(ext.opts, ext.opts.isServer)
|
||||||
|
doAssert(ext.compCtx.deflateInit2(
|
||||||
|
level = ext.opts.level,
|
||||||
|
meth = Z_DEFLATED,
|
||||||
|
windowBits,
|
||||||
|
memLevel = ext.opts.memLevel,
|
||||||
|
strategy = ext.opts.strategy) == Z_OK
|
||||||
|
)
|
||||||
|
ext.compCtxState = ContextState.Initialized
|
||||||
|
|
||||||
|
proc compress(zs: var ZStream, data: openArray[byte]): seq[byte] =
|
||||||
|
var buf: array[0xFFFF, byte]
|
||||||
|
|
||||||
|
# these casting is needed to prevent compilation
|
||||||
|
# error with CLANG
|
||||||
|
zs.next_in = cast[ptr cuchar](data[0].unsafeAddr)
|
||||||
|
zs.avail_in = data.len.cuint
|
||||||
|
|
||||||
|
while true:
|
||||||
|
zs.next_out = cast[ptr cuchar](buf[0].addr)
|
||||||
|
zs.avail_out = buf.len.cuint
|
||||||
|
|
||||||
|
let r = zs.deflate(Z_SYNC_FLUSH)
|
||||||
|
let outSize = buf.len - zs.avail_out.int
|
||||||
|
result.add toOpenArray(buf, 0, outSize-1)
|
||||||
|
|
||||||
|
if r == Z_STREAM_END:
|
||||||
|
break
|
||||||
|
elif r == Z_OK:
|
||||||
|
# need more input or more output available
|
||||||
|
if zs.avail_in > 0 or zs.avail_out == 0:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise newException(WSExtError, "compression error " & $r)
|
||||||
|
|
||||||
|
proc decompress(zs: var ZStream, limit: int, data: openArray[byte]): seq[byte] =
|
||||||
|
var buf: array[0xFFFF, byte]
|
||||||
|
|
||||||
|
# these casting is needed to prevent compilation
|
||||||
|
# error with CLANG
|
||||||
|
zs.next_in = cast[ptr cuchar](data[0].unsafeAddr)
|
||||||
|
zs.avail_in = data.len.cuint
|
||||||
|
|
||||||
|
while true:
|
||||||
|
zs.next_out = cast[ptr cuchar](buf[0].addr)
|
||||||
|
zs.avail_out = buf.len.cuint
|
||||||
|
|
||||||
|
let r = zs.inflate(Z_NO_FLUSH)
|
||||||
|
let outSize = buf.len - zs.avail_out.int
|
||||||
|
result.add toOpenArray(buf, 0, outSize-1)
|
||||||
|
|
||||||
|
if result.len > limit:
|
||||||
|
raise newException(WSExtError, "decompression exceeds allowed limit")
|
||||||
|
|
||||||
|
if r == Z_STREAM_END:
|
||||||
|
break
|
||||||
|
elif r == Z_OK:
|
||||||
|
# need more input or more output available
|
||||||
|
if zs.avail_in > 0 or zs.avail_out == 0:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise newException(WSExtError, "decompression error " & $r)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
||||||
|
if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}:
|
||||||
|
# only data frames can be decompressed
|
||||||
|
return frame
|
||||||
|
|
||||||
|
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
||||||
|
# we want to know if this message is compressed or not
|
||||||
|
# if the frame opcode is text or binary, it should also the first frame
|
||||||
|
ext.compressedMsg = frame.rsv1
|
||||||
|
# clear rsv1 bit because we already done with it
|
||||||
|
frame.rsv1 = false
|
||||||
|
|
||||||
|
if not ext.compressedMsg:
|
||||||
|
# don't bother with uncompressed message
|
||||||
|
return frame
|
||||||
|
|
||||||
|
if ext.decompCtxState == ContextState.Invalid:
|
||||||
|
ext.decompressInit()
|
||||||
|
|
||||||
|
# even though the frame.data.len == 0, the stream needs
|
||||||
|
# to be closed with trailing bytes if it's a final frame
|
||||||
|
|
||||||
|
var data: seq[byte]
|
||||||
|
var buf: array[0xFFFF, byte]
|
||||||
|
|
||||||
|
while data.len < frame.length.int:
|
||||||
|
let len = min(frame.length.int - data.len, buf.len)
|
||||||
|
let read = await frame.read(ext.session.stream.reader, addr buf[0], len)
|
||||||
|
data.add toOpenArray(buf, 0, read - 1)
|
||||||
|
|
||||||
|
if data.len > ext.session.frameSize:
|
||||||
|
raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size")
|
||||||
|
|
||||||
|
if frame.fin:
|
||||||
|
data.add TrailingBytes
|
||||||
|
|
||||||
|
frame.data = decompress(ext.decompCtx, ext.opts.decompressLimit, data)
|
||||||
|
trace "DeflateExt decompress", input=frame.length, output=frame.data.len
|
||||||
|
|
||||||
|
frame.length = frame.data.len.uint64
|
||||||
|
frame.offset = 0
|
||||||
|
frame.consumed = 0
|
||||||
|
frame.mask = false # clear mask flag, decompressed content is not masked
|
||||||
|
|
||||||
|
if frame.fin:
|
||||||
|
# decompression using `client_` prefixed config
|
||||||
|
let noContextTakeover = getContextTakeover(ext.opts, not ext.opts.isServer)
|
||||||
|
if noContextTakeover:
|
||||||
|
doAssert(ext.decompCtx.inflateReset() == Z_OK)
|
||||||
|
ext.decompCtxState = ContextState.Reset
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
||||||
|
if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}:
|
||||||
|
# only data frames can be compressed
|
||||||
|
return frame
|
||||||
|
|
||||||
|
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
||||||
|
# we only set rsv1 bit to true if the message is compressible
|
||||||
|
# and only set the first frame's rsv1
|
||||||
|
# if the frame opcode is text or binary, it should also the first frame
|
||||||
|
ext.compressedMsg = frame.data.len >= ext.opts.threshold
|
||||||
|
frame.rsv1 = ext.compressedMsg
|
||||||
|
|
||||||
|
if not ext.compressedMsg:
|
||||||
|
# don't bother with incompressible message
|
||||||
|
return frame
|
||||||
|
|
||||||
|
if ext.compCtxState == ContextState.Invalid:
|
||||||
|
ext.compressInit()
|
||||||
|
|
||||||
|
frame.length = frame.data.len.uint64
|
||||||
|
frame.data = compress(ext.compCtx, frame.data)
|
||||||
|
trace "DeflateExt compress", input=frame.length, output=frame.data.len
|
||||||
|
|
||||||
|
if frame.fin:
|
||||||
|
# remove trailing bytes
|
||||||
|
when not defined(release):
|
||||||
|
var trailer: array[4, byte]
|
||||||
|
trailer[0] = frame.data[^4]
|
||||||
|
trailer[1] = frame.data[^3]
|
||||||
|
trailer[2] = frame.data[^2]
|
||||||
|
trailer[3] = frame.data[^1]
|
||||||
|
doAssert trailer == TrailingBytes
|
||||||
|
frame.data.setLen(frame.data.len - 4)
|
||||||
|
|
||||||
|
frame.length = frame.data.len.uint64
|
||||||
|
frame.offset = 0
|
||||||
|
frame.consumed = 0
|
||||||
|
|
||||||
|
if frame.fin:
|
||||||
|
# compression using `server_` prefixed config
|
||||||
|
let noContextTakeover = getContextTakeover(ext.opts, ext.opts.isServer)
|
||||||
|
if noContextTakeover:
|
||||||
|
doAssert(ext.compCtx.deflateReset() == Z_OK)
|
||||||
|
ext.compCtxState = ContextState.Reset
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
method toHttpOptions(ext: DeflateExt): string =
|
||||||
|
# using paramStr here is a bit clunky
|
||||||
|
extID & ext.paramStr
|
||||||
|
|
||||||
|
proc destroyExt(ext: DeflateExt) =
|
||||||
|
if ext.compCtxState != ContextState.Invalid:
|
||||||
|
# zlib.deflateEnd somehow return DATA_ERROR
|
||||||
|
# when compression succeed some cases.
|
||||||
|
# we forget to do something?
|
||||||
|
discard ext.compCtx.deflateEnd()
|
||||||
|
ext.compCtxState = ContextState.Invalid
|
||||||
|
|
||||||
|
if ext.decompCtxState != ContextState.Invalid:
|
||||||
|
doAssert(ext.decompCtx.inflateEnd() == Z_OK)
|
||||||
|
ext.decompCtxState = ContextState.Invalid
|
||||||
|
|
||||||
|
proc makeOffer(
|
||||||
|
clientNoContextTakeOver: bool,
|
||||||
|
clientMaxWindowBits: int): string =
|
||||||
|
|
||||||
|
var param = extID
|
||||||
|
if clientMaxWindowBits in {9..15}:
|
||||||
|
param.add "; client_max_window_bits=" & $clientMaxWindowBits
|
||||||
|
else:
|
||||||
|
param.add "; client_max_window_bits"
|
||||||
|
|
||||||
|
if clientNoContextTakeOver:
|
||||||
|
param.add "; client_no_context_takeover"
|
||||||
|
|
||||||
|
param
|
||||||
|
|
||||||
|
proc deflateFactory*(
|
||||||
|
threshold = ExtDeflateThreshold,
|
||||||
|
decompressLimit = ExtDeflateDecompressLimit,
|
||||||
|
level = Z_DEFAULT_LEVEL,
|
||||||
|
strategy = Z_DEFAULT_STRATEGY,
|
||||||
|
memLevel = Z_DEFAULT_MEM_LEVEL,
|
||||||
|
clientNoContextTakeOver = false,
|
||||||
|
clientMaxWindowBits = 15): ExtFactory =
|
||||||
|
|
||||||
|
proc factory(isServer: bool,
|
||||||
|
args: seq[ExtParam]): Result[Ext, string] {.
|
||||||
|
gcsafe, raises: [Defect].} =
|
||||||
|
|
||||||
|
# capture user configuration via closure
|
||||||
|
var opts = DeflateOpts(
|
||||||
|
isServer: isServer,
|
||||||
|
threshold: threshold,
|
||||||
|
decompressLimit: decompressLimit,
|
||||||
|
level: level,
|
||||||
|
strategy: strategy,
|
||||||
|
memLevel: memLevel
|
||||||
|
)
|
||||||
|
let resp = createParams(args, opts)
|
||||||
|
if resp.isErr:
|
||||||
|
return err(resp.error)
|
||||||
|
|
||||||
|
var ext: DeflateExt
|
||||||
|
ext.new(destroyExt)
|
||||||
|
ext.name = extID
|
||||||
|
ext.paramStr = resp.get()
|
||||||
|
ext.opts = opts
|
||||||
|
ext.compressedMsg = false
|
||||||
|
ext.compCtxState = ContextState.Invalid
|
||||||
|
ext.decompCtxState= ContextState.Invalid
|
||||||
|
|
||||||
|
ok(ext)
|
||||||
|
|
||||||
|
ExtFactory(
|
||||||
|
name: extID,
|
||||||
|
factory: factory,
|
||||||
|
clientOffer: makeOffer(
|
||||||
|
clientNoContextTakeOver,
|
||||||
|
clientMaxWindowBits
|
||||||
|
)
|
||||||
|
)
|
Loading…
Reference in New Issue