Add server Hook for authentication (#133)

why:
  JWT authentication needs that
This commit is contained in:
Jordan Hrycaj 2022-04-05 16:19:52 +01:00 committed by GitHub
parent 9e0a9496c5
commit d4ae2328d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
import import
chronicles, httputils, chronos, websock/websock, chronicles, httputils, chronos, websock/[websock, types],
websock/extensions/compression/deflate, websock/extensions/compression/deflate,
stew/byteutils, json_serialization/std/net, stew/byteutils, json_serialization/std/net,
".."/[errors, server] ".."/[errors, server]
@ -11,16 +11,18 @@ logScope:
type type
RpcWebSocketServer* = ref object of RpcServer RpcWebSocketServer* = ref object of RpcServer
authHook: seq[Hook]
server: StreamServer server: StreamServer
wsserver: WSServer wsserver: WSServer
proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} = proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} =
trace "Handling request:", uri = request.uri.path trace "Handling request:", uri = request.uri.path
trace "Initiating web socket connection." trace "Initiating web socket connection."
try: try:
let server = rpc.wsserver let server = rpc.wsserver
let ws = await server.handleRequest(request) let ws = await server.handleRequest(request, hooks = rpc.authHook)
if ws.readyState != Open: if ws.readyState != ReadyState.Open:
error "Failed to open websocket connection" error "Failed to open websocket connection"
return return
@ -58,24 +60,26 @@ proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} =
except WebSocketError as exc: except WebSocketError as exc:
error "WebSocket error:", exception = exc.msg error "WebSocket error:", exception = exc.msg
proc initWebsocket(rpc: RpcWebSocketServer, compression: bool) = proc initWebsocket(rpc: RpcWebSocketServer,
compression: bool, authHandler: seq[Hook]) =
if compression: if compression:
let deflateFactory = deflateFactory() let deflateFactory = deflateFactory()
rpc.wsserver = WSServer.new(factories = [deflateFactory]) rpc.wsserver = WSServer.new(factories = [deflateFactory])
else: else:
rpc.wsserver = WSServer.new() rpc.wsserver = WSServer.new()
rpc.authHook = authHandler
proc newRpcWebSocketServer*( proc newRpcWebSocketServer*(
address: TransportAddress, address: TransportAddress,
compression: bool = false, compression: bool = false,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,ServerFlags.ReuseAddr},
ServerFlags.ReuseAddr}): RpcWebSocketServer = authHandler: seq[Hook] = @[]): RpcWebSocketServer =
var server = new(RpcWebSocketServer) var server = new(RpcWebSocketServer)
proc processCallback(request: HttpRequest): Future[void] = proc processCallback(request: HttpRequest): Future[void] =
handleRequest(server, request) handleRequest(server, request)
server.initWebsocket(compression) server.initWebsocket(compression, authHandler)
server.server = HttpServer.create( server.server = HttpServer.create(
address, address,
processCallback, processCallback,
@ -88,13 +92,14 @@ proc newRpcWebSocketServer*(
host: string, host: string,
port: Port, port: Port,
compression: bool = false, compression: bool = false,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr},
ServerFlags.ReuseAddr}): RpcWebSocketServer = authHandler: seq[Hook] = @[]): RpcWebSocketServer =
newRpcWebSocketServer( newRpcWebSocketServer(
initTAddress(host, port), initTAddress(host, port),
compression, compression,
flags flags,
authHandler
) )
proc newRpcWebSocketServer*( proc newRpcWebSocketServer*(
@ -106,13 +111,14 @@ proc newRpcWebSocketServer*(
ServerFlags.ReuseAddr}, ServerFlags.ReuseAddr},
tlsFlags: set[TLSFlags] = {}, tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12, tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12): RpcWebSocketServer = tlsMaxVersion = TLSVersion.TLS12,
authHandler: seq[Hook] = @[]): RpcWebSocketServer =
var server = new(RpcWebSocketServer) var server = new(RpcWebSocketServer)
proc processCallback(request: HttpRequest): Future[void] = proc processCallback(request: HttpRequest): Future[void] =
handleRequest(server, request) handleRequest(server, request)
server.initWebsocket(compression) server.initWebsocket(compression, authHandler)
server.server = TlsHttpServer.create( server.server = TlsHttpServer.create(
address, address,
tlsPrivateKey, tlsPrivateKey,
@ -136,7 +142,8 @@ proc newRpcWebSocketServer*(
ServerFlags.ReuseAddr}, ServerFlags.ReuseAddr},
tlsFlags: set[TLSFlags] = {}, tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12, tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12): RpcWebSocketServer = tlsMaxVersion = TLSVersion.TLS12,
authHandler: seq[Hook] = @[]): RpcWebSocketServer =
newRpcWebSocketServer( newRpcWebSocketServer(
initTAddress(host, port), initTAddress(host, port),
@ -146,7 +153,8 @@ proc newRpcWebSocketServer*(
flags, flags,
tlsFlags, tlsFlags,
tlsMinVersion, tlsMinVersion,
tlsMaxVersion tlsMaxVersion,
authHandler
) )
proc start*(server: RpcWebSocketServer) = proc start*(server: RpcWebSocketServer) =