From 4a7a058843cdb7a6a4fd25a55f0959c51b0b5847 Mon Sep 17 00:00:00 2001 From: jangko Date: Thu, 3 Mar 2022 13:21:35 +0700 Subject: [PATCH] add http header hook fix #101 --- tests/testcommon.nim | 1 + tests/testhooks.nim | 201 +++++++++++++++++++++++++++++++++++++++++++ websock/types.nim | 27 +++++- websock/websock.nim | 36 +++++++- 4 files changed, 262 insertions(+), 3 deletions(-) create mode 100644 tests/testhooks.nim diff --git a/tests/testcommon.nim b/tests/testcommon.nim index 57e7fd0a..660f8ad0 100644 --- a/tests/testcommon.nim +++ b/tests/testcommon.nim @@ -14,3 +14,4 @@ import ./testutf8 import ./testextutils import ./extensions/testexts import ./extensions/testcompression +import ./testhooks diff --git a/tests/testhooks.nim b/tests/testhooks.nim new file mode 100644 index 00000000..1119bdb5 --- /dev/null +++ b/tests/testhooks.nim @@ -0,0 +1,201 @@ +## nim-websock +## Copyright (c) 2021 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +import std/strutils +import pkg/[ + httputils, + chronos, + chronicles, + stew/byteutils, + asynctest/unittest2] + +import ../websock/websock + +import ./helpers + +let address = initTAddress("127.0.0.1:8888") + +type + TokenHook* = ref object of Hook + status: int + token: string + request: HttpRequest + +proc clientAppendGoodToken(ctx: Hook, headers: var HttpTable): + Result[void, string] {.gcsafe, raises: [Defect].} = + headers.add("auth-token", "good-token") + return ok() + +proc clientAppendBadToken(ctx: Hook, headers: var HttpTable): + Result[void, string] {.gcsafe, raises: [Defect].} = + headers.add("auth-token", "bad-token") + return ok() + +proc clientVerify(ctx: Hook, headers: HttpTable): + Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} = + var p = TokenHook(ctx) + p.token = headers.getString("auth-status") + return ok() + +proc serverVerify(ctx: Hook, headers: HttpTable): + Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} = + var p = TokenHook(ctx) + if headers.getString("auth-token") == "good-token": + p.status = 101 + return ok() + +proc serverAppend(ctx: Hook, headers: var HttpTable): + Result[void, string] {.gcsafe, raises: [Defect].} = + var p = TokenHook(ctx) + if p.status == 101: + headers.add("auth-status", "accept") + else: + headers.add("auth-status", "reject") + p.status = 0 + return ok() + +proc goodClientHook(): Hook = + TokenHook( + append: clientAppendGoodToken, + verify: clientVerify + ) + +proc badClientHook(): Hook = + TokenHook( + append: clientAppendBadToken, + verify: clientVerify + ) + +proc serverHook(): Hook = + TokenHook( + append: serverAppend, + verify: serverVerify + ) + +proc serverVerifyWithCode(ctx: Hook, headers: HttpTable): + Future[Result[void, string]] {.async, gcsafe, raises: [Defect].} = + var p = TokenHook(ctx) + if headers.getString("auth-token") == "good-token": + p.status = 101 + return ok() + else: + await p.request.stream.writer.sendError(Http401) + return err("authentication error") + +proc serverHookWithCode(request: HttpRequest): Hook = + TokenHook( + append: serverAppend, + verify: serverVerifyWithCode, + request: request + ) + +suite "Test Hooks": + var + server: HttpServer + goodCP = goodClientHook() + badCP = badClientHook() + + teardown: + server.stop() + await server.closeWait() + + test "client with valid token": + proc handle(request: HttpRequest) {.async.} = + check request.uri.path == WSPath + let + server = WSServer.new() + ws = await server.handleRequest( + request, + hooks = @[serverHook()] + ) + + server = createServer( + address = address, + handler = handle, + flags = {ReuseAddr}) + + let session = await WebSocket.connect( + host = initTAddress("127.0.0.1:8888"), + path = WSPath, + hooks = @[goodCP] + ) + + check TokenHook(goodCP).token == "accept" + await session.stream.closeWait() + + test "client with bad token": + proc handle(request: HttpRequest) {.async.} = + check request.uri.path == WSPath + let + server = WSServer.new() + ws = await server.handleRequest( + request, + hooks = @[serverHook()] + ) + + server = createServer( + address = address, + handler = handle, + flags = {ReuseAddr}) + + let session = await WebSocket.connect( + host = initTAddress("127.0.0.1:8888"), + path = WSPath, + hooks = @[badCP] + ) + + check TokenHook(badCP).token == "reject" + await session.stream.closeWait() + + test "server hook with code get good client": + proc handle(request: HttpRequest) {.async.} = + check request.uri.path == WSPath + let + server = WSServer.new() + ws = await server.handleRequest( + request, + hooks = @[serverHookWithCode(request)] + ) + + server = createServer( + address = address, + handler = handle, + flags = {ReuseAddr}) + + let session = await WebSocket.connect( + host = initTAddress("127.0.0.1:8888"), + path = WSPath, + hooks = @[goodCP] + ) + + check TokenHook(goodCP).token == "accept" + await session.stream.closeWait() + + test "server hook with code get bad client": + proc handle(request: HttpRequest) {.async.} = + check request.uri.path == WSPath + let + server = WSServer.new() + ws = await server.handleRequest( + request, + hooks = @[serverHookWithCode(request)] + ) + + server = createServer( + address = address, + handler = handle, + flags = {ReuseAddr}) + + expect WSFailedUpgradeError: + let session = await WebSocket.connect( + host = initTAddress("127.0.0.1:8888"), + path = WSPath, + hooks = @[badCP] + ) + await session.stream.closeWait() diff --git a/websock/types.nim b/websock/types.nim index 9538bb22..f80102d2 100644 --- a/websock/types.nim +++ b/websock/types.nim @@ -9,7 +9,11 @@ {.push raises: [Defect].} -import pkg/[chronos, chronos/streams/tlsstream, stew/results] +import pkg/[chronos, + chronos/streams/tlsstream, + chronos/apps/http/httptable, + httputils, + stew/results] import ./utils const @@ -112,6 +116,26 @@ type factory*: ExtFactoryProc clientOffer*: string + # client exec order: + # 1. append to request header + # 2. verify response header + # server exec order: + # 1. verify request header + # 2. append to response header + # ------------------------------ + # Handshake exec order: + # 1. client append to request header + # 2. server verify request header + # 3. server reply with response header + # 4. client verify response header from server + Hook* = ref object of RootObj + append*: proc(ctx: Hook, + headers: var HttpTable): Result[void, string] + {.gcsafe, raises: [Defect].} + verify*: proc(ctx: Hook, + headers: HttpTable): Future[Result[void, string]] + {.closure, gcsafe, raises: [Defect].} + WebSocketError* = object of CatchableError WSMalformedHeaderError* = object of WebSocketError WSFailedUpgradeError* = object of WebSocketError @@ -133,6 +157,7 @@ type WSInvalidOpcodeError* = object of WebSocketError WSInvalidUTF8* = object of WebSocketError WSExtError* = object of WebSocketError + WSHookError* = object of WebSocketError const StatusNotUsed* = (StatusCodes(0)..StatusCodes(999)) diff --git a/websock/websock.nim b/websock/websock.nim index 8af803f2..3845d397 100644 --- a/websock/websock.nim +++ b/websock/websock.nim @@ -28,7 +28,7 @@ import pkg/[chronos, import ./utils, ./frame, ./session, /types, ./http, ./extensions/extutils -export utils, session, frame, types, http +export utils, session, frame, types, http, httptable logScope: topics = "websock ws-server" @@ -109,6 +109,7 @@ proc connect*( hostName: string = "", # override used when the hostname has been externally resolved protocols: seq[string] = @[], factories: seq[ExtFactory] = @[], + hooks: seq[Hook] = @[], secure = false, flags: set[TLSFlags] = {}, version = WSDefaultVersion, @@ -149,6 +150,13 @@ proc connect*( if extOffer.len > 0: headers.add("Sec-WebSocket-Extensions", extOffer) + for hp in hooks: + if hp.append == nil: continue + let res = hp.append(hp, headers) + if res.isErr: + raise newException(WSHookError, + "Header plugin execution failed: " & res.error) + let response = try: await client.request(path, headers = headers) except CatchableError as exc: @@ -168,6 +176,13 @@ proc connect*( raise newException(WSFailedUpgradeError, &"Invalid protocol returned {proto}!") + for hp in hooks: + if hp.verify == nil: continue + let res = await hp.verify(hp, response.headers) + if res.isErr: + raise newException(WSHookError, + "Header verification failed: " & res.error) + var extensions: seq[Ext] let exts = response.headers.getList("Sec-WebSocket-Extensions") discard selectExt(false, extensions, factories, exts) @@ -194,6 +209,7 @@ proc connect*( uri: Uri, protocols: seq[string] = @[], factories: seq[ExtFactory] = @[], + hooks: seq[Hook] = @[], flags: set[TLSFlags] = {}, version = WSDefaultVersion, frameSize = WSDefaultFrameSize, @@ -222,6 +238,7 @@ proc connect*( path = uri.path, protocols = protocols, factories = factories, + hooks = hooks, secure = secure, flags = flags, version = version, @@ -234,7 +251,8 @@ proc connect*( proc handleRequest*( ws: WSServer, request: HttpRequest, - version: uint = WSDefaultVersion): Future[WSSession] + version: uint = WSDefaultVersion, + hooks: seq[Hook] = @[]): Future[WSSession] {. async, raises: [ @@ -270,6 +288,13 @@ proc handleRequest*( it in ws.protocols ) + for hp in hooks: + if hp.verify == nil: continue + let res = await hp.verify(hp, request.headers) + if res.isErr: + raise newException(WSHookError, + "Header verification failed: " & res.error) + let cKey = ws.key & WSGuid acceptKey = Base64Pad.encode( @@ -293,6 +318,13 @@ proc handleRequest*( # send back any accepted extensions headers.add("Sec-WebSocket-Extensions", extResp) + for hp in hooks: + if hp.append == nil: continue + let res = hp.append(hp, headers) + if res.isErr: + raise newException(WSHookError, + "Header plugin execution failed: " & res.error) + try: await request.sendResponse(Http101, headers = headers) except CancelledError as exc: