mirror of
https://github.com/logos-storage/nim-websock.git
synced 2026-01-05 07:03:10 +00:00
Rework http (#38)
* wip * wip * move http under ws folder * use asyctest * wip * wip * rework response sending * make example work with latest changes * wip request/response * misc * fix example to use new http layer * pass tls flags to client * more cleanup * unused imports * more unsused imports * better headers * add helpre sendError * export sendError * attach selected proto to session * move proto to session * handle unsupported version * fix tests * comment out for now * fix utf8 tests * allow tests to be ran in tls * misc * use Port type * add tls flags * better api * run tls tests * fix tests on windows * allow running tests with tls * mic * wip * fix autobahn ci * handle close * cleanup * logging and error handling * remove old stream
This commit is contained in:
parent
723971a39d
commit
64da1a4344
16
.github/workflows/ci.yml
vendored
16
.github/workflows/ci.yml
vendored
@ -224,18 +224,20 @@ jobs:
|
||||
sed -i "s/COMMIT_SHA/$GITHUB_SHA/g" autobahn/index.md
|
||||
markdown2 autobahn/index.md > autobahn/reports/index.html
|
||||
|
||||
chmod +x ./scripts/start_server.sh
|
||||
./scripts/start_server.sh
|
||||
nim c examples/server.nim
|
||||
examples/server &
|
||||
pid=$!
|
||||
cd autobahn
|
||||
wstest --mode fuzzingclient --spec fuzzingclient.json
|
||||
kill $(pidof server)
|
||||
|
||||
kill $pid
|
||||
cd ..
|
||||
chmod +x ./scripts/start_server_tls.sh
|
||||
./scripts/start_server_tls.sh
|
||||
|
||||
nim c examples/tlsserver.nim
|
||||
examples/tlsserver &
|
||||
pid=$!
|
||||
cd autobahn
|
||||
wstest --mode fuzzingclient --spec fuzzingclient_tls.json
|
||||
kill $(pidof tlsserver)
|
||||
kill $pid
|
||||
|
||||
- name: Deploy autobahn report.
|
||||
if: runner.os == 'linux' && matrix.target.cpu == 'amd64' && github.event_name == 'push'
|
||||
|
||||
@ -28,6 +28,7 @@ proc main() {.async.} =
|
||||
break
|
||||
except WebSocketError as exc:
|
||||
error "WebSocket error:", exception = exc.msg
|
||||
raise exc
|
||||
|
||||
await sleepAsync(100.millis)
|
||||
|
||||
|
||||
@ -1,49 +1,46 @@
|
||||
|
||||
import std/uri
|
||||
import pkg/[chronos,
|
||||
chronos/apps/http/httpserver,
|
||||
chronicles,
|
||||
httputils]
|
||||
|
||||
import ../ws/ws
|
||||
|
||||
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isOk():
|
||||
let request = r.get()
|
||||
debug "Handling request:", uri = request.uri.path
|
||||
if request.uri.path == "/ws":
|
||||
debug "Initiating web socket connection."
|
||||
try:
|
||||
let server = WSServer.new()
|
||||
let ws = await server.handleRequest(request)
|
||||
if ws.readyState != Open:
|
||||
error "Failed to open websocket connection."
|
||||
return
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
debug "Handling request:", uri = request.uri.path
|
||||
if request.uri.path != "/ws":
|
||||
return
|
||||
|
||||
debug "Websocket handshake completed."
|
||||
while true:
|
||||
let recvData = await ws.recv()
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
debug "Websocket closed."
|
||||
break
|
||||
debug "Initiating web socket connection."
|
||||
try:
|
||||
let server = WSServer.new()
|
||||
let ws = await server.handleRequest(request)
|
||||
if ws.readyState != Open:
|
||||
error "Failed to open websocket connection"
|
||||
return
|
||||
|
||||
debug "Client Response: ", size = recvData.len
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
debug "Websocket handshake completed"
|
||||
while true:
|
||||
let recvData = await ws.recv()
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
debug "Websocket closed"
|
||||
break
|
||||
|
||||
except WebSocketError as exc:
|
||||
error "WebSocket error:", exception = exc.msg
|
||||
|
||||
discard await request.respond(Http200, "Hello World")
|
||||
else:
|
||||
return dumbResponse()
|
||||
debug "Client Response: ", size = recvData.len
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
except WebSocketError as exc:
|
||||
error "WebSocket error:", exception = exc.msg
|
||||
|
||||
when isMainModule:
|
||||
let address = initTAddress("127.0.0.1:8888")
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = HttpServerRef.new(
|
||||
address, process,
|
||||
socketFlags = socketFlags)
|
||||
proc main() {.async.} =
|
||||
let
|
||||
address = initTAddress("127.0.0.1:8888")
|
||||
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
server = HttpServer.create(address, handle, flags = socketFlags)
|
||||
|
||||
let server = res.get()
|
||||
server.start()
|
||||
info "Server listening at ", data = address
|
||||
waitFor server.join()
|
||||
server.start()
|
||||
info "Server listening at ", data = $server.localAddress()
|
||||
await server.join()
|
||||
|
||||
waitFor(main())
|
||||
|
||||
@ -16,7 +16,7 @@ proc main() {.async.} =
|
||||
|
||||
let reqData = "Hello Server"
|
||||
try:
|
||||
echo "sending client "
|
||||
debug "sending client "
|
||||
await ws.send(reqData)
|
||||
let buff = await ws.recv()
|
||||
if buff.len <= 0:
|
||||
|
||||
@ -1,59 +1,54 @@
|
||||
import pkg/[chronos,
|
||||
chronos/apps/http/shttpserver,
|
||||
chronicles,
|
||||
httputils,
|
||||
stew/byteutils]
|
||||
|
||||
import pkg/[chronos/streams/tlsstream]
|
||||
|
||||
import ../ws/ws
|
||||
import ../tests/keys
|
||||
|
||||
let secureKey = TLSPrivateKey.init(SecureKey)
|
||||
let secureCert = TLSCertificate.init(SecureCert)
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
debug "Handling request:", uri = request.uri.path
|
||||
if request.uri.path != "/wss":
|
||||
debug "Initiating web socket connection."
|
||||
return
|
||||
|
||||
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isOk():
|
||||
let request = r.get()
|
||||
try:
|
||||
let server = WSServer.new(protos = ["myfancyprotocol"])
|
||||
var ws = await server.handleRequest(request)
|
||||
if ws.readyState != Open:
|
||||
error "Failed to open websocket connection."
|
||||
return
|
||||
debug "Websocket handshake completed."
|
||||
# Only reads header for data frame.
|
||||
echo "receiving server "
|
||||
let recvData = await ws.recv()
|
||||
if recvData.len <= 0:
|
||||
debug "Empty messages"
|
||||
break
|
||||
|
||||
debug "Handling request:", uri = request.uri.path
|
||||
if request.uri.path == "/wss":
|
||||
debug "Initiating web socket connection."
|
||||
try:
|
||||
let server = WSServer.new(protos = ["myfancyprotocol"])
|
||||
var ws = await server.handleRequest(request)
|
||||
if ws.readyState != Open:
|
||||
error "Failed to open websocket connection."
|
||||
return
|
||||
debug "Websocket handshake completed."
|
||||
# Only reads header for data frame.
|
||||
echo "receiving server "
|
||||
let recvData = await ws.recv()
|
||||
if recvData.len <= 0:
|
||||
debug "Empty messages"
|
||||
break
|
||||
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
return
|
||||
debug "Response: ", data = string.fromBytes(recvData)
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
except WebSocketError:
|
||||
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||
discard await request.respond(Http200, "Hello World")
|
||||
else:
|
||||
return dumbResponse()
|
||||
if ws.readyState == ReadyState.Closed:
|
||||
return
|
||||
debug "Response: ", data = string.fromBytes(recvData)
|
||||
await ws.send(recvData,
|
||||
if ws.binary: Opcode.Binary else: Opcode.Text)
|
||||
except WebSocketError:
|
||||
error "WebSocket error:", exception = getCurrentExceptionMsg()
|
||||
|
||||
when isMainModule:
|
||||
proc main() {.async.} =
|
||||
let address = initTAddress("127.0.0.1:8888")
|
||||
let serverFlags = {Secure, NotifyDisconnect}
|
||||
let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, process,
|
||||
serverFlags = serverFlags,
|
||||
socketFlags = socketFlags,
|
||||
tlsPrivateKey = secureKey,
|
||||
tlsCertificate = secureCert)
|
||||
let server = TlsHttpServer.create(
|
||||
address = address,
|
||||
handler = handle,
|
||||
tlsPrivateKey = TLSPrivateKey.init(SecureKey),
|
||||
tlsCertificate = TLSCertificate.init(SecureCert),
|
||||
flags = socketFlags)
|
||||
|
||||
let server = res.get()
|
||||
server.start()
|
||||
info "Server listening at ", data = address
|
||||
waitFor server.join()
|
||||
info "Server listening at ", data = $server.localAddress()
|
||||
await server.join()
|
||||
|
||||
waitFor(main())
|
||||
@ -1,27 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
nim c -r examples/server.nim &
|
||||
|
||||
max_iterations=10
|
||||
wait_seconds=6
|
||||
http_endpoint="http://127.0.0.1:8888/"
|
||||
|
||||
iterations=0
|
||||
while true
|
||||
do
|
||||
((iterations++))
|
||||
echo "Attempt $iterations"
|
||||
sleep $wait_seconds
|
||||
|
||||
http_code=$(curl --verbose -s -o /tmp/result.txt -w '%{http_code}' "$http_endpoint";)
|
||||
|
||||
if [ "$http_code" -eq 200 ]; then
|
||||
echo "Server Up"
|
||||
break
|
||||
fi
|
||||
|
||||
if [ "$iterations" -ge "$max_iterations" ]; then
|
||||
echo "Loop Timeout"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
@ -1,27 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
nim c -r examples/tlsserver.nim &
|
||||
|
||||
max_iterations=10
|
||||
wait_seconds=6
|
||||
http_endpoint="https://127.0.0.1:8888/"
|
||||
|
||||
iterations=0
|
||||
while true
|
||||
do
|
||||
((iterations++))
|
||||
echo "Attempt $iterations"
|
||||
sleep $wait_seconds
|
||||
|
||||
http_code=$(curl -k --verbose -s -o /tmp/result.txt -w '%{http_code}' "$http_endpoint";)
|
||||
|
||||
if [ "$http_code" -eq 200 ]; then
|
||||
echo "Server Up"
|
||||
break
|
||||
fi
|
||||
|
||||
if [ "$iterations" -ge "$max_iterations" ]; then
|
||||
echo "Loop Timeout"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
23
tests/asyncunit.nim
Normal file
23
tests/asyncunit.nim
Normal file
@ -0,0 +1,23 @@
|
||||
import unittest2
|
||||
export unittest2 except suite, test
|
||||
|
||||
template suite*(name, body) =
|
||||
suite name:
|
||||
|
||||
template setup(setupBody) {.used.} =
|
||||
setup:
|
||||
let asyncproc = proc {.async.} = setupBody
|
||||
waitFor asyncproc()
|
||||
|
||||
template teardown(teardownBody) {.used.} =
|
||||
teardown:
|
||||
let asyncproc = proc {.async.} = teardownBody
|
||||
waitFor asyncproc()
|
||||
|
||||
let suiteproc = proc = body # Avoids GcUnsafe2 warnings with chronos
|
||||
suiteproc()
|
||||
|
||||
template test*(name, body) =
|
||||
test name:
|
||||
let asyncproc = proc {.async.} = body
|
||||
waitFor asyncproc()
|
||||
@ -1,6 +1,4 @@
|
||||
{. warning[UnusedImport]:off .}
|
||||
|
||||
import ./testframes
|
||||
import ./testwebsockets
|
||||
import ./testtlswebsockets
|
||||
import ./testutf8
|
||||
@ -1,207 +0,0 @@
|
||||
import std/strutils, httputils
|
||||
|
||||
import pkg/[asynctest,
|
||||
chronos,
|
||||
chronicles,
|
||||
chronos/apps/http/shttpserver,
|
||||
stew/byteutils]
|
||||
|
||||
import ../ws/ws, ../examples/tlsserver
|
||||
|
||||
import ./keys
|
||||
|
||||
proc waitForClose(ws: WSSession) {.async.} =
|
||||
try:
|
||||
while ws.readystate != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
except CatchableError:
|
||||
debug "Closing websocket"
|
||||
|
||||
var server: SecureHttpServerRef
|
||||
|
||||
let
|
||||
address = initTAddress("127.0.0.1:8888")
|
||||
serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect}
|
||||
socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr}
|
||||
clientFlags = {NoVerifyHost, NoVerifyServerName}
|
||||
secureKey = TLSPrivateKey.init(SecureKey)
|
||||
secureCert = TLSCertificate.init(SecureCert)
|
||||
|
||||
suite "Test websocket TLS handshake":
|
||||
teardown:
|
||||
await server.closeWait()
|
||||
|
||||
test "Test for websocket TLS incorrect protocol":
|
||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
|
||||
expect WSProtoMismatchError:
|
||||
discard await server.handleRequest(request)
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
serverFlags = serverFlags,
|
||||
socketFlags = socketFlags,
|
||||
tlsPrivateKey = secureKey,
|
||||
tlsCertificate = secureCert)
|
||||
|
||||
server = res.get()
|
||||
server.start()
|
||||
|
||||
expect WSFailedUpgradeError:
|
||||
discard await WebSocket.tlsConnect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/wss",
|
||||
protocols = @["wrongproto"],
|
||||
clientFlags)
|
||||
|
||||
test "Test for websocket TLS incorrect version":
|
||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
|
||||
expect WSVersionError:
|
||||
discard await server.handleRequest(request)
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
serverFlags = serverFlags,
|
||||
socketFlags = socketFlags,
|
||||
tlsPrivateKey = secureKey,
|
||||
tlsCertificate = secureCert)
|
||||
|
||||
server = res.get()
|
||||
server.start()
|
||||
|
||||
expect WSFailedUpgradeError:
|
||||
discard await WebSocket.tlsConnect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/wss",
|
||||
protocols = @["wrongproto"],
|
||||
clientFlags,
|
||||
version = 14)
|
||||
|
||||
test "Test for websocket TLS client headers":
|
||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
check r.isOk()
|
||||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
check request.headers.getString("Connection").toUpperAscii() ==
|
||||
"Upgrade".toUpperAscii()
|
||||
check request.headers.getString("Upgrade").toUpperAscii() ==
|
||||
"websocket".toUpperAscii()
|
||||
check request.headers.getString("Cache-Control").toUpperAscii() ==
|
||||
"no-cache".toUpperAscii()
|
||||
check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion
|
||||
|
||||
check request.headers.contains("Sec-WebSocket-Key")
|
||||
discard await request.respond(Http200, "Connection established")
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
serverFlags = serverFlags,
|
||||
socketFlags = socketFlags,
|
||||
tlsPrivateKey = secureKey,
|
||||
tlsCertificate = secureCert)
|
||||
|
||||
server = res.get()
|
||||
server.start()
|
||||
|
||||
expect WSFailedUpgradeError:
|
||||
discard await WebSocket.tlsConnect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/wss",
|
||||
protocols = @["proto"],
|
||||
clientFlags)
|
||||
|
||||
suite "Test websocket TLS transmission":
|
||||
teardown:
|
||||
await server.closeWait()
|
||||
|
||||
test "Server - test reading simple frame":
|
||||
let testString = "Hello!"
|
||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
let servRes = await ws.recv()
|
||||
check string.fromBytes(servRes) == testString
|
||||
|
||||
await waitForClose(ws)
|
||||
return dumbResponse()
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
serverFlags = serverFlags,
|
||||
socketFlags = socketFlags,
|
||||
tlsPrivateKey = secureKey,
|
||||
tlsCertificate = secureCert)
|
||||
|
||||
server = res.get()
|
||||
server.start()
|
||||
|
||||
let wsClient = await WebSocket.tlsConnect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/wss",
|
||||
protocols = @["proto"],
|
||||
clientFlags)
|
||||
|
||||
await wsClient.send(testString)
|
||||
await wsClient.close()
|
||||
|
||||
test "Client - test reading simple frame":
|
||||
let testString = "Hello!"
|
||||
proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
check request.uri.path == "/wss"
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
await ws.send(testString)
|
||||
await ws.close()
|
||||
|
||||
return dumbResponse()
|
||||
|
||||
let res = SecureHttpServerRef.new(
|
||||
address, cb,
|
||||
serverFlags = serverFlags,
|
||||
socketFlags = socketFlags,
|
||||
tlsPrivateKey = secureKey,
|
||||
tlsCertificate = secureCert)
|
||||
|
||||
server = res.get()
|
||||
server.start()
|
||||
|
||||
let wsClient = await WebSocket.tlsConnect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/wss",
|
||||
protocols = @["proto"],
|
||||
clientFlags)
|
||||
|
||||
let clientRes = await wsClient.recv()
|
||||
check string.fromBytes(clientRes) == testString
|
||||
await waitForClose(wsClient)
|
||||
@ -13,7 +13,6 @@ import
|
||||
stew/byteutils,
|
||||
asynctest,
|
||||
chronos,
|
||||
chronos/apps/http/httpserver,
|
||||
chronicles
|
||||
],
|
||||
../ws/[ws, utf8_dfa]
|
||||
@ -81,21 +80,17 @@ proc waitForClose(ws: WSSession) {.async.} =
|
||||
|
||||
# TODO: use new test framework from dryajov
|
||||
# if it is ready.
|
||||
var server: HttpServerRef
|
||||
var server: HttpServer
|
||||
let address = initTAddress("127.0.0.1:8888")
|
||||
|
||||
suite "UTF-8 validator in action":
|
||||
teardown:
|
||||
await server.stop()
|
||||
server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
test "valid UTF-8 sequence":
|
||||
let testData = "hello world"
|
||||
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == "/ws"
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
@ -108,32 +103,30 @@ suite "UTF-8 validator in action":
|
||||
|
||||
await waitForClose(ws)
|
||||
|
||||
let res = HttpServerRef.new(address, process)
|
||||
server = res.get()
|
||||
server = HttpServer.create(
|
||||
address,
|
||||
handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let wsClient = await WebSocket.connect(
|
||||
let session = await WebSocket.connect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/ws",
|
||||
protocols = @["proto"],
|
||||
)
|
||||
|
||||
await wsClient.send(testData)
|
||||
await wsClient.close()
|
||||
await session.send(testData)
|
||||
await session.close()
|
||||
|
||||
test "valid UTF-8 sequence in close reason":
|
||||
let testData = "hello world"
|
||||
let closeReason = "i want to close"
|
||||
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == "/ws"
|
||||
|
||||
proc onClose(status: Status, reason: string): CloseResult{.gcsafe,
|
||||
raises: [Defect].} =
|
||||
proc onClose(status: Status, reason: string):
|
||||
CloseResult {.gcsafe, raises: [Defect].} =
|
||||
try:
|
||||
check status == Status.Fulfilled
|
||||
check reason == closeReason
|
||||
@ -143,7 +136,6 @@ suite "UTF-8 validator in action":
|
||||
|
||||
let server = WSServer.new(protos = ["proto"], onClose = onClose)
|
||||
let ws = await server.handleRequest(request)
|
||||
|
||||
let res = await ws.recv()
|
||||
check:
|
||||
string.fromBytes(res) == testData
|
||||
@ -151,57 +143,54 @@ suite "UTF-8 validator in action":
|
||||
|
||||
await waitForClose(ws)
|
||||
|
||||
let res = HttpServerRef.new(address, process)
|
||||
server = res.get()
|
||||
server = HttpServer.create(
|
||||
address,
|
||||
handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let wsClient = await WebSocket.connect(
|
||||
let session = await WebSocket.connect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/ws",
|
||||
protocols = @["proto"],
|
||||
)
|
||||
|
||||
await wsClient.send(testData)
|
||||
await wsClient.close(reason = closeReason)
|
||||
await session.send(testData)
|
||||
await session.close(reason = closeReason)
|
||||
|
||||
test "invalid UTF-8 sequence":
|
||||
# TODO: how to check for Invalid UTF8 exception?
|
||||
let testData = "hello world\xc0\xaf"
|
||||
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == "/ws"
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
let ws = await server.handleRequest(request)
|
||||
discard await ws.recv()
|
||||
|
||||
let res = HttpServerRef.new(address, process)
|
||||
server = res.get()
|
||||
server = HttpServer.create(
|
||||
address,
|
||||
handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let wsClient = await WebSocket.connect(
|
||||
let session = await WebSocket.connect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/ws",
|
||||
protocols = @["proto"]
|
||||
)
|
||||
|
||||
await wsClient.send(testData)
|
||||
await waitForClose(wsClient)
|
||||
check wsClient.readyState == ReadyState.Closed
|
||||
await session.send(testData)
|
||||
await waitForClose( session)
|
||||
check session.readyState == ReadyState.Closed
|
||||
|
||||
test "invalid UTF-8 sequence close code":
|
||||
# TODO: how to check for Invalid UTF8 exception?
|
||||
let testData = "hello world"
|
||||
let closeReason = "i want to close\xc0\xaf"
|
||||
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isErr():
|
||||
return dumbResponse()
|
||||
|
||||
let request = r.get()
|
||||
proc handle(request: HttpRequest) {.async.} =
|
||||
check request.uri.path == "/ws"
|
||||
|
||||
let server = WSServer.new(protos = ["proto"])
|
||||
@ -212,18 +201,20 @@ suite "UTF-8 validator in action":
|
||||
string.fromBytes(res) == testData
|
||||
ws.binary == false
|
||||
|
||||
let res = HttpServerRef.new(address, process)
|
||||
server = res.get()
|
||||
server = HttpServer.create(
|
||||
address,
|
||||
handle,
|
||||
flags = {ReuseAddr})
|
||||
server.start()
|
||||
|
||||
let wsClient = await WebSocket.connect(
|
||||
let session = await WebSocket.connect(
|
||||
"127.0.0.1",
|
||||
Port(8888),
|
||||
path = "/ws",
|
||||
protocols = @["proto"]
|
||||
)
|
||||
|
||||
await wsClient.send(testData)
|
||||
await wsClient.close(reason = closeReason)
|
||||
await waitForClose(wsClient)
|
||||
check wsClient.readyState == ReadyState.Closed
|
||||
await session.send(testData)
|
||||
await session.close(reason = closeReason)
|
||||
await waitForClose( session)
|
||||
check session.readyState == ReadyState.Closed
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
10
ws.nimble
10
ws.nimble
@ -15,5 +15,11 @@ requires "nimcrypto"
|
||||
requires "bearssl"
|
||||
|
||||
task test, "run tests":
|
||||
exec "nim c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testall.nim"
|
||||
rmFile "./tests/testall"
|
||||
exec "nim c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testcommon.nim"
|
||||
rmFile "./tests/testcommon"
|
||||
|
||||
exec "nim c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testwebsockets.nim"
|
||||
rmFile "./tests/testwebsockets"
|
||||
|
||||
exec "nim -d:secure c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info ./tests/testwebsockets.nim"
|
||||
rmFile "./tests/testwebsockets"
|
||||
|
||||
11
ws/frame.nim
11
ws/frame.nim
@ -9,7 +9,12 @@
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import pkg/[chronos, chronicles, stew/endians2, stew/results]
|
||||
import pkg/[
|
||||
chronos,
|
||||
chronicles,
|
||||
stew/byteutils,
|
||||
stew/endians2,
|
||||
stew/results]
|
||||
import ./types
|
||||
|
||||
#[
|
||||
@ -94,7 +99,6 @@ proc encode*(
|
||||
ret.add(len.toBytesBE())
|
||||
|
||||
var data = f.data
|
||||
|
||||
if f.mask:
|
||||
# If we need to mask it generate random mask key and mask the data.
|
||||
mask(data, f.maskKey, offset)
|
||||
@ -118,6 +122,7 @@ proc decode*(
|
||||
##
|
||||
|
||||
var header = newSeq[byte](2)
|
||||
debug "Reading new frame"
|
||||
await reader.readExactly(addr header[0], 2)
|
||||
if header.len != 2:
|
||||
debug "Invalid websocket header length"
|
||||
@ -137,7 +142,7 @@ proc decode*(
|
||||
frame.rsv3 = HeaderFlag.rsv3 in hf
|
||||
|
||||
let opcode = (b0 and 0x0f)
|
||||
if opcode > ord(Opcode.high):
|
||||
if opcode > ord(Opcode.Pong):
|
||||
raise newException(WSOpcodeMismatchError, "Wrong opcode!")
|
||||
|
||||
frame.opcode = (opcode).Opcode
|
||||
|
||||
13
ws/http.nim
Normal file
13
ws/http.nim
Normal file
@ -0,0 +1,13 @@
|
||||
import std/uri
|
||||
import pkg/[
|
||||
chronos,
|
||||
chronos/apps/http/httptable,
|
||||
chronos/streams/tlsstream,
|
||||
httputils]
|
||||
|
||||
import ./http/client, ./http/server, ./http/common
|
||||
|
||||
export uri, httputils, client, server, httptable, tlsstream
|
||||
export TlsHttpClient, HttpClient, HttpServer,
|
||||
HttpResponse, HttpRequest, closeWait, sendResponse,
|
||||
sendError
|
||||
167
ws/http/client.nim
Normal file
167
ws/http/client.nim
Normal file
@ -0,0 +1,167 @@
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import std/[uri, strutils]
|
||||
import pkg/[
|
||||
chronos,
|
||||
chronicles,
|
||||
httputils,
|
||||
stew/byteutils]
|
||||
|
||||
import ./common
|
||||
|
||||
type
|
||||
HttpClient* = ref object of RootObj
|
||||
connected*: bool
|
||||
hostname*: string
|
||||
address*: TransportAddress
|
||||
version*: HttpVersion
|
||||
port*: Port
|
||||
stream*: AsyncStream
|
||||
buf*: seq[byte]
|
||||
|
||||
TlsHttpClient* = ref object of HttpClient
|
||||
tlsFlags*: set[TLSFlags]
|
||||
minVersion*: TLSVersion
|
||||
maxVersion*: TLSVersion
|
||||
|
||||
proc close*(client: HttpClient): Future[void] =
|
||||
client.stream.closeWait()
|
||||
|
||||
proc readResponse(stream: AsyncStreamReader): Future[HttpResponseHeader] {.async.} =
|
||||
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||
try:
|
||||
let
|
||||
hlenfut = stream.readUntil(
|
||||
addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep)
|
||||
ores = await withTimeout(hlenfut, HttpHeadersTimeout)
|
||||
|
||||
if not ores:
|
||||
raise newException(HttpError,
|
||||
"Timeout expired while receiving headers")
|
||||
|
||||
let hlen = hlenfut.read()
|
||||
buffer.setLen(hlen)
|
||||
|
||||
return buffer.parseResponse()
|
||||
except CatchableError as exc:
|
||||
debug "Exception reading headers", exc = exc.msg
|
||||
buffer.setLen(0)
|
||||
raise exc
|
||||
|
||||
proc generateHeaders(
|
||||
requestUrl: Uri,
|
||||
httpMethod: HttpMethod,
|
||||
version: HttpVersion,
|
||||
headers: HttpTables): string =
|
||||
var headersData = toUpperAscii($httpMethod)
|
||||
headersData.add " "
|
||||
|
||||
if not requestUrl.path.startsWith("/"): headersData.add "/"
|
||||
headersData.add(requestUrl.path & " ")
|
||||
headersData.add($version & CRLF)
|
||||
|
||||
for (key, val) in headers.stringItems(true):
|
||||
headersData.add(key)
|
||||
headersData.add(": ")
|
||||
headersData.add(val)
|
||||
headersData.add(CRLF)
|
||||
|
||||
headersData.add(CRLF)
|
||||
return headersData
|
||||
|
||||
proc request*(
|
||||
client: HttpClient,
|
||||
url: string | Uri,
|
||||
httpMethod = MethodGet,
|
||||
headers: HttpTables,
|
||||
body: seq[byte] = @[]): Future[HttpResponse] {.async.} =
|
||||
## Helper that actually makes the request.
|
||||
## Does not handle redirects.
|
||||
##
|
||||
|
||||
if not client.connected:
|
||||
raise newException(HttpError, "No connection to host!")
|
||||
|
||||
let requestUrl =
|
||||
when url is string:
|
||||
url.parseUri()
|
||||
else:
|
||||
url
|
||||
|
||||
if requestUrl.scheme == "":
|
||||
raise newException(HttpError, "No uri scheme supplied.")
|
||||
|
||||
let headerString = generateHeaders(requestUrl, httpMethod, client.version, headers)
|
||||
|
||||
await client.stream.writer.write(headerString)
|
||||
let response = await client.stream.reader.readResponse()
|
||||
let headers =
|
||||
block:
|
||||
var res = HttpTable.init()
|
||||
for key, value in response.headers():
|
||||
res.add(key, value)
|
||||
res
|
||||
|
||||
return HttpResponse(
|
||||
headers: headers,
|
||||
stream: client.stream,
|
||||
code: response.code,
|
||||
reason: response.reason())
|
||||
|
||||
proc connect*(
|
||||
T: typedesc[HttpClient | TlsHttpClient],
|
||||
address: TransportAddress,
|
||||
version = HttpVersion11,
|
||||
tlsFlags: set[TLSFlags] = {},
|
||||
tlsMinVersion = TLSVersion.TLS11,
|
||||
tlsMaxVersion = TLSVersion.TLS12): Future[T] {.async.} =
|
||||
|
||||
let transp = await connect(address)
|
||||
let client = T(
|
||||
hostname: address.host,
|
||||
port: address.port,
|
||||
address: transp.remoteAddress(),
|
||||
version: version)
|
||||
|
||||
var stream = AsyncStream(
|
||||
reader: newAsyncStreamReader(transp),
|
||||
writer: newAsyncStreamWriter(transp))
|
||||
|
||||
when T is TlsHttpClient:
|
||||
client.tlsFlags = tlsFlags
|
||||
client.minVersion = tlsMinVersion
|
||||
client.maxVersion = tlsMaxVersion
|
||||
|
||||
let tlsStream = newTLSClientAsyncStream(
|
||||
stream.reader,
|
||||
stream.writer,
|
||||
address.host,
|
||||
minVersion = tlsMinVersion,
|
||||
maxVersion = tlsMaxVersion,
|
||||
flags = tlsFlags)
|
||||
|
||||
stream = AsyncStream(
|
||||
reader: tlsStream.reader,
|
||||
writer: tlsStream.writer)
|
||||
|
||||
client.stream = stream
|
||||
client.connected = true
|
||||
|
||||
return client
|
||||
|
||||
proc connect*(
|
||||
T: typedesc[HttpClient | TlsHttpClient],
|
||||
host: string,
|
||||
port: int = 80,
|
||||
version = HttpVersion11,
|
||||
tlsFlags: set[TLSFlags] = {},
|
||||
tlsMinVersion = TLSVersion.TLS11,
|
||||
tlsMaxVersion = TLSVersion.TLS12): Future[T]
|
||||
{.raises: [Defect, HttpError].} =
|
||||
|
||||
let address = try:
|
||||
initTAddress(host, port)
|
||||
except TransportAddressError as exc:
|
||||
raise newException(HttpError, exc.msg)
|
||||
|
||||
return T.connect(address, version, tlsFlags, tlsMinVersion, tlsMaxVersion)
|
||||
122
ws/http/common.nim
Normal file
122
ws/http/common.nim
Normal file
@ -0,0 +1,122 @@
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import std/[uri]
|
||||
import pkg/[
|
||||
chronos,
|
||||
httputils,
|
||||
stew/byteutils,
|
||||
chronicles]
|
||||
|
||||
import pkg/[
|
||||
chronos/apps/http/httptable,
|
||||
chronos/streams/tlsstream]
|
||||
|
||||
export httputils, httptable, tlsstream, uri
|
||||
|
||||
const
|
||||
MaxHttpHeadersSize* = 8192 # maximum size of HTTP headers in octets
|
||||
MaxHttpRequestSize* = 128 * 1024 # maximum size of HTTP body in octets
|
||||
HttpHeadersTimeout* = 120.seconds # timeout for receiving headers (120 sec)
|
||||
HeaderSep* = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')]
|
||||
CRLF* = "\r\n"
|
||||
|
||||
type
|
||||
ReqStatus* {.pure.} = enum
|
||||
Success, Error, ErrorFailure
|
||||
|
||||
HttpCommon* = ref object of RootObj
|
||||
headers*: HttpTable
|
||||
code*: int
|
||||
version*: HttpVersion
|
||||
stream*: AsyncStream
|
||||
|
||||
HttpRequest* = ref object of HttpCommon
|
||||
uri*: Uri
|
||||
meth*: HttpMethod
|
||||
|
||||
# TODO: add useful response params, like body len
|
||||
HttpResponse* = ref object of HttpCommon
|
||||
reason*: string
|
||||
|
||||
HttpError* = object of CatchableError
|
||||
HttpHeaderError* = HttpError
|
||||
|
||||
proc closeTransp*(transp: StreamTransport) {.async.} =
|
||||
if not transp.closed():
|
||||
await transp.closeWait()
|
||||
|
||||
proc closeStream*(stream: AsyncStreamRW) {.async.} =
|
||||
if not stream.closed():
|
||||
await stream.closeWait()
|
||||
|
||||
proc closeWait*(stream: AsyncStream) {.async.} =
|
||||
await allFutures(
|
||||
stream.reader.tsource.closeTransp(),
|
||||
stream.reader.closeStream(),
|
||||
stream.writer.closeStream()
|
||||
)
|
||||
|
||||
proc sendResponse*(
|
||||
request: HttpRequest,
|
||||
code: HttpCode,
|
||||
headers: HttpTables = HttpTable.init(),
|
||||
data: seq[byte] = @[],
|
||||
version = HttpVersion11,
|
||||
content = "") {.async.} =
|
||||
## Send response
|
||||
##
|
||||
|
||||
var headers = headers
|
||||
var response: string = $version
|
||||
response.add(" ")
|
||||
response.add($code)
|
||||
response.add(CRLF)
|
||||
response.add("Date: " & httpDate() & CRLF)
|
||||
|
||||
if data.len > 0:
|
||||
if headers.getInt("Content-Length").int != data.len:
|
||||
warn "Wrong content length header, overriding"
|
||||
headers.set("Content-Length", $data.len)
|
||||
|
||||
if headers.getString("Content-Type") != content:
|
||||
headers.set("Content-Type",
|
||||
if content.len > 0: content else: "text/html")
|
||||
|
||||
for key, val in headers.stringItems(true):
|
||||
response.add(key)
|
||||
response.add(": ")
|
||||
response.add(val)
|
||||
response.add(CRLF)
|
||||
|
||||
response.add(CRLF)
|
||||
await request.stream.writer.write(
|
||||
response.toBytes() & data)
|
||||
|
||||
proc sendResponse*(
|
||||
request: HttpRequest,
|
||||
code: HttpCode,
|
||||
headers: HttpTables = HttpTable.init(),
|
||||
data: string,
|
||||
version = HttpVersion11,
|
||||
content = ""): Future[void] =
|
||||
request.sendResponse(code, headers, data.toBytes(), version, content)
|
||||
|
||||
proc sendError*(
|
||||
stream: AsyncStreamWriter,
|
||||
code: HttpCode,
|
||||
version = HttpVersion11) {.async.} =
|
||||
let content = $code
|
||||
var response: string = $version
|
||||
response.add(" ")
|
||||
response.add(content & CRLF)
|
||||
response.add(CRLF)
|
||||
|
||||
await stream.write(
|
||||
response.toBytes() &
|
||||
content.toBytes())
|
||||
|
||||
proc sendError*(
|
||||
request: HttpRequest,
|
||||
code: HttpCode,
|
||||
version = HttpVersion11): Future[void] =
|
||||
request.stream.writer.sendError(code, version)
|
||||
214
ws/http/server.nim
Normal file
214
ws/http/server.nim
Normal file
@ -0,0 +1,214 @@
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import std/uri
|
||||
import pkg/[
|
||||
chronos,
|
||||
chronicles,
|
||||
httputils]
|
||||
|
||||
import ./common
|
||||
|
||||
type
|
||||
HttpAsyncCallback* = proc (request: HttpRequest):
|
||||
Future[void] {.closure, gcsafe, raises: [Defect].}
|
||||
|
||||
HttpServer* = ref object of StreamServer
|
||||
handler*: HttpAsyncCallback
|
||||
|
||||
TlsHttpServer* = ref object of HttpServer
|
||||
tlsFlags*: set[TLSFlags]
|
||||
tlsPrivateKey*: TLSPrivateKey
|
||||
tlsCertificate*: TLSCertificate
|
||||
minVersion*: TLSVersion
|
||||
maxVersion*: TLSVersion
|
||||
|
||||
proc validateRequest(
|
||||
stream: AsyncStreamWriter,
|
||||
header: HttpRequestHeader): Future[ReqStatus] {.async.} =
|
||||
## Validate Request
|
||||
##
|
||||
|
||||
if header.meth notin {MethodGet}:
|
||||
debug "GET method is only allowed", address = stream.tsource.remoteAddress()
|
||||
await stream.sendError(Http405, version = header.version)
|
||||
return ReqStatus.Error
|
||||
|
||||
var hlen = header.contentLength()
|
||||
if hlen < 0 or hlen > MaxHttpRequestSize:
|
||||
debug "Invalid header length", address = stream.tsource.remoteAddress()
|
||||
await stream.sendError(Http413, version = header.version)
|
||||
return ReqStatus.Error
|
||||
|
||||
return ReqStatus.Success
|
||||
|
||||
proc handleRequest(
|
||||
server: HttpServer,
|
||||
stream: AsyncStream) {.async.} =
|
||||
## Process transport data to the HTTP server
|
||||
##
|
||||
|
||||
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||
let remoteAddr = stream.reader.tsource.remoteAddress()
|
||||
debug "Received connection", address = $remoteAddr
|
||||
try:
|
||||
let hlenfut = stream.reader.readUntil(
|
||||
addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep)
|
||||
let ores = await withTimeout(hlenfut, HttpHeadersTimeout)
|
||||
if not ores:
|
||||
# Timeout
|
||||
debug "Timeout expired while receiving headers", address = $remoteAddr
|
||||
await stream.writer.sendError(Http408, version = HttpVersion11)
|
||||
return
|
||||
|
||||
let hlen = hlenfut.read()
|
||||
buffer.setLen(hlen)
|
||||
let requestData = buffer.parseRequest()
|
||||
if requestData.failed():
|
||||
# Header could not be parsed
|
||||
debug "Malformed header received", address = $remoteAddr
|
||||
await stream.writer.sendError(Http400, version = HttpVersion11)
|
||||
return
|
||||
|
||||
var vres = await stream.writer.validateRequest(requestData)
|
||||
let hdrs =
|
||||
block:
|
||||
var res = HttpTable.init()
|
||||
for key, value in requestData.headers():
|
||||
res.add(key, value)
|
||||
res
|
||||
|
||||
if vres == ReqStatus.ErrorFailure:
|
||||
debug "Remote peer disconnected", address = $remoteAddr
|
||||
return
|
||||
|
||||
debug "Received valid HTTP request", address = $remoteAddr
|
||||
# Call the user's handler.
|
||||
if server.handler != nil:
|
||||
await server.handler(
|
||||
HttpRequest(
|
||||
headers: hdrs,
|
||||
stream: stream,
|
||||
uri: requestData.uri().parseUri()))
|
||||
except TransportLimitError:
|
||||
# size of headers exceeds `MaxHttpHeadersSize`
|
||||
debug "Maximum size of headers limit reached", address = $remoteAddr
|
||||
await stream.writer.sendError(Http413, version = HttpVersion11)
|
||||
except TransportIncompleteError:
|
||||
# remote peer disconnected
|
||||
debug "Remote peer disconnected", address = $remoteAddr
|
||||
except TransportOsError as exc:
|
||||
debug "Problems with networking", address = $remoteAddr, error = exc.msg
|
||||
except CatchableError as exc:
|
||||
debug "Unknown exception", address = $remoteAddr, error = exc.msg
|
||||
finally:
|
||||
await stream.closeWait()
|
||||
|
||||
proc handleConnCb(
|
||||
server: StreamServer,
|
||||
transp: StreamTransport) {.async.} =
|
||||
|
||||
let stream = AsyncStream(
|
||||
reader: newAsyncStreamReader(transp),
|
||||
writer: newAsyncStreamWriter(transp))
|
||||
|
||||
let httpServer = HttpServer(server)
|
||||
await httpServer.handleRequest(stream)
|
||||
|
||||
proc handleTlsConnCb(
|
||||
server: StreamServer,
|
||||
transp: StreamTransport) {.async.} =
|
||||
|
||||
let tlsHttpServer = TlsHttpServer(server)
|
||||
let stream = newTLSServerAsyncStream(
|
||||
newAsyncStreamReader(transp),
|
||||
newAsyncStreamWriter(transp),
|
||||
tlsHttpServer.tlsPrivateKey,
|
||||
tlsHttpServer.tlsCertificate,
|
||||
minVersion = tlsHttpServer.minVersion,
|
||||
maxVersion = tlsHttpServer.maxVersion,
|
||||
flags = tlsHttpServer.tlsFlags)
|
||||
|
||||
await HttpServer(tlsHttpServer)
|
||||
.handleRequest(AsyncStream(
|
||||
reader: stream.reader,
|
||||
writer: stream.writer))
|
||||
|
||||
proc create*(
|
||||
_: typedesc[HttpServer],
|
||||
address: TransportAddress,
|
||||
handler: HttpAsyncCallback = nil,
|
||||
flags: set[ServerFlags] = {}): HttpServer
|
||||
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
||||
## Make a new HTTP Server
|
||||
##
|
||||
|
||||
var server = HttpServer(handler: handler)
|
||||
server = HttpServer(
|
||||
createStreamServer(
|
||||
address,
|
||||
handleConnCb,
|
||||
flags,
|
||||
child = StreamServer(server)))
|
||||
|
||||
return server
|
||||
|
||||
proc create*(
|
||||
_: typedesc[HttpServer],
|
||||
host: string,
|
||||
port: Port,
|
||||
handler: HttpAsyncCallback = nil,
|
||||
flags: set[ServerFlags] = {}): HttpServer
|
||||
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
||||
## Make a new HTTP Server
|
||||
##
|
||||
|
||||
return HttpServer.create(initTAddress(host, port), handler, flags)
|
||||
|
||||
proc create*(
|
||||
_: typedesc[TlsHttpServer],
|
||||
address: TransportAddress,
|
||||
tlsPrivateKey: TLSPrivateKey,
|
||||
tlsCertificate: TLSCertificate,
|
||||
handler: HttpAsyncCallback = nil,
|
||||
flags: set[ServerFlags] = {},
|
||||
tlsFlags: set[TLSFlags] = {},
|
||||
tlsMinVersion = TLSVersion.TLS12,
|
||||
tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer
|
||||
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
||||
|
||||
var server = TlsHttpServer(
|
||||
handler: handler,
|
||||
tlsPrivateKey: tlsPrivateKey,
|
||||
tlsCertificate: tlsCertificate,
|
||||
minVersion: tlsMinVersion,
|
||||
maxVersion: tlsMaxVersion)
|
||||
|
||||
server = TlsHttpServer(
|
||||
createStreamServer(
|
||||
address,
|
||||
handleTlsConnCb,
|
||||
flags,
|
||||
child = StreamServer(server)))
|
||||
|
||||
return server
|
||||
|
||||
proc create*(
|
||||
_: typedesc[TlsHttpServer],
|
||||
host: string,
|
||||
port: Port,
|
||||
tlsPrivateKey: TLSPrivateKey,
|
||||
tlsCertificate: TLSCertificate,
|
||||
handler: HttpAsyncCallback = nil,
|
||||
flags: set[ServerFlags] = {},
|
||||
tlsFlags: set[TLSFlags] = {},
|
||||
tlsMinVersion = TLSVersion.TLS12,
|
||||
tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer
|
||||
{.raises: [Defect, CatchableError].} = # TODO: remove CatchableError
|
||||
TlsHttpServer.create(
|
||||
address = initTAddress(host, port),
|
||||
handler = handler,
|
||||
tlsPrivateKey = tlsPrivateKey,
|
||||
tlsCertificate = tlsCertificate,
|
||||
flags = flags,
|
||||
tlsFlags = tlsFlags)
|
||||
163
ws/session.nim
163
ws/session.nim
@ -10,16 +10,15 @@
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import pkg/[chronos, chronicles, stew/byteutils, stew/endians2]
|
||||
import ./types, ./frame, ./utils, ./stream, ./utf8_dfa
|
||||
import ./types, ./frame, ./utils, ./utf8_dfa, ./http
|
||||
|
||||
import pkg/chronos/[
|
||||
streams/asyncstream,
|
||||
streams/tlsstream]
|
||||
import pkg/chronos/[streams/asyncstream]
|
||||
|
||||
type
|
||||
WSSession* = ref object of WebSocket
|
||||
stream*: AsyncStream
|
||||
frame*: Frame
|
||||
proto*: string
|
||||
|
||||
proc prepareCloseBody(code: Status, reason: string): seq[byte] =
|
||||
result = reason.toBytes
|
||||
@ -102,7 +101,7 @@ proc handleClose*(
|
||||
opcode = frame.opcode
|
||||
readyState = ws.readyState
|
||||
|
||||
debug "Handling close sequence"
|
||||
debug "Handling close"
|
||||
|
||||
if ws.readyState notin {ReadyState.Open}:
|
||||
debug "Connection isn't open, abortig close sequence!"
|
||||
@ -133,7 +132,8 @@ proc handleClose*(
|
||||
reason = string.fromBytes(payLoad[2..payLoad.high])
|
||||
|
||||
if not ws.binary and validateUTF8(reason) == false:
|
||||
raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected in close reason")
|
||||
raise newException(WSInvalidUTF8,
|
||||
"Invalid UTF8 sequence detected in close reason")
|
||||
|
||||
var rcode: Status
|
||||
if code in {Status.Fulfilled}:
|
||||
@ -157,6 +157,15 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
|
||||
## Handle control frames
|
||||
##
|
||||
|
||||
logScope:
|
||||
fin = frame.fin
|
||||
masked = frame.mask
|
||||
opcode = frame.opcode
|
||||
readyState = ws.readyState
|
||||
len = frame.length
|
||||
|
||||
debug "Handling control frame"
|
||||
|
||||
if not frame.fin:
|
||||
raise newException(WSFragmentedControlFrameError,
|
||||
"Control frame cannot be fragmented!")
|
||||
@ -165,70 +174,53 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} =
|
||||
raise newException(WSPayloadTooLarge,
|
||||
"Control message payload is greater than 125 bytes!")
|
||||
|
||||
try:
|
||||
var payLoad = newSeq[byte](frame.length.int)
|
||||
if frame.length > 0:
|
||||
payLoad.setLen(frame.length.int)
|
||||
# Read control frame payload.
|
||||
await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int)
|
||||
if frame.mask:
|
||||
mask(
|
||||
payLoad.toOpenArray(0, payLoad.high),
|
||||
frame.maskKey)
|
||||
var payLoad = newSeq[byte](frame.length.int)
|
||||
if frame.length > 0:
|
||||
payLoad.setLen(frame.length.int)
|
||||
# Read control frame payload.
|
||||
await ws.stream.reader.readExactly(addr payLoad[0], frame.length.int)
|
||||
if frame.mask:
|
||||
mask(
|
||||
payLoad.toOpenArray(0, payLoad.high),
|
||||
frame.maskKey)
|
||||
|
||||
# Process control frame payload.
|
||||
case frame.opcode:
|
||||
of Opcode.Ping:
|
||||
if not isNil(ws.onPing):
|
||||
try:
|
||||
ws.onPing(payLoad)
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
|
||||
# Process control frame payload.
|
||||
case frame.opcode:
|
||||
of Opcode.Ping:
|
||||
if not isNil(ws.onPing):
|
||||
try:
|
||||
ws.onPing(payLoad)
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Ping callback, this is most likelly a bug", exc = exc.msg
|
||||
|
||||
# send pong to remote
|
||||
await ws.send(payLoad, Opcode.Pong)
|
||||
of Opcode.Pong:
|
||||
if not isNil(ws.onPong):
|
||||
try:
|
||||
ws.onPong(payLoad)
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
|
||||
of Opcode.Close:
|
||||
await ws.handleClose(frame, payLoad)
|
||||
else:
|
||||
raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
|
||||
except WebSocketError as exc:
|
||||
debug "Handled websocket exception", exc = exc.msg
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "Exception handling control messages", exc = exc.msg
|
||||
ws.readyState = ReadyState.Closed
|
||||
await ws.stream.closeWait()
|
||||
# send pong to remote
|
||||
await ws.send(payLoad, Opcode.Pong)
|
||||
of Opcode.Pong:
|
||||
if not isNil(ws.onPong):
|
||||
try:
|
||||
ws.onPong(payLoad)
|
||||
except CatchableError as exc:
|
||||
debug "Exception in Pong callback, this is most likelly a bug", exc = exc.msg
|
||||
of Opcode.Close:
|
||||
await ws.handleClose(frame, payLoad)
|
||||
else:
|
||||
raise newException(WSInvalidOpcodeError, "Invalid control opcode!")
|
||||
|
||||
proc readFrame*(ws: WSSession): Future[Frame] {.async.} =
|
||||
## Gets a frame from the WebSocket.
|
||||
## See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
##
|
||||
|
||||
try:
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let frame = await Frame.decode(ws.stream.reader, ws.masked)
|
||||
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
let frame = await Frame.decode(ws.stream.reader, ws.masked)
|
||||
debug "Decoded new frame", opcode = frame.opcode, len = frame.length, mask = frame.mask
|
||||
|
||||
# return the current frame if it's not one of the control frames
|
||||
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
||||
await ws.handleControl(frame) # process control frames# process control frames
|
||||
continue
|
||||
# return the current frame if it's not one of the control frames
|
||||
if frame.opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}:
|
||||
await ws.handleControl(frame) # process control frames# process control frames
|
||||
continue
|
||||
|
||||
return frame
|
||||
except WebSocketError as exc:
|
||||
trace "Websocket error", exc = exc.msg
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
debug "Exception reading frame, dropping socket", exc = exc.msg
|
||||
ws.readyState = ReadyState.Closed
|
||||
await ws.stream.closeWait()
|
||||
raise exc
|
||||
return frame
|
||||
|
||||
proc ping*(ws: WSSession, data: seq[byte] = @[]): Future[void] =
|
||||
ws.send(data, opcode = Opcode.Ping)
|
||||
@ -303,17 +295,11 @@ proc recv*(
|
||||
raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected")
|
||||
|
||||
return consumed.int
|
||||
|
||||
except WebSocketError as exc:
|
||||
debug "Websocket error", exc = exc.msg
|
||||
except CatchableError as exc:
|
||||
ws.readyState = ReadyState.Closed
|
||||
await ws.stream.closeWait()
|
||||
raise exc
|
||||
except CancelledError as exc:
|
||||
debug "Cancelling reading", exc = exc.msg
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
debug "Exception reading frames", exc = exc.msg
|
||||
raise exc
|
||||
|
||||
proc recv*(
|
||||
ws: WSSession,
|
||||
@ -330,34 +316,25 @@ proc recv*(
|
||||
## In all other cases it awaits a full message.
|
||||
##
|
||||
var res: seq[byte]
|
||||
try:
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
var buf = newSeq[byte](ws.frameSize)
|
||||
let read = await ws.recv(addr buf[0], buf.len)
|
||||
if read <= 0:
|
||||
break
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
var buf = newSeq[byte](ws.frameSize)
|
||||
let read = await ws.recv(addr buf[0], buf.len)
|
||||
if read <= 0:
|
||||
break
|
||||
|
||||
buf.setLen(read)
|
||||
if res.len + buf.len > size:
|
||||
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
|
||||
buf.setLen(read)
|
||||
if res.len + buf.len > size:
|
||||
raise newException(WSMaxMessageSizeError, "Max message size exceeded")
|
||||
|
||||
res.add(buf)
|
||||
res.add(buf)
|
||||
|
||||
# no more frames
|
||||
if isNil(ws.frame):
|
||||
break
|
||||
# no more frames
|
||||
if isNil(ws.frame):
|
||||
break
|
||||
|
||||
# read the entire message, exit
|
||||
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
||||
break
|
||||
except WebSocketError as exc:
|
||||
debug "Websocket error", exc = exc.msg
|
||||
raise exc
|
||||
except CancelledError as exc:
|
||||
debug "Cancelling reading", exc = exc.msg
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
debug "Exception reading frames", exc = exc.msg
|
||||
# read the entire message, exit
|
||||
if ws.frame.fin and ws.frame.remainder().int <= 0:
|
||||
break
|
||||
|
||||
return res
|
||||
|
||||
@ -380,7 +357,5 @@ proc close*(
|
||||
# read frames until closed
|
||||
while ws.readyState != ReadyState.Closed:
|
||||
discard await ws.recv()
|
||||
|
||||
except CatchableError as exc:
|
||||
debug "Exception closing", exc = exc.msg
|
||||
|
||||
|
||||
@ -1,58 +0,0 @@
|
||||
import pkg/[chronos,
|
||||
chronos/apps/http/httpserver,
|
||||
chronos/timer,
|
||||
chronicles,
|
||||
httputils]
|
||||
import strutils
|
||||
|
||||
const
|
||||
HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')]
|
||||
HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec)
|
||||
MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets
|
||||
|
||||
proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} =
|
||||
var buffer = newSeq[byte](MaxHttpHeadersSize)
|
||||
var error = false
|
||||
try:
|
||||
let hlenfut = rstream.readUntil(
|
||||
addr buffer[0], MaxHttpHeadersSize,
|
||||
sep = HeaderSep)
|
||||
let ores = await withTimeout(hlenfut, HttpHeadersTimeout)
|
||||
if not ores:
|
||||
# Timeout
|
||||
debug "Timeout expired while receiving headers",
|
||||
address = rstream.tsource.remoteAddress()
|
||||
error = true
|
||||
else:
|
||||
let hlen = hlenfut.read()
|
||||
buffer.setLen(hlen)
|
||||
except AsyncStreamLimitError:
|
||||
# size of headers exceeds `MaxHttpHeadersSize`
|
||||
debug "Maximum size of headers limit reached",
|
||||
address = rstream.tsource.remoteAddress()
|
||||
error = true
|
||||
except AsyncStreamIncompleteError:
|
||||
# remote peer disconnected
|
||||
debug "Remote peer disconnected", address = rstream.tsource.remoteAddress()
|
||||
error = true
|
||||
except AsyncStreamError as exc:
|
||||
debug "Problems with networking", address = rstream.tsource.remoteAddress(),
|
||||
error = exc.msg
|
||||
error = true
|
||||
|
||||
if error:
|
||||
buffer.setLen(0)
|
||||
|
||||
return buffer
|
||||
|
||||
proc closeWait*(wsStream: AsyncStream) {.async.} =
|
||||
# TODO: this is most likelly wrongs
|
||||
await allFutures(
|
||||
wsStream.writer.closeWait(),
|
||||
wsStream.reader.closeWait())
|
||||
|
||||
await allFutures(
|
||||
wsStream.writer.tsource.closeWait(),
|
||||
wsStream.reader.tsource.closeWait())
|
||||
|
||||
# TODO: Implement stream read and write wrapper.
|
||||
@ -9,7 +9,7 @@
|
||||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import chronos
|
||||
import pkg/[chronos, chronos/streams/tlsstream]
|
||||
import ./utils
|
||||
|
||||
const
|
||||
@ -19,7 +19,6 @@ const
|
||||
WSDefaultFrameSize* = 1 shl 20 # 1mb
|
||||
WSMaxMessageSize* = 20 shl 20 # 20mb
|
||||
WSGuid* = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
CRLF* = "\r\n"
|
||||
|
||||
type
|
||||
ReadyState* {.pure.} = enum
|
||||
@ -38,6 +37,7 @@ type
|
||||
Ping = 0x9 ## Denotes a ping.
|
||||
Pong = 0xa ## Denotes a pong.
|
||||
# B-F are reserved for further control frames.
|
||||
Reserved = 0xf
|
||||
|
||||
HeaderFlag* {.pure, size: sizeof(uint8).} = enum
|
||||
rsv3
|
||||
@ -96,10 +96,10 @@ type
|
||||
extensions: seq[Extension] # extension active for this session
|
||||
version*: uint
|
||||
key*: string
|
||||
proto*: string
|
||||
readyState*: ReadyState
|
||||
masked*: bool # send masked packets
|
||||
binary*: bool # is payload binary?
|
||||
flags*: set[TLSFlags]
|
||||
rng*: Rng
|
||||
frameSize*: int
|
||||
onPing*: ControlCb
|
||||
|
||||
@ -7,7 +7,7 @@ const randMax = 18_446_744_073_709_551_615'u64
|
||||
type
|
||||
Rng* = ref BrHmacDrbgContext
|
||||
|
||||
proc newRng*(): ref BrHmacDrbgContext =
|
||||
proc newRng*(): Rng =
|
||||
# You should only create one instance of the RNG per application / library
|
||||
# Ref is used so that it can be shared between components
|
||||
# TODO consider moving to bearssl
|
||||
@ -15,10 +15,11 @@ proc newRng*(): ref BrHmacDrbgContext =
|
||||
if seeder == nil:
|
||||
return nil
|
||||
|
||||
var rng = (ref BrHmacDrbgContext)()
|
||||
var rng = Rng()
|
||||
brHmacDrbgInit(addr rng[], addr sha256Vtable, nil, 0)
|
||||
if seeder(addr rng.vtable) == 0:
|
||||
return nil
|
||||
|
||||
rng
|
||||
|
||||
proc rand*(rng: Rng, max: Natural): int =
|
||||
|
||||
332
ws/ws.nim
332
ws/ws.nim
@ -11,13 +11,13 @@
|
||||
|
||||
import std/[tables,
|
||||
strutils,
|
||||
strformat,
|
||||
sequtils,
|
||||
uri,
|
||||
parseutils]
|
||||
|
||||
import pkg/[chronos,
|
||||
chronos/apps/http/httptable,
|
||||
chronos/apps/http/httpserver,
|
||||
chronos/streams/asyncstream,
|
||||
chronos/streams/tlsstream,
|
||||
chronicles,
|
||||
@ -28,156 +28,19 @@ import pkg/[chronos,
|
||||
stew/base10,
|
||||
nimcrypto/sha]
|
||||
|
||||
import ./utils, ./stream, ./frame, ./session, /types
|
||||
import ./utils, ./frame, ./session, /types, ./http
|
||||
|
||||
export utils, session, frame, stream, types
|
||||
export utils, session, frame, types, http
|
||||
|
||||
type
|
||||
HttpCode* = enum
|
||||
Http101 = 101 # Switching Protocols
|
||||
|
||||
WSServer* = ref object of WebSocket
|
||||
protocols: seq[string]
|
||||
|
||||
proc `$`(ht: HttpTables): string =
|
||||
## Returns string representation of HttpTable/Ref.
|
||||
var res = ""
|
||||
for key, value in ht.stringItems(true):
|
||||
res.add(key.normalizeHeaderName())
|
||||
res.add(": ")
|
||||
res.add(value)
|
||||
res.add(CRLF)
|
||||
func toException(e: string): ref WebSocketError =
|
||||
(ref WebSocketError)(msg: e)
|
||||
|
||||
## add for end of header mark
|
||||
res.add(CRLF)
|
||||
res
|
||||
|
||||
proc handshake*(
|
||||
ws: WSServer,
|
||||
request: HttpRequestRef,
|
||||
stream: AsyncStream,
|
||||
version: uint = WSDefaultVersion): Future[WSSession] {.async.} =
|
||||
## Handles the websocket handshake.
|
||||
##
|
||||
|
||||
let
|
||||
reqHeaders = request.headers
|
||||
|
||||
ws.version = Base10.decode(
|
||||
uint,
|
||||
reqHeaders.getString("Sec-WebSocket-Version"))
|
||||
.tryGet() # this method throws
|
||||
|
||||
if ws.version != version:
|
||||
raise newException(WSVersionError,
|
||||
"Websocket version not supported, Version: " &
|
||||
reqHeaders.getString("Sec-WebSocket-Version"))
|
||||
|
||||
ws.key = reqHeaders.getString("Sec-WebSocket-Key").strip()
|
||||
var protos = @[""]
|
||||
if reqHeaders.contains("Sec-WebSocket-Protocol"):
|
||||
let wantProtos = reqHeaders.getList("Sec-WebSocket-Protocol")
|
||||
protos = wantProtos.filterIt(
|
||||
it in ws.protocols
|
||||
)
|
||||
|
||||
if protos.len <= 0:
|
||||
raise newException(WSProtoMismatchError,
|
||||
"Protocol mismatch (expected: " & ws.protocols.join(", ") & ", got: " &
|
||||
wantProtos.join(", ") & ")")
|
||||
|
||||
let
|
||||
cKey = ws.key & WSGuid
|
||||
acceptKey = Base64Pad.encode(
|
||||
sha1.digest(cKey.toOpenArray(0, cKey.high)).data)
|
||||
|
||||
var headerData = [
|
||||
("Connection", "Upgrade"),
|
||||
("Upgrade", "webSocket"),
|
||||
("Sec-WebSocket-Accept", acceptKey)]
|
||||
|
||||
var headers = HttpTable.init(headerData)
|
||||
if protos.len > 0:
|
||||
headers.add("Sec-WebSocket-Protocol", protos[0]) # send back the first matching proto
|
||||
|
||||
try:
|
||||
discard await request.respond(httputils.Http101, "", headers)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
raise newException(WSHandshakeError,
|
||||
"Failed to sent handshake response. Error: " & exc.msg)
|
||||
|
||||
return WSSession(
|
||||
readyState: ReadyState.Open,
|
||||
stream: stream,
|
||||
proto: protos[0],
|
||||
masked: false,
|
||||
rng: ws.rng,
|
||||
frameSize: ws.frameSize,
|
||||
onPing: ws.onPing,
|
||||
onPong: ws.onPong,
|
||||
onClose: ws.onClose)
|
||||
|
||||
proc initiateHandshake(
|
||||
uri: Uri,
|
||||
address: TransportAddress,
|
||||
headers: HttpTable,
|
||||
flags: set[TLSFlags] = {}): Future[AsyncStream] {.async.} =
|
||||
## Initiate handshake with server
|
||||
|
||||
var transp: StreamTransport
|
||||
try:
|
||||
transp = await connect(address)
|
||||
except CatchableError as exc:
|
||||
raise newException(
|
||||
TransportError,
|
||||
"Cannot connect to " & $address & " Error: " & exc.msg)
|
||||
|
||||
let
|
||||
requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers
|
||||
reader = newAsyncStreamReader(transp)
|
||||
writer = newAsyncStreamWriter(transp)
|
||||
|
||||
var stream: AsyncStream
|
||||
|
||||
try:
|
||||
var res: seq[byte]
|
||||
if uri.scheme == "https":
|
||||
let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags)
|
||||
stream = AsyncStream(
|
||||
reader: tlsstream.reader,
|
||||
writer: tlsstream.writer)
|
||||
|
||||
await tlsstream.writer.write(requestHeader)
|
||||
res = await tlsstream.reader.readHeaders()
|
||||
else:
|
||||
stream = AsyncStream(
|
||||
reader: reader,
|
||||
writer: writer)
|
||||
await stream.writer.write(requestHeader)
|
||||
res = await stream.reader.readHeaders()
|
||||
|
||||
if res.len == 0:
|
||||
raise newException(ValueError, "Empty response from server")
|
||||
|
||||
let resHeader = res.parseResponse()
|
||||
if resHeader.failed():
|
||||
# Header could not be parsed
|
||||
raise newException(WSMalformedHeaderError, "Malformed header received.")
|
||||
|
||||
if resHeader.code != ord(Http101):
|
||||
raise newException(WSFailedUpgradeError,
|
||||
"Server did not reply with a websocket upgrade:" &
|
||||
" Header code: " & $resHeader.code &
|
||||
" Header reason: " & resHeader.reason() &
|
||||
" Address: " & $transp.remoteAddress())
|
||||
except CatchableError as exc:
|
||||
debug "Websocket failed during handshake", exc = exc.msg
|
||||
await stream.closeWait()
|
||||
raise exc
|
||||
|
||||
return stream
|
||||
func toException(e: cstring): ref WebSocketError =
|
||||
(ref WebSocketError)(msg: $e)
|
||||
|
||||
proc connect*(
|
||||
_: type WebSocket,
|
||||
@ -193,18 +56,21 @@ proc connect*(
|
||||
## create a new websockets client
|
||||
##
|
||||
|
||||
var key = Base64.encode(genWebSecKey(newRng()))
|
||||
var rng = if isNil(rng): newRng() else: rng
|
||||
var key = Base64.encode(genWebSecKey(rng))
|
||||
var uri = uri
|
||||
case uri.scheme
|
||||
of "ws":
|
||||
uri.scheme = "http"
|
||||
of "wss":
|
||||
uri.scheme = "https"
|
||||
else:
|
||||
raise newException(WSWrongUriSchemeError,
|
||||
"uri scheme has to be 'ws' or 'wss'")
|
||||
let client = case uri.scheme:
|
||||
of "wss":
|
||||
uri.scheme = "https"
|
||||
await TlsHttpClient.connect(uri.hostname, uri.port.parseInt(), tlsFlags = flags)
|
||||
of "ws":
|
||||
uri.scheme = "http"
|
||||
await HttpClient.connect(uri.hostname, uri.port.parseInt())
|
||||
else:
|
||||
raise newException(WSWrongUriSchemeError,
|
||||
"uri scheme has to be 'ws' or 'wss'")
|
||||
|
||||
var headerData = [
|
||||
let headerData = [
|
||||
("Connection", "Upgrade"),
|
||||
("Upgrade", "websocket"),
|
||||
("Cache-Control", "no-cache"),
|
||||
@ -212,19 +78,34 @@ proc connect*(
|
||||
("Sec-WebSocket-Key", key)]
|
||||
|
||||
var headers = HttpTable.init(headerData)
|
||||
|
||||
if protocols.len != 0:
|
||||
if protocols.len > 0:
|
||||
headers.add("Sec-WebSocket-Protocol", protocols.join(", "))
|
||||
|
||||
let address = initTAddress(uri.hostname & ":" & uri.port)
|
||||
let stream = await initiateHandshake(uri, address, headers, flags)
|
||||
let response = try:
|
||||
await client.request(uri, headers = headers)
|
||||
except CatchableError as exc:
|
||||
debug "Websocket failed during handshake", exc = exc.msg
|
||||
await client.close()
|
||||
raise exc
|
||||
|
||||
if response.code != Http101.toInt():
|
||||
raise newException(WSFailedUpgradeError,
|
||||
&"Server did not reply with a websocket upgrade: " &
|
||||
&"Header code: {response.code} Header reason: {response.reason} " &
|
||||
&"Address: {client.address}")
|
||||
|
||||
let proto = response.headers.getString("Sec-WebSocket-Protocol")
|
||||
if proto.len > 0 and protocols.len > 0:
|
||||
if proto notin protocols:
|
||||
raise newException(WSFailedUpgradeError,
|
||||
&"Invalid protocol returned {proto}!")
|
||||
|
||||
# Client data should be masked.
|
||||
return WSSession(
|
||||
stream: stream,
|
||||
stream: client.stream,
|
||||
readyState: ReadyState.Open,
|
||||
masked: true,
|
||||
rng: if isNil(rng): newRng() else: rng,
|
||||
rng: rng,
|
||||
frameSize: frameSize,
|
||||
onPing: onPing,
|
||||
onPong: onPong,
|
||||
@ -232,41 +113,49 @@ proc connect*(
|
||||
|
||||
proc connect*(
|
||||
_: type WebSocket,
|
||||
host: string,
|
||||
port: Port,
|
||||
address: TransportAddress,
|
||||
path: string,
|
||||
protocols: seq[string] = @[],
|
||||
secure = false,
|
||||
flags: set[TLSFlags] = {},
|
||||
version = WSDefaultVersion,
|
||||
frameSize = WSDefaultFrameSize,
|
||||
onPing: ControlCb = nil,
|
||||
onPong: ControlCb = nil,
|
||||
onClose: CloseCb = nil): Future[WSSession] {.async.} =
|
||||
onClose: CloseCb = nil,
|
||||
rng: Rng = nil): Future[WSSession] {.async.} =
|
||||
## Create a new websockets client
|
||||
## using a string path
|
||||
##
|
||||
|
||||
var uri = "ws://" & host & ":" & $port
|
||||
var uri = if secure:
|
||||
&"wss://"
|
||||
else:
|
||||
&"ws://"
|
||||
|
||||
uri &= address.host & ":" & $address.port
|
||||
if path.startsWith("/"):
|
||||
uri.add path
|
||||
else:
|
||||
uri.add "/" & path
|
||||
uri.add &"/{path}"
|
||||
|
||||
return await WebSocket.connect(
|
||||
parseUri(uri),
|
||||
protocols,
|
||||
{},
|
||||
version,
|
||||
frameSize,
|
||||
onPing,
|
||||
onPong,
|
||||
onClose)
|
||||
uri = parseUri(uri),
|
||||
protocols = protocols,
|
||||
flags = flags,
|
||||
version = version,
|
||||
frameSize = frameSize,
|
||||
onPing = onPing,
|
||||
onPong = onPong,
|
||||
onClose = onClose)
|
||||
|
||||
proc tlsConnect*(
|
||||
proc connect*(
|
||||
_: type WebSocket,
|
||||
host: string,
|
||||
port: Port,
|
||||
path: string,
|
||||
protocols: seq[string] = @[],
|
||||
secure = false,
|
||||
flags: set[TLSFlags] = {},
|
||||
version = WSDefaultVersion,
|
||||
frameSize = WSDefaultFrameSize,
|
||||
@ -275,38 +164,91 @@ proc tlsConnect*(
|
||||
onClose: CloseCb = nil,
|
||||
rng: Rng = nil): Future[WSSession] {.async.} =
|
||||
|
||||
var uri = "wss://" & host & ":" & $port
|
||||
if path.startsWith("/"):
|
||||
uri.add path
|
||||
else:
|
||||
uri.add "/" & path
|
||||
|
||||
return await WebSocket.connect(
|
||||
parseUri(uri),
|
||||
protocols,
|
||||
flags,
|
||||
version,
|
||||
frameSize,
|
||||
onPing,
|
||||
onPong,
|
||||
onClose,
|
||||
rng)
|
||||
address = initTAddress(host, port),
|
||||
path = path,
|
||||
protocols = protocols,
|
||||
flags = flags,
|
||||
version = version,
|
||||
frameSize = frameSize,
|
||||
onPing = onPing,
|
||||
onPong = onPong,
|
||||
onClose = onClose,
|
||||
rng = rng)
|
||||
|
||||
proc handleRequest*(
|
||||
ws: WSServer,
|
||||
request: HttpRequestRef): Future[WSSession]
|
||||
{.raises: [Defect, WSHandshakeError].} =
|
||||
request: HttpRequest,
|
||||
version: uint = WSDefaultVersion): Future[WSSession]
|
||||
{.
|
||||
async,
|
||||
raises: [
|
||||
Defect,
|
||||
WSHandshakeError,
|
||||
WSProtoMismatchError]
|
||||
.} =
|
||||
## Creates a new socket from a request.
|
||||
##
|
||||
|
||||
if not request.headers.contains("Sec-WebSocket-Version"):
|
||||
raise newException(WSHandshakeError, "Missing version header")
|
||||
|
||||
let wsStream = AsyncStream(
|
||||
reader: request.connection.reader,
|
||||
writer: request.connection.writer)
|
||||
ws.version = Base10.decode(
|
||||
uint,
|
||||
request.headers.getString("Sec-WebSocket-Version"))
|
||||
.tryGet() # this method throws
|
||||
|
||||
return ws.handshake(request, wsStream)
|
||||
if ws.version != version:
|
||||
await request.stream.writer.sendError(Http426)
|
||||
debug "Websocket version not supported", version = ws.version
|
||||
|
||||
raise newException(WSVersionError,
|
||||
&"Websocket version not supported, Version: {version}")
|
||||
|
||||
ws.key = request.headers.getString("Sec-WebSocket-Key").strip()
|
||||
let wantProtos = if request.headers.contains("Sec-WebSocket-Protocol"):
|
||||
request.headers.getList("Sec-WebSocket-Protocol")
|
||||
else:
|
||||
@[""]
|
||||
|
||||
let protos = wantProtos.filterIt(
|
||||
it in ws.protocols
|
||||
)
|
||||
|
||||
let
|
||||
cKey = ws.key & WSGuid
|
||||
acceptKey = Base64Pad.encode(
|
||||
sha1.digest(cKey.toOpenArray(0, cKey.high)).data)
|
||||
|
||||
var headers = HttpTable.init([
|
||||
("Connection", "Upgrade"),
|
||||
("Upgrade", "websocket"),
|
||||
("Sec-WebSocket-Accept", acceptKey)])
|
||||
|
||||
let protocol = if protos.len > 0: protos[0] else: ""
|
||||
if protocol.len > 0:
|
||||
headers.add("Sec-WebSocket-Protocol", protocol) # send back the first matching proto
|
||||
else:
|
||||
debug "Didn't match any protocol", supported = ws.protocols, requested = wantProtos
|
||||
|
||||
try:
|
||||
await request.sendResponse(Http101, headers = headers)
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
raise newException(WSHandshakeError,
|
||||
"Failed to sent handshake response. Error: " & exc.msg)
|
||||
|
||||
return WSSession(
|
||||
readyState: ReadyState.Open,
|
||||
stream: request.stream,
|
||||
proto: protocol,
|
||||
masked: false,
|
||||
rng: ws.rng,
|
||||
frameSize: ws.frameSize,
|
||||
onPing: ws.onPing,
|
||||
onPong: ws.onPong,
|
||||
onClose: ws.onClose)
|
||||
|
||||
proc new*(
|
||||
_: typedesc[WSServer],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user