Fix webscoket close and fix test cases.

This commit is contained in:
Arijit Das 2021-04-14 16:37:38 +05:30 committed by jangko
parent 0feac12a67
commit 51d834a0a1
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
8 changed files with 966 additions and 554 deletions

View File

@ -35,7 +35,7 @@ jobs:
- target: - target:
os: windows os: windows
builder: windows-2019 builder: windows-2019
name: '${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})' name: "${{ matrix.target.os }}-${{ matrix.target.cpu }} (${{ matrix.branch }})"
runs-on: ${{ matrix.builder }} runs-on: ${{ matrix.builder }}
steps: steps:
- name: Checkout nim-ws - name: Checkout nim-ws

View File

@ -1,8 +1,7 @@
import pkg/[chronos, import pkg/[chronos,
chronos/apps/http/httpserver, chronos/apps/http/httpserver,
chronicles, chronicles,
httputils, httputils]
stew/byteutils]
import ../ws/ws import ../ws/ws
@ -13,22 +12,18 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if request.uri.path == "/ws": if request.uri.path == "/ws":
debug "Initiating web socket connection." debug "Initiating web socket connection."
try: try:
var ws = await createServer(request,"") let ws = await createServer(request, "")
if ws.readyState != Open: if ws.readyState != Open:
error "Failed to open websocket connection." error "Failed to open websocket connection."
return return
debug "Websocket handshake completed." debug "Websocket handshake completed."
while ws.readyState != ReadyState.Closed: while true:
# Only reads header for data frame. let recvData = await ws.recv()
var recvData = await ws.recv() if ws.readyState == ReadyState.Closed:
if recvData.len <= 0: debug "Websocket closed."
debug "Empty messages"
break break
# debug "Client Response: ", data = string.fromBytes(recvData), size = recvData.len
debug "Client Response: ", size = recvData.len debug "Client Response: ", size = recvData.len
await ws.send(recvData) await ws.send(recvData)
# await ws.close()
except WebSocketError as exc: except WebSocketError as exc:
error "WebSocket error:", exception = exc.msg error "WebSocket error:", exception = exc.msg

View File

@ -1,7 +1,7 @@
import unittest import unittest
include ../ws/ws include ../ws/ws
include ../ws/random include ../ws/utils
# TODO: Fix Test. # TODO: Fix Test.

View File

@ -2,6 +2,7 @@ import std/strutils, httputils
import pkg/[asynctest, import pkg/[asynctest,
chronos, chronos,
chronicles,
chronos/apps/http/shttpserver, chronos/apps/http/shttpserver,
stew/byteutils] stew/byteutils]
@ -10,14 +11,23 @@ import ../ws/[ws, stream],
import ./keys import ./keys
var server: SecureHttpServerRef proc waitForClose(ws: WebSocket) {.async.} =
let address = initTAddress("127.0.0.1:8888") try:
let serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect} while ws.readystate != ReadyState.Closed:
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} discard await ws.recv()
let clientFlags = {NoVerifyHost, NoVerifyServerName} except CatchableError:
debug "Closing websocket"
var server: SecureHttpServerRef
let
address = initTAddress("127.0.0.1:8888")
serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect}
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
clientFlags = {NoVerifyHost, NoVerifyServerName}
secureKey = TLSPrivateKey.init(SecureKey)
secureCert = TLSCertificate.init(SecureCert)
let secureKey = TLSPrivateKey.init(SecureKey)
let secureCert = TLSCertificate.init(SecureCert)
suite "Test websocket TLS handshake": suite "Test websocket TLS handshake":
teardown: teardown:
@ -31,10 +41,7 @@ suite "Test websocket TLS handshake":
let request = r.get() let request = r.get()
check request.uri.path == "/wss" check request.uri.path == "/wss"
expect WSProtoMismatchError: expect WSProtoMismatchError:
var ws = await createServer(request, "proto") discard await createServer(request, "proto")
check ws.readyState == ReadyState.Closed
return await request.respond(Http200, "Connection established")
let res = SecureHttpServerRef.new( let res = SecureHttpServerRef.new(
address, cb, address, cb,
@ -62,10 +69,7 @@ suite "Test websocket TLS handshake":
let request = r.get() let request = r.get()
check request.uri.path == "/wss" check request.uri.path == "/wss"
expect WSVersionError: expect WSVersionError:
var ws = await createServer(request, "proto") discard await createServer(request, "proto")
check ws.readyState == ReadyState.Closed
return await request.respond(Http200, "Connection established")
let res = SecureHttpServerRef.new( let res = SecureHttpServerRef.new(
address, cb, address, cb,
@ -91,14 +95,17 @@ suite "Test websocket TLS handshake":
check r.isOk() check r.isOk()
let request = r.get() let request = r.get()
check request.uri.path == "/wss" check request.uri.path == "/wss"
check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii() check request.headers.getString("Connection").toUpperAscii() ==
check request.headers.getString("Upgrade").toUpperAscii() == "websocket".toUpperAscii() "Upgrade".toUpperAscii()
check request.headers.getString("Cache-Control").toUpperAscii() == "no-cache".toUpperAscii() check request.headers.getString("Upgrade").toUpperAscii() ==
"websocket".toUpperAscii()
check request.headers.getString("Cache-Control").toUpperAscii() ==
"no-cache".toUpperAscii()
check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion
check request.headers.contains("Sec-WebSocket-Key") check request.headers.contains("Sec-WebSocket-Key")
discard await request.respond( Http200,"Connection established") discard await request.respond(Http200, "Connection established")
let res = SecureHttpServerRef.new( let res = SecureHttpServerRef.new(
address, cb, address, cb,
@ -133,41 +140,7 @@ suite "Test websocket TLS transmission":
let ws = await createServer(request, "proto") let ws = await createServer(request, "proto")
let servRes = await ws.recv() let servRes = await ws.recv()
check string.fromBytes(servRes) == testString check string.fromBytes(servRes) == testString
await ws.close() await waitForClose(ws)
return dumbResponse()
let res = SecureHttpServerRef.new(
address, cb,
serverFlags = serverFlags,
socketFlags = socketFlags,
tlsPrivateKey = secureKey,
tlsCertificate = secureCert)
server = res.get()
server.start()
let wsClient = await WebSocket.tlsConnect(
"127.0.0.1",
Port(8888),
path = "/wss",
protocols = @["proto"],
clientFlags)
await wsClient.send(testString)
await wsClient.close()
test "Client - test reading simple frame":
let testString = "Hello!"
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/wss"
let ws = await createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == testString
await ws.close()
return dumbResponse() return dumbResponse()
let res = SecureHttpServerRef.new( let res = SecureHttpServerRef.new(
@ -222,4 +195,4 @@ suite "Test websocket TLS transmission":
var clientRes = await wsClient.recv() var clientRes = await wsClient.recv()
check string.fromBytes(clientRes) == testString check string.fromBytes(clientRes) == testString
await wsClient.close() await waitForClose(wsClient)

View File

@ -1,15 +1,29 @@
import std/strutils,httputils import std/[strutils, random], httputils
import pkg/[asynctest, import pkg/[asynctest,
chronos, chronos,
chronos/apps/http/httpserver, chronos/apps/http/httpserver,
chronicles,
stew/byteutils] stew/byteutils]
import ../ws/[ws, stream] import ../ws/[ws, stream]
include ../ws/ws
var server: HttpServerRef var server: HttpServerRef
let address = initTAddress("127.0.0.1:8888") let address = initTAddress("127.0.0.1:8888")
proc rndStr*(size: int): string =
for _ in .. size:
add(result, char(rand(int('A') .. int('z'))))
proc waitForClose(ws: WebSocket) {.async.} =
try:
while ws.readystate != ReadyState.Closed:
discard await ws.recv()
except CatchableError:
debug "Closing websocket"
suite "Test handshake": suite "Test handshake":
teardown: teardown:
await server.stop() await server.stop()
@ -18,13 +32,12 @@ suite "Test handshake":
test "Test for incorrect protocol": test "Test for incorrect protocol":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
expect WSProtoMismatchError: expect WSProtoMismatchError:
var ws = await createServer(request, "proto") discard await createServer(request, "proto")
check ws.readyState == ReadyState.Closed
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -45,8 +58,7 @@ suite "Test handshake":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
expect WSVersionError: expect WSVersionError:
var ws = await createServer(request, "proto") discard await createServer(request, "proto")
check ws.readyState == ReadyState.Closed
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -63,17 +75,22 @@ suite "Test handshake":
test "Test for client headers": test "Test for client headers":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii() check request.headers.getString("Connection").toUpperAscii() ==
check request.headers.getString("Upgrade").toUpperAscii() == "websocket".toUpperAscii() "Upgrade".toUpperAscii()
check request.headers.getString("Cache-Control").toUpperAscii() == "no-cache".toUpperAscii() check request.headers.getString("Upgrade").toUpperAscii() ==
"websocket".toUpperAscii()
check request.headers.getString("Cache-Control").toUpperAscii() ==
"no-cache".toUpperAscii()
check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion
check request.headers.contains("Sec-WebSocket-Key") check request.headers.contains("Sec-WebSocket-Key")
discard await request.respond(Http200, "Connection established")
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
server.start() server.start()
@ -88,7 +105,7 @@ suite "Test handshake":
test "Test for incorrect scheme": test "Test for incorrect scheme":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
@ -109,23 +126,52 @@ suite "Test handshake":
parseUri(uri), parseUri(uri),
protocols = @["proto"]) protocols = @["proto"])
test "AsyncStream leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
suite "Test transmission": suite "Test transmission":
teardown: teardown:
await server.closeWait() await server.closeWait()
test "Send text message message with payload of length 65535":
let testString = rndStr(65535)
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == testString
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"])
await wsClient.send(testString)
await wsClient.close()
test "Server - test reading simple frame": test "Server - test reading simple frame":
let testString = "Hello!" let testString = "Hello!"
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await createServer(request, "proto") let ws = await createServer(request, "proto")
let servRes = await ws.recv() let servRes = await ws.recv()
check string.fromBytes(servRes) == testString check string.fromBytes(servRes) == testString
await ws.close() await waitForClose(ws)
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -143,7 +189,7 @@ suite "Test transmission":
let testString = "Hello!" let testString = "Hello!"
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
@ -162,18 +208,89 @@ suite "Test transmission":
protocols = @["proto"]) protocols = @["proto"])
var clientRes = await wsClient.recv() var clientRes = await wsClient.recv()
await wsClient.close()
check string.fromBytes(clientRes) == testString check string.fromBytes(clientRes) == testString
await waitForClose(wsClient)
test "AsyncStream leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
suite "Test ping-pong": suite "Test ping-pong":
teardown: teardown:
await server.closeWait() await server.closeWait()
test "Send text Message fragmented into 2 fragments, one ping with payload in-between":
var ping, pong = false
let testString = "1234567890"
let msg = toBytes(testString)
let maxFrameSize = 5
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
request,
"proto",
onPing = proc() =
ping = true
)
let respData = await ws.recv()
check string.fromBytes(respData) == testString
await waitForClose(ws)
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"],
frameSize = maxFrameSize,
onPong = proc() =
pong = true
)
let maskKey = genMaskKey(newRng())
let encframe = encodeFrame(Frame(
fin: false,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Text,
mask: true,
data: msg[0..4],
maskKey: maskKey))
await wsClient.stream.writer.write(encframe)
await wsClient.ping()
let encframe1 = encodeFrame(Frame(
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode.Cont,
mask: true,
data: msg[5..9],
maskKey: maskKey))
await wsClient.stream.writer.write(encframe1)
await wsClient.close()
check:
ping
pong
test "Server - test ping-pong control messages": test "Server - test ping-pong control messages":
var ping, pong = false var ping, pong = false
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
@ -200,7 +317,7 @@ suite "Test ping-pong":
ping = true ping = true
) )
discard await wsClient.recv() await waitForClose(wsClient)
check: check:
ping ping
pong pong
@ -209,7 +326,7 @@ suite "Test ping-pong":
var ping, pong = false var ping, pong = false
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr(): if r.isErr():
return return dumbResponse()
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
@ -219,10 +336,10 @@ suite "Test ping-pong":
onPing = proc() = onPing = proc() =
ping = true ping = true
) )
await waitForClose(ws)
discard await ws.recv() check:
await ws.close() ping
pong
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
server.start() server.start()
@ -238,9 +355,13 @@ suite "Test ping-pong":
await wsClient.ping() await wsClient.ping()
await wsClient.close() await wsClient.close()
test "AsyncStream leaks test":
check: check:
ping getTracker("async.stream.reader").isLeaked() == false
pong getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
suite "Test framing": suite "Test framing":
teardown: teardown:
@ -268,8 +389,7 @@ suite "Test framing":
let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len) let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len)
check read2 == 5 check read2 == 5
await ws.close() await waitForClose(ws)
return dumbResponse()
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -296,7 +416,6 @@ suite "Test framing":
let ws = await createServer(request, "proto") let ws = await createServer(request, "proto")
await ws.send(testString) await ws.send(testString)
await ws.close() await ws.close()
return dumbResponse()
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -310,8 +429,14 @@ suite "Test framing":
expect WSMaxMessageSizeError: expect WSMaxMessageSizeError:
discard await wsClient.recv(5) discard await wsClient.recv(5)
await waitForClose(wsClient)
await wsClient.close() test "AsyncStream leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
suite "Test Closing": suite "Test Closing":
teardown: teardown:
@ -326,7 +451,6 @@ suite "Test Closing":
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await createServer(request, "proto") let ws = await createServer(request, "proto")
await ws.close() await ws.close()
return dumbResponse()
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -338,7 +462,7 @@ suite "Test Closing":
path = "/ws", path = "/ws",
protocols = @["proto"]) protocols = @["proto"])
discard await wsClient.recv() await waitForClose(wsClient)
check wsClient.readyState == ReadyState.Closed check wsClient.readyState == ReadyState.Closed
test "Server closing with status": test "Server closing with status":
@ -348,8 +472,8 @@ suite "Test Closing":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
proc closeServer(status: Status, reason: string): CloseResult proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
{.gcsafe, raises: [Defect].} = raises: [Defect].} =
try: try:
check status == Status.TooLarge check status == Status.TooLarge
check reason == "Message too big!" check reason == "Message too big!"
@ -364,14 +488,13 @@ suite "Test Closing":
onClose = closeServer) onClose = closeServer)
await ws.close() await ws.close()
return dumbResponse()
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
server.start() server.start()
proc clientClose(status: Status, reason: string): CloseResult proc clientClose(status: Status, reason: string): CloseResult {.gcsafe,
{.gcsafe, raises: [Defect].} = raises: [Defect].} =
try: try:
check status == Status.Fulfilled check status == Status.Fulfilled
return (Status.TooLarge, "Message too big!") return (Status.TooLarge, "Message too big!")
@ -385,7 +508,7 @@ suite "Test Closing":
protocols = @["proto"], protocols = @["proto"],
onClose = clientClose) onClose = clientClose)
discard await wsClient.recv() await waitForClose(wsClient)
check wsClient.readyState == ReadyState.Closed check wsClient.readyState == ReadyState.Closed
test "Client closing": test "Client closing":
@ -396,9 +519,7 @@ suite "Test Closing":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
let ws = await createServer(request, "proto") let ws = await createServer(request, "proto")
discard await ws.recv() await waitForClose(ws)
await ws.close()
return dumbResponse()
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
@ -418,8 +539,8 @@ suite "Test Closing":
let request = r.get() let request = r.get()
check request.uri.path == "/ws" check request.uri.path == "/ws"
proc closeServer(status: Status, reason: string): CloseResult proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
{.gcsafe, raises: [Defect].} = raises: [Defect].} =
try: try:
check status == Status.Fulfilled check status == Status.Fulfilled
return (Status.TooLarge, "Message too big!") return (Status.TooLarge, "Message too big!")
@ -430,16 +551,14 @@ suite "Test Closing":
request, request,
"proto", "proto",
onClose = closeServer) onClose = closeServer)
discard await ws.recv() await waitForClose(ws)
await ws.close()
return dumbResponse()
let res = HttpServerRef.new(address, cb) let res = HttpServerRef.new(address, cb)
server = res.get() server = res.get()
server.start() server.start()
proc clientClose(status: Status, reason: string): CloseResult proc clientClose(status: Status, reason: string): CloseResult {.gcsafe,
{.gcsafe, raises: [Defect].} = raises: [Defect].} =
try: try:
check status == Status.TooLarge check status == Status.TooLarge
check reason == "Message too big!" check reason == "Message too big!"
@ -456,3 +575,253 @@ suite "Test Closing":
await wsClient.close() await wsClient.close()
check wsClient.readyState == ReadyState.Closed check wsClient.readyState == ReadyState.Closed
test "Server closing with valid close code 3999":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
request,
"proto")
await ws.close(code = Status.ReservedCode)
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
proc closeClient(status: Status, reason: string): CloseResult{.gcsafe,
raises: [Defect].} =
try:
check status == Status.ReservedCode
return (Status.ReservedCode, "Reserved Status")
except Exception as exc:
raise newException(Defect, exc.msg)
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"],
onClose = closeClient)
await waitForClose(wsClient)
test "Client closing with valid close code 3999":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
proc closeServer(status: Status, reason: string): CloseResult{.gcsafe,
raises: [Defect].} =
try:
check status == Status.ReservedCode
return (Status.ReservedCode, "Reserved Status")
except Exception as exc:
raise newException(Defect, exc.msg)
let ws = await createServer(
request,
"proto",
onClose = closeServer)
await waitForClose(ws)
return dumbResponse()
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"])
await wsClient.close(code = Status.ReservedCode)
test "Server closing with Payload of length 2":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
# Close with payload of length 2
await ws.close(reason = "HH")
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"])
await waitForClose(wsClient)
test "Client closing with Payload of length 2":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
await waitForClose(ws)
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"])
# Close with payload of length 2
await wsClient.close(reason = "HH")
test "AsyncStream leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false
suite "Test Payload":
teardown:
await server.closeWait()
test "Test payload message length":
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
request,
"proto")
expect WSPayloadTooLarge:
discard await ws.recv()
await waitForClose(ws)
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
let str = rndStr(126)
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"])
await wsClient.send(toBytes(str), Opcode.Ping)
await wsClient.close()
test "Test single empty payload":
let emptyStr = ""
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == emptyStr
await waitForClose(ws)
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"])
await wsClient.send(emptyStr)
await wsClient.close()
test "Test multiple empty payload":
let emptyStr = ""
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(request, "proto")
let servRes = await ws.recv()
check string.fromBytes(servRes) == emptyStr
await waitForClose(ws)
let res = HttpServerRef.new(
address, cb)
server = res.get()
server.start()
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"])
for i in 0..3:
await wsClient.send(emptyStr)
await wsClient.close()
test "Send ping with small text payload":
let testData = toBytes("Hello, world!")
var ping, pong = false
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
if r.isErr():
return dumbResponse()
let request = r.get()
check request.uri.path == "/ws"
let ws = await createServer(
request,
"proto",
onPing = proc() =
ping = true
)
await waitForClose(ws)
let res = HttpServerRef.new(
address, process)
server = res.get()
server.start()
let wsClient = await WebSocket.connect(
"127.0.0.1",
Port(8888),
path = "/ws",
protocols = @["proto"],
onPong = proc() =
pong = true
)
await wsClient.send(testData, Opcode.Ping)
await wsClient.close()
check:
ping
pong
test "AsyncStream leaks test":
check:
getTracker("async.stream.reader").isLeaked() == false
getTracker("async.stream.writer").isLeaked() == false
getTracker("stream.server").isLeaked() == false
getTracker("stream.transport").isLeaked() == false

View File

@ -1,4 +1,4 @@
import pkg/[chronos, import pkg/[chronos,
chronos/apps/http/httpserver, chronos/apps/http/httpserver,
chronos/timer, chronos/timer,
chronicles, chronicles,
@ -45,7 +45,11 @@ proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} =
return buffer return buffer
proc closeWait*(wsStream : AsyncStream) {.async.} = proc closeWait*(wsStream: AsyncStream) {.async.} =
await allFutures(
wsStream.writer.closeWait(),
wsStream.reader.closeWait())
await allFutures( await allFutures(
wsStream.writer.tsource.closeWait(), wsStream.writer.tsource.closeWait(),
wsStream.reader.tsource.closeWait()) wsStream.reader.tsource.closeWait())

179
ws/ws.nim
View File

@ -18,7 +18,7 @@ import pkg/[chronos,
stew/base10, stew/base10,
nimcrypto/sha] nimcrypto/sha]
import ./random, ./stream import ./utils, ./stream
#[ #[
+---------------------------------------------------------------+ +---------------------------------------------------------------+
@ -71,7 +71,13 @@ type
WSMaxMessageSizeError* = object of WebSocketError WSMaxMessageSizeError* = object of WebSocketError
WSClosedError* = object of WebSocketError WSClosedError* = object of WebSocketError
WSSendError* = object of WebSocketError WSSendError* = object of WebSocketError
WSPayloadTooLarge = object of WebSocketError WSPayloadTooLarge* = object of WebSocketError
WSReserverdOpcodeError* = object of WebSocketError
WSFragmentedControlFrameError* = object of WebSocketError
WSInvalidCloseCodeError* = object of WebSocketError
WSPayloadLengthError* = object of WebSocketError
WSInvalidOpcodeError* = object of WebSocketError
Base16Error* = object of CatchableError Base16Error* = object of CatchableError
## Base16 specific exception type ## Base16 specific exception type
@ -111,7 +117,7 @@ type
TooLarge = 1009 TooLarge = 1009
NoExtensions = 1010 NoExtensions = 1010
UnexpectedError = 1011 UnexpectedError = 1011
TlsError # use by clients ReservedCode = 3999 # use by clients
# 3000-3999 reserved for libs # 3000-3999 reserved for libs
# 4000-4999 reserved for applications # 4000-4999 reserved for applications
@ -156,7 +162,7 @@ template remainder*(frame: Frame): uint64 =
proc `$`(ht: HttpTables): string = proc `$`(ht: HttpTables): string =
## Returns string representation of HttpTable/Ref. ## Returns string representation of HttpTable/Ref.
var res = "" var res = ""
for key,value in ht.stringItems(true): for key, value in ht.stringItems(true):
res.add(key.normalizeHeaderName()) res.add(key.normalizeHeaderName())
res.add(": ") res.add(": ")
res.add(value) res.add(value)
@ -209,10 +215,11 @@ proc handshake*(
wantProtocol & ")") wantProtocol & ")")
let cKey = ws.key & WSGuid let cKey = ws.key & WSGuid
let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, cKey.high)).data) let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0,
cKey.high)).data)
var headerData = [ var headerData = [
("Connection", "Upgrade"), ("Connection", "Upgrade"),
("Upgrade", "webSocket" ), ("Upgrade", "webSocket"),
("Sec-WebSocket-Accept", acceptKey)] ("Sec-WebSocket-Accept", acceptKey)]
var headers = HttpTable.init(headerData) var headers = HttpTable.init(headerData)
@ -222,7 +229,8 @@ proc handshake*(
try: try:
discard await request.respond(httputils.Http101, "", headers) discard await request.respond(httputils.Http101, "", headers)
except CatchableError as exc: except CatchableError as exc:
raise newException(WSHandshakeError, "Failed to sent handshake response. Error: " & exc.msg) raise newException(WSHandshakeError,
"Failed to sent handshake response. Error: " & exc.msg)
ws.readyState = ReadyState.Open ws.readyState = ReadyState.Open
proc createServer*( proc createServer*(
@ -330,6 +338,8 @@ proc send*(
maskKey = genMaskKey(ws.rng) maskKey = genMaskKey(ws.rng)
if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
if ws.readyState in {ReadyState.Closing} and opcode notin {Opcode.Close}:
return
await ws.stream.writer.write(encodeFrame(Frame( await ws.stream.writer.write(encodeFrame(Frame(
fin: true, fin: true,
rsv1: false, rsv1: false,
@ -344,9 +354,9 @@ proc send*(
let maxSize = ws.frameSize let maxSize = ws.frameSize
var i = 0 var i = 0
while i < data.len: while ws.readyState notin {ReadyState.Closing}:
let len = min(data.len, (maxSize + i)) let len = min(data.len, (maxSize + i))
let inFrame = Frame( let encFrame = encodeFrame(Frame(
fin: if (i + len >= data.len): true else: false, fin: if (i + len >= data.len): true else: false,
rsv1: false, rsv1: false,
rsv2: false, rsv2: false,
@ -354,15 +364,21 @@ proc send*(
opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames
mask: ws.masked, mask: ws.masked,
data: data[i ..< len], data: data[i ..< len],
maskKey: maskKey) maskKey: maskKey))
await ws.stream.writer.write(encodeFrame(inFrame)) await ws.stream.writer.write(encFrame)
i += len i += len
if i >= data.len:
break
proc send*(ws: WebSocket, data: string): Future[void] = proc send*(ws: WebSocket, data: string): Future[void] =
send(ws, toBytes(data), Opcode.Text) send(ws, toBytes(data), Opcode.Text)
proc handleClose*(ws: WebSocket, frame: Frame) {.async.} = proc handleClose*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} =
if ws.readyState notin {ReadyState.Open}:
return
logScope: logScope:
fin = frame.fin fin = frame.fin
masked = frame.mask masked = frame.mask
@ -370,46 +386,46 @@ proc handleClose*(ws: WebSocket, frame: Frame) {.async.} =
serverState = ws.readyState serverState = ws.readyState
debug "Handling close sequence" debug "Handling close sequence"
if ws.readyState == ReadyState.Open or ws.readyState == ReadyState.Closing: var
# Read control frame payload. code = Status.Fulfilled
var data = newSeq[byte](frame.length) reason = ""
if frame.length > 0:
# Read the data.
await ws.stream.reader.readExactly(addr data[0], int frame.length)
unmask(data.toOpenArray(0, data.high), frame.maskKey)
var code: Status if payLoad.len == 1:
if data.len > 0: raise newException(WSPayloadLengthError, "Invalid close frame with payload length 1!")
let ccode = uint16.fromBytesBE(data[0..<2]) # first two bytes are the status elif payLoad.len > 1:
doAssert(ccode > 999, "No valid code in close message!") # first two bytes are the status
let ccode = uint16.fromBytesBE(payLoad[0..<2])
if ccode <= 999 or ccode > 1015:
raise newException(WSInvalidCloseCodeError, "Invalid code in close message!")
try:
code = Status(ccode) code = Status(ccode)
data = data[2..data.high] except RangeError:
code = Status.Fulfilled
# remining payload bytes are reason for closing
reason = string.fromBytes(payLoad[2..payLoad.high])
var rcode: Status
if code in {Status.Fulfilled}:
rcode = Status.Fulfilled
var rcode = Status.Fulfilled
var reason = ""
if not isNil(ws.onClose): if not isNil(ws.onClose):
try: try:
(rcode, reason) = ws.onClose(code, string.fromBytes(data)) (rcode, reason) = ws.onClose(code, reason)
except CatchableError as exc: except CatchableError as exc:
debug "Exception in Close callback, this is most likely a bug", exc = exc.msg debug "Exception in Close callback, this is most likely a bug", exc = exc.msg
# don't respond to a terminated connection # don't respond to a terminated connection
if ws.readyState != ReadyState.Closing: if ws.readyState != ReadyState.Closing:
ws.readyState = ReadyState.Closing
await ws.send(prepareCloseBody(rcode, reason), Opcode.Close) await ws.send(prepareCloseBody(rcode, reason), Opcode.Close)
await ws.stream.closeWait()
ws.readyState = ReadyState.Closed ws.readyState = ReadyState.Closed
else: await ws.stream.closeWait()
raiseAssert("Invalid state during close!")
proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = proc handleControl*(ws: WebSocket, frame: Frame, payLoad: seq[byte] = @[]) {.async.} =
## handle control frames ## handle control frames
## ##
if frame.length > 125:
raise newException(WSPayloadTooLarge,
"Control message payload is greater than 125 bytes!")
try: try:
# Process control frame payload. # Process control frame payload.
case frame.opcode: case frame.opcode:
@ -421,7 +437,7 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
# send pong to remote # send pong to remote
await ws.send(@[], Opcode.Pong) await ws.send(payLoad, Opcode.Pong)
of Opcode.Pong: of Opcode.Pong:
if not isNil(ws.onPong): if not isNil(ws.onPong):
try: try:
@ -429,9 +445,12 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} =
except CatchableError as exc: except CatchableError as exc:
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
of Opcode.Close: of Opcode.Close:
await ws.handleClose(frame) await ws.handleClose(frame, payLoad)
else: else:
raiseAssert("Invalid control opcode") raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
except WebSocketError as exc:
debug "Handled websocket exception", exc = exc.msg
raise exc
except CatchableError as exc: except CatchableError as exc:
trace "Exception handling control messages", exc = exc.msg trace "Exception handling control messages", exc = exc.msg
ws.readyState = ReadyState.Closed ws.readyState = ReadyState.Closed
@ -467,8 +486,12 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} =
if opcode > ord(Opcode.high): if opcode > ord(Opcode.high):
raise newException(WSOpcodeMismatchError, "Wrong opcode!") raise newException(WSOpcodeMismatchError, "Wrong opcode!")
frame.opcode = (opcode).Opcode let frameOpcode = (opcode).Opcode
if frameOpcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary,
Opcode.Ping, Opcode.Pong, Opcode.Close}:
raise newException(WSReserverdOpcodeError, "Unknown opcode received!")
frame.opcode = frameOpcode
# If any of the rsv are set close the socket. # If any of the rsv are set close the socket.
if frame.rsv1 or frame.rsv2 or frame.rsv3: if frame.rsv1 or frame.rsv2 or frame.rsv3:
raise newException(WSRsvMismatchError, "WebSocket rsv mismatch") raise newException(WSRsvMismatchError, "WebSocket rsv mismatch")
@ -508,12 +531,33 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} =
# return the current frame if it's not one of the control frames # return the current frame if it's not one of the control frames
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
asyncSpawn ws.handleControl(frame) # process control frames if not frame.fin:
raise newException(WSFragmentedControlFrameError, "Control frame cannot be fragmented!")
if frame.length > 125:
raise newException(WSPayloadTooLarge,
"Control message payload is greater than 125 bytes!")
var payLoad = newSeq[byte](frame.length)
if frame.length > 0:
# Read control frame payload.
await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int)
unmask(payLoad.toOpenArray(0, payLoad.high), frame.maskKey)
await ws.handleControl(frame, payLoad) # process control frames# process control frames
continue continue
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask debug "Decoded new frame", opcode = frame.opcode, len = frame.length,
mask = frame.mask
return frame return frame
except WSReserverdOpcodeError as exc:
trace "Handled websocket opcode exception", exc = exc.msg
raise exc
except WSPayloadTooLarge as exc:
debug "Handled payload too large exception", exc = exc.msg
raise exc
except WebSocketError as exc:
debug "Handled websocket exception", exc = exc.msg
raise exc
except CatchableError as exc: except CatchableError as exc:
debug "Exception reading frame, dropping socket", exc = exc.msg debug "Exception reading frame, dropping socket", exc = exc.msg
ws.readyState = ReadyState.Closed ws.readyState = ReadyState.Closed
@ -543,17 +587,32 @@ proc recv*(
while consumed < size: while consumed < size:
# we might have to read more than # we might have to read more than
# one frame to fill the buffer # one frame to fill the buffer
if isNil(ws.frame):
ws.frame = await ws.readFrame()
# all has been consumed from the frame # all has been consumed from the frame
# read the next frame # read the next frame
if ws.frame.remainder() <= 0: if isNil(ws.frame):
ws.frame = await ws.readFrame() ws.frame = await ws.readFrame()
# This could happen if the connection is closed.
if isNil(ws.frame):
return consumed.int
if ws.frame.opcode == Opcode.Cont:
raise newException(WSOpcodeMismatchError, "First frame cannot be continue frame")
elif (not ws.frame.fin and ws.frame.remainder() <= 0):
ws.frame = await ws.readFrame()
# This could happen if the connection is closed.
if isNil(ws.frame):
return consumed.int
if ws.frame.opcode != Opcode.Cont:
raise newException(WSOpcodeMismatchError, "expected continue frame")
if ws.frame.fin and ws.frame.remainder().int <= 0:
ws.frame = nil
break
let len = min(ws.frame.remainder().int, size - consumed) let len = min(ws.frame.remainder().int, size - consumed)
if len == 0:
continue
let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len) let read = await ws.stream.reader.readOnce(addr pbuffer[consumed], len)
if read <= 0: if read <= 0:
continue continue
@ -562,14 +621,18 @@ proc recv*(
unmask( unmask(
pbuffer.toOpenArray(consumed, (consumed + read) - 1), pbuffer.toOpenArray(consumed, (consumed + read) - 1),
ws.frame.maskKey, ws.frame.maskKey,
consumed) ws.frame.consumed.int)
consumed += read consumed += read
ws.frame.consumed += read.uint64 ws.frame.consumed += read.uint64
if ws.frame.fin and ws.frame.remainder().int <= 0:
break
return consumed.int return consumed.int
except WebSocketError as exc:
debug "Websocket error", exc = exc.msg
ws.readyState = ReadyState.Closed
await ws.stream.closeWait()
raise exc
except CancelledError as exc: except CancelledError as exc:
debug "Cancelling reading", exc = exc.msg debug "Cancelling reading", exc = exc.msg
raise exc raise exc
@ -611,7 +674,8 @@ proc recv*(
# read the entire message, exit # read the entire message, exit
if ws.frame.fin and ws.frame.remainder().int <= 0: if ws.frame.fin and ws.frame.remainder().int <= 0:
break break
except WSMaxMessageSizeError as exc: except WebSocketError as exc:
debug "Websocket error", exc = exc.msg
raise exc raise exc
except CancelledError as exc: except CancelledError as exc:
debug "Cancelling reading", exc = exc.msg debug "Cancelling reading", exc = exc.msg
@ -659,11 +723,14 @@ proc initiateHandshake(
TransportError, TransportError,
"Cannot connect to " & $transp.remoteAddress() & " Error: " & exc.msg) "Cannot connect to " & $transp.remoteAddress() & " Error: " & exc.msg)
let requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers let
let reader = newAsyncStreamReader(transp) requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers
let writer = newAsyncStreamWriter(transp) reader = newAsyncStreamReader(transp)
writer = newAsyncStreamWriter(transp)
var stream: AsyncStream var stream: AsyncStream
try:
var res: seq[byte] var res: seq[byte]
if uri.scheme == "https": if uri.scheme == "https":
let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags) let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags)
@ -694,6 +761,10 @@ proc initiateHandshake(
" Header code: " & $resHeader.code & " Header code: " & $resHeader.code &
" Header reason: " & resHeader.reason() & " Header reason: " & resHeader.reason() &
" Address: " & $transp.remoteAddress()) " Address: " & $transp.remoteAddress())
except CatchableError as exc:
debug "Websocket failed during handshake", exc = exc.msg
await stream.closeWait()
raise exc
return stream return stream
@ -738,7 +809,7 @@ proc connect*(
# Client data should be masked. # Client data should be masked.
return WebSocket( return WebSocket(
stream: stream, stream: stream,
readyState: Open, readyState: ReadyState.Open,
masked: true, masked: true,
rng: newRng(), rng: newRng(),
frameSize: frameSize, frameSize: frameSize,