From 5d0bcf63750a085fe4ca6a273d04ee7f321fec6d Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Tue, 25 May 2021 08:02:32 -0600 Subject: [PATCH] Client server (#29) * better client/server separation (WIP) * add extensions interface * index out of bounds --- examples/client.nim | 2 +- examples/server.nim | 3 +- examples/tlsserver.nim | 3 +- tests/testtlswebsockets.nim | 22 +++-- tests/testwebsockets.nim | 166 +++++++++++++++++++++-------------- ws/extension.nim | 26 ++++++ ws/utils.nim | 3 + ws/ws.nim | 167 ++++++++++++++++++++++-------------- 8 files changed, 254 insertions(+), 138 deletions(-) create mode 100644 ws/extension.nim diff --git a/examples/client.nim b/examples/client.nim index ea61b9917e..274c75064b 100644 --- a/examples/client.nim +++ b/examples/client.nim @@ -3,7 +3,7 @@ import pkg/[ chronicles, stew/byteutils] -import ../ws/ws +import ../ws/ws, ../ws/errors proc main() {.async.} = let ws = await WebSocket.connect( diff --git a/examples/server.nim b/examples/server.nim index f5c33da1d5..fa477fe9f6 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -12,7 +12,8 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if request.uri.path == "/ws": debug "Initiating web socket connection." try: - let ws = await WebSocket.createServer(request, "") + let server = WSServer.new() + let ws = await server.handleRequest(request) if ws.readyState != Open: error "Failed to open websocket connection." return diff --git a/examples/tlsserver.nim b/examples/tlsserver.nim index 35f1a25faf..c2ba9e3d3f 100644 --- a/examples/tlsserver.nim +++ b/examples/tlsserver.nim @@ -18,7 +18,8 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if request.uri.path == "/wss": debug "Initiating web socket connection." try: - var ws = await WebSocket.createServer(request, "myfancyprotocol") + let server = WSServer.new(protos = ["myfancyprotocol"]) + var ws = await server.handleRequest(request) if ws.readyState != Open: error "Failed to open websocket connection." return diff --git a/tests/testtlswebsockets.nim b/tests/testtlswebsockets.nim index d1c49b0bd1..786cc9c6c3 100644 --- a/tests/testtlswebsockets.nim +++ b/tests/testtlswebsockets.nim @@ -11,7 +11,7 @@ import ../ws/[ws, stream, errors], import ./keys -proc waitForClose(ws: WebSocket) {.async.} = +proc waitForClose(ws: WSSession) {.async.} = try: while ws.readystate != ReadyState.Closed: discard await ws.recv() @@ -39,8 +39,10 @@ suite "Test websocket TLS handshake": let request = r.get() check request.uri.path == "/wss" + let server = WSServer.new(protos = ["proto"]) + expect WSProtoMismatchError: - discard await WebSocket.createServer(request, "proto") + discard await server.handleRequest(request) let res = SecureHttpServerRef.new( address, cb, @@ -67,8 +69,10 @@ suite "Test websocket TLS handshake": let request = r.get() check request.uri.path == "/wss" + let server = WSServer.new(protos = ["proto"]) + expect WSVersionError: - discard await WebSocket.createServer(request, "proto") + discard await server.handleRequest(request) let res = SecureHttpServerRef.new( address, cb, @@ -135,9 +139,13 @@ suite "Test websocket TLS transmission": let request = r.get() check request.uri.path == "/wss" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + let servRes = await ws.recv() check string.fromBytes(servRes) == testString + await waitForClose(ws) return dumbResponse() @@ -169,9 +177,13 @@ suite "Test websocket TLS transmission": let request = r.get() check request.uri.path == "/wss" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + await ws.send(testString) await ws.close() + return dumbResponse() let res = SecureHttpServerRef.new( diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index 2f0ddf8831..2d1dd2d4f2 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -19,7 +19,7 @@ proc rndBin*(size: int): seq[byte] = for _ in .. size: add(result, byte(rand(0 .. 255))) -proc waitForClose(ws: WebSocket) {.async.} = +proc waitForClose(ws: WSSession) {.async.} = try: while ws.readystate != ReadyState.Closed: discard await ws.recv() @@ -38,8 +38,9 @@ suite "Test handshake": let request = r.get() check request.uri.path == "/ws" + let server = WSServer.new(protos = ["proto"]) expect WSProtoMismatchError: - discard await WebSocket.createServer(request, "proto") + discard await server.handleRequest(request) let res = HttpServerRef.new(address, cb) server = res.get() @@ -58,8 +59,9 @@ suite "Test handshake": return dumbResponse() let request = r.get() check request.uri.path == "/ws" + let server = WSServer.new(protos = ["ws"]) expect WSVersionError: - discard await WebSocket.createServer(request, "proto") + discard await server.handleRequest(request) let res = HttpServerRef.new(address, cb) server = res.get() @@ -110,8 +112,9 @@ suite "Test handshake": let request = r.get() check request.uri.path == "/ws" + let server = WSServer.new(protos = ["proto"]) expect WSProtoMismatchError: - var ws = await WebSocket.createServer(request, "proto") + var ws = await server.handleRequest(request) check ws.readyState == ReadyState.Closed return await request.respond(Http200, "Connection established") @@ -135,6 +138,7 @@ suite "Test handshake": suite "Test transmission": teardown: + await server.stop() await server.closeWait() test "Send text message message with payload of length 65535": @@ -145,7 +149,8 @@ suite "Test transmission": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) let servRes = await ws.recv() check string.fromBytes(servRes) == testString @@ -167,10 +172,14 @@ suite "Test transmission": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = if r.isErr(): return dumbResponse() + let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) let servRes = await ws.recv() + check string.fromBytes(servRes) == testString await waitForClose(ws) @@ -183,6 +192,7 @@ suite "Test transmission": Port(8888), path = "/ws", protocols = @["proto"]) + await wsClient.send(testString) await wsClient.close() @@ -199,7 +209,9 @@ suite "Test transmission": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + await ws.send(testString) await ws.close() @@ -239,19 +251,19 @@ suite "Test ping-pong": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer( - request, - "proto", + let server = WSServer.new( + protos = ["proto"], onPing = proc(data: openArray[byte]) = ping = true - ) + ) + + let ws = await server.handleRequest(request) let respData = await ws.recv() check string.fromBytes(respData) == testString await waitForClose(ws) - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -310,12 +322,12 @@ suite "Test ping-pong": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer( - request, - "proto", + let server = WSServer.new( + protos = ["proto"], onPong = proc(data: openArray[byte]) = pong = true ) + let ws = await server.handleRequest(request) await ws.ping() await ws.close() @@ -346,12 +358,13 @@ suite "Test ping-pong": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer( - request, - "proto", + let server = WSServer.new( + protos = ["proto"], onPing = proc(data: openArray[byte]) = ping = true ) + + let ws = await server.handleRequest(request) await waitForClose(ws) check: ping @@ -393,7 +406,8 @@ suite "Test framing": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) let frame1 = await ws.readFrame() check not isNil(frame1) var data1 = newSeq[byte](frame1.remainder().int) @@ -435,7 +449,8 @@ suite "Test framing": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) await ws.send(testString) await ws.close() @@ -476,7 +491,8 @@ suite "Test Closing": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) await ws.close() let res = HttpServerRef.new(address, cb) @@ -515,11 +531,12 @@ suite "Test Closing": return (Status.Fulfilled, "") - let ws = await WebSocket.createServer( - request, - "proto", - onClose = closeServer) + let server = WSServer.new( + protos = ["proto"], + onClose = closeServer + ) + let ws = await server.handleRequest(request) await ws.close() let res = HttpServerRef.new(address, cb) @@ -551,7 +568,8 @@ suite "Test Closing": let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) await waitForClose(ws) let res = HttpServerRef.new(address, cb) @@ -580,10 +598,12 @@ suite "Test Closing": except Exception as exc: raise newException(Defect, exc.msg) - let ws = await WebSocket.createServer( - request, - "proto", - onClose = closeServer) + let server = WSServer.new( + protos = ["proto"], + onClose = closeServer + ) + + let ws = await server.handleRequest(request) await waitForClose(ws) let res = HttpServerRef.new(address, cb) @@ -615,14 +635,12 @@ suite "Test Closing": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer( - request, - "proto") + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) await ws.close(code = Status.ReservedCode) - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -658,16 +676,16 @@ suite "Test Closing": except Exception as exc: raise newException(Defect, exc.msg) - let ws = await WebSocket.createServer( - request, - "proto", - onClose = closeServer) + let server = WSServer.new( + protos = ["proto"], + onClose = closeServer + ) + let ws = await server.handleRequest(request) await waitForClose(ws) return dumbResponse() - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -689,7 +707,10 @@ suite "Test Closing": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + # Close with payload of length 2 await ws.close(reason = "HH") @@ -711,7 +732,10 @@ suite "Test Closing": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + await waitForClose(ws) let res = HttpServerRef.new( @@ -744,9 +768,8 @@ suite "Test Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer( - request, - "proto") + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) expect WSPayloadTooLarge: discard await ws.recv() @@ -775,8 +798,11 @@ suite "Test Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) let servRes = await ws.recv() + check string.fromBytes(servRes) == emptyStr await waitForClose(ws) @@ -801,8 +827,11 @@ suite "Test Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) let servRes = await ws.recv() + check string.fromBytes(servRes) == emptyStr await waitForClose(ws) @@ -829,13 +858,13 @@ suite "Test Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer( - request, - "proto", - onPing = proc(data: openArray[byte]) = - ping = data == testData - ) + let server = WSServer.new( + protos = ["proto"], + onPing = proc(data: openArray[byte]) = + ping = data == testData) + + let ws = await server.handleRequest(request) await waitForClose(ws) let res = HttpServerRef.new( @@ -876,7 +905,9 @@ suite "Test Binary message with Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) let servRes = await ws.recv() check: @@ -885,8 +916,7 @@ suite "Test Binary message with Payload": await waitForClose(ws) - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -906,7 +936,10 @@ suite "Test Binary message with Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer(request, "proto") + + let server = WSServer.new(protos = ["proto"]) + let ws = await server.handleRequest(request) + let servRes = await ws.recv() check: @@ -939,12 +972,13 @@ suite "Test Binary message with Payload": return dumbResponse() let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer( - request, - "proto", + + let server = WSServer.new( + protos = ["proto"], onPing = proc(data: openArray[byte]) = ping = true ) + let ws = await server.handleRequest(request) let res = await ws.recv() check: @@ -976,14 +1010,16 @@ suite "Test Binary message with Payload": proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if r.isErr(): return dumbResponse() + let request = r.get() check request.uri.path == "/ws" - let ws = await WebSocket.createServer( - request, - "proto", + + let server = WSServer.new( + protos = ["proto"], onPing = proc(data: openArray[byte]) = ping = true ) + let ws = await server.handleRequest(request) let res = await ws.recv() check: diff --git a/ws/extension.nim b/ws/extension.nim new file mode 100644 index 0000000000..b4c875d3ff --- /dev/null +++ b/ws/extension.nim @@ -0,0 +1,26 @@ +## Nim-Libp2p +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +{.push raises: [Defect].} + +import pkg/[chronos, chronicles] +import ./frame + +type + Extension* = ref object of RootObj + name*: string + +proc `name=`*(self: Extension, name: string) = + raiseAssert "Can't change extensions name!" + +method decode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} = + raiseAssert "Not implemented!" + +method encode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} = + raiseAssert "Not implemented!" diff --git a/ws/utils.nim b/ws/utils.nim index 2862b02da2..66302a3178 100644 --- a/ws/utils.nim +++ b/ws/utils.nim @@ -4,6 +4,9 @@ export bearssl ## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng const randMax = 18_446_744_073_709_551_615'u64 +type + Rng* = ref BrHmacDrbgContext + proc newRng*(): ref BrHmacDrbgContext = # You should only create one instance of the RNG per application / library # Ref is used so that it can be shared between components diff --git a/ws/ws.nim b/ws/ws.nim index d3fdfc4be5..f36a18b072 100644 --- a/ws/ws.nim +++ b/ws/ws.nim @@ -12,6 +12,7 @@ import std/[tables, strutils, + sequtils, uri, parseutils] @@ -28,7 +29,7 @@ import pkg/[chronos, stew/base10, nimcrypto/sha] -import ./utils, ./stream, ./frame, ./errors +import ./utils, ./stream, ./frame, ./errors, ./extension const SHA1DigestSize* = 20 @@ -77,21 +78,27 @@ type CloseCb* = proc(code: Status, reason: string): CloseResult {.gcsafe, raises: [Defect].} - WebSocket* = ref object - stream*: AsyncStream + WebSocket* = ref object of RootObj + extensions: seq[Extension] # extension active for this session version*: uint key*: string - protocol*: string + proto*: string readyState*: ReadyState masked*: bool # send masked packets binary*: bool # is payload binary? rng*: ref BrHmacDrbgContext frameSize: int - frame: Frame onPing: ControlCb onPong: ControlCb onClose: CloseCb + WSServer* = ref object of WebSocket + protocols: seq[string] + + WSSession* = ref object of WebSocket + stream*: AsyncStream + frame*: Frame + template remainder*(frame: Frame): uint64 = frame.length - frame.consumed @@ -114,11 +121,13 @@ proc prepareCloseBody(code: Status, reason: string): seq[byte] = result = @(ord(code).uint16.toBytesBE()) & result proc handshake*( - ws: WebSocket, + ws: WSServer, request: HttpRequestRef, - version: uint = WSDefaultVersion) {.async.} = + stream: AsyncStream, + version: uint = WSDefaultVersion): Future[WSSession] {.async.} = ## Handles the websocket handshake. ## + let reqHeaders = request.headers @@ -133,15 +142,21 @@ proc handshake*( reqHeaders.getString("Sec-WebSocket-Version")) ws.key = reqHeaders.getString("Sec-WebSocket-Key").strip() + var protos = @[""] if reqHeaders.contains("Sec-WebSocket-Protocol"): - let wantProtocol = reqHeaders.getString("Sec-WebSocket-Protocol").strip() - if ws.protocol != wantProtocol: - raise newException(WSProtoMismatchError, - "Protocol mismatch (expected: " & ws.protocol & ", got: " & - wantProtocol & ")") + let wantProtos = reqHeaders.getList("Sec-WebSocket-Protocol") + protos = wantProtos.filterIt( + it in ws.protocols + ) - let cKey = ws.key & WSGuid - let acceptKey = Base64Pad.encode( + if protos.len <= 0: + raise newException(WSProtoMismatchError, + "Protocol mismatch (expected: " & ws.protocols.join(", ") & ", got: " & + wantProtos.join(", ") & ")") + + let + cKey = ws.key & WSGuid + acceptKey = Base64Pad.encode( sha1.digest(cKey.toOpenArray(0, cKey.high)).data) var headerData = [ @@ -150,50 +165,30 @@ proc handshake*( ("Sec-WebSocket-Accept", acceptKey)] var headers = HttpTable.init(headerData) - if ws.protocol != "": - headers.add("Sec-WebSocket-Protocol", ws.protocol) + if protos.len > 0: + headers.add("Sec-WebSocket-Protocol", protos[0]) # send back the first matching proto try: discard await request.respond(httputils.Http101, "", headers) + except CancelledError as exc: + raise exc except CatchableError as exc: raise newException(WSHandshakeError, "Failed to sent handshake response. Error: " & exc.msg) - ws.readyState = ReadyState.Open - -proc createServer*( - _: typedesc[WebSocket], - request: HttpRequestRef, - protocol: string = "", - frameSize = WSDefaultFrameSize, - onPing: ControlCb = nil, - onPong: ControlCb = nil, - onClose: CloseCb = nil): Future[WebSocket] {.async.} = - ## Creates a new socket from a request. - ## - - if not request.headers.contains("Sec-WebSocket-Version"): - raise newException(WSHandshakeError, "Missing version header") - - let wsStream = AsyncStream( - reader: request.connection.reader, - writer: request.connection.writer) - - var ws = WebSocket( - stream: wsStream, - protocol: protocol, + return WSSession( + readyState: ReadyState.Open, + stream: stream, + proto: protos[0], masked: false, - rng: newRng(), - frameSize: frameSize, - onPing: onPing, - onPong: onPong, - onClose: onClose) - - await ws.handshake(request) - return ws + rng: ws.rng, + frameSize: ws.frameSize, + onPing: ws.onPing, + onPong: ws.onPong, + onClose: ws.onClose) proc send*( - ws: WebSocket, + ws: WSSession, data: seq[byte] = @[], opcode: Opcode) {.async.} = ## Send a frame @@ -252,20 +247,23 @@ proc send*( if i >= data.len: break -proc send*(ws: WebSocket, data: string): Future[void] = +proc send*(ws: WSSession, data: string): Future[void] = send(ws, toBytes(data), Opcode.Text) -proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} = +proc handleClose*(ws: WSSession, frame: Frame, payLoad: seq[byte] = @[]) {.async.} = + ## Handle close sequence + ## logScope: fin = frame.fin masked = frame.mask opcode = frame.opcode - serverState = ws.readyState + readyState = ws.readyState debug "Handling close sequence" if ws.readyState notin {ReadyState.Open}: + debug "Connection isn't open, abortig close sequence!" return var @@ -310,8 +308,8 @@ proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async ws.readyState = ReadyState.Closed await ws.stream.closeWait() -proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = - ## handle control frames +proc handleControl*(ws: WSSession, frame: Frame) {.async.} = + ## Handle control frames ## if not frame.fin: @@ -362,7 +360,7 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = ws.readyState = ReadyState.Closed await ws.stream.closeWait() -proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = +proc readFrame*(ws: WSSession): Future[Frame] {.async.} = ## Gets a frame from the WebSocket. ## See https://tools.ietf.org/html/rfc6455#section-5.2 ## @@ -387,11 +385,11 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = await ws.stream.closeWait() raise exc -proc ping*(ws: WebSocket, data: seq[byte] = @[]): Future[void] = +proc ping*(ws: WSSession, data: seq[byte] = @[]): Future[void] = ws.send(data, opcode = Opcode.Ping) proc recv*( - ws: WebSocket, + ws: WSSession, data: pointer, size: int): Future[int] {.async.} = ## Attempts to read up to `size` bytes @@ -470,7 +468,7 @@ proc recv*( debug "Exception reading frames", exc = exc.msg proc recv*( - ws: WebSocket, + ws: WSSession, size = WSMaxMessageSize): Future[seq[byte]] {.async.} = ## Attempt to read a full message up to max `size` ## bytes in `frameSize` chunks. @@ -516,7 +514,7 @@ proc recv*( return res proc close*( - ws: WebSocket, + ws: WSSession, code: Status = Status.Fulfilled, reason: string = "") {.async.} = ## Close the Socket, sends close packet. @@ -607,7 +605,8 @@ proc connect*( frameSize = WSDefaultFrameSize, onPing: ControlCb = nil, onPong: ControlCb = nil, - onClose: CloseCb = nil): Future[WebSocket] {.async.} = + onClose: CloseCb = nil, + rng: Rng = nil): Future[WSSession] {.async.} = ## create a new websockets client ## @@ -619,7 +618,8 @@ proc connect*( of "wss": uri.scheme = "https" else: - raise newException(WSWrongUriSchemeError, "uri scheme has to be 'ws' or 'wss'") + raise newException(WSWrongUriSchemeError, + "uri scheme has to be 'ws' or 'wss'") var headerData = [ ("Connection", "Upgrade"), @@ -637,11 +637,11 @@ proc connect*( let stream = await initiateHandshake(uri, address, headers, flags) # Client data should be masked. - return WebSocket( + return WSSession( stream: stream, readyState: ReadyState.Open, masked: true, - rng: newRng(), + rng: if isNil(rng): newRng() else: rng, frameSize: frameSize, onPing: onPing, onPong: onPong, @@ -657,7 +657,7 @@ proc connect*( frameSize = WSDefaultFrameSize, onPing: ControlCb = nil, onPong: ControlCb = nil, - onClose: CloseCb = nil): Future[WebSocket] {.async.} = + onClose: CloseCb = nil): Future[WSSession] {.async.} = ## Create a new websockets client ## using a string path ## @@ -689,7 +689,8 @@ proc tlsConnect*( frameSize = WSDefaultFrameSize, onPing: ControlCb = nil, onPong: ControlCb = nil, - onClose: CloseCb = nil): Future[WebSocket] {.async.} = + onClose: CloseCb = nil, + rng: Rng = nil): Future[WSSession] {.async.} = var uri = "wss://" & host & ":" & $port if path.startsWith("/"): @@ -705,4 +706,40 @@ proc tlsConnect*( frameSize, onPing, onPong, - onClose) + onClose, + rng) + +proc handleRequest*( + ws: WSServer, + request: HttpRequestRef): Future[WSSession] + {.raises: [Defect, WSHandshakeError].} = + ## Creates a new socket from a request. + ## + + if not request.headers.contains("Sec-WebSocket-Version"): + raise newException(WSHandshakeError, "Missing version header") + + let wsStream = AsyncStream( + reader: request.connection.reader, + writer: request.connection.writer) + + return ws.handshake(request, wsStream) + +proc new*( + _: typedesc[WSServer], + protos: openArray[string] = [""], + frameSize = WSDefaultFrameSize, + onPing: ControlCb = nil, + onPong: ControlCb = nil, + onClose: CloseCb = nil, + extensions: openArray[Extension] = [], + rng: Rng = nil): WSServer = + + return WSServer( + protocols: @protos, + masked: false, + rng: if isNil(rng): newRng() else: rng, + frameSize: frameSize, + onPing: onPing, + onPong: onPong, + onClose: onClose)