fixes related to extensions

- set RSV bits in frame.encode
- move ExtParams from extutils.nim to types.nim
- remodel extension factory type
- accept/reject extensions offer in server
- offer/accept extensions in client
This commit is contained in:
jangko 2021-06-15 21:27:56 +07:00
parent 5af418f850
commit a96a123bfe
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
4 changed files with 128 additions and 25 deletions

View File

@ -9,13 +9,10 @@
import
std/strutils,
pkg/httputils
pkg/httputils,
../types
type
ExtParam* = object
name* : string
value*: string
AppExt* = object
name* : string
params*: seq[ExtParam]
@ -140,11 +137,11 @@ proc parseExt*[T: BChar](data: openarray[T], output: var seq[AppExt]): bool =
ext.params[^1].name = system.move(param.name)
ext.params[^1].value = system.move(param.value)
if lex.tok notin {tkSemCol, tkComma, tkEof}:
if lex.tok notin {tkSemCol, tkComma, tkEof}:
return false
output.setLen(output.len + 1)
output[^1].name = system.move(ext.name)
output[^1].name = toLowerAscii(ext.name)
output[^1].params = system.move(ext.params)
if lex.tok == tkEof:

View File

@ -99,7 +99,13 @@ proc encode*(
var ret: seq[byte]
var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags.
if f.fin:
b0 = b0 or 128'u8
b0 = b0 or 0x80'u8
if f.rsv1:
b0 = b0 or 0x40'u8
if f.rsv2:
b0 = b0 or 0x20'u8
if f.rsv3:
b0 = b0 or 0x10'u8
ret.add(b0)
@ -218,6 +224,9 @@ proc decode*(
for i in countdown(extensions.high, extensions.low):
frame = await extensions[i].decode(frame)
# we check rsv bits after extensions,
# because they have special meaning for extensions.
# rsv bits will be cleared by extensions if they are set by peer.
# If any of the rsv are set close the socket.
if frame.rsv1 or frame.rsv2 or frame.rsv3:
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")

View File

@ -9,8 +9,7 @@
{.push raises: [Defect].}
import std/tables
import pkg/[chronos, chronos/streams/tlsstream]
import pkg/[chronos, chronos/streams/tlsstream, stew/results]
import ./utils
const
@ -96,13 +95,21 @@ type
Ext* = ref object of RootObj
name*: string
options*: Table[string, string]
session*: WSSession
ExtFactory* = proc(
name: string,
session: WSSession,
options: Table[string, string]): Ext {.raises: [Defect].}
ExtParam* = object
name* : string
value*: string
ExtFactoryProc* = proc(
isServer: bool,
args: seq[ExtParam]): Result[Ext, string] {.
gcsafe, raises: [Defect].}
ExtFactory* = object
name*: string
factory*: ExtFactoryProc
clientOffer*: string
WebSocketError* = object of CatchableError
WSMalformedHeaderError* = object of WebSocketError
@ -124,6 +131,7 @@ type
WSPayloadLengthError* = object of WebSocketError
WSInvalidOpcodeError* = object of WebSocketError
WSInvalidUTF8* = object of WebSocketError
WSExtError* = object of WebSocketError
const
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))

107
ws/ws.nim
View File

@ -27,7 +27,7 @@ import pkg/[chronos,
stew/base10,
nimcrypto/sha]
import ./utils, ./frame, ./session, /types, ./http
import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils
export utils, session, frame, types, http
@ -45,11 +45,69 @@ func toException(e: string): ref WebSocketError =
func toException(e: cstring): ref WebSocketError =
(ref WebSocketError)(msg: $e)
func contains(extensions: openArray[Ext], extName: string): bool =
for ext in extensions:
if ext.name == extName:
return true
proc getFactory(factories: openArray[ExtFactory], extName: string): ExtFactoryProc =
for n in factories:
if n.name == extName:
return n.factory
proc selectExt(isServer: bool,
extensions: var seq[Ext],
factories: openArray[ExtFactory],
exts: openArray[string]): string {.raises: [Defect, WSExtError].} =
var extList: seq[AppExt]
var response = ""
for ext in exts:
# each of "Sec-WebSocket-Extensions" can have multiple
# extensions or fallback extension
if not parseExt(ext, extList):
raise newException(WSExtError, "extension syntax error: " & ext)
for i, ext in extList:
if extensions.contains(ext.name):
# don't accept this fallback if prev ext
# configuration already accepted
trace "extension fallback not accepted", ext=ext.name
continue
# now look for right factory
let factory = factories.getFactory(ext.name)
if factory.isNil:
# no factory? it's ok, just skip it
trace "no extension factory", ext=ext.name
continue
let extRes = factory(isServer, ext.params)
if extRes.isErr:
# cannot create extension because of
# wrong/incompatible params? skip or fallback
trace "skip extension", ext=ext.name, msg=extRes.error
continue
let ext = extRes.get()
doAssert(not ext.isNil)
if i > 0:
# add separator if more than one exts
response.add ", "
response.add ext.toHttpOptions
# finally, accept the extension
trace "extension accepted", ext=ext.name
extensions.add ext
# HTTP response for "Sec-WebSocket-Extensions"
response
proc connect*(
_: type WebSocket,
uri: Uri,
protocols: seq[string] = @[],
extensions: seq[Ext] = @[],
factories: seq[ExtFactory] = @[],
flags: set[TLSFlags] = {},
version = WSDefaultVersion,
frameSize = WSDefaultFrameSize,
@ -86,6 +144,15 @@ proc connect*(
if protocols.len > 0:
headers.add("Sec-WebSocket-Protocol", protocols.join(", "))
var extOffer = ""
for i, f in factories:
if i > 0:
extOffer.add ", "
extOffer.add f.clientOffer
if extOffer.len > 0:
headers.add("Sec-WebSocket-Extensions", extOffer)
let response = try:
await client.request(uri, headers = headers)
except CatchableError as exc:
@ -105,24 +172,33 @@ proc connect*(
raise newException(WSFailedUpgradeError,
&"Invalid protocol returned {proto}!")
var extensions: seq[Ext]
let exts = response.headers.getList("Sec-WebSocket-Extensions")
discard selectExt(false, extensions, factories, exts)
# Client data should be masked.
return WSSession(
let session = WSSession(
stream: client.stream,
readyState: ReadyState.Open,
masked: true,
extensions: @extensions,
extensions: system.move(extensions),
rng: rng,
frameSize: frameSize,
onPing: onPing,
onPong: onPong,
onClose: onClose)
for ext in session.extensions:
ext.session = session
return session
proc connect*(
_: type WebSocket,
address: TransportAddress,
path: string,
protocols: seq[string] = @[],
extensions: seq[Ext] = @[],
factories: seq[ExtFactory] = @[],
secure = false,
flags: set[TLSFlags] = {},
version = WSDefaultVersion,
@ -149,7 +225,7 @@ proc connect*(
return await WebSocket.connect(
uri = parseUri(uri),
protocols = protocols,
extensions = extensions,
factories = factories,
flags = flags,
version = version,
frameSize = frameSize,
@ -163,7 +239,7 @@ proc connect*(
port: Port,
path: string,
protocols: seq[string] = @[],
extensions: seq[Ext] = @[],
factories: seq[ExtFactory] = @[],
secure = false,
flags: set[TLSFlags] = {},
version = WSDefaultVersion,
@ -177,7 +253,7 @@ proc connect*(
address = initTAddress(host, port),
path = path,
protocols = protocols,
extensions = extensions,
factories = factories,
secure = secure,
flags = flags,
version = version,
@ -242,6 +318,13 @@ proc handleRequest*(
else:
trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
# it is possible to have multiple "Sec-WebSocket-Extensions"
let exts = request.headers.getList("Sec-WebSocket-Extensions")
let extResp = selectExt(true, ws.extensions, ws.factories, exts)
if extResp.len > 0:
# send back any accepted extensions
headers.add("Sec-WebSocket-Extensions", extResp)
try:
await request.sendResponse(Http101, headers = headers)
except CancelledError as exc:
@ -250,10 +333,11 @@ proc handleRequest*(
raise newException(WSHandshakeError,
"Failed to sent handshake response. Error: " & exc.msg)
return WSSession(
let session = WSSession(
readyState: ReadyState.Open,
stream: request.stream,
proto: protocol,
extensions: system.move(ws.extensions),
masked: false,
rng: ws.rng,
frameSize: ws.frameSize,
@ -261,6 +345,11 @@ proc handleRequest*(
onPong: ws.onPong,
onClose: ws.onClose)
for ext in session.extensions:
ext.session = session
return session
proc new*(
_: typedesc[WSServer],
protos: openArray[string] = [""],