mirror of
https://github.com/status-im/nim-websock.git
synced 2025-02-11 08:56:52 +00:00
parent
6e73e34975
commit
4a7a058843
@ -14,3 +14,4 @@ import ./testutf8
|
||||
import ./testextutils
|
||||
import ./extensions/testexts
|
||||
import ./extensions/testcompression
|
||||
import ./testhooks
|
||||
|
201
tests/testhooks.nim
Normal file
201
tests/testhooks.nim
Normal file
@ -0,0 +1,201 @@
|
||||
## nim-websock
|
||||
## Copyright (c) 2021 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import std/strutils
|
||||
import pkg/[
|
||||
httputils,
|
||||
chronos,
|
||||
chronicles,
|
||||
stew/byteutils,
|
||||
asynctest/unittest2]
|
||||
|
||||
import ../websock/websock
|
||||
|
||||
import ./helpers
|
||||
|
||||
let address = initTAddress("127.0.0.1:8888")
|
||||
|
||||
type
|
||||
TokenHook* = ref object of Hook
|
||||
status: int
|
||||
token: string
|
||||
request: HttpRequest
|
||||
|
||||
proc clientAppendGoodToken(ctx: Hook, headers: var HttpTable):
|
||||
Result[void, string] {.gcsafe, raises: [Defect].} =
|
||||
headers.add("auth-token", "good-token")
|
||||
return ok()
|
||||
|
||||
proc clientAppendBadToken(ctx: Hook, headers: var HttpTable):
|
||||
Result[void, string] {.gcsafe, raises: [Defect].} =
|
||||
headers.add("auth-token", "bad-token")
|
||||
return ok()
|
||||
|
||||
proc clientVerify(ctx: Hook, headers: HttpTable):
|
||||
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
|
||||
var p = TokenHook(ctx)
|
||||
p.token = headers.getString("auth-status")
|
||||
return ok()
|
||||
|
||||
proc serverVerify(ctx: Hook, headers: HttpTable):
|
||||
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
|
||||
var p = TokenHook(ctx)
|
||||
if headers.getString("auth-token") == "good-token":
|
||||
p.status = 101
|
||||
return ok()
|
||||
|
||||
proc serverAppend(ctx: Hook, headers: var HttpTable):
|
||||
Result[void, string] {.gcsafe, raises: [Defect].} =
|
||||
var p = TokenHook(ctx)
|
||||
if p.status == 101:
|
||||
headers.add("auth-status", "accept")
|
||||
else:
|
||||
headers.add("auth-status", "reject")
|
||||
p.status = 0
|
||||
return ok()
|
||||
|
||||
proc goodClientHook(): Hook =
|
||||
TokenHook(
|
||||
append: clientAppendGoodToken,
|
||||
verify: clientVerify
|
||||
)
|
||||
|
||||
proc badClientHook(): Hook =
|
||||
TokenHook(
|
||||
append: clientAppendBadToken,
|
||||
verify: clientVerify
|
||||
)
|
||||
|
||||
proc serverHook(): Hook =
|
||||
TokenHook(
|
||||
append: serverAppend,
|
||||
verify: serverVerify
|
||||
)
|
||||
|
||||
proc serverVerifyWithCode(ctx: Hook, headers: HttpTable):
|
||||
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
|
||||
var p = TokenHook(ctx)
|
||||
if headers.getString("auth-token") == "good-token":
|
||||
p.status = 101
|
||||
return ok()
|
||||
else:
|
||||
await p.request.stream.writer.sendError(Http401)
|
||||
return err("authentication error")
|
||||
|
||||
proc serverHookWithCode(request: HttpRequest): Hook =
|
||||
TokenHook(
|
||||
append: serverAppend,
|
||||
verify: serverVerifyWithCode,
|
||||
request: request
|
||||
)
|
||||
|
||||
suite "Test Hooks":
|
||||
var
|
||||
server: HttpServer
|
||||
goodCP = goodClientHook()
|
||||
badCP = badClientHook()
|
||||
|
||||
teardown:
|
||||
server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
test "client with valid token":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
let
|
||||
server = WSServer.new()
|
||||
ws = await server.handleRequest(
|
||||
request,
|
||||
hooks = @[serverHook()]
|
||||
)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
|
||||
let session = await WebSocket.connect(
|
||||
host = initTAddress("127.0.0.1:8888"),
|
||||
path = WSPath,
|
||||
hooks = @[goodCP]
|
||||
)
|
||||
|
||||
check TokenHook(goodCP).token == "accept"
|
||||
await session.stream.closeWait()
|
||||
|
||||
test "client with bad token":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
let
|
||||
server = WSServer.new()
|
||||
ws = await server.handleRequest(
|
||||
request,
|
||||
hooks = @[serverHook()]
|
||||
)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
|
||||
let session = await WebSocket.connect(
|
||||
host = initTAddress("127.0.0.1:8888"),
|
||||
path = WSPath,
|
||||
hooks = @[badCP]
|
||||
)
|
||||
|
||||
check TokenHook(badCP).token == "reject"
|
||||
await session.stream.closeWait()
|
||||
|
||||
test "server hook with code get good client":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
let
|
||||
server = WSServer.new()
|
||||
ws = await server.handleRequest(
|
||||
request,
|
||||
hooks = @[serverHookWithCode(request)]
|
||||
)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
|
||||
let session = await WebSocket.connect(
|
||||
host = initTAddress("127.0.0.1:8888"),
|
||||
path = WSPath,
|
||||
hooks = @[goodCP]
|
||||
)
|
||||
|
||||
check TokenHook(goodCP).token == "accept"
|
||||
await session.stream.closeWait()
|
||||
|
||||
test "server hook with code get bad client":
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == WSPath
|
||||
let
|
||||
server = WSServer.new()
|
||||
ws = await server.handleRequest(
|
||||
request,
|
||||
hooks = @[serverHookWithCode(request)]
|
||||
)
|
||||
|
||||
server = createServer(
|
||||
address = address,
|
||||
handler = handle,
|
||||
flags = {ReuseAddr})
|
||||
|
||||
expect WSFailedUpgradeError:
|
||||
let session = await WebSocket.connect(
|
||||
host = initTAddress("127.0.0.1:8888"),
|
||||
path = WSPath,
|
||||
hooks = @[badCP]
|
||||
)
|
||||
await session.stream.closeWait()
|
@ -9,7 +9,11 @@
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import pkg/[chronos, chronos/streams/tlsstream, stew/results]
|
||||
import pkg/[chronos,
|
||||
chronos/streams/tlsstream,
|
||||
chronos/apps/http/httptable,
|
||||
httputils,
|
||||
stew/results]
|
||||
import ./utils
|
||||
|
||||
const
|
||||
@ -112,6 +116,26 @@ type
|
||||
factory*: ExtFactoryProc
|
||||
clientOffer*: string
|
||||
|
||||
# client exec order:
|
||||
# 1. append to request header
|
||||
# 2. verify response header
|
||||
# server exec order:
|
||||
# 1. verify request header
|
||||
# 2. append to response header
|
||||
# ------------------------------
|
||||
# Handshake exec order:
|
||||
# 1. client append to request header
|
||||
# 2. server verify request header
|
||||
# 3. server reply with response header
|
||||
# 4. client verify response header from server
|
||||
Hook* = ref object of RootObj
|
||||
append*: proc(ctx: Hook,
|
||||
headers: var HttpTable): Result[void, string]
|
||||
{.gcsafe, raises: [Defect].}
|
||||
verify*: proc(ctx: Hook,
|
||||
headers: HttpTable): Future[Result[void, string]]
|
||||
{.closure, gcsafe, raises: [Defect].}
|
||||
|
||||
WebSocketError* = object of CatchableError
|
||||
WSMalformedHeaderError* = object of WebSocketError
|
||||
WSFailedUpgradeError* = object of WebSocketError
|
||||
@ -133,6 +157,7 @@ type
|
||||
WSInvalidOpcodeError* = object of WebSocketError
|
||||
WSInvalidUTF8* = object of WebSocketError
|
||||
WSExtError* = object of WebSocketError
|
||||
WSHookError* = object of WebSocketError
|
||||
|
||||
const
|
||||
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))
|
||||
|
@ -28,7 +28,7 @@ import pkg/[chronos,
|
||||
|
||||
import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils
|
||||
|
||||
export utils, session, frame, types, http
|
||||
export utils, session, frame, types, http, httptable
|
||||
|
||||
logScope:
|
||||
topics = "websock ws-server"
|
||||
@ -109,6 +109,7 @@ proc connect*(
|
||||
hostName: string = "", # override used when the hostname has been externally resolved
|
||||
protocols: seq[string] = @[],
|
||||
factories: seq[ExtFactory] = @[],
|
||||
hooks: seq[Hook] = @[],
|
||||
secure = false,
|
||||
flags: set[TLSFlags] = {},
|
||||
version = WSDefaultVersion,
|
||||
@ -149,6 +150,13 @@ proc connect*(
|
||||
if extOffer.len > 0:
|
||||
headers.add("Sec-WebSocket-Extensions", extOffer)
|
||||
|
||||
for hp in hooks:
|
||||
if hp.append == nil: continue
|
||||
let res = hp.append(hp, headers)
|
||||
if res.isErr:
|
||||
raise newException(WSHookError,
|
||||
"Header plugin execution failed: " & res.error)
|
||||
|
||||
let response = try:
|
||||
await client.request(path, headers = headers)
|
||||
except CatchableError as exc:
|
||||
@ -168,6 +176,13 @@ proc connect*(
|
||||
raise newException(WSFailedUpgradeError,
|
||||
&"Invalid protocol returned {proto}!")
|
||||
|
||||
for hp in hooks:
|
||||
if hp.verify == nil: continue
|
||||
let res = await hp.verify(hp, response.headers)
|
||||
if res.isErr:
|
||||
raise newException(WSHookError,
|
||||
"Header verification failed: " & res.error)
|
||||
|
||||
var extensions: seq[Ext]
|
||||
let exts = response.headers.getList("Sec-WebSocket-Extensions")
|
||||
discard selectExt(false, extensions, factories, exts)
|
||||
@ -194,6 +209,7 @@ proc connect*(
|
||||
uri: Uri,
|
||||
protocols: seq[string] = @[],
|
||||
factories: seq[ExtFactory] = @[],
|
||||
hooks: seq[Hook] = @[],
|
||||
flags: set[TLSFlags] = {},
|
||||
version = WSDefaultVersion,
|
||||
frameSize = WSDefaultFrameSize,
|
||||
@ -222,6 +238,7 @@ proc connect*(
|
||||
path = uri.path,
|
||||
protocols = protocols,
|
||||
factories = factories,
|
||||
hooks = hooks,
|
||||
secure = secure,
|
||||
flags = flags,
|
||||
version = version,
|
||||
@ -234,7 +251,8 @@ proc connect*(
|
||||
proc handleRequest*(
|
||||
ws: WSServer,
|
||||
request: HttpRequest,
|
||||
version: uint = WSDefaultVersion): Future[WSSession]
|
||||
version: uint = WSDefaultVersion,
|
||||
hooks: seq[Hook] = @[]): Future[WSSession]
|
||||
{.
|
||||
async,
|
||||
raises: [
|
||||
@ -270,6 +288,13 @@ proc handleRequest*(
|
||||
it in ws.protocols
|
||||
)
|
||||
|
||||
for hp in hooks:
|
||||
if hp.verify == nil: continue
|
||||
let res = await hp.verify(hp, request.headers)
|
||||
if res.isErr:
|
||||
raise newException(WSHookError,
|
||||
"Header verification failed: " & res.error)
|
||||
|
||||
let
|
||||
cKey = ws.key & WSGuid
|
||||
acceptKey = Base64Pad.encode(
|
||||
@ -293,6 +318,13 @@ proc handleRequest*(
|
||||
# send back any accepted extensions
|
||||
headers.add("Sec-WebSocket-Extensions", extResp)
|
||||
|
||||
for hp in hooks:
|
||||
if hp.append == nil: continue
|
||||
let res = hp.append(hp, headers)
|
||||
if res.isErr:
|
||||
raise newException(WSHookError,
|
||||
"Header plugin execution failed: " & res.error)
|
||||
|
||||
try:
|
||||
await request.sendResponse(Http101, headers = headers)
|
||||
except CancelledError as exc:
|
||||
|
Loading…
x
Reference in New Issue
Block a user