implement hook to handle CORS and JWT auth

- fixes #138
- fixes #126
- fixes #38
This commit is contained in:
jangko 2022-07-17 12:41:18 +07:00
parent 12e921c2ea
commit 0fee4be2cc
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
6 changed files with 156 additions and 63 deletions

View File

@ -128,6 +128,7 @@ else:
proc connect*( proc connect*(
client: RpcWebSocketClient, uri: string, client: RpcWebSocketClient, uri: string,
compression = false, compression = false,
hooks: seq[Hook] = @[],
flags: set[TLSFlags] = {NoVerifyHost, NoVerifyServerName}) {.async.} = flags: set[TLSFlags] = {NoVerifyHost, NoVerifyServerName}) {.async.} =
var ext: seq[ExtFactory] = if compression: @[deflateFactory()] var ext: seq[ExtFactory] = if compression: @[deflateFactory()]
else: @[] else: @[]
@ -135,6 +136,7 @@ else:
let ws = await WebSocket.connect( let ws = await WebSocket.connect(
uri=uri, uri=uri,
factories=ext, factories=ext,
hooks=hooks,
flags=flags flags=flags
) )
client.transport = ws client.transport = ws

View File

@ -40,7 +40,7 @@ proc getWebSocketClientConfig*(
uri: string, uri: string,
compression: bool = false, compression: bool = false,
flags: set[TLSFlags] = { flags: set[TLSFlags] = {
NoVerifyHost, NoVerifyServerName}): ClientConfig = TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName}): ClientConfig =
ClientConfig(kind: WebSocket, wsUri: uri, compression: compression, flags: flags) ClientConfig(kind: WebSocket, wsUri: uri, compression: compression, flags: flags)
proc proxyCall(client: RpcClient, name: string): RpcProc = proc proxyCall(client: RpcClient, name: string): RpcProc =
@ -82,7 +82,10 @@ proc connectToProxy(proxy: RpcProxy): Future[void] =
of Http: of Http:
return proxy.httpClient.connect(proxy.httpUri) return proxy.httpClient.connect(proxy.httpUri)
of WebSocket: of WebSocket:
return proxy.webSocketClient.connect(proxy.wsUri, proxy.compression, proxy.flags) return proxy.webSocketClient.connect(
uri = proxy.wsUri,
compression = proxy.compression,
flags = proxy.flags)
proc start*(proxy: RpcProxy) {.async.} = proc start*(proxy: RpcProxy) {.async.} =
proxy.rpcHttpServer.start() proxy.rpcHttpServer.start()

View File

@ -5,7 +5,7 @@ import
chronos/apps/http/[httpserver, shttpserver], chronos/apps/http/[httpserver, shttpserver],
".."/[errors, server] ".."/[errors, server]
export server export server, shttpserver
logScope: logScope:
topics = "JSONRPC-HTTP-SERVER" topics = "JSONRPC-HTTP-SERVER"
@ -14,13 +14,31 @@ type
ReqStatus = enum ReqStatus = enum
Success, Error, ErrorFailure Success, Error, ErrorFailure
# HttpAuthHook: handle CORS, JWT auth, etc. in HTTP header
# before actual request processed
# return value:
# - nil: auth success, continue execution
# - HttpResponse: could not authenticate, stop execution
# and return the response
HttpAuthHook* = proc(request: HttpRequestRef): Future[HttpResponseRef]
{.gcsafe, raises: [Defect, CatchableError].}
RpcHttpServer* = ref object of RpcServer RpcHttpServer* = ref object of RpcServer
httpServers: seq[HttpServerRef] httpServers: seq[HttpServerRef]
authHooks: seq[HttpAuthHook]
proc processClientRpc(rpcServer: RpcServer): HttpProcessCallback = proc processClientRpc(rpcServer: RpcHttpServer): HttpProcessCallback =
return proc (req: RequestFence): Future[HttpResponseRef] {.async.} = return proc (req: RequestFence): Future[HttpResponseRef] {.async.} =
if req.isOk(): if req.isOk():
let request = req.get() let request = req.get()
# if hook result is not nil,
# it means we should return immediately
for hook in rpcServer.authHooks:
let res = await hook(request)
if not res.isNil:
return res
let body = await request.getBody() let body = await request.getBody()
let future = rpcServer.route(string.fromBytes(body)) let future = rpcServer.route(string.fromBytes(body))
@ -177,36 +195,36 @@ proc addSecureHttpServer*(server: RpcHttpServer,
for a in resolvedAddresses(address, port): for a in resolvedAddresses(address, port):
server.addSecureHttpServer(a, tlsPrivateKey, tlsCertificate) server.addSecureHttpServer(a, tlsPrivateKey, tlsCertificate)
proc new*(T: type RpcHttpServer): T = proc new*(T: type RpcHttpServer, authHooks: seq[HttpAuthHook] = @[]): T =
T(router: RpcRouter.init(), httpServers: @[]) T(router: RpcRouter.init(), httpServers: @[], authHooks: authHooks)
proc new*(T: type RpcHttpServer, router: RpcRouter): T = proc new*(T: type RpcHttpServer, router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): T =
T(router: router, httpServers: @[]) T(router: router, httpServers: @[], authHooks: authHooks)
proc newRpcHttpServer*(): RpcHttpServer = proc newRpcHttpServer*(authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
RpcHttpServer.new() RpcHttpServer.new(authHooks)
proc newRpcHttpServer*(router: RpcRouter): RpcHttpServer = proc newRpcHttpServer*(router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
RpcHttpServer.new(router) RpcHttpServer.new(router, authHooks)
proc newRpcHttpServer*(addresses: openArray[TransportAddress]): RpcHttpServer = proc newRpcHttpServer*(addresses: openArray[TransportAddress], authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
## Create new server and assign it to addresses ``addresses``. ## Create new server and assign it to addresses ``addresses``.
result = newRpcHttpServer() result = newRpcHttpServer(authHooks)
result.addHttpServers(addresses) result.addHttpServers(addresses)
proc newRpcHttpServer*(addresses: openArray[string]): RpcHttpServer = proc newRpcHttpServer*(addresses: openArray[string], authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
## Create new server and assign it to addresses ``addresses``. ## Create new server and assign it to addresses ``addresses``.
result = newRpcHttpServer() result = newRpcHttpServer(authHooks)
result.addHttpServers(addresses) result.addHttpServers(addresses)
proc newRpcHttpServer*(addresses: openArray[string], router: RpcRouter): RpcHttpServer = proc newRpcHttpServer*(addresses: openArray[string], router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
## Create new server and assign it to addresses ``addresses``. ## Create new server and assign it to addresses ``addresses``.
result = newRpcHttpServer(router) result = newRpcHttpServer(router, authHooks)
result.addHttpServers(addresses) result.addHttpServers(addresses)
proc newRpcHttpServer*(addresses: openArray[TransportAddress], router: RpcRouter): RpcHttpServer = proc newRpcHttpServer*(addresses: openArray[TransportAddress], router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
## Create new server and assign it to addresses ``addresses``. ## Create new server and assign it to addresses ``addresses``.
result = newRpcHttpServer(router) result = newRpcHttpServer(router, authHooks)
result.addHttpServers(addresses) result.addHttpServers(addresses)
proc start*(server: RpcHttpServer) = proc start*(server: RpcHttpServer) =

View File

@ -10,51 +10,34 @@ logScope:
topics = "JSONRPC-WS-SERVER" topics = "JSONRPC-WS-SERVER"
type type
RpcWebSocketServerAuth* = ##\ # WsAuthHook: handle CORS, JWT auth, etc. in HTTP header
## Authenticator function. On error, the resulting `HttpCode` is sent back\ # before actual request processed
## to the client and the `string` argument will be used in an exception,\ # return value:
## following. # - true: auth success, continue execution
proc(req: HttpTable): Result[void,(HttpCode,string)] # - false: could not authenticate, stop execution
{.gcsafe, raises: [Defect].} # and return the response
WsAuthHook* = proc(request: HttpRequest): Future[bool]
{.gcsafe, raises: [Defect, CatchableError].}
RpcWebSocketServer* = ref object of RpcServer RpcWebSocketServer* = ref object of RpcServer
authHook: Option[RpcWebSocketServerAuth] ## Authorization call back handler
server: StreamServer server: StreamServer
wsserver: WSServer wsserver: WSServer
authHooks: seq[WsAuthHook]
HookEx = ref object of Hook
handler: RpcWebSocketServerAuth ## from `RpcWebSocketServer`
request: HttpRequest ## current request needed for error response
proc authWithHtCodeResponse(ctx: Hook, headers: HttpTable):
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
## Wrapper around authorization handler which is stored in the
## extended `Hook` object.
let
cty = ctx.HookEx
rc = cty.handler(headers)
if rc.isErr:
await cty.request.stream.writer.sendError(rc.error[0])
return err(rc.error[1])
return ok()
proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} = proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} =
trace "Handling request:", uri = request.uri.path trace "Handling request:", uri = request.uri.path
trace "Initiating web socket connection." trace "Initiating web socket connection."
# Authorization handler constructor (if enabled) # if hook result is false,
var hooks: seq[Hook] # it means we should return immediately
if rpc.authHook.isSome: for hook in rpc.authHooks:
let hookEx = HookEx( let res = await hook(request)
append: nil, if not res:
request: request, return
handler: rpc.authHook.get,
verify: authWithHtCodeResponse)
hooks = @[hookEx.Hook]
try: try:
let server = rpc.wsserver let server = rpc.wsserver
let ws = await server.handleRequest(request, hooks = hooks) let ws = await server.handleRequest(request)
if ws.readyState != ReadyState.Open: if ws.readyState != ReadyState.Open:
error "Failed to open websocket connection" error "Failed to open websocket connection"
return return
@ -94,25 +77,25 @@ proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} =
error "WebSocket error:", exception = exc.msg error "WebSocket error:", exception = exc.msg
proc initWebsocket(rpc: RpcWebSocketServer, compression: bool, proc initWebsocket(rpc: RpcWebSocketServer, compression: bool,
authHandler: Option[RpcWebSocketServerAuth]) = authHooks: seq[WsAuthHook]) =
if compression: if compression:
let deflateFactory = deflateFactory() let deflateFactory = deflateFactory()
rpc.wsserver = WSServer.new(factories = [deflateFactory]) rpc.wsserver = WSServer.new(factories = [deflateFactory])
else: else:
rpc.wsserver = WSServer.new() rpc.wsserver = WSServer.new()
rpc.authHook = authHandler rpc.authHooks = authHooks
proc newRpcWebSocketServer*( proc newRpcWebSocketServer*(
address: TransportAddress, address: TransportAddress,
compression: bool = false, compression: bool = false,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,ServerFlags.ReuseAddr}, flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,ServerFlags.ReuseAddr},
authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer =
var server = new(RpcWebSocketServer) var server = new(RpcWebSocketServer)
proc processCallback(request: HttpRequest): Future[void] = proc processCallback(request: HttpRequest): Future[void] =
handleRequest(server, request) handleRequest(server, request)
server.initWebsocket(compression, authHandler) server.initWebsocket(compression, authHooks)
server.server = HttpServer.create( server.server = HttpServer.create(
address, address,
processCallback, processCallback,
@ -126,13 +109,13 @@ proc newRpcWebSocketServer*(
port: Port, port: Port,
compression: bool = false, compression: bool = false,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}, flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr},
authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer =
newRpcWebSocketServer( newRpcWebSocketServer(
initTAddress(host, port), initTAddress(host, port),
compression, compression,
flags, flags,
authHandler authHooks
) )
proc newRpcWebSocketServer*( proc newRpcWebSocketServer*(
@ -145,13 +128,13 @@ proc newRpcWebSocketServer*(
tlsFlags: set[TLSFlags] = {}, tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12, tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12, tlsMaxVersion = TLSVersion.TLS12,
authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer =
var server = new(RpcWebSocketServer) var server = new(RpcWebSocketServer)
proc processCallback(request: HttpRequest): Future[void] = proc processCallback(request: HttpRequest): Future[void] =
handleRequest(server, request) handleRequest(server, request)
server.initWebsocket(compression, authHandler) server.initWebsocket(compression, authHooks)
server.server = TlsHttpServer.create( server.server = TlsHttpServer.create(
address, address,
tlsPrivateKey, tlsPrivateKey,
@ -176,7 +159,7 @@ proc newRpcWebSocketServer*(
tlsFlags: set[TLSFlags] = {}, tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12, tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12, tlsMaxVersion = TLSVersion.TLS12,
authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer = authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer =
newRpcWebSocketServer( newRpcWebSocketServer(
initTAddress(host, port), initTAddress(host, port),
@ -187,7 +170,7 @@ proc newRpcWebSocketServer*(
tlsFlags, tlsFlags,
tlsMinVersion, tlsMinVersion,
tlsMaxVersion, tlsMaxVersion,
authHandler authHooks
) )
proc start*(server: RpcWebSocketServer) = proc start*(server: RpcWebSocketServer) =

View File

@ -9,3 +9,4 @@ import
when not useNews: when not useNews:
# The proxy implementation is based on websock # The proxy implementation is based on websock
import testproxy import testproxy
import testhook

86
tests/testhook.nim Normal file
View File

@ -0,0 +1,86 @@
import
unittest, json, chronicles,
websock/websock,
../json_rpc/[rpcclient, rpcserver, clients/config]
const
serverHost = "localhost"
serverPort = 8547
serverAddress = serverHost & ":" & $serverPort
proc setupServer*(srv: RpcServer) =
srv.rpc("testHook") do(input: string):
return %("Hello " & input)
proc authHeaders(): seq[(string, string)] =
@[("Auth-Token", "Good Token")]
suite "HTTP server hook test":
proc mockAuth(req: HttpRequestRef): Future[HttpResponseRef] {.async.} =
if req.headers.getString("Auth-Token") == "Good Token":
return HttpResponseRef(nil)
return await req.respond(Http401, "Unauthorized access")
let srv = newRpcHttpServer([serverAddress], @[HttpAuthHook(mockAuth)])
srv.setupServer()
srv.start()
test "no auth token":
let client = newRpcHttpClient()
waitFor client.connect(serverHost, Port(serverPort), false)
expect ErrorResponse:
let r = waitFor client.call("testHook", %[%"abc"])
test "good auth token":
let client = newRpcHttpClient(getHeaders = authHeaders)
waitFor client.connect(serverHost, Port(serverPort), false)
let r = waitFor client.call("testHook", %[%"abc"])
check r.getStr == "Hello abc"
waitFor srv.closeWait()
proc wsAuthHeaders(ctx: Hook,
headers: var HttpTable): Result[void, string]
{.gcsafe, raises: [Defect].} =
headers.add("Auth-Token", "Good Token")
return ok()
suite "Websocket server hook test":
let hook = Hook(append: wsAuthHeaders)
proc mockAuth(req: websock.HttpRequest): Future[bool] {.async.} =
if not req.headers.contains("Auth-Token"):
await req.sendResponse(code = Http403, data = "Missing Auth-Token")
return false
let token = req.headers.getString("Auth-Token")
if token != "Good Token":
await req.sendResponse(code = Http401, data = "Unauthorized access")
return false
return true
let srv = newRpcWebSocketServer(
"127.0.0.1",
Port(8545),
authHooks = @[WsAuthHook(mockAuth)]
)
srv.setupServer()
srv.start()
let client = newRpcWebSocketClient()
test "no auth token":
try:
waitFor client.connect("ws://127.0.0.1:8545/")
check false
except CatchableError as e:
check e.msg == "Server did not reply with a websocket upgrade: Header code: 403 Header reason: Forbidden Address: 127.0.0.1:8545"
test "good auth token":
waitFor client.connect("ws://127.0.0.1:8545/", hooks = @[hook])
let r = waitFor client.call("testHook", %[%"abc"])
check r.getStr == "Hello abc"
srv.stop()
waitFor srv.closeWait()