implement hook to handle CORS and JWT auth

- fixes #138
- fixes #126
- fixes #38
This commit is contained in:
jangko 2022-07-17 12:41:18 +07:00
parent 12e921c2ea
commit 0fee4be2cc
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
6 changed files with 156 additions and 63 deletions

View File

@ -128,6 +128,7 @@ else:
proc connect*(
client: RpcWebSocketClient, uri: string,
compression = false,
hooks: seq[Hook] = @[],
flags: set[TLSFlags] = {NoVerifyHost, NoVerifyServerName}) {.async.} =
var ext: seq[ExtFactory] = if compression: @[deflateFactory()]
else: @[]
@ -135,6 +136,7 @@ else:
let ws = await WebSocket.connect(
uri=uri,
factories=ext,
hooks=hooks,
flags=flags
)
client.transport = ws

View File

@ -40,7 +40,7 @@ proc getWebSocketClientConfig*(
uri: string,
compression: bool = false,
flags: set[TLSFlags] = {
NoVerifyHost, NoVerifyServerName}): ClientConfig =
TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName}): ClientConfig =
ClientConfig(kind: WebSocket, wsUri: uri, compression: compression, flags: flags)
proc proxyCall(client: RpcClient, name: string): RpcProc =
@ -82,7 +82,10 @@ proc connectToProxy(proxy: RpcProxy): Future[void] =
of Http:
return proxy.httpClient.connect(proxy.httpUri)
of WebSocket:
return proxy.webSocketClient.connect(proxy.wsUri, proxy.compression, proxy.flags)
return proxy.webSocketClient.connect(
uri = proxy.wsUri,
compression = proxy.compression,
flags = proxy.flags)
proc start*(proxy: RpcProxy) {.async.} =
proxy.rpcHttpServer.start()

View File

@ -5,7 +5,7 @@ import
chronos/apps/http/[httpserver, shttpserver],
".."/[errors, server]
export server
export server, shttpserver
logScope:
topics = "JSONRPC-HTTP-SERVER"
@ -14,13 +14,31 @@ type
ReqStatus = enum
Success, Error, ErrorFailure
# HttpAuthHook: handle CORS, JWT auth, etc. in HTTP header
# before actual request processed
# return value:
# - nil: auth success, continue execution
# - HttpResponse: could not authenticate, stop execution
# and return the response
HttpAuthHook* = proc(request: HttpRequestRef): Future[HttpResponseRef]
{.gcsafe, raises: [Defect, CatchableError].}
RpcHttpServer* = ref object of RpcServer
httpServers: seq[HttpServerRef]
authHooks: seq[HttpAuthHook]
proc processClientRpc(rpcServer: RpcServer): HttpProcessCallback =
proc processClientRpc(rpcServer: RpcHttpServer): HttpProcessCallback =
return proc (req: RequestFence): Future[HttpResponseRef] {.async.} =
if req.isOk():
let request = req.get()
# if hook result is not nil,
# it means we should return immediately
for hook in rpcServer.authHooks:
let res = await hook(request)
if not res.isNil:
return res
let body = await request.getBody()
let future = rpcServer.route(string.fromBytes(body))
@ -177,36 +195,36 @@ proc addSecureHttpServer*(server: RpcHttpServer,
for a in resolvedAddresses(address, port):
server.addSecureHttpServer(a, tlsPrivateKey, tlsCertificate)
proc new*(T: type RpcHttpServer): T =
T(router: RpcRouter.init(), httpServers: @[])
proc new*(T: type RpcHttpServer, authHooks: seq[HttpAuthHook] = @[]): T =
T(router: RpcRouter.init(), httpServers: @[], authHooks: authHooks)
proc new*(T: type RpcHttpServer, router: RpcRouter): T =
T(router: router, httpServers: @[])
proc new*(T: type RpcHttpServer, router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): T =
T(router: router, httpServers: @[], authHooks: authHooks)
proc newRpcHttpServer*(): RpcHttpServer =
RpcHttpServer.new()
proc newRpcHttpServer*(authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
RpcHttpServer.new(authHooks)
proc newRpcHttpServer*(router: RpcRouter): RpcHttpServer =
RpcHttpServer.new(router)
proc newRpcHttpServer*(router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
RpcHttpServer.new(router, authHooks)
proc newRpcHttpServer*(addresses: openArray[TransportAddress]): RpcHttpServer =
proc newRpcHttpServer*(addresses: openArray[TransportAddress], authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
## Create new server and assign it to addresses ``addresses``.
result = newRpcHttpServer()
result = newRpcHttpServer(authHooks)
result.addHttpServers(addresses)
proc newRpcHttpServer*(addresses: openArray[string]): RpcHttpServer =
proc newRpcHttpServer*(addresses: openArray[string], authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
## Create new server and assign it to addresses ``addresses``.
result = newRpcHttpServer()
result = newRpcHttpServer(authHooks)
result.addHttpServers(addresses)
proc newRpcHttpServer*(addresses: openArray[string], router: RpcRouter): RpcHttpServer =
proc newRpcHttpServer*(addresses: openArray[string], router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
## Create new server and assign it to addresses ``addresses``.
result = newRpcHttpServer(router)
result = newRpcHttpServer(router, authHooks)
result.addHttpServers(addresses)
proc newRpcHttpServer*(addresses: openArray[TransportAddress], router: RpcRouter): RpcHttpServer =
proc newRpcHttpServer*(addresses: openArray[TransportAddress], router: RpcRouter, authHooks: seq[HttpAuthHook] = @[]): RpcHttpServer =
## Create new server and assign it to addresses ``addresses``.
result = newRpcHttpServer(router)
result = newRpcHttpServer(router, authHooks)
result.addHttpServers(addresses)
proc start*(server: RpcHttpServer) =

View File

@ -10,51 +10,34 @@ logScope:
topics = "JSONRPC-WS-SERVER"
type
RpcWebSocketServerAuth* = ##\
## Authenticator function. On error, the resulting `HttpCode` is sent back\
## to the client and the `string` argument will be used in an exception,\
## following.
proc(req: HttpTable): Result[void,(HttpCode,string)]
{.gcsafe, raises: [Defect].}
# WsAuthHook: handle CORS, JWT auth, etc. in HTTP header
# before actual request processed
# return value:
# - true: auth success, continue execution
# - false: could not authenticate, stop execution
# and return the response
WsAuthHook* = proc(request: HttpRequest): Future[bool]
{.gcsafe, raises: [Defect, CatchableError].}
RpcWebSocketServer* = ref object of RpcServer
authHook: Option[RpcWebSocketServerAuth] ## Authorization call back handler
server: StreamServer
wsserver: WSServer
HookEx = ref object of Hook
handler: RpcWebSocketServerAuth ## from `RpcWebSocketServer`
request: HttpRequest ## current request needed for error response
proc authWithHtCodeResponse(ctx: Hook, headers: HttpTable):
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
## Wrapper around authorization handler which is stored in the
## extended `Hook` object.
let
cty = ctx.HookEx
rc = cty.handler(headers)
if rc.isErr:
await cty.request.stream.writer.sendError(rc.error[0])
return err(rc.error[1])
return ok()
authHooks: seq[WsAuthHook]
proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} =
trace "Handling request:", uri = request.uri.path
trace "Initiating web socket connection."
# Authorization handler constructor (if enabled)
var hooks: seq[Hook]
if rpc.authHook.isSome:
let hookEx = HookEx(
append: nil,
request: request,
handler: rpc.authHook.get,
verify: authWithHtCodeResponse)
hooks = @[hookEx.Hook]
# if hook result is false,
# it means we should return immediately
for hook in rpc.authHooks:
let res = await hook(request)
if not res:
return
try:
let server = rpc.wsserver
let ws = await server.handleRequest(request, hooks = hooks)
let ws = await server.handleRequest(request)
if ws.readyState != ReadyState.Open:
error "Failed to open websocket connection"
return
@ -94,25 +77,25 @@ proc handleRequest(rpc: RpcWebSocketServer, request: HttpRequest) {.async.} =
error "WebSocket error:", exception = exc.msg
proc initWebsocket(rpc: RpcWebSocketServer, compression: bool,
authHandler: Option[RpcWebSocketServerAuth]) =
authHooks: seq[WsAuthHook]) =
if compression:
let deflateFactory = deflateFactory()
rpc.wsserver = WSServer.new(factories = [deflateFactory])
else:
rpc.wsserver = WSServer.new()
rpc.authHook = authHandler
rpc.authHooks = authHooks
proc newRpcWebSocketServer*(
address: TransportAddress,
compression: bool = false,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay,ServerFlags.ReuseAddr},
authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer =
authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer =
var server = new(RpcWebSocketServer)
proc processCallback(request: HttpRequest): Future[void] =
handleRequest(server, request)
server.initWebsocket(compression, authHandler)
server.initWebsocket(compression, authHooks)
server.server = HttpServer.create(
address,
processCallback,
@ -126,13 +109,13 @@ proc newRpcWebSocketServer*(
port: Port,
compression: bool = false,
flags: set[ServerFlags] = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr},
authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer =
authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer =
newRpcWebSocketServer(
initTAddress(host, port),
compression,
flags,
authHandler
authHooks
)
proc newRpcWebSocketServer*(
@ -145,13 +128,13 @@ proc newRpcWebSocketServer*(
tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12,
authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer =
authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer =
var server = new(RpcWebSocketServer)
proc processCallback(request: HttpRequest): Future[void] =
handleRequest(server, request)
server.initWebsocket(compression, authHandler)
server.initWebsocket(compression, authHooks)
server.server = TlsHttpServer.create(
address,
tlsPrivateKey,
@ -176,7 +159,7 @@ proc newRpcWebSocketServer*(
tlsFlags: set[TLSFlags] = {},
tlsMinVersion = TLSVersion.TLS12,
tlsMaxVersion = TLSVersion.TLS12,
authHandler = none(RpcWebSocketServerAuth)): RpcWebSocketServer =
authHooks: seq[WsAuthHook] = @[]): RpcWebSocketServer =
newRpcWebSocketServer(
initTAddress(host, port),
@ -187,7 +170,7 @@ proc newRpcWebSocketServer*(
tlsFlags,
tlsMinVersion,
tlsMaxVersion,
authHandler
authHooks
)
proc start*(server: RpcWebSocketServer) =

View File

@ -9,3 +9,4 @@ import
when not useNews:
# The proxy implementation is based on websock
import testproxy
import testhook

86
tests/testhook.nim Normal file
View File

@ -0,0 +1,86 @@
import
unittest, json, chronicles,
websock/websock,
../json_rpc/[rpcclient, rpcserver, clients/config]
const
serverHost = "localhost"
serverPort = 8547
serverAddress = serverHost & ":" & $serverPort
proc setupServer*(srv: RpcServer) =
srv.rpc("testHook") do(input: string):
return %("Hello " & input)
proc authHeaders(): seq[(string, string)] =
@[("Auth-Token", "Good Token")]
suite "HTTP server hook test":
proc mockAuth(req: HttpRequestRef): Future[HttpResponseRef] {.async.} =
if req.headers.getString("Auth-Token") == "Good Token":
return HttpResponseRef(nil)
return await req.respond(Http401, "Unauthorized access")
let srv = newRpcHttpServer([serverAddress], @[HttpAuthHook(mockAuth)])
srv.setupServer()
srv.start()
test "no auth token":
let client = newRpcHttpClient()
waitFor client.connect(serverHost, Port(serverPort), false)
expect ErrorResponse:
let r = waitFor client.call("testHook", %[%"abc"])
test "good auth token":
let client = newRpcHttpClient(getHeaders = authHeaders)
waitFor client.connect(serverHost, Port(serverPort), false)
let r = waitFor client.call("testHook", %[%"abc"])
check r.getStr == "Hello abc"
waitFor srv.closeWait()
proc wsAuthHeaders(ctx: Hook,
headers: var HttpTable): Result[void, string]
{.gcsafe, raises: [Defect].} =
headers.add("Auth-Token", "Good Token")
return ok()
suite "Websocket server hook test":
let hook = Hook(append: wsAuthHeaders)
proc mockAuth(req: websock.HttpRequest): Future[bool] {.async.} =
if not req.headers.contains("Auth-Token"):
await req.sendResponse(code = Http403, data = "Missing Auth-Token")
return false
let token = req.headers.getString("Auth-Token")
if token != "Good Token":
await req.sendResponse(code = Http401, data = "Unauthorized access")
return false
return true
let srv = newRpcWebSocketServer(
"127.0.0.1",
Port(8545),
authHooks = @[WsAuthHook(mockAuth)]
)
srv.setupServer()
srv.start()
let client = newRpcWebSocketClient()
test "no auth token":
try:
waitFor client.connect("ws://127.0.0.1:8545/")
check false
except CatchableError as e:
check e.msg == "Server did not reply with a websocket upgrade: Header code: 403 Header reason: Forbidden Address: 127.0.0.1:8545"
test "good auth token":
waitFor client.connect("ws://127.0.0.1:8545/", hooks = @[hook])
let r = waitFor client.call("testHook", %[%"abc"])
check r.getStr == "Hello abc"
srv.stop()
waitFor srv.closeWait()