372 lines
10 KiB
Nim
372 lines
10 KiB
Nim
## nim-websock
|
|
## Copyright (c) 2021 Status Research & Development GmbH
|
|
## Licensed under either of
|
|
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
|
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
|
## at your option.
|
|
## This file may not be copied, modified, or distributed except according to
|
|
## those terms.
|
|
|
|
{.push raises: [Defect].}
|
|
|
|
import std/[tables,
|
|
strutils,
|
|
strformat,
|
|
sequtils,
|
|
uri]
|
|
|
|
import pkg/[chronos,
|
|
chronos/apps/http/httptable,
|
|
chronos/streams/asyncstream,
|
|
chronos/streams/tlsstream,
|
|
chronicles,
|
|
httputils,
|
|
stew/byteutils,
|
|
stew/base64,
|
|
stew/base10,
|
|
nimcrypto/sha]
|
|
|
|
import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils
|
|
|
|
export utils, session, frame, types, http, httptable
|
|
|
|
logScope:
|
|
topics = "websock ws-server"
|
|
|
|
type
  WSServer* = ref object of WebSocket
    ## Server-side websocket endpoint; `handleRequest` turns incoming HTTP
    ## upgrade requests into `WSSession` objects using this configuration.
    protocols: seq[string]
      ## Sub-protocols this server accepts; the client's requested
      ## protocols are filtered against this list during the handshake.
    factories: seq[ExtFactory]
      ## Extension factories consulted when negotiating the
      ## "Sec-WebSocket-Extensions" header.
|
|
|
|
func toException(e: string): ref WebSocketError =
  ## Wraps the error text `e` in a freshly allocated `WebSocketError`.
  ## NOTE(review): presumably picked up as the error-to-exception
  ## conversion for `Result` values (e.g. by `tryGet`) - confirm.
  newException(WebSocketError, e)
|
|
|
|
func toException(e: cstring): ref WebSocketError =
  ## Same as the `string` overload, for errors reported as `cstring`;
  ## the text is copied into a Nim string before being stored.
  let text = $e
  newException(WebSocketError, text)
|
|
|
|
func contains(extensions: openArray[Ext], extName: string): bool =
  ## Returns `true` when at least one entry of `extensions` carries the
  ## name `extName`, `false` otherwise (including for an empty list).
  extensions.anyIt(it.name == extName)
|
|
|
|
proc getFactory(factories: openArray[ExtFactory], extName: string): ExtFactoryProc =
  ## Looks up the factory proc registered under `extName`.
  ##
  ## The first entry whose name matches wins. When nothing matches, the
  ## default (zero) value of `ExtFactoryProc` is returned, which callers
  ## detect with `isNil`.
  for candidate in factories:
    if candidate.name == extName:
      result = candidate.factory
      break
|
|
|
|
proc selectExt(isServer: bool,
               extensions: var seq[Ext],
               factories: openArray[ExtFactory],
               exts: openArray[string]): string {.raises: [Defect, WSExtError].} =
  ## Negotiates websocket extensions.
  ##
  ## `exts` holds the raw values of the "Sec-WebSocket-Extensions"
  ## header(s). Each offered extension is matched against `factories`;
  ## every extension that a factory manages to instantiate is appended to
  ## `extensions`.
  ##
  ## Returns the value to send back in "Sec-WebSocket-Extensions": the
  ## accepted extensions joined with ", ", or an empty string when none
  ## was accepted.
  ##
  ## Raises `WSExtError` when a header value cannot be parsed.

  var extList: seq[AppExt]
  for ext in exts:
    # each of "Sec-WebSocket-Extensions" can have multiple
    # extensions or fallback extension
    if not parseExt(ext, extList):
      raise newException(WSExtError, "extension syntax error: " & ext)

  var response = ""
  for ext in extList:
    if extensions.contains(ext.name):
      # don't accept this fallback if prev ext
      # configuration already accepted
      trace "extension fallback not accepted", ext=ext.name
      continue

    # now look for right factory
    let factory = factories.getFactory(ext.name)
    if factory.isNil:
      # no factory? it's ok, just skip it
      trace "no extension factory", ext=ext.name
      continue

    let extRes = factory(isServer, ext.params)
    if extRes.isErr:
      # cannot create extension because of
      # wrong/incompatible params? skip or fallback
      trace "skip extension", ext=ext.name, msg=extRes.error
      continue

    let accepted = extRes.get()
    doAssert(not accepted.isNil)

    # Separate entries based on what has actually been written so far.
    # Keying this on the candidate index produced a leading ", " whenever
    # the first candidate(s) were skipped and a later one was accepted.
    if response.len > 0:
      response.add ", "
    response.add accepted.toHttpOptions

    # finally, accept the extension
    trace "extension accepted", ext=accepted.name
    extensions.add accepted

  # HTTP response for "Sec-WebSocket-Extensions"
  response
|
|
|
|
proc connect*(
    _: type WebSocket,
    host: string | TransportAddress,
    path: string,
    hostName: string = "", # override used when the hostname has been externally resolved
    protocols: seq[string] = @[],
    factories: seq[ExtFactory] = @[],
    hooks: seq[Hook] = @[],
    secure = false,
    flags: set[TLSFlags] = {},
    version = WSDefaultVersion,
    frameSize = WSDefaultFrameSize,
    onPing: ControlCb = nil,
    onPong: ControlCb = nil,
    onClose: CloseCb = nil,
    rng: Rng = nil): Future[WSSession] {.async.} =
  ## Creates a websocket client session: connects to `host` (over TLS
  ## when `secure` is set), sends the HTTP upgrade request for `path` and
  ## validates the reply.
  ##
  ## * `hostName` - when non-empty, used instead of `$host` for the
  ##   "Host" header and as the TLS host name.
  ## * `protocols` - offered via "Sec-WebSocket-Protocol"; a protocol
  ##   returned by the server must be one of them.
  ## * `factories` - extensions offered to, and negotiated with, the server.
  ## * `hooks` - `append` hooks may add request headers, `verify` hooks
  ##   inspect the response headers.
  ## * `rng` - random source for the handshake key; a new one is created
  ##   when `nil`.
  ##
  ## Raises `WSFailedUpgradeError` when the server does not answer with
  ## `101` or selects a protocol that was not offered, and `WSHookError`
  ## when a hook reports an error. Whenever the handshake fails after the
  ## transport has been connected, the client connection is closed before
  ## the error is re-raised.

  let
    rng = if isNil(rng): newRng() else: rng
    key = Base64Pad.encode(genWebSecKey(rng))
    # Deliberately not called `hostname`: Nim identifiers are
    # style-insensitive, so that name would shadow the `hostName` parameter.
    hostHeader = if hostName.len > 0: hostName else: $host

  let client = if secure:
      await TlsHttpClient.connect(host, tlsFlags = flags, hostName = hostHeader)
    else:
      await HttpClient.connect(host)

  try:
    let headerData = [
      ("Connection", "Upgrade"),
      ("Upgrade", "websocket"),
      ("Cache-Control", "no-cache"),
      ("Sec-WebSocket-Version", $version),
      ("Sec-WebSocket-Key", key),
      ("Host", hostHeader)]

    var headers = HttpTable.init(headerData)
    if protocols.len > 0:
      headers.add("Sec-WebSocket-Protocol", protocols.join(", "))

    var extOffer = ""
    for i, f in factories:
      if i > 0:
        extOffer.add ", "
      extOffer.add f.clientOffer

    if extOffer.len > 0:
      headers.add("Sec-WebSocket-Extensions", extOffer)

    for hp in hooks:
      if hp.append == nil: continue
      let res = hp.append(hp, headers)
      if res.isErr:
        raise newException(WSHookError,
          "Header plugin execution failed: " & res.error)

    let response = await client.request(path, headers = headers)

    if response.code != Http101.toInt():
      raise newException(WSFailedUpgradeError,
        &"Server did not reply with a websocket upgrade: " &
        &"Header code: {response.code} Header reason: {response.reason} " &
        &"Address: {client.address}")

    let proto = response.headers.getString("Sec-WebSocket-Protocol")
    if proto.len > 0 and protocols.len > 0:
      if proto notin protocols:
        raise newException(WSFailedUpgradeError,
          &"Invalid protocol returned {proto}!")

    for hp in hooks:
      if hp.verify == nil: continue
      let res = await hp.verify(hp, response.headers)
      if res.isErr:
        raise newException(WSHookError,
          "Header verification failed: " & res.error)

    var extensions: seq[Ext]
    let exts = response.headers.getList("Sec-WebSocket-Extensions")
    discard selectExt(false, extensions, factories, exts)

    # Client data should be masked.
    let session = WSSession(
      stream: client.stream,
      readyState: ReadyState.Open,
      proto: proto, # sub-protocol selected by the server ("" if none)
      masked: true,
      extensions: system.move(extensions),
      rng: rng,
      frameSize: frameSize,
      onPing: onPing,
      onPong: onPong,
      onClose: onClose)

    # give every negotiated extension a back-reference to its session
    for ext in session.extensions:
      ext.session = session

    return session
  except CatchableError as exc:
    # Any failure after the transport is up (hook error, request error,
    # rejected upgrade, bad extension header, ...) must not leak the
    # connection.
    trace "Websocket failed during handshake", exc = exc.msg
    await client.close()
    raise exc
|
|
|
|
proc connect*(
    _: type WebSocket,
    uri: Uri,
    protocols: seq[string] = @[],
    factories: seq[ExtFactory] = @[],
    hooks: seq[Hook] = @[],
    flags: set[TLSFlags] = {},
    version = WSDefaultVersion,
    frameSize = WSDefaultFrameSize,
    onPing: ControlCb = nil,
    onPong: ControlCb = nil,
    onClose: CloseCb = nil,
    rng: Rng = nil): Future[WSSession]
    {.raises: [Defect, WSWrongUriSchemeError].} =
  ## Create a new websockets client
  ## using a Uri
  ##
  ## The scheme selects the transport: `wss` connects over TLS, `ws` in
  ## plain text; any other scheme raises `WSWrongUriSchemeError`. A
  ## missing port defaults to 443 (`wss`) or 80 (`ws`). The request
  ## target is the URI path (or "/" when the path is empty) followed by
  ## "?" and the query string when one is present. All remaining
  ## arguments are forwarded unchanged to the host/path `connect`
  ## overload.

  let secure = case uri.scheme:
    of "wss": true
    of "ws": false
    else:
      raise newException(WSWrongUriSchemeError,
        "uri scheme has to be 'ws' or 'wss'")

  var uri = uri
  if uri.port.len <= 0:
    uri.port = if secure: "443" else: "80"

  # Build the request target. An empty path is not a valid HTTP
  # request target, and the query component must be sent along with the
  # path rather than silently dropped.
  var target = if uri.path.len > 0: uri.path else: "/"
  if uri.query.len > 0:
    target.add '?'
    target.add uri.query

  return WebSocket.connect(
    host = uri.hostname & ":" & uri.port,
    path = target,
    protocols = protocols,
    factories = factories,
    hooks = hooks,
    secure = secure,
    flags = flags,
    version = version,
    frameSize = frameSize,
    onPing = onPing,
    onPong = onPong,
    onClose = onClose,
    rng = rng)
|
|
|
|
proc handleRequest*(
    ws: WSServer,
    request: HttpRequest,
    version: uint = WSDefaultVersion,
    hooks: seq[Hook] = @[]): Future[WSSession]
    {.async, raises: [Defect, WSHandshakeError, WSProtoMismatchError].} =
  ## Creates a new socket from a request.
  ##
  ## Performs the server side of the websocket handshake on `request`:
  ## checks "Sec-WebSocket-Version" against `version`, picks the first
  ## requested sub-protocol the server supports, runs the `verify` and
  ## `append` hooks, negotiates extensions and replies with `101`.
  ##
  ## Raises `WSHandshakeError` when the version header is missing or the
  ## response cannot be sent, `WSVersionError` on a version mismatch
  ## (after replying with `426`) and `WSHookError` when a hook reports
  ## an error.
  ##
  ## All per-handshake state (client key, negotiated extensions) is kept
  ## in locals. `ws` is shared by every request and this proc suspends at
  ## several `await` points, so state parked on `ws` could be overwritten
  ## by an interleaved handshake or survive a failed one and end up in
  ## an unrelated session.

  if not request.headers.contains("Sec-WebSocket-Version"):
    raise newException(WSHandshakeError, "Missing version header")

  let versionHeader = request.headers.getString("Sec-WebSocket-Version")
  let requestedVersion =
    Base10.decode(uint, versionHeader).tryGet() # this method throws
  # still recorded on the server object for callers that inspect it
  ws.version = requestedVersion

  if requestedVersion != version:
    await request.stream.writer.sendError(Http426)
    trace "Websocket version not supported", version = requestedVersion

    raise newException(WSVersionError,
      &"Websocket version not supported, Version: {version}")

  let clientKey = request.headers.getString("Sec-WebSocket-Key").strip()
  # still recorded on the server object for callers that inspect it
  ws.key = clientKey

  let wantProtos =
    if request.headers.contains("Sec-WebSocket-Protocol"):
      request.headers.getList("Sec-WebSocket-Protocol")
    else:
      @[""]

  let protos = wantProtos.filterIt(it in ws.protocols)

  for hp in hooks:
    if hp.verify == nil: continue
    let res = await hp.verify(hp, request.headers)
    if res.isErr:
      raise newException(WSHookError,
        "Header verification failed: " & res.error)

  # Derived from the local key, not `ws.key`: the hook above may have
  # suspended this handshake while another one replaced `ws.key`.
  let
    cKey = clientKey & WSGuid
    acceptKey = Base64Pad.encode(
      sha1.digest(cKey.toOpenArray(0, cKey.high)).data)

  var headers = HttpTable.init([
    ("Connection", "Upgrade"),
    ("Upgrade", "websocket"),
    ("Sec-WebSocket-Accept", acceptKey)])

  let protocol = if protos.len > 0: protos[0] else: ""
  if protocol.len > 0:
    headers.add("Sec-WebSocket-Protocol", protocol) # send back the first matching proto
  else:
    trace "Didn't match any protocol", supported = ws.protocols, requested = wantProtos

  # it is possible to have multiple "Sec-WebSocket-Extensions"
  var extensions: seq[Ext]
  let exts = request.headers.getList("Sec-WebSocket-Extensions")
  let extResp = selectExt(true, extensions, ws.factories, exts)
  if extResp.len > 0:
    # send back any accepted extensions
    headers.add("Sec-WebSocket-Extensions", extResp)

  for hp in hooks:
    if hp.append == nil: continue
    let res = hp.append(hp, headers)
    if res.isErr:
      raise newException(WSHookError,
        "Header plugin execution failed: " & res.error)

  try:
    await request.sendResponse(Http101, headers = headers)
  except CancelledError as exc:
    raise exc
  except CatchableError as exc:
    raise newException(WSHandshakeError,
        "Failed to sent handshake response. Error: " & exc.msg)

  let session = WSSession(
    readyState: ReadyState.Open,
    stream: request.stream,
    proto: protocol,
    extensions: system.move(extensions),
    masked: false,
    rng: ws.rng,
    frameSize: ws.frameSize,
    onPing: ws.onPing,
    onPong: ws.onPong,
    onClose: ws.onClose)

  # give every negotiated extension a back-reference to its session
  for ext in session.extensions:
    ext.session = session

  return session
|
|
|
|
proc new*(
    _: typedesc[WSServer],
    protos: openArray[string] = [""],
    factories: openArray[ExtFactory] = [],
    frameSize = WSDefaultFrameSize,
    onPing: ControlCb = nil,
    onPong: ControlCb = nil,
    onClose: CloseCb = nil,
    rng: Rng = nil): WSServer =
  ## Builds a `WSServer`.
  ##
  ## * `protos` - sub-protocols the server accepts.
  ## * `factories` - extension factories used during negotiation.
  ## * `frameSize` - frame size handed on to the sessions it creates.
  ## * `onPing` / `onPong` / `onClose` - control callbacks handed on to
  ##   the sessions it creates.
  ## * `rng` - random source; a fresh one is created when `nil`.

  let serverRng = if isNil(rng): newRng() else: rng

  WSServer(
    protocols: @protos,
    factories: @factories,
    masked: false, # stored as unmasked, like the sessions handleRequest builds
    rng: serverRng,
    frameSize: frameSize,
    onPing: onPing,
    onPong: onPong,
    onClose: onClose)
|