diff --git a/json_rpc/servers/websocketserver.nim b/json_rpc/servers/websocketserver.nim index 6dfa477..eb53c0b 100644 --- a/json_rpc/servers/websocketserver.nim +++ b/json_rpc/servers/websocketserver.nim @@ -10,18 +10,51 @@ logScope: topics = "JSONRPC-WS-SERVER" type + RpcWebSocketServerAuth* = ##\ + ## Authenticator function. On error, the resulting `HttpCode` is sent back\ + ## to the client and the `string` argument will be used in an exception,\ + ## following. + proc(req: HttpTable): Result[void,(HttpCode,string)] + {.gcsafe, raises: [Defect].} + RpcWebSocketServer* = ref object of RpcServer - authHook: seq[Hook] + authHook: Option[RpcWebSocketServerAuth] ## Authorization call back handler server: StreamServer wsserver: WSServer + HookEx = ref object of Hook + handler: RpcWebSocketServerAuth ## from `RpcWebSocketServer` + request: HttpRequest ## current request needed for error response + +proc authWithHtCodeResponse(ctx: Hook, headers: HttpTable): + Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} = + ## Wrapper around authorization handler which is stored in the + ## extended `Hook` object. + let + cty = ctx.HookEx + rc = cty.handler(headers) + if rc.isErr: + await cty.request.stream.writer.sendError(rc.error[0]) + return err(rc.error[1]) + return ok() + proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} = trace "Handling request:", uri = request.uri.path trace "Initiating web socket connection." + # Authorization handler constructor (if enabled) + var hooks: seq[Hook] + if rpc.authHook.isSome: + let hookEx = HookEx( + append: nil, + request: request, + handler: rpc.authHook.get, + verify: authWithHtCodeResponse) + hooks = @[hookEx.Hook] + try: let server = rpc.wsserver - let ws = await server.handleRequest(request, hooks = rpc.authHook) + let ws = await server.handleRequest(request, hooks = hooks) if ws.readyState != ReadyState.Open: error "Failed to open websocket connection" return @@ -60,8 +93,8 @@ proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} = except WebSocketError as exc: error "WebSocket error:", exception = exc.msg -proc initWebsocket(rpc: RpcWebSocketServer, - compression: bool, authHandler: seq[Hook]) = +proc initWebsocket(rpc: RpcWebSocketServer, compression: bool, + authHandler: Option[RpcWebSocketServerAuth]) = if compression: let deflateFactory = deflateFactory() rpc.wsserver = WSServer.new(factories = [deflateFactory]) @@ -73,7 +106,7 @@ proc newRpcWebSocketServer*( address: TransportAddress, compression: bool = false, flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,ServerFlags.ReuseAddr}, - authHandler: seq[Hook] = @[]): RpcWebSocketServer = + authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = var server = new(RpcWebSocketServer) proc processCallback(request: HttpRequest): Future[void] = @@ -93,7 +126,7 @@ proc newRpcWebSocketServer*( port: Port, compression: bool = false, flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}, - authHandler: seq[Hook] = @[]): RpcWebSocketServer = + authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = newRpcWebSocketServer( initTAddress(host, port), @@ -112,7 +145,7 @@ proc newRpcWebSocketServer*( tlsFlags: set[TLSFlags] = {}, tlsMinVersion = TLSVersion.TLS12, tlsMaxVersion = TLSVersion.TLS12, - authHandler: seq[Hook] = @[]): RpcWebSocketServer = + authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = var server = new(RpcWebSocketServer) proc processCallback(request: HttpRequest): Future[void] = @@ -143,7 +176,7 @@ proc newRpcWebSocketServer*( tlsFlags: set[TLSFlags] = {}, tlsMinVersion = TLSVersion.TLS12, tlsMaxVersion = TLSVersion.TLS12, - authHandler: seq[Hook] = @[]): RpcWebSocketServer = + authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = newRpcWebSocketServer( initTAddress(host, port),