add http header hook

fix #101
This commit is contained in:
jangko 2022-03-03 13:21:35 +07:00
parent 6e73e34975
commit 4a7a058843
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
4 changed files with 262 additions and 3 deletions

View File

@ -14,3 +14,4 @@ import ./testutf8
import ./testextutils import ./testextutils
import ./extensions/testexts import ./extensions/testexts
import ./extensions/testcompression import ./extensions/testcompression
import ./testhooks

201
tests/testhooks.nim Normal file
View File

@ -0,0 +1,201 @@
## nim-websock
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import std/strutils
import pkg/[
httputils,
chronos,
chronicles,
stew/byteutils,
asynctest/unittest2]
import ../websock/websock
import ./helpers
let address = initTAddress("127.0.0.1:8888")
type
TokenHook* = ref object of Hook
status: int
token: string
request: HttpRequest
proc clientAppendGoodToken(ctx: Hook, headers: var HttpTable):
Result[void, string] {.gcsafe, raises: [Defect].} =
headers.add("auth-token", "good-token")
return ok()
proc clientAppendBadToken(ctx: Hook, headers: var HttpTable):
Result[void, string] {.gcsafe, raises: [Defect].} =
headers.add("auth-token", "bad-token")
return ok()
proc clientVerify(ctx: Hook, headers: HttpTable):
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
var p = TokenHook(ctx)
p.token = headers.getString("auth-status")
return ok()
proc serverVerify(ctx: Hook, headers: HttpTable):
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
var p = TokenHook(ctx)
if headers.getString("auth-token") == "good-token":
p.status = 101
return ok()
proc serverAppend(ctx: Hook, headers: var HttpTable):
Result[void, string] {.gcsafe, raises: [Defect].} =
var p = TokenHook(ctx)
if p.status == 101:
headers.add("auth-status", "accept")
else:
headers.add("auth-status", "reject")
p.status = 0
return ok()
proc goodClientHook(): Hook =
TokenHook(
append: clientAppendGoodToken,
verify: clientVerify
)
proc badClientHook(): Hook =
TokenHook(
append: clientAppendBadToken,
verify: clientVerify
)
proc serverHook(): Hook =
TokenHook(
append: serverAppend,
verify: serverVerify
)
proc serverVerifyWithCode(ctx: Hook, headers: HttpTable):
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
var p = TokenHook(ctx)
if headers.getString("auth-token") == "good-token":
p.status = 101
return ok()
else:
await p.request.stream.writer.sendError(Http401)
return err("authentication error")
proc serverHookWithCode(request: HttpRequest): Hook =
TokenHook(
append: serverAppend,
verify: serverVerifyWithCode,
request: request
)
suite "Test Hooks":
var
server: HttpServer
goodCP = goodClientHook()
badCP = badClientHook()
teardown:
server.stop()
await server.closeWait()
test "client with valid token":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let
server = WSServer.new()
ws = await server.handleRequest(
request,
hooks = @[serverHook()]
)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
let session = await WebSocket.connect(
host = initTAddress("127.0.0.1:8888"),
path = WSPath,
hooks = @[goodCP]
)
check TokenHook(goodCP).token == "accept"
await session.stream.closeWait()
test "client with bad token":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let
server = WSServer.new()
ws = await server.handleRequest(
request,
hooks = @[serverHook()]
)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
let session = await WebSocket.connect(
host = initTAddress("127.0.0.1:8888"),
path = WSPath,
hooks = @[badCP]
)
check TokenHook(badCP).token == "reject"
await session.stream.closeWait()
test "server hook with code get good client":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let
server = WSServer.new()
ws = await server.handleRequest(
request,
hooks = @[serverHookWithCode(request)]
)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
let session = await WebSocket.connect(
host = initTAddress("127.0.0.1:8888"),
path = WSPath,
hooks = @[goodCP]
)
check TokenHook(goodCP).token == "accept"
await session.stream.closeWait()
test "server hook with code get bad client":
proc handle(request: HttpRequest) {.async.} =
check request.uri.path == WSPath
let
server = WSServer.new()
ws = await server.handleRequest(
request,
hooks = @[serverHookWithCode(request)]
)
server = createServer(
address = address,
handler = handle,
flags = {ReuseAddr})
expect WSFailedUpgradeError:
let session = await WebSocket.connect(
host = initTAddress("127.0.0.1:8888"),
path = WSPath,
hooks = @[badCP]
)
await session.stream.closeWait()

View File

@ -9,7 +9,11 @@
{.push raises: [Defect].} {.push raises: [Defect].}
import pkg/[chronos, chronos/streams/tlsstream, stew/results] import pkg/[chronos,
chronos/streams/tlsstream,
chronos/apps/http/httptable,
httputils,
stew/results]
import ./utils import ./utils
const const
@ -112,6 +116,26 @@ type
factory*: ExtFactoryProc factory*: ExtFactoryProc
clientOffer*: string clientOffer*: string
# client exec order:
# 1. append to request header
# 2. verify response header
# server exec order:
# 1. verify request header
# 2. append to response header
# ------------------------------
# Handshake exec order:
# 1. client append to request header
# 2. server verify request header
# 3. server reply with response header
# 4. client verify response header from server
Hook* = ref object of RootObj
append*: proc(ctx: Hook,
headers: var HttpTable): Result[void, string]
{.gcsafe, raises: [Defect].}
verify*: proc(ctx: Hook,
headers: HttpTable): Future[Result[void, string]]
{.closure, gcsafe, raises: [Defect].}
WebSocketError* = object of CatchableError WebSocketError* = object of CatchableError
WSMalformedHeaderError* = object of WebSocketError WSMalformedHeaderError* = object of WebSocketError
WSFailedUpgradeError* = object of WebSocketError WSFailedUpgradeError* = object of WebSocketError
@ -133,6 +157,7 @@ type
WSInvalidOpcodeError* = object of WebSocketError WSInvalidOpcodeError* = object of WebSocketError
WSInvalidUTF8* = object of WebSocketError WSInvalidUTF8* = object of WebSocketError
WSExtError* = object of WebSocketError WSExtError* = object of WebSocketError
WSHookError* = object of WebSocketError
const const
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999)) StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))

View File

@ -28,7 +28,7 @@ import pkg/[chronos,
import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils
export utils, session, frame, types, http export utils, session, frame, types, http, httptable
logScope: logScope:
topics = "websock ws-server" topics = "websock ws-server"
@ -109,6 +109,7 @@ proc connect*(
hostName: string = "", # override used when the hostname has been externally resolved hostName: string = "", # override used when the hostname has been externally resolved
protocols: seq[string] = @[], protocols: seq[string] = @[],
factories: seq[ExtFactory] = @[], factories: seq[ExtFactory] = @[],
hooks: seq[Hook] = @[],
secure = false, secure = false,
flags: set[TLSFlags] = {}, flags: set[TLSFlags] = {},
version = WSDefaultVersion, version = WSDefaultVersion,
@ -149,6 +150,13 @@ proc connect*(
if extOffer.len > 0: if extOffer.len > 0:
headers.add("Sec-WebSocket-Extensions", extOffer) headers.add("Sec-WebSocket-Extensions", extOffer)
for hp in hooks:
if hp.append == nil: continue
let res = hp.append(hp, headers)
if res.isErr:
raise newException(WSHookError,
"Header plugin execution failed: " & res.error)
let response = try: let response = try:
await client.request(path, headers = headers) await client.request(path, headers = headers)
except CatchableError as exc: except CatchableError as exc:
@ -168,6 +176,13 @@ proc connect*(
raise newException(WSFailedUpgradeError, raise newException(WSFailedUpgradeError,
&"Invalid protocol returned {proto}!") &"Invalid protocol returned {proto}!")
for hp in hooks:
if hp.verify == nil: continue
let res = await hp.verify(hp, response.headers)
if res.isErr:
raise newException(WSHookError,
"Header verification failed: " & res.error)
var extensions: seq[Ext] var extensions: seq[Ext]
let exts = response.headers.getList("Sec-WebSocket-Extensions") let exts = response.headers.getList("Sec-WebSocket-Extensions")
discard selectExt(false, extensions, factories, exts) discard selectExt(false, extensions, factories, exts)
@ -194,6 +209,7 @@ proc connect*(
uri: Uri, uri: Uri,
protocols: seq[string] = @[], protocols: seq[string] = @[],
factories: seq[ExtFactory] = @[], factories: seq[ExtFactory] = @[],
hooks: seq[Hook] = @[],
flags: set[TLSFlags] = {}, flags: set[TLSFlags] = {},
version = WSDefaultVersion, version = WSDefaultVersion,
frameSize = WSDefaultFrameSize, frameSize = WSDefaultFrameSize,
@ -222,6 +238,7 @@ proc connect*(
path = uri.path, path = uri.path,
protocols = protocols, protocols = protocols,
factories = factories, factories = factories,
hooks = hooks,
secure = secure, secure = secure,
flags = flags, flags = flags,
version = version, version = version,
@ -234,7 +251,8 @@ proc connect*(
proc handleRequest*( proc handleRequest*(
ws: WSServer, ws: WSServer,
request: HttpRequest, request: HttpRequest,
version: uint = WSDefaultVersion): Future[WSSession] version: uint = WSDefaultVersion,
hooks: seq[Hook] = @[]): Future[WSSession]
{. {.
async, async,
raises: [ raises: [
@ -270,6 +288,13 @@ proc handleRequest*(
it in ws.protocols it in ws.protocols
) )
for hp in hooks:
if hp.verify == nil: continue
let res = await hp.verify(hp, request.headers)
if res.isErr:
raise newException(WSHookError,
"Header verification failed: " & res.error)
let let
cKey = ws.key & WSGuid cKey = ws.key & WSGuid
acceptKey = Base64Pad.encode( acceptKey = Base64Pad.encode(
@ -293,6 +318,13 @@ proc handleRequest*(
# send back any accepted extensions # send back any accepted extensions
headers.add("Sec-WebSocket-Extensions", extResp) headers.add("Sec-WebSocket-Extensions", extResp)
for hp in hooks:
if hp.append == nil: continue
let res = hp.append(hp, headers)
if res.isErr:
raise newException(WSHookError,
"Header plugin execution failed: " & res.error)
try: try:
await request.sendResponse(Http101, headers = headers) await request.sendResponse(Http101, headers = headers)
except CancelledError as exc: except CancelledError as exc: