implement permessage-deflate compression extension
depends on zlib as it's backend compressor pass both client and server tests in autobahn test suite
This commit is contained in:
parent
fef04a1595
commit
14d8e51f53
|
@ -10,7 +10,7 @@
|
|||
import
|
||||
std/[strutils],
|
||||
pkg/[chronos, chronicles, stew/byteutils],
|
||||
../ws/[ws, types, frame]
|
||||
../ws/[ws, types, frame, extensions/compression/deflate]
|
||||
|
||||
const
|
||||
clientFlags = {NoVerifyHost, NoVerifyServerName}
|
||||
|
@ -28,13 +28,14 @@ else:
|
|||
secure = false
|
||||
serverPort = 9001
|
||||
|
||||
proc connectServer(path: string): Future[WSSession] {.async.} =
|
||||
proc connectServer(path: string, factories: seq[ExtFactory] = @[]): Future[WSSession] {.async.} =
|
||||
let ws = await WebSocket.connect(
|
||||
host = "127.0.0.1",
|
||||
port = Port(serverPort),
|
||||
path = path,
|
||||
secure=secure,
|
||||
flags=clientFlags
|
||||
flags=clientFlags,
|
||||
factories = factories
|
||||
)
|
||||
return ws
|
||||
|
||||
|
@ -71,11 +72,12 @@ proc main() {.async.} =
|
|||
let caseCount = await getCaseCount()
|
||||
trace "case count", count=caseCount
|
||||
|
||||
var deflateFactory = @[deflateFactory()]
|
||||
for i in 1..caseCount:
|
||||
trace "runcase", no=i
|
||||
let path = "/runCase?case=$1&agent=$2" % [$i, agent]
|
||||
try:
|
||||
let ws = await connectServer(path)
|
||||
let ws = await connectServer(path, deflateFactory)
|
||||
|
||||
while ws.readystate != ReadyState.Closed:
|
||||
# echo back
|
||||
|
|
|
@ -12,7 +12,7 @@ import pkg/[chronos,
|
|||
chronicles,
|
||||
httputils]
|
||||
|
||||
import ../ws/ws
|
||||
import ../ws/[ws, extensions/compression/deflate]
|
||||
import ../tests/keys
|
||||
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
|
@ -23,7 +23,8 @@ proc handle(request: HttpRequest) {.async.} =
|
|||
|
||||
trace "Initiating web socket connection."
|
||||
try:
|
||||
let server = WSServer.new()
|
||||
let deflateFactory = deflateFactory()
|
||||
let server = WSServer.new(factories = [deflateFactory])
|
||||
let ws = await server.handleRequest(request)
|
||||
if ws.readyState != Open:
|
||||
error "Failed to open websocket connection"
|
||||
|
@ -32,8 +33,13 @@ proc handle(request: HttpRequest) {.async.} =
|
|||
trace "Websocket handshake completed"
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let recvData = await ws.recv()
|
||||
|
||||
trace "Client Response: ", size = recvData.len, binary = ws.binary
|
||||
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
# if session already terminated by peer,
|
||||
# no need to send response
|
||||
break
|
||||
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ requires "stew >= 0.1.0"
|
|||
requires "asynctest >= 0.2.0 & < 0.3.0"
|
||||
requires "nimcrypto"
|
||||
requires "bearssl"
|
||||
requires "https://github.com/status-im/nim-zlib"
|
||||
|
||||
task test, "run tests":
|
||||
exec "nim --hints:off c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testcommon.nim"
|
||||
|
|
|
@ -1,212 +0,0 @@
|
|||
## nim-ws
|
||||
## Copyright (c) 2021 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
|
||||
import
|
||||
std/[strutils],
|
||||
pkg/[stew/results, chronos],
|
||||
../../types, ../../frame, ./miniz/miniz_api
|
||||
|
||||
type
|
||||
DeflateOpts = object
|
||||
isServer: bool
|
||||
serverNoContextTakeOver: bool
|
||||
clientNoContextTakeOver: bool
|
||||
serverMaxWindowBits: int
|
||||
clientMaxWindowBits: int
|
||||
|
||||
DeflateExt = ref object of Ext
|
||||
paramStr: string
|
||||
opts: DeflateOpts
|
||||
# 'messageCompressed' is a two way flag:
|
||||
# 1. the original message is compressible or not
|
||||
# 2. the frame we received need decompression or not
|
||||
messageCompressed: bool
|
||||
|
||||
const
|
||||
extID = "permessage-deflate"
|
||||
|
||||
proc concatParam(resp: var string, param: string) =
|
||||
resp.add ';'
|
||||
resp.add param
|
||||
|
||||
proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] =
|
||||
if arg.value.len == 0:
|
||||
return ok("")
|
||||
|
||||
if arg.value.len > 2:
|
||||
return err("window bits expect 2 bytes, got " & $arg.value.len)
|
||||
|
||||
for n in arg.value:
|
||||
if n notin Digits:
|
||||
return err("window bits value contains illegal char: " & $n)
|
||||
|
||||
var winbit = 0
|
||||
for i in 0..<arg.value.len:
|
||||
winbit = winbit * 10 + arg.value[i].int - '0'.int
|
||||
|
||||
if winbit < 8 or winbit > 15:
|
||||
return err("window bits should between 8-15, got " & $winbit)
|
||||
|
||||
res = winbit
|
||||
return ok("=" & arg.value)
|
||||
|
||||
proc validateParams(args: seq[ExtParam],
|
||||
opts: var DeflateOpts): Result[string, string] =
|
||||
# besides validating extensions params, this proc
|
||||
# also constructing extension param for response
|
||||
var resp = ""
|
||||
for arg in args:
|
||||
case arg.name
|
||||
of "server_no_context_takeover":
|
||||
if arg.value.len > 0:
|
||||
return err("'server_no_context_takeover' should have no param")
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
opts.serverNoContextTakeOver = true
|
||||
of "client_no_context_takeover":
|
||||
if arg.value.len > 0:
|
||||
return err("'client_no_context_takeover' should have no param")
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
opts.clientNoContextTakeOver = true
|
||||
of "server_max_window_bits":
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
let res = validateWindowBits(arg, opts.serverMaxWindowBits)
|
||||
if res.isErr:
|
||||
return res
|
||||
resp.add res.get()
|
||||
of "client_max_window_bits":
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
let res = validateWindowBits(arg, opts.clientMaxWindowBits)
|
||||
if res.isErr:
|
||||
return res
|
||||
resp.add res.get()
|
||||
else:
|
||||
return err("unrecognized param: " & arg.name)
|
||||
|
||||
ok(resp)
|
||||
|
||||
method decode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
|
||||
if not ext.messageCompressed:
|
||||
return data
|
||||
|
||||
# TODO: append trailing bytes
|
||||
var mz = MzStream(
|
||||
next_in: cast[ptr cuchar](data[0].unsafeAddr),
|
||||
avail_in: data.len.cuint
|
||||
)
|
||||
|
||||
let windowBits = if ext.opts.serverMaxWindowBits == 0:
|
||||
MZ_DEFAULT_WINDOW_BITS
|
||||
else:
|
||||
MzWindowBits(ext.opts.serverMaxWindowBits)
|
||||
|
||||
doAssert(mz.inflateInit2(windowBits) == MZ_OK)
|
||||
var res: seq[byte]
|
||||
var buf: array[0xFFFF, byte]
|
||||
|
||||
while true:
|
||||
mz.next_out = cast[ptr cuchar](buf[0].addr)
|
||||
mz.avail_out = buf.len.cuint
|
||||
let r = mz.inflate(MZ_SYNC_FLUSH)
|
||||
let outSize = buf.len - mz.avail_out.int
|
||||
res.add toOpenArray(buf, 0, outSize-1)
|
||||
if r == MZ_STREAM_END:
|
||||
break
|
||||
elif r == MZ_OK:
|
||||
continue
|
||||
else:
|
||||
doAssert(false, "decompression error")
|
||||
|
||||
doAssert(mz.inflateEnd() == MZ_OK)
|
||||
return res
|
||||
|
||||
method encode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
|
||||
var mz = MzStream(
|
||||
next_in: cast[ptr cuchar](data[0].unsafeAddr),
|
||||
avail_in: data.len.cuint
|
||||
)
|
||||
|
||||
let windowBits = if ext.opts.serverMaxWindowBits == 0:
|
||||
MZ_DEFAULT_WINDOW_BITS
|
||||
else:
|
||||
MzWindowBits(ext.opts.serverMaxWindowBits)
|
||||
|
||||
doAssert(mz.deflateInit2(
|
||||
level = MZ_DEFAULT_LEVEL,
|
||||
meth = MZ_DEFLATED,
|
||||
windowBits,
|
||||
1,
|
||||
strategy = MZ_DEFAULT_STRATEGY) == MZ_OK
|
||||
)
|
||||
|
||||
let maxSize = mz.deflateBound(data.len.culong).int
|
||||
var res: seq[byte]
|
||||
var buf: array[0xFFFF, byte]
|
||||
|
||||
while true:
|
||||
mz.next_out = cast[ptr cuchar](buf[0].addr)
|
||||
mz.avail_out = buf.len.cuint
|
||||
let r = mz.deflate(MZ_FINISH)
|
||||
let outSize = buf.len - mz.avail_out.int
|
||||
res.add toOpenArray(buf, 0, outSize-1)
|
||||
if r == MZ_STREAM_END:
|
||||
break
|
||||
elif r == MZ_OK:
|
||||
continue
|
||||
else:
|
||||
doAssert(false, "compression error")
|
||||
|
||||
# TODO: cut trailing bytes
|
||||
doAssert(mz.deflateEnd() == MZ_OK)
|
||||
ext.messageCompressed = res.len < data.len
|
||||
if ext.messageCompressed:
|
||||
return res
|
||||
else:
|
||||
return data
|
||||
|
||||
method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
||||
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
||||
# only data frame can be compressed
|
||||
# and we want to know if this message is compressed or not
|
||||
# if the frame opcode is text or binary, it should also the first frame
|
||||
ext.messageCompressed = frame.rsv1
|
||||
# clear rsv1 bit because we already done with it
|
||||
frame.rsv1 = false
|
||||
return frame
|
||||
|
||||
method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
||||
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
||||
# only data frame can be compressed
|
||||
# and we only set rsv1 bit to true if the message is compressible
|
||||
# if the frame opcode is text or binary, it should also the first frame
|
||||
frame.rsv1 = ext.messageCompressed
|
||||
return frame
|
||||
|
||||
method toHttpOptions(ext: DeflateExt): string =
|
||||
# using paramStr here is a bit clunky
|
||||
extID & "; " & ext.paramStr
|
||||
|
||||
proc deflateExtFactory(isServer: bool, args: seq[ExtParam]): Result[Ext, string] {.raises: [Defect].} =
|
||||
var opts = DeflateOpts(isServer: isServer)
|
||||
let resp = validateParams(args, opts)
|
||||
if resp.isErr:
|
||||
return err(resp.error)
|
||||
let ext = DeflateExt(
|
||||
name: extID,
|
||||
paramStr: resp.get(),
|
||||
opts: opts
|
||||
)
|
||||
ok(ext)
|
||||
|
||||
const
|
||||
deflateFactory* = (extID, deflateExtFactory)
|
|
@ -0,0 +1,394 @@
|
|||
## nim-ws
|
||||
## Copyright (c) 2021 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import
|
||||
std/[strutils],
|
||||
pkg/[stew/results,
|
||||
chronos,
|
||||
chronicles,
|
||||
zlib],
|
||||
../../types,
|
||||
../../frame
|
||||
|
||||
type
|
||||
DeflateOpts = object
|
||||
isServer: bool
|
||||
decompressLimit: int # max allowed decompression size
|
||||
threshold: int # size in bytes below which messages
|
||||
# should not be compressed.
|
||||
level: ZLevel # compression level
|
||||
strategy: ZStrategy # compression strategy
|
||||
memLevel: ZMemLevel # hint for miniz memory consumption
|
||||
serverNoContextTakeOver: bool
|
||||
clientNoContextTakeOver: bool
|
||||
serverMaxWindowBits: int
|
||||
clientMaxWindowBits: int
|
||||
|
||||
ContextState {.pure.} = enum
|
||||
Invalid
|
||||
Initialized
|
||||
Reset
|
||||
|
||||
DeflateExt = ref object of Ext
|
||||
paramStr : string
|
||||
opts : DeflateOpts
|
||||
compressedMsg : bool
|
||||
compCtx : ZStream
|
||||
compCtxState : ContextState
|
||||
decompCtx : ZStream
|
||||
decompCtxState: ContextState
|
||||
|
||||
const
|
||||
extID = "permessage-deflate"
|
||||
TrailingBytes = [0x00.byte, 0x00.byte, 0xff.byte, 0xff.byte]
|
||||
ExtDeflateThreshold* = 1024
|
||||
ExtDeflateDecompressLimit* = 10 shl 20 # 10mb
|
||||
|
||||
proc concatParam(resp: var string, param: string) =
|
||||
resp.add "; "
|
||||
resp.add param
|
||||
|
||||
proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] =
|
||||
if arg.value.len == 0:
|
||||
return ok("")
|
||||
|
||||
if arg.value.len > 2:
|
||||
return err("window bits expect 2 bytes, got " & $arg.value.len)
|
||||
|
||||
for n in arg.value:
|
||||
if n notin Digits:
|
||||
return err("window bits value contains illegal char: " & $n)
|
||||
|
||||
var winbit = 0
|
||||
for i in 0..<arg.value.len:
|
||||
winbit = winbit * 10 + arg.value[i].int - '0'.int
|
||||
|
||||
if winbit < 8 or winbit > 15:
|
||||
return err("window bits should between 8-15, got " & $winbit)
|
||||
|
||||
res = winbit
|
||||
return ok("=" & arg.value)
|
||||
|
||||
proc createParams(args: seq[ExtParam],
|
||||
opts: var DeflateOpts): Result[string, string] =
|
||||
# besides validating extensions params, this proc
|
||||
# also constructing extension params for response
|
||||
var resp = ""
|
||||
for arg in args:
|
||||
case arg.name
|
||||
of "server_no_context_takeover":
|
||||
if arg.value.len > 0:
|
||||
return err("'server_no_context_takeover' should have no param")
|
||||
opts.serverNoContextTakeOver = true
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
of "client_no_context_takeover":
|
||||
if arg.value.len > 0:
|
||||
return err("'client_no_context_takeover' should have no param")
|
||||
opts.clientNoContextTakeOver = true
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
of "server_max_window_bits":
|
||||
let res = validateWindowBits(arg, opts.serverMaxWindowBits)
|
||||
if res.isErr:
|
||||
return res
|
||||
if opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
if opts.serverMaxWindowBits == 8:
|
||||
# zlib does not support windowBits == 8
|
||||
resp.add "=9"
|
||||
else:
|
||||
resp.add res.get()
|
||||
of "client_max_window_bits":
|
||||
let res = validateWindowBits(arg, opts.clientMaxWindowBits)
|
||||
if res.isErr:
|
||||
return res
|
||||
if not opts.isServer:
|
||||
concatParam(resp, arg.name)
|
||||
if opts.clientMaxWindowBits == 8:
|
||||
# zlib does not support windowBits == 8
|
||||
resp.add "=9"
|
||||
else:
|
||||
resp.add res.get()
|
||||
else:
|
||||
return err("unrecognized param: " & arg.name)
|
||||
|
||||
ok(resp)
|
||||
|
||||
proc getWindowBits(opts: DeflateOpts, isServer: bool): ZWindowBits =
|
||||
if isServer:
|
||||
if opts.serverMaxWindowBits == 0:
|
||||
Z_RAW_DEFLATE
|
||||
else:
|
||||
ZWindowBits(-opts.serverMaxWindowBits)
|
||||
else:
|
||||
if opts.clientMaxWindowBits == 0:
|
||||
Z_RAW_DEFLATE
|
||||
else:
|
||||
ZWindowBits(-opts.clientMaxWindowBits)
|
||||
|
||||
proc getContextTakeover(opts: DeflateOpts, isServer: bool): bool =
|
||||
if isServer:
|
||||
opts.serverNoContextTakeOver
|
||||
else:
|
||||
opts.clientNoContextTakeOver
|
||||
|
||||
proc decompressInit(ext: DeflateExt) =
|
||||
# decompression using `client_` prefixed config
|
||||
let windowBits = getWindowBits(ext.opts, not ext.opts.isServer)
|
||||
doAssert(ext.decompCtx.inflateInit2(windowBits) == Z_OK)
|
||||
ext.decompCtxState = ContextState.Initialized
|
||||
|
||||
proc compressInit(ext: DeflateExt) =
|
||||
# compression using `server_` prefixed config
|
||||
let windowBits = getWindowBits(ext.opts, ext.opts.isServer)
|
||||
doAssert(ext.compCtx.deflateInit2(
|
||||
level = ext.opts.level,
|
||||
meth = Z_DEFLATED,
|
||||
windowBits,
|
||||
memLevel = ext.opts.memLevel,
|
||||
strategy = ext.opts.strategy) == Z_OK
|
||||
)
|
||||
ext.compCtxState = ContextState.Initialized
|
||||
|
||||
proc compress(zs: var ZStream, data: openArray[byte]): seq[byte] =
|
||||
var buf: array[0xFFFF, byte]
|
||||
|
||||
# these casting is needed to prevent compilation
|
||||
# error with CLANG
|
||||
zs.next_in = cast[ptr cuchar](data[0].unsafeAddr)
|
||||
zs.avail_in = data.len.cuint
|
||||
|
||||
while true:
|
||||
zs.next_out = cast[ptr cuchar](buf[0].addr)
|
||||
zs.avail_out = buf.len.cuint
|
||||
|
||||
let r = zs.deflate(Z_SYNC_FLUSH)
|
||||
let outSize = buf.len - zs.avail_out.int
|
||||
result.add toOpenArray(buf, 0, outSize-1)
|
||||
|
||||
if r == Z_STREAM_END:
|
||||
break
|
||||
elif r == Z_OK:
|
||||
# need more input or more output available
|
||||
if zs.avail_in > 0 or zs.avail_out == 0:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise newException(WSExtError, "compression error " & $r)
|
||||
|
||||
proc decompress(zs: var ZStream, limit: int, data: openArray[byte]): seq[byte] =
|
||||
var buf: array[0xFFFF, byte]
|
||||
|
||||
# these casting is needed to prevent compilation
|
||||
# error with CLANG
|
||||
zs.next_in = cast[ptr cuchar](data[0].unsafeAddr)
|
||||
zs.avail_in = data.len.cuint
|
||||
|
||||
while true:
|
||||
zs.next_out = cast[ptr cuchar](buf[0].addr)
|
||||
zs.avail_out = buf.len.cuint
|
||||
|
||||
let r = zs.inflate(Z_NO_FLUSH)
|
||||
let outSize = buf.len - zs.avail_out.int
|
||||
result.add toOpenArray(buf, 0, outSize-1)
|
||||
|
||||
if result.len > limit:
|
||||
raise newException(WSExtError, "decompression exceeds allowed limit")
|
||||
|
||||
if r == Z_STREAM_END:
|
||||
break
|
||||
elif r == Z_OK:
|
||||
# need more input or more output available
|
||||
if zs.avail_in > 0 or zs.avail_out == 0:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise newException(WSExtError, "decompression error " & $r)
|
||||
|
||||
return result
|
||||
|
||||
method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
||||
if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}:
|
||||
# only data frames can be decompressed
|
||||
return frame
|
||||
|
||||
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
||||
# we want to know if this message is compressed or not
|
||||
# if the frame opcode is text or binary, it should also the first frame
|
||||
ext.compressedMsg = frame.rsv1
|
||||
# clear rsv1 bit because we already done with it
|
||||
frame.rsv1 = false
|
||||
|
||||
if not ext.compressedMsg:
|
||||
# don't bother with uncompressed message
|
||||
return frame
|
||||
|
||||
if ext.decompCtxState == ContextState.Invalid:
|
||||
ext.decompressInit()
|
||||
|
||||
# even though the frame.data.len == 0, the stream needs
|
||||
# to be closed with trailing bytes if it's a final frame
|
||||
|
||||
var data: seq[byte]
|
||||
var buf: array[0xFFFF, byte]
|
||||
|
||||
while data.len < frame.length.int:
|
||||
let len = min(frame.length.int - data.len, buf.len)
|
||||
let read = await frame.read(ext.session.stream.reader, addr buf[0], len)
|
||||
data.add toOpenArray(buf, 0, read - 1)
|
||||
|
||||
if data.len > ext.session.frameSize:
|
||||
raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size")
|
||||
|
||||
if frame.fin:
|
||||
data.add TrailingBytes
|
||||
|
||||
frame.data = decompress(ext.decompCtx, ext.opts.decompressLimit, data)
|
||||
trace "DeflateExt decompress", input=frame.length, output=frame.data.len
|
||||
|
||||
frame.length = frame.data.len.uint64
|
||||
frame.offset = 0
|
||||
frame.consumed = 0
|
||||
frame.mask = false # clear mask flag, decompressed content is not masked
|
||||
|
||||
if frame.fin:
|
||||
# decompression using `client_` prefixed config
|
||||
let noContextTakeover = getContextTakeover(ext.opts, not ext.opts.isServer)
|
||||
if noContextTakeover:
|
||||
doAssert(ext.decompCtx.inflateReset() == Z_OK)
|
||||
ext.decompCtxState = ContextState.Reset
|
||||
|
||||
return frame
|
||||
|
||||
method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
|
||||
if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}:
|
||||
# only data frames can be compressed
|
||||
return frame
|
||||
|
||||
if frame.opcode in {Opcode.Text, Opcode.Binary}:
|
||||
# we only set rsv1 bit to true if the message is compressible
|
||||
# and only set the first frame's rsv1
|
||||
# if the frame opcode is text or binary, it should also the first frame
|
||||
ext.compressedMsg = frame.data.len >= ext.opts.threshold
|
||||
frame.rsv1 = ext.compressedMsg
|
||||
|
||||
if not ext.compressedMsg:
|
||||
# don't bother with incompressible message
|
||||
return frame
|
||||
|
||||
if ext.compCtxState == ContextState.Invalid:
|
||||
ext.compressInit()
|
||||
|
||||
frame.length = frame.data.len.uint64
|
||||
frame.data = compress(ext.compCtx, frame.data)
|
||||
trace "DeflateExt compress", input=frame.length, output=frame.data.len
|
||||
|
||||
if frame.fin:
|
||||
# remove trailing bytes
|
||||
when not defined(release):
|
||||
var trailer: array[4, byte]
|
||||
trailer[0] = frame.data[^4]
|
||||
trailer[1] = frame.data[^3]
|
||||
trailer[2] = frame.data[^2]
|
||||
trailer[3] = frame.data[^1]
|
||||
doAssert trailer == TrailingBytes
|
||||
frame.data.setLen(frame.data.len - 4)
|
||||
|
||||
frame.length = frame.data.len.uint64
|
||||
frame.offset = 0
|
||||
frame.consumed = 0
|
||||
|
||||
if frame.fin:
|
||||
# compression using `server_` prefixed config
|
||||
let noContextTakeover = getContextTakeover(ext.opts, ext.opts.isServer)
|
||||
if noContextTakeover:
|
||||
doAssert(ext.compCtx.deflateReset() == Z_OK)
|
||||
ext.compCtxState = ContextState.Reset
|
||||
|
||||
return frame
|
||||
|
||||
method toHttpOptions(ext: DeflateExt): string =
|
||||
# using paramStr here is a bit clunky
|
||||
extID & ext.paramStr
|
||||
|
||||
proc destroyExt(ext: DeflateExt) =
|
||||
if ext.compCtxState != ContextState.Invalid:
|
||||
# zlib.deflateEnd somehow return DATA_ERROR
|
||||
# when compression succeed some cases.
|
||||
# we forget to do something?
|
||||
discard ext.compCtx.deflateEnd()
|
||||
ext.compCtxState = ContextState.Invalid
|
||||
|
||||
if ext.decompCtxState != ContextState.Invalid:
|
||||
doAssert(ext.decompCtx.inflateEnd() == Z_OK)
|
||||
ext.decompCtxState = ContextState.Invalid
|
||||
|
||||
proc makeOffer(
|
||||
clientNoContextTakeOver: bool,
|
||||
clientMaxWindowBits: int): string =
|
||||
|
||||
var param = extID
|
||||
if clientMaxWindowBits in {9..15}:
|
||||
param.add "; client_max_window_bits=" & $clientMaxWindowBits
|
||||
else:
|
||||
param.add "; client_max_window_bits"
|
||||
|
||||
if clientNoContextTakeOver:
|
||||
param.add "; client_no_context_takeover"
|
||||
|
||||
param
|
||||
|
||||
proc deflateFactory*(
|
||||
threshold = ExtDeflateThreshold,
|
||||
decompressLimit = ExtDeflateDecompressLimit,
|
||||
level = Z_DEFAULT_LEVEL,
|
||||
strategy = Z_DEFAULT_STRATEGY,
|
||||
memLevel = Z_DEFAULT_MEM_LEVEL,
|
||||
clientNoContextTakeOver = false,
|
||||
clientMaxWindowBits = 15): ExtFactory =
|
||||
|
||||
proc factory(isServer: bool,
|
||||
args: seq[ExtParam]): Result[Ext, string] {.
|
||||
gcsafe, raises: [Defect].} =
|
||||
|
||||
# capture user configuration via closure
|
||||
var opts = DeflateOpts(
|
||||
isServer: isServer,
|
||||
threshold: threshold,
|
||||
decompressLimit: decompressLimit,
|
||||
level: level,
|
||||
strategy: strategy,
|
||||
memLevel: memLevel
|
||||
)
|
||||
let resp = createParams(args, opts)
|
||||
if resp.isErr:
|
||||
return err(resp.error)
|
||||
|
||||
var ext: DeflateExt
|
||||
ext.new(destroyExt)
|
||||
ext.name = extID
|
||||
ext.paramStr = resp.get()
|
||||
ext.opts = opts
|
||||
ext.compressedMsg = false
|
||||
ext.compCtxState = ContextState.Invalid
|
||||
ext.decompCtxState= ContextState.Invalid
|
||||
|
||||
ok(ext)
|
||||
|
||||
ExtFactory(
|
||||
name: extID,
|
||||
factory: factory,
|
||||
clientOffer: makeOffer(
|
||||
clientNoContextTakeOver,
|
||||
clientMaxWindowBits
|
||||
)
|
||||
)
|
Loading…
Reference in New Issue