implement permessage-deflate compression extension

depends on zlib as it's backend compressor
pass both client and server tests in autobahn test suite
This commit is contained in:
jangko 2021-06-20 10:19:08 +07:00
parent fef04a1595
commit 14d8e51f53
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
5 changed files with 410 additions and 219 deletions

View File

@ -10,7 +10,7 @@
import import
std/[strutils], std/[strutils],
pkg/[chronos, chronicles, stew/byteutils], pkg/[chronos, chronicles, stew/byteutils],
../ws/[ws, types, frame] ../ws/[ws, types, frame, extensions/compression/deflate]
const const
clientFlags = {NoVerifyHost, NoVerifyServerName} clientFlags = {NoVerifyHost, NoVerifyServerName}
@ -28,13 +28,14 @@ else:
secure = false secure = false
serverPort = 9001 serverPort = 9001
proc connectServer(path: string): Future[WSSession] {.async.} = proc connectServer(path: string, factories: seq[ExtFactory] = @[]): Future[WSSession] {.async.} =
let ws = await WebSocket.connect( let ws = await WebSocket.connect(
host = "127.0.0.1", host = "127.0.0.1",
port = Port(serverPort), port = Port(serverPort),
path = path, path = path,
secure=secure, secure=secure,
flags=clientFlags flags=clientFlags,
factories = factories
) )
return ws return ws
@ -71,11 +72,12 @@ proc main() {.async.} =
let caseCount = await getCaseCount() let caseCount = await getCaseCount()
trace "case count", count=caseCount trace "case count", count=caseCount
var deflateFactory = @[deflateFactory()]
for i in 1..caseCount: for i in 1..caseCount:
trace "runcase", no=i trace "runcase", no=i
let path = "/runCase?case=$1&agent=$2" % [$i, agent] let path = "/runCase?case=$1&agent=$2" % [$i, agent]
try: try:
let ws = await connectServer(path) let ws = await connectServer(path, deflateFactory)
while ws.readystate != ReadyState.Closed: while ws.readystate != ReadyState.Closed:
# echo back # echo back

View File

@ -12,7 +12,7 @@ import pkg/[chronos,
chronicles, chronicles,
httputils] httputils]
import ../ws/ws import ../ws/[ws, extensions/compression/deflate]
import ../tests/keys import ../tests/keys
proc handle(request: HttpRequest) {.async.} = proc handle(request: HttpRequest) {.async.} =
@ -23,7 +23,8 @@ proc handle(request: HttpRequest) {.async.} =
trace "Initiating web socket connection." trace "Initiating web socket connection."
try: try:
let server = WSServer.new() let deflateFactory = deflateFactory()
let server = WSServer.new(factories = [deflateFactory])
let ws = await server.handleRequest(request) let ws = await server.handleRequest(request)
if ws.readyState != Open: if ws.readyState != Open:
error "Failed to open websocket connection" error "Failed to open websocket connection"
@ -32,8 +33,13 @@ proc handle(request: HttpRequest) {.async.} =
trace "Websocket handshake completed" trace "Websocket handshake completed"
while ws.readyState != ReadyState.Closed: while ws.readyState != ReadyState.Closed:
let recvData = await ws.recv() let recvData = await ws.recv()
trace "Client Response: ", size = recvData.len, binary = ws.binary trace "Client Response: ", size = recvData.len, binary = ws.binary
if ws.readyState == ReadyState.Closed:
# if session already terminated by peer,
# no need to send response
break
await ws.send(recvData, await ws.send(recvData,
if ws.binary: Opcode.Binary else: Opcode.Text) if ws.binary: Opcode.Binary else: Opcode.Text)

View File

@ -22,6 +22,7 @@ requires "stew >= 0.1.0"
requires "asynctest >= 0.2.0 & < 0.3.0" requires "asynctest >= 0.2.0 & < 0.3.0"
requires "nimcrypto" requires "nimcrypto"
requires "bearssl" requires "bearssl"
requires "https://github.com/status-im/nim-zlib"
task test, "run tests": task test, "run tests":
exec "nim --hints:off c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testcommon.nim" exec "nim --hints:off c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testcommon.nim"

View File

@ -1,212 +0,0 @@
## nim-ws
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import
std/[strutils],
pkg/[stew/results, chronos],
../../types, ../../frame, ./miniz/miniz_api
type
DeflateOpts = object
isServer: bool
serverNoContextTakeOver: bool
clientNoContextTakeOver: bool
serverMaxWindowBits: int
clientMaxWindowBits: int
DeflateExt = ref object of Ext
paramStr: string
opts: DeflateOpts
# 'messageCompressed' is a two way flag:
# 1. the original message is compressible or not
# 2. the frame we received need decompression or not
messageCompressed: bool
const
extID = "permessage-deflate"
proc concatParam(resp: var string, param: string) =
resp.add ';'
resp.add param
proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] =
if arg.value.len == 0:
return ok("")
if arg.value.len > 2:
return err("window bits expect 2 bytes, got " & $arg.value.len)
for n in arg.value:
if n notin Digits:
return err("window bits value contains illegal char: " & $n)
var winbit = 0
for i in 0..<arg.value.len:
winbit = winbit * 10 + arg.value[i].int - '0'.int
if winbit < 8 or winbit > 15:
return err("window bits should between 8-15, got " & $winbit)
res = winbit
return ok("=" & arg.value)
proc validateParams(args: seq[ExtParam],
opts: var DeflateOpts): Result[string, string] =
# besides validating extensions params, this proc
# also constructing extension param for response
var resp = ""
for arg in args:
case arg.name
of "server_no_context_takeover":
if arg.value.len > 0:
return err("'server_no_context_takeover' should have no param")
if opts.isServer:
concatParam(resp, arg.name)
opts.serverNoContextTakeOver = true
of "client_no_context_takeover":
if arg.value.len > 0:
return err("'client_no_context_takeover' should have no param")
if opts.isServer:
concatParam(resp, arg.name)
opts.clientNoContextTakeOver = true
of "server_max_window_bits":
if opts.isServer:
concatParam(resp, arg.name)
let res = validateWindowBits(arg, opts.serverMaxWindowBits)
if res.isErr:
return res
resp.add res.get()
of "client_max_window_bits":
if opts.isServer:
concatParam(resp, arg.name)
let res = validateWindowBits(arg, opts.clientMaxWindowBits)
if res.isErr:
return res
resp.add res.get()
else:
return err("unrecognized param: " & arg.name)
ok(resp)
method decode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
if not ext.messageCompressed:
return data
# TODO: append trailing bytes
var mz = MzStream(
next_in: cast[ptr cuchar](data[0].unsafeAddr),
avail_in: data.len.cuint
)
let windowBits = if ext.opts.serverMaxWindowBits == 0:
MZ_DEFAULT_WINDOW_BITS
else:
MzWindowBits(ext.opts.serverMaxWindowBits)
doAssert(mz.inflateInit2(windowBits) == MZ_OK)
var res: seq[byte]
var buf: array[0xFFFF, byte]
while true:
mz.next_out = cast[ptr cuchar](buf[0].addr)
mz.avail_out = buf.len.cuint
let r = mz.inflate(MZ_SYNC_FLUSH)
let outSize = buf.len - mz.avail_out.int
res.add toOpenArray(buf, 0, outSize-1)
if r == MZ_STREAM_END:
break
elif r == MZ_OK:
continue
else:
doAssert(false, "decompression error")
doAssert(mz.inflateEnd() == MZ_OK)
return res
method encode*(ext: DeflateExt, data: seq[byte]): Future[seq[byte]] {.async.} =
var mz = MzStream(
next_in: cast[ptr cuchar](data[0].unsafeAddr),
avail_in: data.len.cuint
)
let windowBits = if ext.opts.serverMaxWindowBits == 0:
MZ_DEFAULT_WINDOW_BITS
else:
MzWindowBits(ext.opts.serverMaxWindowBits)
doAssert(mz.deflateInit2(
level = MZ_DEFAULT_LEVEL,
meth = MZ_DEFLATED,
windowBits,
1,
strategy = MZ_DEFAULT_STRATEGY) == MZ_OK
)
let maxSize = mz.deflateBound(data.len.culong).int
var res: seq[byte]
var buf: array[0xFFFF, byte]
while true:
mz.next_out = cast[ptr cuchar](buf[0].addr)
mz.avail_out = buf.len.cuint
let r = mz.deflate(MZ_FINISH)
let outSize = buf.len - mz.avail_out.int
res.add toOpenArray(buf, 0, outSize-1)
if r == MZ_STREAM_END:
break
elif r == MZ_OK:
continue
else:
doAssert(false, "compression error")
# TODO: cut trailing bytes
doAssert(mz.deflateEnd() == MZ_OK)
ext.messageCompressed = res.len < data.len
if ext.messageCompressed:
return res
else:
return data
method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
if frame.opcode in {Opcode.Text, Opcode.Binary}:
# only data frame can be compressed
# and we want to know if this message is compressed or not
# if the frame opcode is text or binary, it should also the first frame
ext.messageCompressed = frame.rsv1
# clear rsv1 bit because we already done with it
frame.rsv1 = false
return frame
method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
if frame.opcode in {Opcode.Text, Opcode.Binary}:
# only data frame can be compressed
# and we only set rsv1 bit to true if the message is compressible
# if the frame opcode is text or binary, it should also the first frame
frame.rsv1 = ext.messageCompressed
return frame
method toHttpOptions(ext: DeflateExt): string =
# using paramStr here is a bit clunky
extID & "; " & ext.paramStr
proc deflateExtFactory(isServer: bool, args: seq[ExtParam]): Result[Ext, string] {.raises: [Defect].} =
var opts = DeflateOpts(isServer: isServer)
let resp = validateParams(args, opts)
if resp.isErr:
return err(resp.error)
let ext = DeflateExt(
name: extID,
paramStr: resp.get(),
opts: opts
)
ok(ext)
const
deflateFactory* = (extID, deflateExtFactory)

View File

@ -0,0 +1,394 @@
## nim-ws
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import
std/[strutils],
pkg/[stew/results,
chronos,
chronicles,
zlib],
../../types,
../../frame
type
DeflateOpts = object
isServer: bool
decompressLimit: int # max allowed decompression size
threshold: int # size in bytes below which messages
# should not be compressed.
level: ZLevel # compression level
strategy: ZStrategy # compression strategy
memLevel: ZMemLevel # hint for miniz memory consumption
serverNoContextTakeOver: bool
clientNoContextTakeOver: bool
serverMaxWindowBits: int
clientMaxWindowBits: int
ContextState {.pure.} = enum
Invalid
Initialized
Reset
DeflateExt = ref object of Ext
paramStr : string
opts : DeflateOpts
compressedMsg : bool
compCtx : ZStream
compCtxState : ContextState
decompCtx : ZStream
decompCtxState: ContextState
const
extID = "permessage-deflate"
TrailingBytes = [0x00.byte, 0x00.byte, 0xff.byte, 0xff.byte]
ExtDeflateThreshold* = 1024
ExtDeflateDecompressLimit* = 10 shl 20 # 10mb
proc concatParam(resp: var string, param: string) =
resp.add "; "
resp.add param
proc validateWindowBits(arg: ExtParam, res: var int): Result[string, string] =
if arg.value.len == 0:
return ok("")
if arg.value.len > 2:
return err("window bits expect 2 bytes, got " & $arg.value.len)
for n in arg.value:
if n notin Digits:
return err("window bits value contains illegal char: " & $n)
var winbit = 0
for i in 0..<arg.value.len:
winbit = winbit * 10 + arg.value[i].int - '0'.int
if winbit < 8 or winbit > 15:
return err("window bits should between 8-15, got " & $winbit)
res = winbit
return ok("=" & arg.value)
proc createParams(args: seq[ExtParam],
opts: var DeflateOpts): Result[string, string] =
# besides validating extensions params, this proc
# also constructing extension params for response
var resp = ""
for arg in args:
case arg.name
of "server_no_context_takeover":
if arg.value.len > 0:
return err("'server_no_context_takeover' should have no param")
opts.serverNoContextTakeOver = true
if opts.isServer:
concatParam(resp, arg.name)
of "client_no_context_takeover":
if arg.value.len > 0:
return err("'client_no_context_takeover' should have no param")
opts.clientNoContextTakeOver = true
if opts.isServer:
concatParam(resp, arg.name)
of "server_max_window_bits":
let res = validateWindowBits(arg, opts.serverMaxWindowBits)
if res.isErr:
return res
if opts.isServer:
concatParam(resp, arg.name)
if opts.serverMaxWindowBits == 8:
# zlib does not support windowBits == 8
resp.add "=9"
else:
resp.add res.get()
of "client_max_window_bits":
let res = validateWindowBits(arg, opts.clientMaxWindowBits)
if res.isErr:
return res
if not opts.isServer:
concatParam(resp, arg.name)
if opts.clientMaxWindowBits == 8:
# zlib does not support windowBits == 8
resp.add "=9"
else:
resp.add res.get()
else:
return err("unrecognized param: " & arg.name)
ok(resp)
proc getWindowBits(opts: DeflateOpts, isServer: bool): ZWindowBits =
if isServer:
if opts.serverMaxWindowBits == 0:
Z_RAW_DEFLATE
else:
ZWindowBits(-opts.serverMaxWindowBits)
else:
if opts.clientMaxWindowBits == 0:
Z_RAW_DEFLATE
else:
ZWindowBits(-opts.clientMaxWindowBits)
proc getContextTakeover(opts: DeflateOpts, isServer: bool): bool =
if isServer:
opts.serverNoContextTakeOver
else:
opts.clientNoContextTakeOver
proc decompressInit(ext: DeflateExt) =
# decompression using `client_` prefixed config
let windowBits = getWindowBits(ext.opts, not ext.opts.isServer)
doAssert(ext.decompCtx.inflateInit2(windowBits) == Z_OK)
ext.decompCtxState = ContextState.Initialized
proc compressInit(ext: DeflateExt) =
# compression using `server_` prefixed config
let windowBits = getWindowBits(ext.opts, ext.opts.isServer)
doAssert(ext.compCtx.deflateInit2(
level = ext.opts.level,
meth = Z_DEFLATED,
windowBits,
memLevel = ext.opts.memLevel,
strategy = ext.opts.strategy) == Z_OK
)
ext.compCtxState = ContextState.Initialized
proc compress(zs: var ZStream, data: openArray[byte]): seq[byte] =
var buf: array[0xFFFF, byte]
# these casting is needed to prevent compilation
# error with CLANG
zs.next_in = cast[ptr cuchar](data[0].unsafeAddr)
zs.avail_in = data.len.cuint
while true:
zs.next_out = cast[ptr cuchar](buf[0].addr)
zs.avail_out = buf.len.cuint
let r = zs.deflate(Z_SYNC_FLUSH)
let outSize = buf.len - zs.avail_out.int
result.add toOpenArray(buf, 0, outSize-1)
if r == Z_STREAM_END:
break
elif r == Z_OK:
# need more input or more output available
if zs.avail_in > 0 or zs.avail_out == 0:
continue
else:
break
else:
raise newException(WSExtError, "compression error " & $r)
proc decompress(zs: var ZStream, limit: int, data: openArray[byte]): seq[byte] =
var buf: array[0xFFFF, byte]
# these casting is needed to prevent compilation
# error with CLANG
zs.next_in = cast[ptr cuchar](data[0].unsafeAddr)
zs.avail_in = data.len.cuint
while true:
zs.next_out = cast[ptr cuchar](buf[0].addr)
zs.avail_out = buf.len.cuint
let r = zs.inflate(Z_NO_FLUSH)
let outSize = buf.len - zs.avail_out.int
result.add toOpenArray(buf, 0, outSize-1)
if result.len > limit:
raise newException(WSExtError, "decompression exceeds allowed limit")
if r == Z_STREAM_END:
break
elif r == Z_OK:
# need more input or more output available
if zs.avail_in > 0 or zs.avail_out == 0:
continue
else:
break
else:
raise newException(WSExtError, "decompression error " & $r)
return result
method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}:
# only data frames can be decompressed
return frame
if frame.opcode in {Opcode.Text, Opcode.Binary}:
# we want to know if this message is compressed or not
# if the frame opcode is text or binary, it should also the first frame
ext.compressedMsg = frame.rsv1
# clear rsv1 bit because we already done with it
frame.rsv1 = false
if not ext.compressedMsg:
# don't bother with uncompressed message
return frame
if ext.decompCtxState == ContextState.Invalid:
ext.decompressInit()
# even though the frame.data.len == 0, the stream needs
# to be closed with trailing bytes if it's a final frame
var data: seq[byte]
var buf: array[0xFFFF, byte]
while data.len < frame.length.int:
let len = min(frame.length.int - data.len, buf.len)
let read = await frame.read(ext.session.stream.reader, addr buf[0], len)
data.add toOpenArray(buf, 0, read - 1)
if data.len > ext.session.frameSize:
raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size")
if frame.fin:
data.add TrailingBytes
frame.data = decompress(ext.decompCtx, ext.opts.decompressLimit, data)
trace "DeflateExt decompress", input=frame.length, output=frame.data.len
frame.length = frame.data.len.uint64
frame.offset = 0
frame.consumed = 0
frame.mask = false # clear mask flag, decompressed content is not masked
if frame.fin:
# decompression using `client_` prefixed config
let noContextTakeover = getContextTakeover(ext.opts, not ext.opts.isServer)
if noContextTakeover:
doAssert(ext.decompCtx.inflateReset() == Z_OK)
ext.decompCtxState = ContextState.Reset
return frame
method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} =
if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}:
# only data frames can be compressed
return frame
if frame.opcode in {Opcode.Text, Opcode.Binary}:
# we only set rsv1 bit to true if the message is compressible
# and only set the first frame's rsv1
# if the frame opcode is text or binary, it should also the first frame
ext.compressedMsg = frame.data.len >= ext.opts.threshold
frame.rsv1 = ext.compressedMsg
if not ext.compressedMsg:
# don't bother with incompressible message
return frame
if ext.compCtxState == ContextState.Invalid:
ext.compressInit()
frame.length = frame.data.len.uint64
frame.data = compress(ext.compCtx, frame.data)
trace "DeflateExt compress", input=frame.length, output=frame.data.len
if frame.fin:
# remove trailing bytes
when not defined(release):
var trailer: array[4, byte]
trailer[0] = frame.data[^4]
trailer[1] = frame.data[^3]
trailer[2] = frame.data[^2]
trailer[3] = frame.data[^1]
doAssert trailer == TrailingBytes
frame.data.setLen(frame.data.len - 4)
frame.length = frame.data.len.uint64
frame.offset = 0
frame.consumed = 0
if frame.fin:
# compression using `server_` prefixed config
let noContextTakeover = getContextTakeover(ext.opts, ext.opts.isServer)
if noContextTakeover:
doAssert(ext.compCtx.deflateReset() == Z_OK)
ext.compCtxState = ContextState.Reset
return frame
method toHttpOptions(ext: DeflateExt): string =
# using paramStr here is a bit clunky
extID & ext.paramStr
proc destroyExt(ext: DeflateExt) =
if ext.compCtxState != ContextState.Invalid:
# zlib.deflateEnd somehow return DATA_ERROR
# when compression succeed some cases.
# we forget to do something?
discard ext.compCtx.deflateEnd()
ext.compCtxState = ContextState.Invalid
if ext.decompCtxState != ContextState.Invalid:
doAssert(ext.decompCtx.inflateEnd() == Z_OK)
ext.decompCtxState = ContextState.Invalid
proc makeOffer(
clientNoContextTakeOver: bool,
clientMaxWindowBits: int): string =
var param = extID
if clientMaxWindowBits in {9..15}:
param.add "; client_max_window_bits=" & $clientMaxWindowBits
else:
param.add "; client_max_window_bits"
if clientNoContextTakeOver:
param.add "; client_no_context_takeover"
param
proc deflateFactory*(
threshold = ExtDeflateThreshold,
decompressLimit = ExtDeflateDecompressLimit,
level = Z_DEFAULT_LEVEL,
strategy = Z_DEFAULT_STRATEGY,
memLevel = Z_DEFAULT_MEM_LEVEL,
clientNoContextTakeOver = false,
clientMaxWindowBits = 15): ExtFactory =
proc factory(isServer: bool,
args: seq[ExtParam]): Result[Ext, string] {.
gcsafe, raises: [Defect].} =
# capture user configuration via closure
var opts = DeflateOpts(
isServer: isServer,
threshold: threshold,
decompressLimit: decompressLimit,
level: level,
strategy: strategy,
memLevel: memLevel
)
let resp = createParams(args, opts)
if resp.isErr:
return err(resp.error)
var ext: DeflateExt
ext.new(destroyExt)
ext.name = extID
ext.paramStr = resp.get()
ext.opts = opts
ext.compressedMsg = false
ext.compCtxState = ContextState.Invalid
ext.decompCtxState= ContextState.Invalid
ok(ext)
ExtFactory(
name: extID,
factory: factory,
clientOffer: makeOffer(
clientNoContextTakeOver,
clientMaxWindowBits
)
)