From 0fee4be2ccbbfa3f158b741c3a34b04130c8b424 Mon Sep 17 00:00:00 2001 From: jangko Date: Sun, 17 Jul 2022 12:41:18 +0700 Subject: [PATCH] implement hook to handle CORS and JWT auth - fixes #138 - fixes #126 - fixes #38 --- json_rpc/clients/websocketclient.nim | 2 + json_rpc/rpcproxy.nim | 7 ++- json_rpc/servers/httpserver.nim | 54 +++++++++++------ json_rpc/servers/websocketserver.nim | 69 +++++++++------------- tests/all.nim | 1 + tests/testhook.nim | 86 ++++++++++++++++++++++++++++ 6 files changed, 156 insertions(+), 63 deletions(-) create mode 100644 tests/testhook.nim diff --git a/json_rpc/clients/websocketclient.nim b/json_rpc/clients/websocketclient.nim index 9cbecc0..0609a79 100644 --- a/json_rpc/clients/websocketclient.nim +++ b/json_rpc/clients/websocketclient.nim @@ -128,6 +128,7 @@ else: proc connect*( client: RpcWebSocketClient, uri: string, compression = false, + hooks: seq[Hook] = @[], flags: set[TLSFlags] = {NoVerifyHost, NoVerifyServerName}) {.async.} = var ext: seq[ExtFactory] = if compression: @[deflateFactory()] else: @[] @@ -135,6 +136,7 @@ else: let ws = await WebSocket.connect( uri=uri, factories=ext, + hooks=hooks, flags=flags ) client.transport = ws diff --git a/json_rpc/rpcproxy.nim b/json_rpc/rpcproxy.nim index f22b68e..f4a03bd 100644 --- a/json_rpc/rpcproxy.nim +++ b/json_rpc/rpcproxy.nim @@ -40,7 +40,7 @@ proc getWebSocketClientConfig*( uri: string, compression: bool = false, flags: set[TLSFlags] = { - NoVerifyHost, NoVerifyServerName}): ClientConfig = + TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName}): ClientConfig = ClientConfig(kind: WebSocket, wsUri: uri, compression: compression, flags: flags) proc proxyCall(client: RpcClient, name: string): RpcProc = @@ -82,7 +82,10 @@ proc connectToProxy(proxy: RpcProxy): Future[void] = of Http: return proxy.httpClient.connect(proxy.httpUri) of WebSocket: - return proxy.webSocketClient.connect(proxy.wsUri, proxy.compression, proxy.flags) + return proxy.webSocketClient.connect( + uri = proxy.wsUri, + compression = proxy.compression, + flags = proxy.flags) proc start*(proxy: RpcProxy) {.async.} = proxy.rpcHttpServer.start() diff --git a/json_rpc/servers/httpserver.nim b/json_rpc/servers/httpserver.nim index 8cbf295..1e7b2dc 100644 --- a/json_rpc/servers/httpserver.nim +++ b/json_rpc/servers/httpserver.nim @@ -5,7 +5,7 @@ import chronos/apps/http/[httpserver, shttpserver], ".."/[errors, server] -export server +export server, shttpserver logScope: topics = "JSONRPC-HTTP-SERVER" @@ -14,13 +14,31 @@ type ReqStatus = enum Success, Error, ErrorFailure + # HttpAuthHook: handle CORS, JWT auth, etc. in HTTP header + # before actual request processed + # return value: + # - nil: auth success, continue execution + # - HttpResponse: could not authenticate, stop execution + # and return the response + HttpAuthHook* = proc(request: HttpRequestRef): Future[HttpResponseRef] + {.gcsafe, raises: [Defect, CatchableError].} + RpcHttpServer* = ref object of RpcServer httpServers: seq[HttpServerRef] + authHooks: seq[HttpAuthHook] -proc processClientRpc(rpcServer: RpcServer): HttpProcessCallback = +proc processClientRpc(rpcServer: RpcHttpServer): HttpProcessCallback = return proc (req: RequestFence): Future[HttpResponseRef] {.async.} = if req.isOk(): let request = req.get() + + # if hook result is not nil, + # it means we should return immediately + for hook in rpcServer.authHooks: + let res = await hook(request) + if not res.isNil: + return res + let body = await request.getBody() let future = rpcServer.route(string.fromBytes(body)) @@ -177,36 +195,36 @@ proc addSecureHttpServer*(server: RpcHttpServer, for a in resolvedAddresses(address, port): server.addSecureHttpServer(a, tlsPrivateKey, tlsCertificate) -proc new*(T: type RpcHttpServer): T = - T(router: RpcRouter.init(), httpServers: @[]) +proc new*(T: type RpcHttpServer, authHooks: seq[HttpAuthHook] = @[]): T = + T(router: RpcRouter.init(), httpServers: @[], authHooks: authHooks) -proc new*(T: type RpcHttpServer, router: RpcRouter): T = - T(router: router, httpServers: @[]) +proc new*(T: type RpcHttpServer, router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): T = + T(router: router, httpServers: @[], authHooks: authHooks) -proc newRpcHttpServer*(): RpcHttpServer = - RpcHttpServer.new() +proc newRpcHttpServer*(authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer = + RpcHttpServer.new(authHooks) -proc newRpcHttpServer*(router: RpcRouter): RpcHttpServer = - RpcHttpServer.new(router) +proc newRpcHttpServer*(router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer = + RpcHttpServer.new(router, authHooks) -proc newRpcHttpServer*(addresses: openArray[TransportAddress]): RpcHttpServer = +proc newRpcHttpServer*(addresses: openArray[TransportAddress], authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer = ## Create new server and assign it to addresses ``addresses``. - result = newRpcHttpServer() + result = newRpcHttpServer(authHooks) result.addHttpServers(addresses) -proc newRpcHttpServer*(addresses: openArray[string]): RpcHttpServer = +proc newRpcHttpServer*(addresses: openArray[string], authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer = ## Create new server and assign it to addresses ``addresses``. - result = newRpcHttpServer() + result = newRpcHttpServer(authHooks) result.addHttpServers(addresses) -proc newRpcHttpServer*(addresses: openArray[string], router: RpcRouter): RpcHttpServer = +proc newRpcHttpServer*(addresses: openArray[string], router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer = ## Create new server and assign it to addresses ``addresses``. - result = newRpcHttpServer(router) + result = newRpcHttpServer(router, authHooks) result.addHttpServers(addresses) -proc newRpcHttpServer*(addresses: openArray[TransportAddress], router: RpcRouter): RpcHttpServer = +proc newRpcHttpServer*(addresses: openArray[TransportAddress], router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer = ## Create new server and assign it to addresses ``addresses``. - result = newRpcHttpServer(router) + result = newRpcHttpServer(router, authHooks) result.addHttpServers(addresses) proc start*(server: RpcHttpServer) = diff --git a/json_rpc/servers/websocketserver.nim b/json_rpc/servers/websocketserver.nim index eb53c0b..bc57e09 100644 --- a/json_rpc/servers/websocketserver.nim +++ b/json_rpc/servers/websocketserver.nim @@ -10,51 +10,34 @@ logScope: topics = "JSONRPC-WS-SERVER" type - RpcWebSocketServerAuth* = ##\ - ## Authenticator function. On error, the resulting `HttpCode` is sent back\ - ## to the client and the `string` argument will be used in an exception,\ - ## following. - proc(req: HttpTable): Result[void,(HttpCode,string)] - {.gcsafe, raises: [Defect].} + # WsAuthHook: handle CORS, JWT auth, etc. in HTTP header + # before actual request processed + # return value: + # - true: auth success, continue execution + # - false: could not authenticate, stop execution + # and return the response + WsAuthHook* = proc(request: HttpRequest): Future[bool] + {.gcsafe, raises: [Defect, CatchableError].} RpcWebSocketServer* = ref object of RpcServer - authHook: Option[RpcWebSocketServerAuth] ## Authorization call back handler server: StreamServer wsserver: WSServer - - HookEx = ref object of Hook - handler: RpcWebSocketServerAuth ## from `RpcWebSocketServer` - request: HttpRequest ## current request needed for error response - -proc authWithHtCodeResponse(ctx: Hook, headers: HttpTable): - Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} = - ## Wrapper around authorization handler which is stored in the - ## extended `Hook` object. - let - cty = ctx.HookEx - rc = cty.handler(headers) - if rc.isErr: - await cty.request.stream.writer.sendError(rc.error[0]) - return err(rc.error[1]) - return ok() + authHooks: seq[WsAuthHook] proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} = trace "Handling request:", uri = request.uri.path trace "Initiating web socket connection." - # Authorization handler constructor (if enabled) - var hooks: seq[Hook] - if rpc.authHook.isSome: - let hookEx = HookEx( - append: nil, - request: request, - handler: rpc.authHook.get, - verify: authWithHtCodeResponse) - hooks = @[hookEx.Hook] + # if hook result is false, + # it means we should return immediately + for hook in rpc.authHooks: + let res = await hook(request) + if not res: + return try: let server = rpc.wsserver - let ws = await server.handleRequest(request, hooks = hooks) + let ws = await server.handleRequest(request) if ws.readyState != ReadyState.Open: error "Failed to open websocket connection" return @@ -94,25 +77,25 @@ proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} = error "WebSocket error:", exception = exc.msg proc initWebsocket(rpc: RpcWebSocketServer, compression: bool, - authHandler: Option[RpcWebSocketServerAuth]) = + authHooks: seq[WsAuthHook]) = if compression: let deflateFactory = deflateFactory() rpc.wsserver = WSServer.new(factories = [deflateFactory]) else: rpc.wsserver = WSServer.new() - rpc.authHook = authHandler + rpc.authHooks = authHooks proc newRpcWebSocketServer*( address: TransportAddress, compression: bool = false, flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,ServerFlags.ReuseAddr}, - authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = + authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer = var server = new(RpcWebSocketServer) proc processCallback(request: HttpRequest): Future[void] = handleRequest(server, request) - server.initWebsocket(compression, authHandler) + server.initWebsocket(compression, authHooks) server.server = HttpServer.create( address, processCallback, @@ -126,13 +109,13 @@ proc newRpcWebSocketServer*( port: Port, compression: bool = false, flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}, - authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = + authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer = newRpcWebSocketServer( initTAddress(host, port), compression, flags, - authHandler + authHooks ) proc newRpcWebSocketServer*( @@ -145,13 +128,13 @@ proc newRpcWebSocketServer*( tlsFlags: set[TLSFlags] = {}, tlsMinVersion = TLSVersion.TLS12, tlsMaxVersion = TLSVersion.TLS12, - authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = + authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer = var server = new(RpcWebSocketServer) proc processCallback(request: HttpRequest): Future[void] = handleRequest(server, request) - server.initWebsocket(compression, authHandler) + server.initWebsocket(compression, authHooks) server.server = TlsHttpServer.create( address, tlsPrivateKey, @@ -176,7 +159,7 @@ proc newRpcWebSocketServer*( tlsFlags: set[TLSFlags] = {}, tlsMinVersion = TLSVersion.TLS12, tlsMaxVersion = TLSVersion.TLS12, - authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = + authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer = newRpcWebSocketServer( initTAddress(host, port), @@ -187,7 +170,7 @@ proc newRpcWebSocketServer*( tlsFlags, tlsMinVersion, tlsMaxVersion, - authHandler + authHooks ) proc start*(server: RpcWebSocketServer) = diff --git a/tests/all.nim b/tests/all.nim index a5c83e3..38e23f2 100644 --- a/tests/all.nim +++ b/tests/all.nim @@ -9,3 +9,4 @@ import when not useNews: # The proxy implementation is based on websock import testproxy + import testhook diff --git a/tests/testhook.nim b/tests/testhook.nim new file mode 100644 index 0000000..326f11c --- /dev/null +++ b/tests/testhook.nim @@ -0,0 +1,86 @@ +import + unittest, json, chronicles, + websock/websock, + ../json_rpc/[rpcclient, rpcserver, clients/config] + +const + serverHost = "localhost" + serverPort = 8547 + serverAddress = serverHost & ":" & $serverPort + +proc setupServer*(srv: RpcServer) = + srv.rpc("testHook") do(input: string): + return %("Hello " & input) + +proc authHeaders(): seq[(string, string)] = + @[("Auth-Token", "Good Token")] + +suite "HTTP server hook test": + proc mockAuth(req: HttpRequestRef): Future[HttpResponseRef] {.async.} = + if req.headers.getString("Auth-Token") == "Good Token": + return HttpResponseRef(nil) + + return await req.respond(Http401, "Unauthorized access") + + let srv = newRpcHttpServer([serverAddress], @[HttpAuthHook(mockAuth)]) + srv.setupServer() + srv.start() + + test "no auth token": + let client = newRpcHttpClient() + waitFor client.connect(serverHost, Port(serverPort), false) + expect ErrorResponse: + let r = waitFor client.call("testHook", %[%"abc"]) + + test "good auth token": + let client = newRpcHttpClient(getHeaders = authHeaders) + waitFor client.connect(serverHost, Port(serverPort), false) + let r = waitFor client.call("testHook", %[%"abc"]) + check r.getStr == "Hello abc" + + waitFor srv.closeWait() + +proc wsAuthHeaders(ctx: Hook, + headers: var HttpTable): Result[void, string] + {.gcsafe, raises: [Defect].} = + headers.add("Auth-Token", "Good Token") + return ok() + +suite "Websocket server hook test": + let hook = Hook(append: wsAuthHeaders) + + proc mockAuth(req: websock.HttpRequest): Future[bool] {.async.} = + if not req.headers.contains("Auth-Token"): + await req.sendResponse(code = Http403, data = "Missing Auth-Token") + return false + + let token = req.headers.getString("Auth-Token") + if token != "Good Token": + await req.sendResponse(code = Http401, data = "Unauthorized access") + return false + + return true + + let srv = newRpcWebSocketServer( + "127.0.0.1", + Port(8545), + authHooks = @[WsAuthHook(mockAuth)] + ) + srv.setupServer() + srv.start() + let client = newRpcWebSocketClient() + + test "no auth token": + try: + waitFor client.connect("ws://127.0.0.1:8545/") + check false + except CatchableError as e: + check e.msg == "Server did not reply with a websocket upgrade: Header code: 403 Header reason: Forbidden Address: 127.0.0.1:8545" + + test "good auth token": + waitFor client.connect("ws://127.0.0.1:8545/", hooks = @[hook]) + let r = waitFor client.call("testHook", %[%"abc"]) + check r.getStr == "Hello abc" + + srv.stop() + waitFor srv.closeWait()