Client server (#29)

* better client/server separation (WIP)

* add extensions interface

* index out of bounds
This commit is contained in:
Dmitriy Ryajov 2021-05-25 08:02:32 -06:00 committed by GitHub
parent 0a4121c29d
commit 5d0bcf6375
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 254 additions and 138 deletions

View File

@ -3,7 +3,7 @@ import pkg/[
chronicles, chronicles,
stew/byteutils] stew/byteutils]
import ../ws/ws import ../ws/ws, ../ws/errors
proc main() {.async.} = proc main() {.async.} =
let ws = await WebSocket.connect( let ws = await WebSocket.connect(

View File

@ -12,7 +12,8 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if request.uri.path == "/ws": if request.uri.path == "/ws":
debug "Initiating web socket connection." debug "Initiating web socket connection."
try: try:
let ws = await WebSocket.createServer(request, "") let server = WSServer.new()
let ws = await server.handleRequest(request)
if ws.readyState != Open: if ws.readyState != Open:
error "Failed to open websocket connection." error "Failed to open websocket connection."
return return

View File

@ -18,7 +18,8 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if request.uri.path == "/wss": if request.uri.path == "/wss":
debug "Initiating web socket connection." debug "Initiating web socket connection."
try: try:
var ws = await WebSocket.createServer(request, "myfancyprotocol") let server = WSServer.new(protos = ["myfancyprotocol"])
var ws = await server.handleRequest(request)
if ws.readyState != Open: if ws.readyState != Open:
error "Failed to open websocket connection." error "Failed to open websocket connection."
return return

View File

@ -11,7 +11,7 @@ import ../ws/[ws, stream, errors],
import ./keys import ./keys
proc waitForClose(ws: WebSocket) {.async.} = proc waitForClose(ws: WSSession) {.async.} =
try: try:
while ws.readystate != ReadyState.Closed: while ws.readystate != ReadyState.Closed:
discard await ws.recv() discard await ws.recv()
@ -39,8 +39,10 @@ suite "Test websocket TLS handshake":
let request = r.get() let request = r.get()
check request.uri.path == "/wss" check request.uri.path == "/wss"
let server = WSServer.new(protos = ["proto"])
expect WSProtoMismatchError: expect WSProtoMismatchError:
discard await WebSocket.createServer(request, "proto") discard await server.handleRequest(request)
let res = SecureHttpServerRef.new( let res = SecureHttpServerRef.new(
address, cb, address, cb,
@ -67,8 +69,10 @@ suite "Test websocket TLS handshake":
let request = r.get() let request = r.get()
check request.uri.path == "/wss" check request.uri.path == "/wss"
let server = WSServer.new(protos = ["proto"])
expect WSVersionError: expect WSVersionError:
discard await WebSocket.createServer(request, "proto") discard await server.handleRequest(request)
let res = SecureHttpServerRef.new( let res = SecureHttpServerRef.new(
address, cb, address, cb,
@ -135,9 +139,13 @@ suite "Test websocket TLS transmission":
let request = r.get() let request = r.get()
check request.uri.path == "/wss" check request.uri.path == "/wss"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv() let servRes = await ws.recv()
check string.fromBytes(servRes) == testString check string.fromBytes(servRes) == testString
await waitForClose(ws) await waitForClose(ws)
return dumbResponse() return dumbResponse()
@ -169,9 +177,13 @@ suite "Test websocket TLS transmission":
let request = r.get() let request = r.get()
check request.uri.path == "/wss" check request.uri.path == "/wss"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
await ws.send(testString) await ws.send(testString)
await ws.close() await ws.close()
return dumbResponse() return dumbResponse()
let res = SecureHttpServerRef.new( let res = SecureHttpServerRef.new(

View File

@ -19,7 +19,7 @@ proc rndBin*(size: int): seq[byte] =
for _ in .. size: for _ in .. size:
add(result, byte(rand(0 .. 255))) add(result, byte(rand(0 .. 255)))
proc waitForClose(ws: WebSocket) {.async.} = proc waitForClose(ws: WSSession) {.async.} =
try: try:
while ws.readystate != ReadyState.Closed: while ws.readystate != ReadyState.Closed:
discard await ws.recv() discard await ws.recv()
@ -38,8 +38,9 @@ suite "Test handshake":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let server = WSServer.new(protos = ["proto"])
expect WSProtoMismatchError: expect WSProtoMismatchError:
discard await WebSocket.createServer(request, "proto") discard await server.handleRequest(request)
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -58,8 +59,9 @@ suite "Test handshake":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let server = WSServer.new(protos = ["ws"])
expect WSVersionError: expect WSVersionError:
discard await WebSocket.createServer(request, "proto") discard await server.handleRequest(request)
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -110,8 +112,9 @@ suite "Test handshake":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let server = WSServer.new(protos = ["proto"])
expect WSProtoMismatchError: expect WSProtoMismatchError:
var ws = await WebSocket.createServer(request, "proto") var ws = await server.handleRequest(request)
check ws.readyState == ReadyState.Closed check ws.readyState == ReadyState.Closed
return await request.respond(Http200, "Connection established") return await request.respond(Http200, "Connection established")
@ -135,6 +138,7 @@ suite "Test handshake":
suite "Test transmission": suite "Test transmission":
teardown: teardown:
await server.stop()
await server.closeWait() await server.closeWait()
test "Send text message message with payload of length 65535": test "Send text message message with payload of length 65535":
@ -145,7 +149,8 @@ suite "Test transmission":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto") let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv() let servRes = await ws.recv()
check string.fromBytes(servRes) == testString check string.fromBytes(servRes) == testString
@ -167,10 +172,14 @@ suite "Test transmission":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv() let servRes = await ws.recv()
check string.fromBytes(servRes) == testString check string.fromBytes(servRes) == testString
await waitForClose(ws) await waitForClose(ws)
@ -183,6 +192,7 @@ suite "Test transmission":
Port(8888), Port(8888),
path = "/ws", path = "/ws",
protocols = @["proto"]) protocols = @["proto"])
await wsClient.send(testString) await wsClient.send(testString)
await wsClient.close() await wsClient.close()
@ -199,7 +209,9 @@ suite "Test transmission":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto") let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
await ws.send(testString) await ws.send(testString)
await ws.close() await ws.close()
@ -239,19 +251,19 @@ suite "Test ping-pong":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer( let server = WSServer.new(
request, protos = ["proto"],
"proto",
onPing = proc(data: openArray[byte]) = onPing = proc(data: openArray[byte]) =
ping = true ping = true
) )
let ws = await server.handleRequest(request)
let respData = await ws.recv() let respData = await ws.recv()
check string.fromBytes(respData) == testString check string.fromBytes(respData) == testString
await waitForClose(ws) await waitForClose(ws)
let res = HttpServerRef.new( let res = HttpServerRef.new(address, cb)
address, cb)
server = res.get() server = res.get()
server.start() server.start()
@ -310,12 +322,12 @@ suite "Test ping-pong":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer( let server = WSServer.new(
request, protos = ["proto"],
"proto",
onPong = proc(data: openArray[byte]) = onPong = proc(data: openArray[byte]) =
pong = true pong = true
) )
let ws = await server.handleRequest(request)
await ws.ping() await ws.ping()
await ws.close() await ws.close()
@ -346,12 +358,13 @@ suite "Test ping-pong":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer( let server = WSServer.new(
request, protos = ["proto"],
"proto",
onPing = proc(data: openArray[byte]) = onPing = proc(data: openArray[byte]) =
ping = true ping = true
) )
let ws = await server.handleRequest(request)
await waitForClose(ws) await waitForClose(ws)
check: check:
ping ping
@ -393,7 +406,8 @@ suite "Test framing":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto") let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let frame1 = await ws.readFrame() let frame1 = await ws.readFrame()
check not isNil(frame1) check not isNil(frame1)
var data1 = newSeq[byte](frame1.remainder().int) var data1 = newSeq[byte](frame1.remainder().int)
@ -435,7 +449,8 @@ suite "Test framing":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto") let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
await ws.send(testString) await ws.send(testString)
await ws.close() await ws.close()
@ -476,7 +491,8 @@ suite "Test Closing":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto") let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
await ws.close() await ws.close()
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
@ -515,11 +531,12 @@ suite "Test Closing":
return (Status.Fulfilled, "") return (Status.Fulfilled, "")
let ws = await WebSocket.createServer( let server = WSServer.new(
request, protos = ["proto"],
"proto", onClose = closeServer
onClose = closeServer) )
let ws = await server.handleRequest(request)
await ws.close() await ws.close()
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
@ -551,7 +568,8 @@ suite "Test Closing":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto") let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
await waitForClose(ws) await waitForClose(ws)
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
@ -580,10 +598,12 @@ suite "Test Closing":
except Exception as exc: except Exception as exc:
raise newException(Defect, exc.msg) raise newException(Defect, exc.msg)
let ws = await WebSocket.createServer( let server = WSServer.new(
request, protos = ["proto"],
"proto", onClose = closeServer
onClose = closeServer) )
let ws = await server.handleRequest(request)
await waitForClose(ws) await waitForClose(ws)
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
@ -615,14 +635,12 @@ suite "Test Closing":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer( let server = WSServer.new(protos = ["proto"])
request, let ws = await server.handleRequest(request)
"proto")
await ws.close(code = Status.ReservedCode) await ws.close(code = Status.ReservedCode)
let res = HttpServerRef.new( let res = HttpServerRef.new(address, cb)
address, cb)
server = res.get() server = res.get()
server.start() server.start()
@ -658,16 +676,16 @@ suite "Test Closing":
except Exception as exc: except Exception as exc:
raise newException(Defect, exc.msg) raise newException(Defect, exc.msg)
let ws = await WebSocket.createServer( let server = WSServer.new(
request, protos = ["proto"],
"proto", onClose = closeServer
onClose = closeServer) )
let ws = await server.handleRequest(request)
await waitForClose(ws) await waitForClose(ws)
return dumbResponse() return dumbResponse()
let res = HttpServerRef.new( let res = HttpServerRef.new(address, cb)
address, cb)
server = res.get() server = res.get()
server.start() server.start()
@ -689,7 +707,10 @@ suite "Test Closing":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
# Close with payload of length 2 # Close with payload of length 2
await ws.close(reason = "HH") await ws.close(reason = "HH")
@ -711,7 +732,10 @@ suite "Test Closing":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
await waitForClose(ws) await waitForClose(ws)
let res = HttpServerRef.new( let res = HttpServerRef.new(
@ -744,9 +768,8 @@ suite "Test Payload":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer( let server = WSServer.new(protos = ["proto"])
request, let ws = await server.handleRequest(request)
"proto")
expect WSPayloadTooLarge: expect WSPayloadTooLarge:
discard await ws.recv() discard await ws.recv()
@ -775,8 +798,11 @@ suite "Test Payload":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv() let servRes = await ws.recv()
check string.fromBytes(servRes) == emptyStr check string.fromBytes(servRes) == emptyStr
await waitForClose(ws) await waitForClose(ws)
@ -801,8 +827,11 @@ suite "Test Payload":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv() let servRes = await ws.recv()
check string.fromBytes(servRes) == emptyStr check string.fromBytes(servRes) == emptyStr
await waitForClose(ws) await waitForClose(ws)
@ -829,13 +858,13 @@ suite "Test Payload":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(
request,
"proto",
onPing = proc(data: openArray[byte]) =
ping = data == testData
)
let server = WSServer.new(
protos = ["proto"],
onPing = proc(data: openArray[byte]) =
ping = data == testData)
let ws = await server.handleRequest(request)
await waitForClose(ws) await waitForClose(ws)
let res = HttpServerRef.new( let res = HttpServerRef.new(
@ -876,7 +905,9 @@ suite "Test Binary message with Payload":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv() let servRes = await ws.recv()
check: check:
@ -885,8 +916,7 @@ suite "Test Binary message with Payload":
await waitForClose(ws) await waitForClose(ws)
let res = HttpServerRef.new( let res = HttpServerRef.new(address, cb)
address, cb)
server = res.get() server = res.get()
server.start() server.start()
@ -906,7 +936,10 @@ suite "Test Binary message with Payload":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(request, "proto")
let server = WSServer.new(protos = ["proto"])
let ws = await server.handleRequest(request)
let servRes = await ws.recv() let servRes = await ws.recv()
check: check:
@ -939,12 +972,13 @@ suite "Test Binary message with Payload":
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(
request, let server = WSServer.new(
"proto", protos = ["proto"],
onPing = proc(data: openArray[byte]) = onPing = proc(data: openArray[byte]) =
ping = true ping = true
) )
let ws = await server.handleRequest(request)
let res = await ws.recv() let res = await ws.recv()
check: check:
@ -976,14 +1010,16 @@ suite "Test Binary message with Payload":
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return dumbResponse() return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await WebSocket.createServer(
request, let server = WSServer.new(
"proto", protos = ["proto"],
onPing = proc(data: openArray[byte]) = onPing = proc(data: openArray[byte]) =
ping = true ping = true
) )
let ws = await server.handleRequest(request)
let res = await ws.recv() let res = await ws.recv()
check: check:

26
ws/extension.nim Normal file
View File

@ -0,0 +1,26 @@
## Nim-Libp2p
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
import pkg/[chronos, chronicles]
import ./frame
type
Extension* = ref object of RootObj
name*: string
proc `name=`*(self: Extension, name: string) =
raiseAssert "Can't change extensions name!"
method decode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} =
raiseAssert "Not implemented!"
method encode*(self: Extension, frame: Frame): Future[Frame] {.base, async.} =
raiseAssert "Not implemented!"

View File

@ -4,6 +4,9 @@ export bearssl
## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng ## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng
const randMax = 18_446_744_073_709_551_615'u64 const randMax = 18_446_744_073_709_551_615'u64
type
Rng* = ref BrHmacDrbgContext
proc newRng*(): ref BrHmacDrbgContext = proc newRng*(): ref BrHmacDrbgContext =
# You should only create one instance of the RNG per application / library # You should only create one instance of the RNG per application / library
# Ref is used so that it can be shared between components # Ref is used so that it can be shared between components

167
ws/ws.nim
View File

@ -12,6 +12,7 @@
import std/[tables, import std/[tables,
strutils, strutils,
sequtils,
uri, uri,
parseutils] parseutils]
@ -28,7 +29,7 @@ import pkg/[chronos,
stew/base10, stew/base10,
nimcrypto/sha] nimcrypto/sha]
import ./utils, ./stream, ./frame, ./errors import ./utils, ./stream, ./frame, ./errors, ./extension
const const
SHA1DigestSize* = 20 SHA1DigestSize* = 20
@ -77,21 +78,27 @@ type
CloseCb* = proc(code: Status, reason: string): CloseCb* = proc(code: Status, reason: string):
CloseResult {.gcsafe, raises: [Defect].} CloseResult {.gcsafe, raises: [Defect].}
WebSocket* = ref object WebSocket* = ref object of RootObj
stream*: AsyncStream extensions: seq[Extension] # extension active for this session
version*: uint version*: uint
key*: string key*: string
protocol*: string proto*: string
readyState*: ReadyState readyState*: ReadyState
masked*: bool # send masked packets masked*: bool # send masked packets
binary*: bool # is payload binary? binary*: bool # is payload binary?
rng*: ref BrHmacDrbgContext rng*: ref BrHmacDrbgContext
frameSize: int frameSize: int
frame: Frame
onPing: ControlCb onPing: ControlCb
onPong: ControlCb onPong: ControlCb
onClose: CloseCb onClose: CloseCb
WSServer* = ref object of WebSocket
protocols: seq[string]
WSSession* = ref object of WebSocket
stream*: AsyncStream
frame*: Frame
template remainder*(frame: Frame): uint64 = template remainder*(frame: Frame): uint64 =
frame.length - frame.consumed frame.length - frame.consumed
@ -114,11 +121,13 @@ proc prepareCloseBody(code: Status, reason: string): seq[byte] =
result = @(ord(code).uint16.toBytesBE()) & result result = @(ord(code).uint16.toBytesBE()) & result
proc handshake*( proc handshake*(
ws: WebSocket, ws: WSServer,
request: HttpRequestRef, request: HttpRequestRef,
version: uint = WSDefaultVersion) {.async.} = stream: AsyncStream,
version: uint = WSDefaultVersion): Future[WSSession] {.async.} =
## Handles the websocket handshake. ## Handles the websocket handshake.
## ##
let let
reqHeaders = request.headers reqHeaders = request.headers
@ -133,15 +142,21 @@ proc handshake*(
reqHeaders.getString("Sec-WebSocket-Version")) reqHeaders.getString("Sec-WebSocket-Version"))
ws.key = reqHeaders.getString("Sec-WebSocket-Key").strip() ws.key = reqHeaders.getString("Sec-WebSocket-Key").strip()
var protos = @[""]
if reqHeaders.contains("Sec-WebSocket-Protocol"): if reqHeaders.contains("Sec-WebSocket-Protocol"):
let wantProtocol = reqHeaders.getString("Sec-WebSocket-Protocol").strip() let wantProtos = reqHeaders.getList("Sec-WebSocket-Protocol")
if ws.protocol != wantProtocol: protos = wantProtos.filterIt(
raise newException(WSProtoMismatchError, it in ws.protocols
"Protocol mismatch (expected: " & ws.protocol & ", got: " & )
wantProtocol & ")")
let cKey = ws.key & WSGuid if protos.len <= 0:
let acceptKey = Base64Pad.encode( raise newException(WSProtoMismatchError,
"Protocol mismatch (expected: " & ws.protocols.join(", ") & ", got: " &
wantProtos.join(", ") & ")")
let
cKey = ws.key & WSGuid
acceptKey = Base64Pad.encode(
sha1.digest(cKey.toOpenArray(0, cKey.high)).data) sha1.digest(cKey.toOpenArray(0, cKey.high)).data)
var headerData = [ var headerData = [
@ -150,50 +165,30 @@ proc handshake*(
("Sec-WebSocket-Accept", acceptKey)] ("Sec-WebSocket-Accept", acceptKey)]
var headers = HttpTable.init(headerData) var headers = HttpTable.init(headerData)
if ws.protocol != "": if protos.len > 0:
headers.add("Sec-WebSocket-Protocol", ws.protocol) headers.add("Sec-WebSocket-Protocol", protos[0]) # send back the first matching proto
try: try:
discard await request.respond(httputils.Http101, "", headers) discard await request.respond(httputils.Http101, "", headers)
except CancelledError as exc:
raise exc
except CatchableError as exc: except CatchableError as exc:
raise newException(WSHandshakeError, raise newException(WSHandshakeError,
"Failed to sent handshake response. Error: " & exc.msg) "Failed to sent handshake response. Error: " & exc.msg)
ws.readyState = ReadyState.Open return WSSession(
readyState: ReadyState.Open,
proc createServer*( stream: stream,
_: typedesc[WebSocket], proto: protos[0],
request: HttpRequestRef,
protocol: string = "",
frameSize = WSDefaultFrameSize,
onPing: ControlCb = nil,
onPong: ControlCb = nil,
onClose: CloseCb = nil): Future[WebSocket] {.async.} =
## Creates a new socket from a request.
##
if not request.headers.contains("Sec-WebSocket-Version"):
raise newException(WSHandshakeError, "Missing version header")
let wsStream = AsyncStream(
reader: request.connection.reader,
writer: request.connection.writer)
var ws = WebSocket(
stream: wsStream,
protocol: protocol,
masked: false, masked: false,
rng: newRng(), rng: ws.rng,
frameSize: frameSize, frameSize: ws.frameSize,
onPing: onPing, onPing: ws.onPing,
onPong: onPong, onPong: ws.onPong,
onClose: onClose) onClose: ws.onClose)
await ws.handshake(request)
return ws
proc send*( proc send*(
ws: WebSocket, ws: WSSession,
data: seq[byte] = @[], data: seq[byte] = @[],
opcode: Opcode) {.async.} = opcode: Opcode) {.async.} =
## Send a frame ## Send a frame
@ -252,20 +247,23 @@ proc send*(
if i >= data.len: if i >= data.len:
break break
proc send*(ws: WebSocket, data: string): Future[void] = proc send*(ws: WSSession, data: string): Future[void] =
send(ws, toBytes(data), Opcode.Text) send(ws, toBytes(data), Opcode.Text)
proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} = proc handleClose*(ws: WSSession, frame: Frame, payLoad: seq[byte] = @[]) {.async.} =
## Handle close sequence
##
logScope: logScope:
fin = frame.fin fin = frame.fin
masked = frame.mask masked = frame.mask
opcode = frame.opcode opcode = frame.opcode
serverState = ws.readyState readyState = ws.readyState
debug "Handling close sequence" debug "Handling close sequence"
if ws.readyState notin {ReadyState.Open}: if ws.readyState notin {ReadyState.Open}:
debug "Connection isn't open, abortig close sequence!"
return return
var var
@ -310,8 +308,8 @@ proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async
ws.readyState = ReadyState.Closed ws.readyState = ReadyState.Closed
await ws.stream.closeWait() await ws.stream.closeWait()
proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
## handle control frames ## Handle control frames
## ##
if not frame.fin: if not frame.fin:
@ -362,7 +360,7 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
ws.readyState = ReadyState.Closed ws.readyState = ReadyState.Closed
await ws.stream.closeWait() await ws.stream.closeWait()
proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = proc readFrame*(ws: WSSession): Future[Frame] {.async.} =
## Gets a frame from the WebSocket. ## Gets a frame from the WebSocket.
## See https://tools.ietf.org/html/rfc6455#section-5.2 ## See https://tools.ietf.org/html/rfc6455#section-5.2
## ##
@ -387,11 +385,11 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} =
await ws.stream.closeWait() await ws.stream.closeWait()
raise exc raise exc
proc ping*(ws: WebSocket, data: seq[byte] = @[]): Future[void] = proc ping*(ws: WSSession, data: seq[byte] = @[]): Future[void] =
ws.send(data, opcode = Opcode.Ping) ws.send(data, opcode = Opcode.Ping)
proc recv*( proc recv*(
ws: WebSocket, ws: WSSession,
data: pointer, data: pointer,
size: int): Future[int] {.async.} = size: int): Future[int] {.async.} =
## Attempts to read up to `size` bytes ## Attempts to read up to `size` bytes
@ -470,7 +468,7 @@ proc recv*(
debug "Exception reading frames", exc = exc.msg debug "Exception reading frames", exc = exc.msg
proc recv*( proc recv*(
ws: WebSocket, ws: WSSession,
size = WSMaxMessageSize): Future[seq[byte]] {.async.} = size = WSMaxMessageSize): Future[seq[byte]] {.async.} =
## Attempt to read a full message up to max `size` ## Attempt to read a full message up to max `size`
## bytes in `frameSize` chunks. ## bytes in `frameSize` chunks.
@ -516,7 +514,7 @@ proc recv*(
return res return res
proc close*( proc close*(
ws: WebSocket, ws: WSSession,
code: Status = Status.Fulfilled, code: Status = Status.Fulfilled,
reason: string = "") {.async.} = reason: string = "") {.async.} =
## Close the Socket, sends close packet. ## Close the Socket, sends close packet.
@ -607,7 +605,8 @@ proc connect*(
frameSize = WSDefaultFrameSize, frameSize = WSDefaultFrameSize,
onPing: ControlCb = nil, onPing: ControlCb = nil,
onPong: ControlCb = nil, onPong: ControlCb = nil,
onClose: CloseCb = nil): Future[WebSocket] {.async.} = onClose: CloseCb = nil,
rng: Rng = nil): Future[WSSession] {.async.} =
## create a new websockets client ## create a new websockets client
## ##
@ -619,7 +618,8 @@ proc connect*(
of "wss": of "wss":
uri.scheme = "https" uri.scheme = "https"
else: else:
raise newException(WSWrongUriSchemeError, "uri scheme has to be 'ws' or 'wss'") raise newException(WSWrongUriSchemeError,
"uri scheme has to be 'ws' or 'wss'")
var headerData = [ var headerData = [
("Connection", "Upgrade"), ("Connection", "Upgrade"),
@ -637,11 +637,11 @@ proc connect*(
let stream = await initiateHandshake(uri, address, headers, flags) let stream = await initiateHandshake(uri, address, headers, flags)
# Client data should be masked. # Client data should be masked.
return WebSocket( return WSSession(
stream: stream, stream: stream,
readyState: ReadyState.Open, readyState: ReadyState.Open,
masked: true, masked: true,
rng: newRng(), rng: if isNil(rng): newRng() else: rng,
frameSize: frameSize, frameSize: frameSize,
onPing: onPing, onPing: onPing,
onPong: onPong, onPong: onPong,
@ -657,7 +657,7 @@ proc connect*(
frameSize = WSDefaultFrameSize, frameSize = WSDefaultFrameSize,
onPing: ControlCb = nil, onPing: ControlCb = nil,
onPong: ControlCb = nil, onPong: ControlCb = nil,
onClose: CloseCb = nil): Future[WebSocket] {.async.} = onClose: CloseCb = nil): Future[WSSession] {.async.} =
## Create a new websockets client ## Create a new websockets client
## using a string path ## using a string path
## ##
@ -689,7 +689,8 @@ proc tlsConnect*(
frameSize = WSDefaultFrameSize, frameSize = WSDefaultFrameSize,
onPing: ControlCb = nil, onPing: ControlCb = nil,
onPong: ControlCb = nil, onPong: ControlCb = nil,
onClose: CloseCb = nil): Future[WebSocket] {.async.} = onClose: CloseCb = nil,
rng: Rng = nil): Future[WSSession] {.async.} =
var uri = "wss://" & host & ":" & $port var uri = "wss://" & host & ":" & $port
if path.startsWith("/"): if path.startsWith("/"):
@ -705,4 +706,40 @@ proc tlsConnect*(
frameSize, frameSize,
onPing, onPing,
onPong, onPong,
onClose) onClose,
rng)
proc handleRequest*(
ws: WSServer,
request: HttpRequestRef): Future[WSSession]
{.raises: [Defect, WSHandshakeError].} =
## Creates a new socket from a request.
##
if not request.headers.contains("Sec-WebSocket-Version"):
raise newException(WSHandshakeError, "Missing version header")
let wsStream = AsyncStream(
reader: request.connection.reader,
writer: request.connection.writer)
return ws.handshake(request, wsStream)
proc new*(
_: typedesc[WSServer],
protos: openArray[string] = [""],
frameSize = WSDefaultFrameSize,
onPing: ControlCb = nil,
onPong: ControlCb = nil,
onClose: CloseCb = nil,
extensions: openArray[Extension] = [],
rng: Rng = nil): WSServer =
return WSServer(
protocols: @protos,
masked: false,
rng: if isNil(rng): newRng() else: rng,
frameSize: frameSize,
onPing: onPing,
onPong: onPong,
onClose: onClose)