diff --git a/ws/frame.nim b/ws/frame.nim index 10e46d1..5055a69 100644 --- a/ws/frame.nim +++ b/ws/frame.nim @@ -55,7 +55,7 @@ template remainder*(frame: Frame): uint64 = proc encode*( frame: Frame, offset = 0, - extensions: seq[Extension] = @[]): Future[seq[byte]] {.async.} = + extensions: seq[Ext] = @[]): Future[seq[byte]] {.async.} = ## Encodes a frame into a string buffer. ## See https://tools.ietf.org/html/rfc6455#section-5.2 @@ -117,7 +117,7 @@ proc decode*( _: typedesc[Frame], reader: AsyncStreamReader, masked: bool, - extensions: seq[Extension] = @[]): Future[Frame] {.async.} = + extensions: seq[Ext] = @[]): Future[Frame] {.async.} = ## Read and Decode incoming header ## diff --git a/ws/session.nim b/ws/session.nim index 7f36680..c9830a4 100644 --- a/ws/session.nim +++ b/ws/session.nim @@ -61,7 +61,7 @@ proc send*( mask: ws.masked, data: data, # allow sending data with close messages maskKey: maskKey) - .encode())) + .encode(extensions = ws.extensions))) return @@ -212,7 +212,8 @@ proc readFrame*(ws: WSSession): Future[Frame] {.async.} = ## while ws.readyState != ReadyState.Closed: - let frame = await Frame.decode(ws.stream.reader, ws.masked) + let frame = await Frame.decode( + ws.stream.reader, ws.masked, ws.extensions) debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask # return the current frame if it's not one of the control frames diff --git a/ws/types.nim b/ws/types.nim index 8dc08aa..26729f8 100644 --- a/ws/types.nim +++ b/ws/types.nim @@ -9,6 +9,7 @@ {.push raises: [Defect].} +import std/tables import pkg/[chronos, chronos/streams/tlsstream] import ./utils @@ -89,11 +90,15 @@ type CloseCb* = proc(code: Status, reason: string): CloseResult {.gcsafe, raises: [Defect].} - Extension* = ref object of RootObj + Ext* = ref object of RootObj name*: string + options*: Table[string, string] + + ExtFactory* = proc(name: string, options: Table[string, string]): + Ext {.raises: [Defect].} WebSocket* = ref object of RootObj - extensions: seq[Extension] # extension active for this session + extensions*: seq[Ext] version*: uint key*: string readyState*: ReadyState @@ -127,11 +132,14 @@ type WSInvalidOpcodeError* = object of WebSocketError WSInvalidUTF8* = object of WebSocketError -proc `name=`*(self: Extension, name: string) = +proc `name=`*(self: Ext, name: string) = raiseAssert "Can't change extensions name!" -method decode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} = +method decode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} = raiseAssert "Not implemented!" -method encode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} = +method encode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} = + raiseAssert "Not implemented!" + +method toHttpOptions*(self: Ext): string = raiseAssert "Not implemented!" diff --git a/ws/ws.nim b/ws/ws.nim index 280ced2..c2017ec 100644 --- a/ws/ws.nim +++ b/ws/ws.nim @@ -35,6 +35,7 @@ export utils, session, frame, types, http type WSServer* = ref object of WebSocket protocols: seq[string] + factories: seq[ExtFactory] func toException(e: string): ref WebSocketError = (ref WebSocketError)(msg: e) @@ -46,6 +47,7 @@ proc connect*( _: type WebSocket, uri: Uri, protocols: seq[string] = @[], + extensions: seq[Ext] = @[], flags: set[TLSFlags] = {}, version = WSDefaultVersion, frameSize = WSDefaultFrameSize, @@ -105,6 +107,7 @@ proc connect*( stream: client.stream, readyState: ReadyState.Open, masked: true, + extensions: @extensions, rng: rng, frameSize: frameSize, onPing: onPing, @@ -116,6 +119,7 @@ proc connect*( address: TransportAddress, path: string, protocols: seq[string] = @[], + extensions: seq[Ext] = @[], secure = false, flags: set[TLSFlags] = {}, version = WSDefaultVersion, @@ -142,6 +146,7 @@ proc connect*( return await WebSocket.connect( uri = parseUri(uri), protocols = protocols, + extensions = extensions, flags = flags, version = version, frameSize = frameSize, @@ -155,6 +160,7 @@ proc connect*( port: Port, path: string, protocols: seq[string] = @[], + extensions: seq[Ext] = @[], secure = false, flags: set[TLSFlags] = {}, version = WSDefaultVersion, @@ -168,6 +174,7 @@ proc connect*( address = initTAddress(host, port), path = path, protocols = protocols, + extensions = extensions, flags = flags, version = version, frameSize = frameSize, @@ -253,11 +260,11 @@ proc handleRequest*( proc new*( _: typedesc[WSServer], protos: openArray[string] = [""], + factories: openArray[ExtFactory] = [], frameSize = WSDefaultFrameSize, onPing: ControlCb = nil, onPong: ControlCb = nil, onClose: CloseCb = nil, - extensions: openArray[Extension] = [], rng: Rng = nil): WSServer = return WSServer( @@ -265,6 +272,7 @@ proc new*( masked: false, rng: if isNil(rng): newRng() else: rng, frameSize: frameSize, + factories: @factories, onPing: onPing, onPong: onPong, onClose: onClose)