mirror of
https://github.com/status-im/nim-websock.git
synced 2025-02-19 12:58:21 +00:00
fixes related to extensions
- set RSV bits in frame.encode - move ExtParams from extutils.nim to types.nim - remodel extension factory type - accept/reject extensions offer in server - offer/accept extensions in client
This commit is contained in:
parent
5af418f850
commit
a96a123bfe
@ -9,13 +9,10 @@
|
|||||||
|
|
||||||
import
|
import
|
||||||
std/strutils,
|
std/strutils,
|
||||||
pkg/httputils
|
pkg/httputils,
|
||||||
|
../types
|
||||||
|
|
||||||
type
|
type
|
||||||
ExtParam* = object
|
|
||||||
name* : string
|
|
||||||
value*: string
|
|
||||||
|
|
||||||
AppExt* = object
|
AppExt* = object
|
||||||
name* : string
|
name* : string
|
||||||
params*: seq[ExtParam]
|
params*: seq[ExtParam]
|
||||||
@ -144,7 +141,7 @@ proc parseExt*[T: BChar](data: openarray[T], output: var seq[AppExt]): bool =
|
|||||||
return false
|
return false
|
||||||
|
|
||||||
output.setLen(output.len + 1)
|
output.setLen(output.len + 1)
|
||||||
output[^1].name = system.move(ext.name)
|
output[^1].name = toLowerAscii(ext.name)
|
||||||
output[^1].params = system.move(ext.params)
|
output[^1].params = system.move(ext.params)
|
||||||
|
|
||||||
if lex.tok == tkEof:
|
if lex.tok == tkEof:
|
||||||
|
11
ws/frame.nim
11
ws/frame.nim
@ -99,7 +99,13 @@ proc encode*(
|
|||||||
var ret: seq[byte]
|
var ret: seq[byte]
|
||||||
var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags.
|
var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags.
|
||||||
if f.fin:
|
if f.fin:
|
||||||
b0 = b0 or 128'u8
|
b0 = b0 or 0x80'u8
|
||||||
|
if f.rsv1:
|
||||||
|
b0 = b0 or 0x40'u8
|
||||||
|
if f.rsv2:
|
||||||
|
b0 = b0 or 0x20'u8
|
||||||
|
if f.rsv3:
|
||||||
|
b0 = b0 or 0x10'u8
|
||||||
|
|
||||||
ret.add(b0)
|
ret.add(b0)
|
||||||
|
|
||||||
@ -218,6 +224,9 @@ proc decode*(
|
|||||||
for i in countdown(extensions.high, extensions.low):
|
for i in countdown(extensions.high, extensions.low):
|
||||||
frame = await extensions[i].decode(frame)
|
frame = await extensions[i].decode(frame)
|
||||||
|
|
||||||
|
# we check rsv bits after extensions,
|
||||||
|
# because they have special meaning for extensions.
|
||||||
|
# rsv bits will be cleared by extensions if they are set by peer.
|
||||||
# If any of the rsv are set close the socket.
|
# If any of the rsv are set close the socket.
|
||||||
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
if frame.rsv1 or frame.rsv2 or frame.rsv3:
|
||||||
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
|
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
|
||||||
|
22
ws/types.nim
22
ws/types.nim
@ -9,8 +9,7 @@
|
|||||||
|
|
||||||
{.push raises: [Defect].}
|
{.push raises: [Defect].}
|
||||||
|
|
||||||
import std/tables
|
import pkg/[chronos, chronos/streams/tlsstream, stew/results]
|
||||||
import pkg/[chronos, chronos/streams/tlsstream]
|
|
||||||
import ./utils
|
import ./utils
|
||||||
|
|
||||||
const
|
const
|
||||||
@ -96,13 +95,21 @@ type
|
|||||||
|
|
||||||
Ext* = ref object of RootObj
|
Ext* = ref object of RootObj
|
||||||
name*: string
|
name*: string
|
||||||
options*: Table[string, string]
|
|
||||||
session*: WSSession
|
session*: WSSession
|
||||||
|
|
||||||
ExtFactory* = proc(
|
ExtParam* = object
|
||||||
name: string,
|
name* : string
|
||||||
session: WSSession,
|
value*: string
|
||||||
options: Table[string, string]): Ext {.raises: [Defect].}
|
|
||||||
|
ExtFactoryProc* = proc(
|
||||||
|
isServer: bool,
|
||||||
|
args: seq[ExtParam]): Result[Ext, string] {.
|
||||||
|
gcsafe, raises: [Defect].}
|
||||||
|
|
||||||
|
ExtFactory* = object
|
||||||
|
name*: string
|
||||||
|
factory*: ExtFactoryProc
|
||||||
|
clientOffer*: string
|
||||||
|
|
||||||
WebSocketError* = object of CatchableError
|
WebSocketError* = object of CatchableError
|
||||||
WSMalformedHeaderError* = object of WebSocketError
|
WSMalformedHeaderError* = object of WebSocketError
|
||||||
@ -124,6 +131,7 @@ type
|
|||||||
WSPayloadLengthError* = object of WebSocketError
|
WSPayloadLengthError* = object of WebSocketError
|
||||||
WSInvalidOpcodeError* = object of WebSocketError
|
WSInvalidOpcodeError* = object of WebSocketError
|
||||||
WSInvalidUTF8* = object of WebSocketError
|
WSInvalidUTF8* = object of WebSocketError
|
||||||
|
WSExtError* = object of WebSocketError
|
||||||
|
|
||||||
const
|
const
|
||||||
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))
|
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))
|
||||||
|
107
ws/ws.nim
107
ws/ws.nim
@ -27,7 +27,7 @@ import pkg/[chronos,
|
|||||||
stew/base10,
|
stew/base10,
|
||||||
nimcrypto/sha]
|
nimcrypto/sha]
|
||||||
|
|
||||||
import ./utils, ./frame, ./session, /types, ./http
|
import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils
|
||||||
|
|
||||||
export utils, session, frame, types, http
|
export utils, session, frame, types, http
|
||||||
|
|
||||||
@ -45,11 +45,69 @@ func toException(e: string): ref WebSocketError =
|
|||||||
func toException(e: cstring): ref WebSocketError =
|
func toException(e: cstring): ref WebSocketError =
|
||||||
(ref WebSocketError)(msg: $e)
|
(ref WebSocketError)(msg: $e)
|
||||||
|
|
||||||
|
func contains(extensions: openArray[Ext], extName: string): bool =
|
||||||
|
for ext in extensions:
|
||||||
|
if ext.name == extName:
|
||||||
|
return true
|
||||||
|
|
||||||
|
proc getFactory(factories: openArray[ExtFactory], extName: string): ExtFactoryProc =
|
||||||
|
for n in factories:
|
||||||
|
if n.name == extName:
|
||||||
|
return n.factory
|
||||||
|
|
||||||
|
proc selectExt(isServer: bool,
|
||||||
|
extensions: var seq[Ext],
|
||||||
|
factories: openArray[ExtFactory],
|
||||||
|
exts: openArray[string]): string {.raises: [Defect, WSExtError].} =
|
||||||
|
|
||||||
|
var extList: seq[AppExt]
|
||||||
|
var response = ""
|
||||||
|
for ext in exts:
|
||||||
|
# each of "Sec-WebSocket-Extensions" can have multiple
|
||||||
|
# extensions or fallback extension
|
||||||
|
if not parseExt(ext, extList):
|
||||||
|
raise newException(WSExtError, "extension syntax error: " & ext)
|
||||||
|
|
||||||
|
for i, ext in extList:
|
||||||
|
if extensions.contains(ext.name):
|
||||||
|
# don't accept this fallback if prev ext
|
||||||
|
# configuration already accepted
|
||||||
|
trace "extension fallback not accepted", ext=ext.name
|
||||||
|
continue
|
||||||
|
|
||||||
|
# now look for right factory
|
||||||
|
let factory = factories.getFactory(ext.name)
|
||||||
|
if factory.isNil:
|
||||||
|
# no factory? it's ok, just skip it
|
||||||
|
trace "no extension factory", ext=ext.name
|
||||||
|
continue
|
||||||
|
|
||||||
|
let extRes = factory(isServer, ext.params)
|
||||||
|
if extRes.isErr:
|
||||||
|
# cannot create extension because of
|
||||||
|
# wrong/incompatible params? skip or fallback
|
||||||
|
trace "skip extension", ext=ext.name, msg=extRes.error
|
||||||
|
continue
|
||||||
|
|
||||||
|
let ext = extRes.get()
|
||||||
|
doAssert(not ext.isNil)
|
||||||
|
if i > 0:
|
||||||
|
# add separator if more than one exts
|
||||||
|
response.add ", "
|
||||||
|
response.add ext.toHttpOptions
|
||||||
|
|
||||||
|
# finally, accept the extension
|
||||||
|
trace "extension accepted", ext=ext.name
|
||||||
|
extensions.add ext
|
||||||
|
|
||||||
|
# HTTP response for "Sec-WebSocket-Extensions"
|
||||||
|
response
|
||||||
|
|
||||||
proc connect*(
|
proc connect*(
|
||||||
_: type WebSocket,
|
_: type WebSocket,
|
||||||
uri: Uri,
|
uri: Uri,
|
||||||
protocols: seq[string] = @[],
|
protocols: seq[string] = @[],
|
||||||
extensions: seq[Ext] = @[],
|
factories: seq[ExtFactory] = @[],
|
||||||
flags: set[TLSFlags] = {},
|
flags: set[TLSFlags] = {},
|
||||||
version = WSDefaultVersion,
|
version = WSDefaultVersion,
|
||||||
frameSize = WSDefaultFrameSize,
|
frameSize = WSDefaultFrameSize,
|
||||||
@ -86,6 +144,15 @@ proc connect*(
|
|||||||
if protocols.len > 0:
|
if protocols.len > 0:
|
||||||
headers.add("Sec-WebSocket-Protocol", protocols.join(", "))
|
headers.add("Sec-WebSocket-Protocol", protocols.join(", "))
|
||||||
|
|
||||||
|
var extOffer = ""
|
||||||
|
for i, f in factories:
|
||||||
|
if i > 0:
|
||||||
|
extOffer.add ", "
|
||||||
|
extOffer.add f.clientOffer
|
||||||
|
|
||||||
|
if extOffer.len > 0:
|
||||||
|
headers.add("Sec-WebSocket-Extensions", extOffer)
|
||||||
|
|
||||||
let response = try:
|
let response = try:
|
||||||
await client.request(uri, headers = headers)
|
await client.request(uri, headers = headers)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
@ -105,24 +172,33 @@ proc connect*(
|
|||||||
raise newException(WSFailedUpgradeError,
|
raise newException(WSFailedUpgradeError,
|
||||||
&"Invalid protocol returned {proto}!")
|
&"Invalid protocol returned {proto}!")
|
||||||
|
|
||||||
|
var extensions: seq[Ext]
|
||||||
|
let exts = response.headers.getList("Sec-WebSocket-Extensions")
|
||||||
|
discard selectExt(false, extensions, factories, exts)
|
||||||
|
|
||||||
# Client data should be masked.
|
# Client data should be masked.
|
||||||
return WSSession(
|
let session = WSSession(
|
||||||
stream: client.stream,
|
stream: client.stream,
|
||||||
readyState: ReadyState.Open,
|
readyState: ReadyState.Open,
|
||||||
masked: true,
|
masked: true,
|
||||||
extensions: @extensions,
|
extensions: system.move(extensions),
|
||||||
rng: rng,
|
rng: rng,
|
||||||
frameSize: frameSize,
|
frameSize: frameSize,
|
||||||
onPing: onPing,
|
onPing: onPing,
|
||||||
onPong: onPong,
|
onPong: onPong,
|
||||||
onClose: onClose)
|
onClose: onClose)
|
||||||
|
|
||||||
|
for ext in session.extensions:
|
||||||
|
ext.session = session
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
proc connect*(
|
proc connect*(
|
||||||
_: type WebSocket,
|
_: type WebSocket,
|
||||||
address: TransportAddress,
|
address: TransportAddress,
|
||||||
path: string,
|
path: string,
|
||||||
protocols: seq[string] = @[],
|
protocols: seq[string] = @[],
|
||||||
extensions: seq[Ext] = @[],
|
factories: seq[ExtFactory] = @[],
|
||||||
secure = false,
|
secure = false,
|
||||||
flags: set[TLSFlags] = {},
|
flags: set[TLSFlags] = {},
|
||||||
version = WSDefaultVersion,
|
version = WSDefaultVersion,
|
||||||
@ -149,7 +225,7 @@ proc connect*(
|
|||||||
return await WebSocket.connect(
|
return await WebSocket.connect(
|
||||||
uri = parseUri(uri),
|
uri = parseUri(uri),
|
||||||
protocols = protocols,
|
protocols = protocols,
|
||||||
extensions = extensions,
|
factories = factories,
|
||||||
flags = flags,
|
flags = flags,
|
||||||
version = version,
|
version = version,
|
||||||
frameSize = frameSize,
|
frameSize = frameSize,
|
||||||
@ -163,7 +239,7 @@ proc connect*(
|
|||||||
port: Port,
|
port: Port,
|
||||||
path: string,
|
path: string,
|
||||||
protocols: seq[string] = @[],
|
protocols: seq[string] = @[],
|
||||||
extensions: seq[Ext] = @[],
|
factories: seq[ExtFactory] = @[],
|
||||||
secure = false,
|
secure = false,
|
||||||
flags: set[TLSFlags] = {},
|
flags: set[TLSFlags] = {},
|
||||||
version = WSDefaultVersion,
|
version = WSDefaultVersion,
|
||||||
@ -177,7 +253,7 @@ proc connect*(
|
|||||||
address = initTAddress(host, port),
|
address = initTAddress(host, port),
|
||||||
path = path,
|
path = path,
|
||||||
protocols = protocols,
|
protocols = protocols,
|
||||||
extensions = extensions,
|
factories = factories,
|
||||||
secure = secure,
|
secure = secure,
|
||||||
flags = flags,
|
flags = flags,
|
||||||
version = version,
|
version = version,
|
||||||
@ -242,6 +318,13 @@ proc handleRequest*(
|
|||||||
else:
|
else:
|
||||||
trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
|
trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
|
||||||
|
|
||||||
|
# it is possible to have multiple "Sec-WebSocket-Extensions"
|
||||||
|
let exts = request.headers.getList("Sec-WebSocket-Extensions")
|
||||||
|
let extResp = selectExt(true, ws.extensions, ws.factories, exts)
|
||||||
|
if extResp.len > 0:
|
||||||
|
# send back any accepted extensions
|
||||||
|
headers.add("Sec-WebSocket-Extensions", extResp)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await request.sendResponse(Http101, headers = headers)
|
await request.sendResponse(Http101, headers = headers)
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
@ -250,10 +333,11 @@ proc handleRequest*(
|
|||||||
raise newException(WSHandshakeError,
|
raise newException(WSHandshakeError,
|
||||||
"Failed to sent handshake response. Error: " & exc.msg)
|
"Failed to sent handshake response. Error: " & exc.msg)
|
||||||
|
|
||||||
return WSSession(
|
let session = WSSession(
|
||||||
readyState: ReadyState.Open,
|
readyState: ReadyState.Open,
|
||||||
stream: request.stream,
|
stream: request.stream,
|
||||||
proto: protocol,
|
proto: protocol,
|
||||||
|
extensions: system.move(ws.extensions),
|
||||||
masked: false,
|
masked: false,
|
||||||
rng: ws.rng,
|
rng: ws.rng,
|
||||||
frameSize: ws.frameSize,
|
frameSize: ws.frameSize,
|
||||||
@ -261,6 +345,11 @@ proc handleRequest*(
|
|||||||
onPong: ws.onPong,
|
onPong: ws.onPong,
|
||||||
onClose: ws.onClose)
|
onClose: ws.onClose)
|
||||||
|
|
||||||
|
for ext in session.extensions:
|
||||||
|
ext.session = session
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
proc new*(
|
proc new*(
|
||||||
_: typedesc[WSServer],
|
_: typedesc[WSServer],
|
||||||
protos: openArray[string] = [""],
|
protos: openArray[string] = [""],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user