mirror of
https://github.com/status-im/nim-websock.git
synced 2025-02-11 17:07:23 +00:00
parent
6e73e34975
commit
4a7a058843
@ -14,3 +14,4 @@ import ./testutf8
|
|||||||
import ./testextutils
|
import ./testextutils
|
||||||
import ./extensions/testexts
|
import ./extensions/testexts
|
||||||
import ./extensions/testcompression
|
import ./extensions/testcompression
|
||||||
|
import ./testhooks
|
||||||
|
201
tests/testhooks.nim
Normal file
201
tests/testhooks.nim
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
## nim-websock
|
||||||
|
## Copyright (c) 2021 Status Research & Development GmbH
|
||||||
|
## Licensed under either of
|
||||||
|
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||||
|
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||||
|
## at your option.
|
||||||
|
## This file may not be copied, modified, or distributed except according to
|
||||||
|
## those terms.
|
||||||
|
|
||||||
|
import std/strutils
|
||||||
|
import pkg/[
|
||||||
|
httputils,
|
||||||
|
chronos,
|
||||||
|
chronicles,
|
||||||
|
stew/byteutils,
|
||||||
|
asynctest/unittest2]
|
||||||
|
|
||||||
|
import ../websock/websock
|
||||||
|
|
||||||
|
import ./helpers
|
||||||
|
|
||||||
|
let address = initTAddress("127.0.0.1:8888")
|
||||||
|
|
||||||
|
type
|
||||||
|
TokenHook* = ref object of Hook
|
||||||
|
status: int
|
||||||
|
token: string
|
||||||
|
request: HttpRequest
|
||||||
|
|
||||||
|
proc clientAppendGoodToken(ctx: Hook, headers: var HttpTable):
|
||||||
|
Result[void, string] {.gcsafe, raises: [Defect].} =
|
||||||
|
headers.add("auth-token", "good-token")
|
||||||
|
return ok()
|
||||||
|
|
||||||
|
proc clientAppendBadToken(ctx: Hook, headers: var HttpTable):
|
||||||
|
Result[void, string] {.gcsafe, raises: [Defect].} =
|
||||||
|
headers.add("auth-token", "bad-token")
|
||||||
|
return ok()
|
||||||
|
|
||||||
|
proc clientVerify(ctx: Hook, headers: HttpTable):
|
||||||
|
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
|
||||||
|
var p = TokenHook(ctx)
|
||||||
|
p.token = headers.getString("auth-status")
|
||||||
|
return ok()
|
||||||
|
|
||||||
|
proc serverVerify(ctx: Hook, headers: HttpTable):
|
||||||
|
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
|
||||||
|
var p = TokenHook(ctx)
|
||||||
|
if headers.getString("auth-token") == "good-token":
|
||||||
|
p.status = 101
|
||||||
|
return ok()
|
||||||
|
|
||||||
|
proc serverAppend(ctx: Hook, headers: var HttpTable):
|
||||||
|
Result[void, string] {.gcsafe, raises: [Defect].} =
|
||||||
|
var p = TokenHook(ctx)
|
||||||
|
if p.status == 101:
|
||||||
|
headers.add("auth-status", "accept")
|
||||||
|
else:
|
||||||
|
headers.add("auth-status", "reject")
|
||||||
|
p.status = 0
|
||||||
|
return ok()
|
||||||
|
|
||||||
|
proc goodClientHook(): Hook =
|
||||||
|
TokenHook(
|
||||||
|
append: clientAppendGoodToken,
|
||||||
|
verify: clientVerify
|
||||||
|
)
|
||||||
|
|
||||||
|
proc badClientHook(): Hook =
|
||||||
|
TokenHook(
|
||||||
|
append: clientAppendBadToken,
|
||||||
|
verify: clientVerify
|
||||||
|
)
|
||||||
|
|
||||||
|
proc serverHook(): Hook =
|
||||||
|
TokenHook(
|
||||||
|
append: serverAppend,
|
||||||
|
verify: serverVerify
|
||||||
|
)
|
||||||
|
|
||||||
|
proc serverVerifyWithCode(ctx: Hook, headers: HttpTable):
|
||||||
|
Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} =
|
||||||
|
var p = TokenHook(ctx)
|
||||||
|
if headers.getString("auth-token") == "good-token":
|
||||||
|
p.status = 101
|
||||||
|
return ok()
|
||||||
|
else:
|
||||||
|
await p.request.stream.writer.sendError(Http401)
|
||||||
|
return err("authentication error")
|
||||||
|
|
||||||
|
proc serverHookWithCode(request: HttpRequest): Hook =
|
||||||
|
TokenHook(
|
||||||
|
append: serverAppend,
|
||||||
|
verify: serverVerifyWithCode,
|
||||||
|
request: request
|
||||||
|
)
|
||||||
|
|
||||||
|
suite "Test Hooks":
|
||||||
|
var
|
||||||
|
server: HttpServer
|
||||||
|
goodCP = goodClientHook()
|
||||||
|
badCP = badClientHook()
|
||||||
|
|
||||||
|
teardown:
|
||||||
|
server.stop()
|
||||||
|
await server.closeWait()
|
||||||
|
|
||||||
|
test "client with valid token":
|
||||||
|
proc handle(request: HttpRequest) {.async.} =
|
||||||
|
check request.uri.path == WSPath
|
||||||
|
let
|
||||||
|
server = WSServer.new()
|
||||||
|
ws = await server.handleRequest(
|
||||||
|
request,
|
||||||
|
hooks = @[serverHook()]
|
||||||
|
)
|
||||||
|
|
||||||
|
server = createServer(
|
||||||
|
address = address,
|
||||||
|
handler = handle,
|
||||||
|
flags = {ReuseAddr})
|
||||||
|
|
||||||
|
let session = await WebSocket.connect(
|
||||||
|
host = initTAddress("127.0.0.1:8888"),
|
||||||
|
path = WSPath,
|
||||||
|
hooks = @[goodCP]
|
||||||
|
)
|
||||||
|
|
||||||
|
check TokenHook(goodCP).token == "accept"
|
||||||
|
await session.stream.closeWait()
|
||||||
|
|
||||||
|
test "client with bad token":
|
||||||
|
proc handle(request: HttpRequest) {.async.} =
|
||||||
|
check request.uri.path == WSPath
|
||||||
|
let
|
||||||
|
server = WSServer.new()
|
||||||
|
ws = await server.handleRequest(
|
||||||
|
request,
|
||||||
|
hooks = @[serverHook()]
|
||||||
|
)
|
||||||
|
|
||||||
|
server = createServer(
|
||||||
|
address = address,
|
||||||
|
handler = handle,
|
||||||
|
flags = {ReuseAddr})
|
||||||
|
|
||||||
|
let session = await WebSocket.connect(
|
||||||
|
host = initTAddress("127.0.0.1:8888"),
|
||||||
|
path = WSPath,
|
||||||
|
hooks = @[badCP]
|
||||||
|
)
|
||||||
|
|
||||||
|
check TokenHook(badCP).token == "reject"
|
||||||
|
await session.stream.closeWait()
|
||||||
|
|
||||||
|
test "server hook with code get good client":
|
||||||
|
proc handle(request: HttpRequest) {.async.} =
|
||||||
|
check request.uri.path == WSPath
|
||||||
|
let
|
||||||
|
server = WSServer.new()
|
||||||
|
ws = await server.handleRequest(
|
||||||
|
request,
|
||||||
|
hooks = @[serverHookWithCode(request)]
|
||||||
|
)
|
||||||
|
|
||||||
|
server = createServer(
|
||||||
|
address = address,
|
||||||
|
handler = handle,
|
||||||
|
flags = {ReuseAddr})
|
||||||
|
|
||||||
|
let session = await WebSocket.connect(
|
||||||
|
host = initTAddress("127.0.0.1:8888"),
|
||||||
|
path = WSPath,
|
||||||
|
hooks = @[goodCP]
|
||||||
|
)
|
||||||
|
|
||||||
|
check TokenHook(goodCP).token == "accept"
|
||||||
|
await session.stream.closeWait()
|
||||||
|
|
||||||
|
test "server hook with code get bad client":
|
||||||
|
proc handle(request: HttpRequest) {.async.} =
|
||||||
|
check request.uri.path == WSPath
|
||||||
|
let
|
||||||
|
server = WSServer.new()
|
||||||
|
ws = await server.handleRequest(
|
||||||
|
request,
|
||||||
|
hooks = @[serverHookWithCode(request)]
|
||||||
|
)
|
||||||
|
|
||||||
|
server = createServer(
|
||||||
|
address = address,
|
||||||
|
handler = handle,
|
||||||
|
flags = {ReuseAddr})
|
||||||
|
|
||||||
|
expect WSFailedUpgradeError:
|
||||||
|
let session = await WebSocket.connect(
|
||||||
|
host = initTAddress("127.0.0.1:8888"),
|
||||||
|
path = WSPath,
|
||||||
|
hooks = @[badCP]
|
||||||
|
)
|
||||||
|
await session.stream.closeWait()
|
@ -9,7 +9,11 @@
|
|||||||
|
|
||||||
{.push raises: [Defect].}
|
{.push raises: [Defect].}
|
||||||
|
|
||||||
import pkg/[chronos, chronos/streams/tlsstream, stew/results]
|
import pkg/[chronos,
|
||||||
|
chronos/streams/tlsstream,
|
||||||
|
chronos/apps/http/httptable,
|
||||||
|
httputils,
|
||||||
|
stew/results]
|
||||||
import ./utils
|
import ./utils
|
||||||
|
|
||||||
const
|
const
|
||||||
@ -112,6 +116,26 @@ type
|
|||||||
factory*: ExtFactoryProc
|
factory*: ExtFactoryProc
|
||||||
clientOffer*: string
|
clientOffer*: string
|
||||||
|
|
||||||
|
# client exec order:
|
||||||
|
# 1. append to request header
|
||||||
|
# 2. verify response header
|
||||||
|
# server exec order:
|
||||||
|
# 1. verify request header
|
||||||
|
# 2. append to response header
|
||||||
|
# ------------------------------
|
||||||
|
# Handshake exec order:
|
||||||
|
# 1. client append to request header
|
||||||
|
# 2. server verify request header
|
||||||
|
# 3. server reply with response header
|
||||||
|
# 4. client verify response header from server
|
||||||
|
Hook* = ref object of RootObj
|
||||||
|
append*: proc(ctx: Hook,
|
||||||
|
headers: var HttpTable): Result[void, string]
|
||||||
|
{.gcsafe, raises: [Defect].}
|
||||||
|
verify*: proc(ctx: Hook,
|
||||||
|
headers: HttpTable): Future[Result[void, string]]
|
||||||
|
{.closure, gcsafe, raises: [Defect].}
|
||||||
|
|
||||||
WebSocketError* = object of CatchableError
|
WebSocketError* = object of CatchableError
|
||||||
WSMalformedHeaderError* = object of WebSocketError
|
WSMalformedHeaderError* = object of WebSocketError
|
||||||
WSFailedUpgradeError* = object of WebSocketError
|
WSFailedUpgradeError* = object of WebSocketError
|
||||||
@ -133,6 +157,7 @@ type
|
|||||||
WSInvalidOpcodeError* = object of WebSocketError
|
WSInvalidOpcodeError* = object of WebSocketError
|
||||||
WSInvalidUTF8* = object of WebSocketError
|
WSInvalidUTF8* = object of WebSocketError
|
||||||
WSExtError* = object of WebSocketError
|
WSExtError* = object of WebSocketError
|
||||||
|
WSHookError* = object of WebSocketError
|
||||||
|
|
||||||
const
|
const
|
||||||
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))
|
StatusNotUsed* = (StatusCodes(0)..StatusCodes(999))
|
||||||
|
@ -28,7 +28,7 @@ import pkg/[chronos,
|
|||||||
|
|
||||||
import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils
|
import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils
|
||||||
|
|
||||||
export utils, session, frame, types, http
|
export utils, session, frame, types, http, httptable
|
||||||
|
|
||||||
logScope:
|
logScope:
|
||||||
topics = "websock ws-server"
|
topics = "websock ws-server"
|
||||||
@ -109,6 +109,7 @@ proc connect*(
|
|||||||
hostName: string = "", # override used when the hostname has been externally resolved
|
hostName: string = "", # override used when the hostname has been externally resolved
|
||||||
protocols: seq[string] = @[],
|
protocols: seq[string] = @[],
|
||||||
factories: seq[ExtFactory] = @[],
|
factories: seq[ExtFactory] = @[],
|
||||||
|
hooks: seq[Hook] = @[],
|
||||||
secure = false,
|
secure = false,
|
||||||
flags: set[TLSFlags] = {},
|
flags: set[TLSFlags] = {},
|
||||||
version = WSDefaultVersion,
|
version = WSDefaultVersion,
|
||||||
@ -149,6 +150,13 @@ proc connect*(
|
|||||||
if extOffer.len > 0:
|
if extOffer.len > 0:
|
||||||
headers.add("Sec-WebSocket-Extensions", extOffer)
|
headers.add("Sec-WebSocket-Extensions", extOffer)
|
||||||
|
|
||||||
|
for hp in hooks:
|
||||||
|
if hp.append == nil: continue
|
||||||
|
let res = hp.append(hp, headers)
|
||||||
|
if res.isErr:
|
||||||
|
raise newException(WSHookError,
|
||||||
|
"Header plugin execution failed: " & res.error)
|
||||||
|
|
||||||
let response = try:
|
let response = try:
|
||||||
await client.request(path, headers = headers)
|
await client.request(path, headers = headers)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
@ -168,6 +176,13 @@ proc connect*(
|
|||||||
raise newException(WSFailedUpgradeError,
|
raise newException(WSFailedUpgradeError,
|
||||||
&"Invalid protocol returned {proto}!")
|
&"Invalid protocol returned {proto}!")
|
||||||
|
|
||||||
|
for hp in hooks:
|
||||||
|
if hp.verify == nil: continue
|
||||||
|
let res = await hp.verify(hp, response.headers)
|
||||||
|
if res.isErr:
|
||||||
|
raise newException(WSHookError,
|
||||||
|
"Header verification failed: " & res.error)
|
||||||
|
|
||||||
var extensions: seq[Ext]
|
var extensions: seq[Ext]
|
||||||
let exts = response.headers.getList("Sec-WebSocket-Extensions")
|
let exts = response.headers.getList("Sec-WebSocket-Extensions")
|
||||||
discard selectExt(false, extensions, factories, exts)
|
discard selectExt(false, extensions, factories, exts)
|
||||||
@ -194,6 +209,7 @@ proc connect*(
|
|||||||
uri: Uri,
|
uri: Uri,
|
||||||
protocols: seq[string] = @[],
|
protocols: seq[string] = @[],
|
||||||
factories: seq[ExtFactory] = @[],
|
factories: seq[ExtFactory] = @[],
|
||||||
|
hooks: seq[Hook] = @[],
|
||||||
flags: set[TLSFlags] = {},
|
flags: set[TLSFlags] = {},
|
||||||
version = WSDefaultVersion,
|
version = WSDefaultVersion,
|
||||||
frameSize = WSDefaultFrameSize,
|
frameSize = WSDefaultFrameSize,
|
||||||
@ -222,6 +238,7 @@ proc connect*(
|
|||||||
path = uri.path,
|
path = uri.path,
|
||||||
protocols = protocols,
|
protocols = protocols,
|
||||||
factories = factories,
|
factories = factories,
|
||||||
|
hooks = hooks,
|
||||||
secure = secure,
|
secure = secure,
|
||||||
flags = flags,
|
flags = flags,
|
||||||
version = version,
|
version = version,
|
||||||
@ -234,7 +251,8 @@ proc connect*(
|
|||||||
proc handleRequest*(
|
proc handleRequest*(
|
||||||
ws: WSServer,
|
ws: WSServer,
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
version: uint = WSDefaultVersion): Future[WSSession]
|
version: uint = WSDefaultVersion,
|
||||||
|
hooks: seq[Hook] = @[]): Future[WSSession]
|
||||||
{.
|
{.
|
||||||
async,
|
async,
|
||||||
raises: [
|
raises: [
|
||||||
@ -270,6 +288,13 @@ proc handleRequest*(
|
|||||||
it in ws.protocols
|
it in ws.protocols
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for hp in hooks:
|
||||||
|
if hp.verify == nil: continue
|
||||||
|
let res = await hp.verify(hp, request.headers)
|
||||||
|
if res.isErr:
|
||||||
|
raise newException(WSHookError,
|
||||||
|
"Header verification failed: " & res.error)
|
||||||
|
|
||||||
let
|
let
|
||||||
cKey = ws.key & WSGuid
|
cKey = ws.key & WSGuid
|
||||||
acceptKey = Base64Pad.encode(
|
acceptKey = Base64Pad.encode(
|
||||||
@ -293,6 +318,13 @@ proc handleRequest*(
|
|||||||
# send back any accepted extensions
|
# send back any accepted extensions
|
||||||
headers.add("Sec-WebSocket-Extensions", extResp)
|
headers.add("Sec-WebSocket-Extensions", extResp)
|
||||||
|
|
||||||
|
for hp in hooks:
|
||||||
|
if hp.append == nil: continue
|
||||||
|
let res = hp.append(hp, headers)
|
||||||
|
if res.isErr:
|
||||||
|
raise newException(WSHookError,
|
||||||
|
"Header plugin execution failed: " & res.error)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await request.sendResponse(Http101, headers = headers)
|
await request.sendResponse(Http101, headers = headers)
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user