fixes related to extensions
- set RSV bits in frame.encode - move ExtParams from extutils.nim to types.nim - remodel extension factory type - accept/reject extensions offer in server - offer/accept extensions in client
This commit is contained in:
parent
5af418f850
commit
a96a123bfe
|
@ -9,13 +9,10 @@
|
|||
|
||||
import
|
||||
std/strutils,
|
||||
pkg/httputils
|
||||
pkg/httputils,
|
||||
../types
|
||||
|
||||
type
|
||||
ExtParam* = object
|
||||
name* : string
|
||||
value*: string
|
||||
|
||||
AppExt* = object
|
||||
name* : string
|
||||
params*: seq[ExtParam]
|
||||
|
@ -140,11 +137,11 @@ proc parseExt*[T: BChar](data: openarray[T], output: var seq[AppExt]): bool =
|
|||
ext.params[^1].name = system.move(param.name)
|
||||
ext.params[^1].value = system.move(param.value)
|
||||
|
||||
if lex.tok notin {tkSemCol, tkComma, tkEof}:
|
||||
if lex.tok notin {tkSemCol, tkComma, tkEof}:
|
||||
return false
|
||||
|
||||
|
||||
output.setLen(output.len + 1)
|
||||
output[^1].name = system.move(ext.name)
|
||||
output[^1].name = toLowerAscii(ext.name)
|
||||
output[^1].params = system.move(ext.params)
|
||||
|
||||
if lex.tok == tkEof:
|
||||
|
|
11
ws/frame.nim
11
ws/frame.nim
|
@ -99,7 +99,13 @@ proc encode*(
|
|||
var ret: seq[byte]
|
||||
var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags.
|
||||
if f.fin:
|
||||
b0 = b0 or 128'u8
|
||||
b0 = b0 or 0x80'u8
|
||||
if f.rsv1:
|
||||
b0 = b0 or 0x40'u8
|
||||
if f.rsv2:
|
||||
b0 = b0 or 0x20'u8
|
||||
if f.rsv3:
|
||||
b0 = b0 or 0x10'u8
|
||||
|
||||
ret.add(b0)
|
||||
|
||||
|
@ -218,6 +224,9 @@ proc decode*(
|
|||
for i in countdown(extensions.high, extensions.low):
|
||||
frame = await extensions[i].decode(frame)
|
||||
|
||||
# we check rsv bits after extensions,
|
||||
# because they have special meaning for extensions.
|
||||
# rsv bits will be cleared by extensions if they are set by peer.
|
||||
# If any of the rsv are set close the socket.
|
||||
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
||||
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
|
||||
|
|
22
ws/types.nim
22
ws/types.nim
|
@ -9,8 +9,7 @@
|
|||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import std/tables
|
||||
import pkg/[chronos, chronos/streams/tlsstream]
|
||||
import pkg/[chronos, chronos/streams/tlsstream, stew/results]
|
||||
import ./utils
|
||||
|
||||
const
|
||||
|
@ -96,13 +95,21 @@ type
|
|||
|
||||
Ext* = ref object of RootObj
|
||||
name*: string
|
||||
options*: Table[string, string]
|
||||
session*: WSSession
|
||||
|
||||
ExtFactory* = proc(
|
||||
name: string,
|
||||
session: WSSession,
|
||||
options: Table[string, string]): Ext {.raises: [Defect].}
|
||||
ExtParam* = object
|
||||
name* : string
|
||||
value*: string
|
||||
|
||||
ExtFactoryProc* = proc(
|
||||
isServer: bool,
|
||||
args: seq[ExtParam]): Result[Ext, string] {.
|
||||
gcsafe, raises: [Defect].}
|
||||
|
||||
ExtFactory* = object
|
||||
name*: string
|
||||
factory*: ExtFactoryProc
|
||||
clientOffer*: string
|
||||
|
||||
WebSocketError* = object of CatchableError
|
||||
WSMalformedHeaderError* = object of WebSocketError
|
||||
|
@ -124,6 +131,7 @@ type
|
|||
WSPayloadLengthError* = object of WebSocketError
|
||||
WSInvalidOpcodeError* = object of WebSocketError
|
||||
WSInvalidUTF8* = object of WebSocketError
|
||||
WSExtError* = object of WebSocketError
|
||||
|
||||
const
|
||||
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))
|
||||
|
|
107
ws/ws.nim
107
ws/ws.nim
|
@ -27,7 +27,7 @@ import pkg/[chronos,
|
|||
stew/base10,
|
||||
nimcrypto/sha]
|
||||
|
||||
import ./utils, ./frame, ./session, /types, ./http
|
||||
import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils
|
||||
|
||||
export utils, session, frame, types, http
|
||||
|
||||
|
@ -45,11 +45,69 @@ func toException(e: string): ref WebSocketError =
|
|||
func toException(e: cstring): ref WebSocketError =
|
||||
(ref WebSocketError)(msg: $e)
|
||||
|
||||
func contains(extensions: openArray[Ext], extName: string): bool =
|
||||
for ext in extensions:
|
||||
if ext.name == extName:
|
||||
return true
|
||||
|
||||
proc getFactory(factories: openArray[ExtFactory], extName: string): ExtFactoryProc =
|
||||
for n in factories:
|
||||
if n.name == extName:
|
||||
return n.factory
|
||||
|
||||
proc selectExt(isServer: bool,
|
||||
extensions: var seq[Ext],
|
||||
factories: openArray[ExtFactory],
|
||||
exts: openArray[string]): string {.raises: [Defect, WSExtError].} =
|
||||
|
||||
var extList: seq[AppExt]
|
||||
var response = ""
|
||||
for ext in exts:
|
||||
# each of "Sec-WebSocket-Extensions" can have multiple
|
||||
# extensions or fallback extension
|
||||
if not parseExt(ext, extList):
|
||||
raise newException(WSExtError, "extension syntax error: " & ext)
|
||||
|
||||
for i, ext in extList:
|
||||
if extensions.contains(ext.name):
|
||||
# don't accept this fallback if prev ext
|
||||
# configuration already accepted
|
||||
trace "extension fallback not accepted", ext=ext.name
|
||||
continue
|
||||
|
||||
# now look for right factory
|
||||
let factory = factories.getFactory(ext.name)
|
||||
if factory.isNil:
|
||||
# no factory? it's ok, just skip it
|
||||
trace "no extension factory", ext=ext.name
|
||||
continue
|
||||
|
||||
let extRes = factory(isServer, ext.params)
|
||||
if extRes.isErr:
|
||||
# cannot create extension because of
|
||||
# wrong/incompatible params? skip or fallback
|
||||
trace "skip extension", ext=ext.name, msg=extRes.error
|
||||
continue
|
||||
|
||||
let ext = extRes.get()
|
||||
doAssert(not ext.isNil)
|
||||
if i > 0:
|
||||
# add separator if more than one exts
|
||||
response.add ", "
|
||||
response.add ext.toHttpOptions
|
||||
|
||||
# finally, accept the extension
|
||||
trace "extension accepted", ext=ext.name
|
||||
extensions.add ext
|
||||
|
||||
# HTTP response for "Sec-WebSocket-Extensions"
|
||||
response
|
||||
|
||||
proc connect*(
|
||||
_: type WebSocket,
|
||||
uri: Uri,
|
||||
protocols: seq[string] = @[],
|
||||
extensions: seq[Ext] = @[],
|
||||
factories: seq[ExtFactory] = @[],
|
||||
flags: set[TLSFlags] = {},
|
||||
version = WSDefaultVersion,
|
||||
frameSize = WSDefaultFrameSize,
|
||||
|
@ -86,6 +144,15 @@ proc connect*(
|
|||
if protocols.len > 0:
|
||||
headers.add("Sec-WebSocket-Protocol", protocols.join(", "))
|
||||
|
||||
var extOffer = ""
|
||||
for i, f in factories:
|
||||
if i > 0:
|
||||
extOffer.add ", "
|
||||
extOffer.add f.clientOffer
|
||||
|
||||
if extOffer.len > 0:
|
||||
headers.add("Sec-WebSocket-Extensions", extOffer)
|
||||
|
||||
let response = try:
|
||||
await client.request(uri, headers = headers)
|
||||
except CatchableError as exc:
|
||||
|
@ -105,24 +172,33 @@ proc connect*(
|
|||
raise newException(WSFailedUpgradeError,
|
||||
&"Invalid protocol returned {proto}!")
|
||||
|
||||
var extensions: seq[Ext]
|
||||
let exts = response.headers.getList("Sec-WebSocket-Extensions")
|
||||
discard selectExt(false, extensions, factories, exts)
|
||||
|
||||
# Client data should be masked.
|
||||
return WSSession(
|
||||
let session = WSSession(
|
||||
stream: client.stream,
|
||||
readyState: ReadyState.Open,
|
||||
masked: true,
|
||||
extensions: @extensions,
|
||||
extensions: system.move(extensions),
|
||||
rng: rng,
|
||||
frameSize: frameSize,
|
||||
onPing: onPing,
|
||||
onPong: onPong,
|
||||
onClose: onClose)
|
||||
|
||||
for ext in session.extensions:
|
||||
ext.session = session
|
||||
|
||||
return session
|
||||
|
||||
proc connect*(
|
||||
_: type WebSocket,
|
||||
address: TransportAddress,
|
||||
path: string,
|
||||
protocols: seq[string] = @[],
|
||||
extensions: seq[Ext] = @[],
|
||||
factories: seq[ExtFactory] = @[],
|
||||
secure = false,
|
||||
flags: set[TLSFlags] = {},
|
||||
version = WSDefaultVersion,
|
||||
|
@ -149,7 +225,7 @@ proc connect*(
|
|||
return await WebSocket.connect(
|
||||
uri = parseUri(uri),
|
||||
protocols = protocols,
|
||||
extensions = extensions,
|
||||
factories = factories,
|
||||
flags = flags,
|
||||
version = version,
|
||||
frameSize = frameSize,
|
||||
|
@ -163,7 +239,7 @@ proc connect*(
|
|||
port: Port,
|
||||
path: string,
|
||||
protocols: seq[string] = @[],
|
||||
extensions: seq[Ext] = @[],
|
||||
factories: seq[ExtFactory] = @[],
|
||||
secure = false,
|
||||
flags: set[TLSFlags] = {},
|
||||
version = WSDefaultVersion,
|
||||
|
@ -177,7 +253,7 @@ proc connect*(
|
|||
address = initTAddress(host, port),
|
||||
path = path,
|
||||
protocols = protocols,
|
||||
extensions = extensions,
|
||||
factories = factories,
|
||||
secure = secure,
|
||||
flags = flags,
|
||||
version = version,
|
||||
|
@ -242,6 +318,13 @@ proc handleRequest*(
|
|||
else:
|
||||
trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
|
||||
|
||||
# it is possible to have multiple "Sec-WebSocket-Extensions"
|
||||
let exts = request.headers.getList("Sec-WebSocket-Extensions")
|
||||
let extResp = selectExt(true, ws.extensions, ws.factories, exts)
|
||||
if extResp.len > 0:
|
||||
# send back any accepted extensions
|
||||
headers.add("Sec-WebSocket-Extensions", extResp)
|
||||
|
||||
try:
|
||||
await request.sendResponse(Http101, headers = headers)
|
||||
except CancelledError as exc:
|
||||
|
@ -250,10 +333,11 @@ proc handleRequest*(
|
|||
raise newException(WSHandshakeError,
|
||||
"Failed to sent handshake response. Error: " & exc.msg)
|
||||
|
||||
return WSSession(
|
||||
let session = WSSession(
|
||||
readyState: ReadyState.Open,
|
||||
stream: request.stream,
|
||||
proto: protocol,
|
||||
extensions: system.move(ws.extensions),
|
||||
masked: false,
|
||||
rng: ws.rng,
|
||||
frameSize: ws.frameSize,
|
||||
|
@ -261,6 +345,11 @@ proc handleRequest*(
|
|||
onPong: ws.onPong,
|
||||
onClose: ws.onClose)
|
||||
|
||||
for ext in session.extensions:
|
||||
ext.session = session
|
||||
|
||||
return session
|
||||
|
||||
proc new*(
|
||||
_: typedesc[WSServer],
|
||||
protos: openArray[string] = [""],
|
||||
|
|
Loading…
Reference in New Issue