286 lines
7.2 KiB
Nim

## nim-websock
## Copyright (c) 2021 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import std/strutils
import pkg/[chronos, stew/byteutils]
import pkg/asynctest/unittest2
import ../../ws/ws
type
  # Per-test callback: `HelperExtension`'s `encode` and `decode` methods both
  # call it with the extension itself and the frame being processed, and
  # return the frame it produces.
  ExtHandler = proc(ext: Ext; frame: Frame): Future[Frame] {.raises: [Defect].}

  HelperExtension = ref object of Ext
    ## Test double for `Ext` whose encode/decode behaviour is supplied by
    ## each test through `handler`.
    handler*: ExtHandler
proc new*(
  T: typedesc[HelperExtension],
  handler: ExtHandler,
  session: WSSession = nil): HelperExtension =
  ## Builds a `HelperExtension` named "HelperExtension" that routes every
  ## `encode`/`decode` call through `handler`.
  ##
  ## `session` is stored on the extension instead of being ignored;
  ## it defaults to `nil`.
  HelperExtension(
    handler: handler,
    name: "HelperExtension",
    session: session)
method decode*(
  self: HelperExtension,
  frame: Frame): Future[Frame] {.async.} =
  ## Decoding hook: hands the frame to the test-supplied handler and
  ## yields whatever frame the handler produces.
  let handled = await self.handler(self, frame)
  return handled
method encode*(
  self: HelperExtension,
  frame: Frame): Future[Frame] {.async.} =
  ## Encoding hook: hands the frame to the test-supplied handler and
  ## yields whatever frame the handler produces.
  let handled = await self.handler(self, frame)
  return handled
const
  TestString = "Hello" ## Payload carried by every frame built in these tests.
suite "Encode frame extensions flow":
  # Each test builds an unmasked, non-final Text frame carrying `TestString`,
  # runs `encode` with one or more `HelperExtension`s and then inspects the
  # frame object the extensions were handed.

  test "should call extension on encode":
    var data = ""
    proc toUpper(ext: Ext, frame: Frame): Future[Frame] {.async.} =
      checkpoint "toUpper executed"
      data = string.fromBytes(frame.data).toUpper()
      check TestString.toUpper() == data
      frame.data = data.toBytes()
      return frame

    var frame = Frame(
      fin: false,
      rsv1: false,
      rsv2: false,
      rsv3: false,
      opcode: Opcode.Text,
      mask: false,
      data: TestString.toBytes())

    discard await frame.encode(@[HelperExtension.new(toUpper).Ext])
    # The extension mutated this same frame object, so its payload is now
    # the upper-cased string.
    check frame.data == TestString.toUpper().toBytes()

  test "should call extensions in correct order on encode":
    # On encode, extensions must run in list order: `first`, then `second`.
    var count = 0
    proc first(ext: Ext, frame: Frame): Future[Frame] {.async.} =
      checkpoint "first executed"
      check count == 0
      count.inc
      return frame

    proc second(ext: Ext, frame: Frame): Future[Frame] {.async.} =
      checkpoint "second executed"
      check count == 1
      count.inc
      return frame

    var frame = Frame(
      fin: false,
      rsv1: false,
      rsv2: false,
      rsv3: false,
      opcode: Opcode.Text,
      mask: false,
      data: TestString.toBytes())

    discard await frame.encode(@[
      HelperExtension.new(first).Ext,
      HelperExtension.new(second).Ext])

    # Both extensions ran exactly once.
    check count == 2

  test "should allow modifying frame headers":
    # The extension sets all three RSV bits and rewrites the opcode.
    proc changeHeader(ext: Ext, frame: Frame): Future[Frame] {.async.} =
      checkpoint "changeHeader executed"
      frame.rsv1 = true
      frame.rsv2 = true
      frame.rsv3 = true
      frame.opcode = Opcode.Binary
      return frame

    var frame = Frame(
      fin: false,
      rsv1: false,
      rsv2: false,
      rsv3: false,
      opcode: Opcode.Text, # non-final frame; only later fragments use `Continuation`
      mask: false,
      data: TestString.toBytes())

    discard await frame.encode(@[HelperExtension.new(changeHeader).Ext])
    check:
      frame.rsv1 == true
      frame.rsv2 == true
      frame.rsv3 == true
      frame.opcode == Opcode.Binary
suite "Decode frame extensions flow":
  # Each test runs a loopback TCP server. The client writes a masked frame
  # produced by `encode`; the server side decodes it through the extensions
  # under test.
  var
    address: TransportAddress
    server: StreamServer
    maskKey = genMaskKey(newRng())
    transport: StreamTransport # client connection; released in `teardown`
    reader: AsyncStreamReader  # server-side reader passed to `Frame.decode`
    frame: Frame               # result of the server-side decode

  setup:
    server = createStreamServer(
      initTAddress("127.0.0.1:0"),
      flags = {ServerFlags.ReuseAddr})
    address = server.localAddress()

  teardown:
    # Every test stores its client connection in `transport`. Guard against
    # a test that failed before connecting, and reset the variable so the
    # next test never re-closes a stale connection.
    if not transport.isNil:
      await transport.closeWait()
      transport = nil
    # Stop accepting first, then release the listening socket.
    server.stop()
    await server.closeWait()

  test "should call extension on decode":
    var data = ""
    proc toUpper(ext: Ext, frame: Frame): Future[Frame] {.async.} =
      checkpoint "toUpper executed"
      try:
        # This handler reads the `frame.length`-byte payload itself from
        # the shared `reader` (presumably still unread when `decode`
        # invokes extensions).
        var buf = newSeq[byte](frame.length)
        await reader.readExactly(addr buf[0], buf.len)
        if frame.mask:
          mask(buf, maskKey)
          frame.mask = false # payload has been unmasked in place
        data = string.fromBytes(buf).toUpper()
        check:
          TestString.toUpper() == data
        frame.data = data.toBytes()
        return frame
      except CatchableError as exc:
        checkpoint exc.msg
        check false

    proc acceptHandler() {.async, gcsafe.} =
      let serverTransp = await server.accept()
      reader = newAsyncStreamReader(serverTransp)
      frame = await Frame.decode(
        reader,
        false,
        @[HelperExtension.new(toUpper).Ext])
      await reader.closeWait()
      await serverTransp.closeWait()

    let handlerWait = acceptHandler()
    let encodedFrame = (await Frame(
      fin: false,
      rsv1: false,
      rsv2: false,
      rsv3: false,
      opcode: Opcode.Text,
      mask: true,
      maskKey: maskKey,
      data: TestString.toBytes())
      .encode())

    transport = await connect(address)
    let wrote = await transport.write(encodedFrame)

    await handlerWait
    check:
      wrote == encodedFrame.len
      frame.data == TestString.toUpper().toBytes()

  test "should call extensions in reverse order on decode":
    # On decode, extensions must run in reverse list order: `second`, then
    # `first`.
    var count = 0
    proc first(ext: Ext, frame: Frame): Future[Frame] {.async.} =
      checkpoint "first executed"
      check count == 1
      count.inc
      return frame

    proc second(ext: Ext, frame: Frame): Future[Frame] {.async.} =
      checkpoint "second executed"
      check count == 0
      count.inc
      return frame

    proc acceptHandler() {.async, gcsafe.} =
      let serverTransp = await server.accept()
      reader = newAsyncStreamReader(serverTransp)
      frame = await Frame.decode(
        reader,
        false,
        @[HelperExtension.new(first).Ext,
          HelperExtension.new(second).Ext])
      await reader.closeWait()
      await serverTransp.closeWait()

    let handlerWait = acceptHandler()
    let encodedFrame = (await Frame(
      fin: false,
      rsv1: false,
      rsv2: false,
      rsv3: false,
      opcode: Opcode.Text,
      mask: true,
      maskKey: maskKey,
      data: TestString.toBytes())
      .encode())

    # Assign the suite-level variable (not a shadowing `let`) so that
    # `teardown` closes this connection.
    transport = await connect(address)
    let wrote = await transport.write(encodedFrame)

    await handlerWait
    check:
      wrote == encodedFrame.len
      count == 2

  test "should allow modifying frame headers":
    # The frame is sent with all RSV bits set; the extension clears them
    # and rewrites the opcode.
    proc changeHeader(ext: Ext, frame: Frame): Future[Frame] {.async.} =
      checkpoint "changeHeader executed"
      frame.rsv1 = false
      frame.rsv2 = false
      frame.rsv3 = false
      frame.opcode = Opcode.Binary
      return frame

    proc acceptHandler() {.async, gcsafe.} =
      let serverTransp = await server.accept()
      reader = newAsyncStreamReader(serverTransp)
      frame = await Frame.decode(
        reader,
        false,
        @[HelperExtension.new(changeHeader).Ext])

      check:
        frame.rsv1 == false
        frame.rsv2 == false
        frame.rsv3 == false
        frame.opcode == Opcode.Binary

      await reader.closeWait()
      await serverTransp.closeWait()

    let handlerWait = acceptHandler()
    let encodedFrame = (await Frame(
      fin: false,
      rsv1: true,
      rsv2: true,
      rsv3: true,
      opcode: Opcode.Text,
      mask: true,
      maskKey: maskKey,
      data: TestString.toBytes())
      .encode())

    # Assign the suite-level variable so `teardown` closes this connection.
    transport = await connect(address)
    let wrote = await transport.write(encodedFrame)

    await handlerWait
    check:
      wrote == encodedFrame.len