Add server Hook for authentication (#133)

why:
  JWT authentication needs that
This commit is contained in:
Jordan Hrycaj 2022-04-05 16:19:52 +01:00 committed by GitHub
parent 9e0a9496c5
commit d4ae2328d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
import
chronicles, httputils, chronos, websock/websock,
chronicles, httputils, chronos, websock/[websock, types],
websock/extensions/compression/deflate,
stew/byteutils, json_serialization/std/net,
".."/[errors, server]
@ -11,16 +11,18 @@ logScope:
type
RpcWebSocketServer* = ref object of RpcServer
authHook: seq[Hook]
server: StreamServer
wsserver: WSServer
proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} =
trace "Handling request:", uri = request.uri.path
trace "Initiating web socket connection."
try:
let server = rpc.wsserver
let ws = await server.handleRequest(request)
if ws.readyState != Open:
let ws = await server.handleRequest(request, hooks = rpc.authHook)
if ws.readyState != ReadyState.Open:
error "Failed to open websocket connection"
return
@ -58,24 +60,26 @@ proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} =
except WebSocketError as exc:
error "WebSocket error:", exception = exc.msg
proc initWebsocket(rpc: RpcWebSocketServer, compression: bool) =
proc initWebsocket(rpc: RpcWebSocketServer,
compression: bool, authHandler: seq[Hook]) =
if compression:
let deflateFactory = deflateFactory()
rpc.wsserver = WSServer.new(factories = [deflateFactory])
else:
rpc.wsserver = WSServer.new()
rpc.authHook = authHandler
proc newRpcWebSocketServer*(
address: TransportAddress,
compression: bool = false,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,
ServerFlags.ReuseAddr}): RpcWebSocketServer =
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,ServerFlags.ReuseAddr},
authHandler: seq[Hook] = @[]): RpcWebSocketServer =
var server = new(RpcWebSocketServer)
proc processCallback(request: HttpRequest): Future[void] =
handleRequest(server, request)
server.initWebsocket(compression)
server.initWebsocket(compression, authHandler)
server.server = HttpServer.create(
address,
processCallback,
@ -88,13 +92,14 @@ proc newRpcWebSocketServer*(
host: string,
port: Port,
compression: bool = false,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,
ServerFlags.ReuseAddr}): RpcWebSocketServer =
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr},
authHandler: seq[Hook] = @[]): RpcWebSocketServer =
newRpcWebSocketServer(
initTAddress(host, port),
compression,
flags
flags,
authHandler
)
proc newRpcWebSocketServer*(
@ -106,13 +111,14 @@ proc newRpcWebSocketServer*(
ServerFlags.ReuseAddr},
tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12): RpcWebSocketServer =
tlsMaxVersion = TLSVersion.TLS12,
authHandler: seq[Hook] = @[]): RpcWebSocketServer =
var server = new(RpcWebSocketServer)
proc processCallback(request: HttpRequest): Future[void] =
handleRequest(server, request)
server.initWebsocket(compression)
server.initWebsocket(compression, authHandler)
server.server = TlsHttpServer.create(
address,
tlsPrivateKey,
@ -136,7 +142,8 @@ proc newRpcWebSocketServer*(
ServerFlags.ReuseAddr},
tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12): RpcWebSocketServer =
tlsMaxVersion = TLSVersion.TLS12,
authHandler: seq[Hook] = @[]): RpcWebSocketServer =
newRpcWebSocketServer(
initTAddress(host, port),
@ -146,7 +153,8 @@ proc newRpcWebSocketServer*(
flags,
tlsFlags,
tlsMinVersion,
tlsMaxVersion
tlsMaxVersion,
authHandler
)
proc start*(server: RpcWebSocketServer) =