diff --git a/ws/extensions/extutils.nim b/ws/extensions/extutils.nim index 3c334ec3da..21add3a1f9 100644 --- a/ws/extensions/extutils.nim +++ b/ws/extensions/extutils.nim @@ -9,13 +9,10 @@ import std/strutils, - pkg/httputils + pkg/httputils, + ../types type - ExtParam* = object - name* : string - value*: string - AppExt* = object name* : string params*: seq[ExtParam] @@ -140,11 +137,11 @@ proc parseExt*[T: BChar](data: openarray[T], output: var seq[AppExt]): bool = ext.params[^1].name = system.move(param.name) ext.params[^1].value = system.move(param.value) - if lex.tok notin {tkSemCol, tkComma, tkEof}: + if lex.tok notin {tkSemCol, tkComma, tkEof}: return false - + output.setLen(output.len + 1) - output[^1].name = system.move(ext.name) + output[^1].name = toLowerAscii(ext.name) output[^1].params = system.move(ext.params) if lex.tok == tkEof: diff --git a/ws/frame.nim b/ws/frame.nim index 1aeff78cb2..399a977c64 100644 --- a/ws/frame.nim +++ b/ws/frame.nim @@ -99,7 +99,13 @@ proc encode*( var ret: seq[byte] var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags. if f.fin: - b0 = b0 or 128'u8 + b0 = b0 or 0x80'u8 + if f.rsv1: + b0 = b0 or 0x40'u8 + if f.rsv2: + b0 = b0 or 0x20'u8 + if f.rsv3: + b0 = b0 or 0x10'u8 ret.add(b0) @@ -218,6 +224,9 @@ proc decode*( for i in countdown(extensions.high, extensions.low): frame = await extensions[i].decode(frame) + # we check rsv bits after extensions, + # because they have special meaning for extensions. + # rsv bits will be cleared by extensions if they are set by peer. # If any of the rsv are set close the socket. if frame.rsv1 or frame.rsv2 or frame.rsv3: raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") diff --git a/ws/types.nim b/ws/types.nim index 01ab81ab8e..faf654e16c 100644 --- a/ws/types.nim +++ b/ws/types.nim @@ -9,8 +9,7 @@ {.push raises: [Defect].} -import std/tables -import pkg/[chronos, chronos/streams/tlsstream] +import pkg/[chronos, chronos/streams/tlsstream, stew/results] import ./utils const @@ -96,13 +95,21 @@ type Ext* = ref object of RootObj name*: string - options*: Table[string, string] session*: WSSession - ExtFactory* = proc( - name: string, - session: WSSession, - options: Table[string, string]): Ext {.raises: [Defect].} + ExtParam* = object + name* : string + value*: string + + ExtFactoryProc* = proc( + isServer: bool, + args: seq[ExtParam]): Result[Ext, string] {. + gcsafe, raises: [Defect].} + + ExtFactory* = object + name*: string + factory*: ExtFactoryProc + clientOffer*: string WebSocketError* = object of CatchableError WSMalformedHeaderError* = object of WebSocketError @@ -124,6 +131,7 @@ type WSPayloadLengthError* = object of WebSocketError WSInvalidOpcodeError* = object of WebSocketError WSInvalidUTF8* = object of WebSocketError + WSExtError* = object of WebSocketError const StatusNotUsed* = (StatusCodes(0)..StatusCodes(999)) diff --git a/ws/ws.nim b/ws/ws.nim index a43ce4c924..2410445f75 100644 --- a/ws/ws.nim +++ b/ws/ws.nim @@ -27,7 +27,7 @@ import pkg/[chronos, stew/base10, nimcrypto/sha] -import ./utils, ./frame, ./session, /types, ./http +import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils export utils, session, frame, types, http @@ -45,11 +45,69 @@ func toException(e: string): ref WebSocketError = func toException(e: cstring): ref WebSocketError = (ref WebSocketError)(msg: $e) +func contains(extensions: openArray[Ext], extName: string): bool = + for ext in extensions: + if ext.name == extName: + return true + +proc getFactory(factories: openArray[ExtFactory], extName: string): ExtFactoryProc = + for n in factories: + if n.name == extName: + return n.factory + +proc selectExt(isServer: bool, + extensions: var seq[Ext], + factories: openArray[ExtFactory], + exts: openArray[string]): string {.raises: [Defect, WSExtError].} = + + var extList: seq[AppExt] + var response = "" + for ext in exts: + # each of "Sec-WebSocket-Extensions" can have multiple + # extensions or fallback extension + if not parseExt(ext, extList): + raise newException(WSExtError, "extension syntax error: " & ext) + + for i, ext in extList: + if extensions.contains(ext.name): + # don't accept this fallback if prev ext + # configuration already accepted + trace "extension fallback not accepted", ext=ext.name + continue + + # now look for right factory + let factory = factories.getFactory(ext.name) + if factory.isNil: + # no factory? it's ok, just skip it + trace "no extension factory", ext=ext.name + continue + + let extRes = factory(isServer, ext.params) + if extRes.isErr: + # cannot create extension because of + # wrong/incompatible params? skip or fallback + trace "skip extension", ext=ext.name, msg=extRes.error + continue + + let ext = extRes.get() + doAssert(not ext.isNil) + if i > 0: + # add separator if more than one exts + response.add ", " + response.add ext.toHttpOptions + + # finally, accept the extension + trace "extension accepted", ext=ext.name + extensions.add ext + + # HTTP response for "Sec-WebSocket-Extensions" + response + proc connect*( _: type WebSocket, uri: Uri, protocols: seq[string] = @[], - extensions: seq[Ext] = @[], + factories: seq[ExtFactory] = @[], flags: set[TLSFlags] = {}, version = WSDefaultVersion, frameSize = WSDefaultFrameSize, @@ -86,6 +144,15 @@ proc connect*( if protocols.len > 0: headers.add("Sec-WebSocket-Protocol", protocols.join(", ")) + var extOffer = "" + for i, f in factories: + if i > 0: + extOffer.add ", " + extOffer.add f.clientOffer + + if extOffer.len > 0: + headers.add("Sec-WebSocket-Extensions", extOffer) + let response = try: await client.request(uri, headers = headers) except CatchableError as exc: @@ -105,24 +172,33 @@ proc connect*( raise newException(WSFailedUpgradeError, &"Invalid protocol returned {proto}!") + var extensions: seq[Ext] + let exts = response.headers.getList("Sec-WebSocket-Extensions") + discard selectExt(false, extensions, factories, exts) + # Client data should be masked. - return WSSession( + let session = WSSession( stream: client.stream, readyState: ReadyState.Open, masked: true, - extensions: @extensions, + extensions: system.move(extensions), rng: rng, frameSize: frameSize, onPing: onPing, onPong: onPong, onClose: onClose) + for ext in session.extensions: + ext.session = session + + return session + proc connect*( _: type WebSocket, address: TransportAddress, path: string, protocols: seq[string] = @[], - extensions: seq[Ext] = @[], + factories: seq[ExtFactory] = @[], secure = false, flags: set[TLSFlags] = {}, version = WSDefaultVersion, @@ -149,7 +225,7 @@ proc connect*( return await WebSocket.connect( uri = parseUri(uri), protocols = protocols, - extensions = extensions, + factories = factories, flags = flags, version = version, frameSize = frameSize, @@ -163,7 +239,7 @@ proc connect*( port: Port, path: string, protocols: seq[string] = @[], - extensions: seq[Ext] = @[], + factories: seq[ExtFactory] = @[], secure = false, flags: set[TLSFlags] = {}, version = WSDefaultVersion, @@ -177,7 +253,7 @@ proc connect*( address = initTAddress(host, port), path = path, protocols = protocols, - extensions = extensions, + factories = factories, secure = secure, flags = flags, version = version, @@ -242,6 +318,13 @@ proc handleRequest*( else: trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos + # it is possible to have multiple "Sec-WebSocket-Extensions" + let exts = request.headers.getList("Sec-WebSocket-Extensions") + let extResp = selectExt(true, ws.extensions, ws.factories, exts) + if extResp.len > 0: + # send back any accepted extensions + headers.add("Sec-WebSocket-Extensions", extResp) + try: await request.sendResponse(Http101, headers = headers) except CancelledError as exc: @@ -250,10 +333,11 @@ proc handleRequest*( raise newException(WSHandshakeError, "Failed to sent handshake response. Error: " & exc.msg) - return WSSession( + let session = WSSession( readyState: ReadyState.Open, stream: request.stream, proto: protocol, + extensions: system.move(ws.extensions), masked: false, rng: ws.rng, frameSize: ws.frameSize, @@ -261,6 +345,11 @@ proc handleRequest*( onPong: ws.onPong, onClose: ws.onClose) + for ext in session.extensions: + ext.session = session + + return session + proc new*( _: typedesc[WSServer], protos: openArray[string] = [""],