From d4ae2328d4247c59cefd8d5e0fbc3f178a0eb4ef Mon Sep 17 00:00:00 2001 From: Jordan Hrycaj Date: Tue, 5 Apr 2022 16:19:52 +0100 Subject: [PATCH] Add server Hook for authentication (#133) why: JWT authentication needs that --- json_rpc/servers/websocketserver.nim | 36 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/json_rpc/servers/websocketserver.nim b/json_rpc/servers/websocketserver.nim index a4b9ef4..6dfa477 100644 --- a/json_rpc/servers/websocketserver.nim +++ b/json_rpc/servers/websocketserver.nim @@ -1,5 +1,5 @@ import - chronicles, httputils, chronos, websock/websock, + chronicles, httputils, chronos, websock/[websock, types], websock/extensions/compression/deflate, stew/byteutils, json_serialization/std/net, ".."/[errors, server] @@ -11,16 +11,18 @@ logScope: type RpcWebSocketServer* = ref object of RpcServer + authHook: seq[Hook] server: StreamServer wsserver: WSServer proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} = trace "Handling request:", uri = request.uri.path trace "Initiating web socket connection." + try: let server = rpc.wsserver - let ws = await server.handleRequest(request) - if ws.readyState != Open: + let ws = await server.handleRequest(request, hooks = rpc.authHook) + if ws.readyState != ReadyState.Open: error "Failed to open websocket connection" return @@ -58,24 +60,26 @@ proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} = except WebSocketError as exc: error "WebSocket error:", exception = exc.msg -proc initWebsocket(rpc: RpcWebSocketServer, compression: bool) = +proc initWebsocket(rpc: RpcWebSocketServer, + compression: bool, authHandler: seq[Hook]) = if compression: let deflateFactory = deflateFactory() rpc.wsserver = WSServer.new(factories = [deflateFactory]) else: rpc.wsserver = WSServer.new() + rpc.authHook = authHandler proc newRpcWebSocketServer*( address: TransportAddress, compression: bool = false, - flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, - ServerFlags.ReuseAddr}): RpcWebSocketServer = + flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,ServerFlags.ReuseAddr}, + authHandler: seq[Hook] = @[]): RpcWebSocketServer = var server = new(RpcWebSocketServer) proc processCallback(request: HttpRequest): Future[void] = handleRequest(server, request) - server.initWebsocket(compression) + server.initWebsocket(compression, authHandler) server.server = HttpServer.create( address, processCallback, @@ -88,13 +92,14 @@ proc newRpcWebSocketServer*( host: string, port: Port, compression: bool = false, - flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, - ServerFlags.ReuseAddr}): RpcWebSocketServer = + flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}, + authHandler: seq[Hook] = @[]): RpcWebSocketServer = newRpcWebSocketServer( initTAddress(host, port), compression, - flags + flags, + authHandler ) proc newRpcWebSocketServer*( @@ -106,13 +111,14 @@ proc newRpcWebSocketServer*( ServerFlags.ReuseAddr}, tlsFlags: set[TLSFlags] = {}, tlsMinVersion = TLSVersion.TLS12, - tlsMaxVersion = TLSVersion.TLS12): RpcWebSocketServer = + tlsMaxVersion = TLSVersion.TLS12, + authHandler: seq[Hook] = @[]): RpcWebSocketServer = var server = new(RpcWebSocketServer) proc processCallback(request: HttpRequest): Future[void] = handleRequest(server, request) - server.initWebsocket(compression) + server.initWebsocket(compression, authHandler) server.server = TlsHttpServer.create( address, tlsPrivateKey, @@ -136,7 +142,8 @@ proc newRpcWebSocketServer*( ServerFlags.ReuseAddr}, tlsFlags: set[TLSFlags] = {}, tlsMinVersion = TLSVersion.TLS12, - tlsMaxVersion = TLSVersion.TLS12): RpcWebSocketServer = + tlsMaxVersion = TLSVersion.TLS12, + authHandler: seq[Hook] = @[]): RpcWebSocketServer = newRpcWebSocketServer( initTAddress(host, port), @@ -146,7 +153,8 @@ proc newRpcWebSocketServer*( flags, tlsFlags, tlsMinVersion, - tlsMaxVersion + tlsMaxVersion, + authHandler ) proc start*(server: RpcWebSocketServer) =