diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 899c1ee0..0420f6a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,21 +76,21 @@ jobs: if: runner.os == 'Linux' && matrix.target.cpu == 'i386' run: | sudo dpkg --add-architecture i386 - sudo rm /etc/apt/sources.list.d/devel:kubic:libcontainers:stable.list sudo apt-get update -qq sudo DEBIAN_FRONTEND='noninteractive' apt-get install \ - --no-install-recommends -yq gcc-multilib g++-multilib + --no-install-recommends -yq gcc-multilib g++-multilib \ + libssl-dev:i386 mkdir -p external/bin cat << EOF > external/bin/gcc #!/bin/bash - exec $(which gcc) -m32 -mno-adx "\$@" + exec $(which gcc) -m32 "\$@" EOF cat << EOF > external/bin/g++ #!/bin/bash - exec $(which g++) -m32 -mno-adx "\$@" + exec $(which g++) -m32 "\$@" EOF chmod 755 external/bin/gcc external/bin/g++ - echo "${{ github.workspace }}/external/bin" >> $GITHUB_PATH + echo '${{ github.workspace }}/external/bin' >> $GITHUB_PATH - name: Install build dependencies (Windows) if: runner.os == 'Windows' diff --git a/examples/client.nim b/examples/client.nim index cb00025b..ea61b991 100644 --- a/examples/client.nim +++ b/examples/client.nim @@ -1,4 +1,9 @@ -import ../src/ws, nativesockets, chronos,chronicles, stew/byteutils +import pkg/[ + chronos, + chronicles, + stew/byteutils] + +import ../ws/ws proc main() {.async.} = let ws = await WebSocket.connect( diff --git a/examples/server.nim b/examples/server.nim index 379f047c..fceee461 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -3,7 +3,8 @@ chronicles, httputils, stew/byteutils] -import ../src/ws + +import ../ws/ws proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = if r.isOk(): @@ -28,7 +29,7 @@ proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = debug "Client Response: ", size = recvData.len await ws.send(recvData) # await ws.close() - + except WebSocketError as exc: error "WebSocket error:", exception = exc.msg discard await request.respond(Http200, "Hello World") @@ -41,8 +42,8 @@ when isMainModule: let res = HttpServerRef.new( address, process, socketFlags = socketFlags) - + let server = res.get() server.start() info "Server listening at ", data = address - waitFor server.join() \ No newline at end of file + waitFor server.join() diff --git a/examples/tlsclient.nim b/examples/tlsclient.nim new file mode 100644 index 00000000..7b3dbab8 --- /dev/null +++ b/examples/tlsclient.nim @@ -0,0 +1,34 @@ +import pkg/[chronos, + chronos/streams/tlsstream, + chronicles, + stew/byteutils] + +import ../ws/ws + +proc main() {.async.} = + let ws = await WebSocket.tlsConnect( + "127.0.0.1", + Port(8888), + path = "/wss", + protocols = @["myfancyprotocol"], + flags = {NoVerifyHost,NoVerifyServerName}) + debug "Websocket client: ", State = ws.readyState + + let reqData = "Hello Server" + try: + echo "sending client " + await ws.send(reqData) + let buff = await ws.recv() + if buff.len <= 0: + break + let dataStr = string.fromBytes(buff) + debug "Server:", data = dataStr + + assert dataStr == reqData + return # bail out + except WebSocketError as exc: + error "WebSocket error:", exception = exc.msg + + # close the websocket + await ws.close() +waitFor(main()) diff --git a/examples/tlsserver.nim b/examples/tlsserver.nim new file mode 100644 index 00000000..d50ea05f --- /dev/null +++ b/examples/tlsserver.nim @@ -0,0 +1,57 @@ +import pkg/[chronos, + chronos/apps/http/shttpserver, + chronicles, + httputils, + stew/byteutils] + +import ../ws/ws +import ../tests/keys + +let secureKey = TLSPrivateKey.init(SecureKey) +let secureCert = TLSCertificate.init(SecureCert) + +proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isOk(): + let request = r.get() + + debug "Handling request:", uri = request.uri.path + if request.uri.path == "/wss": + debug "Initiating web socket connection." + try: + var ws = await createServer(request, "myfancyprotocol") + if ws.readyState != Open: + error "Failed to open websocket connection." + return + debug "Websocket handshake completed." + # Only reads header for data frame. + echo "receiving server " + let recvData = await ws.recv() + if recvData.len <= 0: + debug "Empty messages" + break + + if ws.readyState == ReadyState.Closed: + return + debug "Response: ", data = string.fromBytes(recvData) + await ws.send(recvData) + except WebSocketError: + error "WebSocket error:", exception = getCurrentExceptionMsg() + discard await request.respond(Http200, "Hello World") + else: + return dumbResponse() + +when isMainModule: + let address = initTAddress("127.0.0.1:8888") + let serverFlags = {Secure, NotifyDisconnect} + let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} + let res = SecureHttpServerRef.new( + address, process, + serverFlags = serverFlags, + socketFlags = socketFlags, + tlsPrivateKey = secureKey, + tlsCertificate = secureCert) + + let server = res.get() + server.start() + info "Server listening at ", data = address + waitFor server.join() diff --git a/tests/keys.nim b/tests/keys.nim new file mode 100644 index 00000000..74e17c5e --- /dev/null +++ b/tests/keys.nim @@ -0,0 +1,55 @@ +const + SecureKey* = """ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCdNv0SX02aeZ4/ +Yc+p/Kwd5UVOHlpmK7/TVC/kcjFbdoUuKNn8pnX/fyhgSKpUYut+te7YRiZhqlaL +EZKjfy8GBZwXZnJCevFkTvGTTebXXExLIsLGfJqKeLAdFCQkX8wV3jV1DT5JLV+D +5+HWaiiBr38gsl4ZbfyedTF40JvzokCmcdlx9bpzX1j/b84L/zSwUyyEcgp5G28F +Jh5TnxAeDHJpOVjr8XMb/xoNqiDF6NwF96hvOZC14mZ1TxxW5bUzXprsy0l52pmh +dN3Crz11+t2h519hRKHxT6/l5pTx/+dApXiP6hMV04CQJNnas3NyRxTDR9dNel+3 ++wD7/PRTAgMBAAECggEBAJuXPEbegxMKog7gYoE9S6oaqchySc0sJyCjBPL2ANsg +JRZV38cnh0hhNDh2MfxqGd7Bd6wbYQjvZ88iiRm+WW+ARcby4MnimtxHNNYwFvG0 +qt0BffqqftfkMYfV0x8coAJUdFtvy+DoQstsxhlJ3uTaJtrZLD/GlmjMWzXSX0Vy +FXiLDO7/LoSjsjaf4e4aLofIyLJS3H1T+5cr/d2mdpRzkeWkxShODsK4cRLOlZ5I +pz4Wm2770DTbiYph8ixl/CnmYn6T7V0F5VYujALknipUBeQY4e/A9vrQ/pvqJV+W +JjFUne6Rxg/lJjh8vNJp2bK1ZbzpwmZLaZIoEz8t/qECgYEAzvCCA48uQPaurSQ3 +cvHDhcVwYmEaH8MW8aIW/5l8XJK60GsUHPFhEsfD/ObI5PJJ9aOqgabpRHkvD4ZY +a8QJBxCy6UeogUeKvGks8VQ34SZXLimmgrL9Mlljv0v9PloEkVYbztYyX4GVO0ov +3oH+hKO+/MclzNDyeXZx3Vv4K+UCgYEAwnyb7tqp7fRqm/8EymIZV5pa0p6h609p +EhCBi9ii6d/ewEjsBhs7bPDBO4PO9ylvOvryYZH1hVbQja2anOCBjO8dAHRHWM86 +964TFriywBQkYxp6dsB8nUjLBDza2xAM3m+OGi9/ATuhEAe5sXp/fZL3tkfSaOXI +A7Gzro+kS9cCgYEAtKScSfEeBlWQa9H2mV9UN5z/mtF61YkeqTW+b8cTGVh4vWEL +wKww+gzqGAV6Duk2CLijKeSDMmO64gl7fC83VjSMiTklbhz+jbQeKFhFI0Sty71N +/j+y6NXBTgdOfLRl0lzhj2/JrzdWBtie6tR9UloCaXSKmb04PTFY+kvDWsUCgYBR +krJUnKJpi/qrM2tu93Zpp/QwIxkG+We4i/PKFDNApQVo4S0d4o4qQ1DJBZ/pSxe8 +RUUkZ3PzWVZgFlCjPAcadbBUYHEMbt7sw7Z98ToIFmqspo53AIVD8yQzwtKIz1KW +eXPAx+sdOUV008ivCBIxOVNswPMfzED4S7Bxpw3iQQKBgGJhct2nBsgu0l2/wzh9 +tpKbalW1RllgptNQzjuBEZMTvPF0L+7BE09/exKtt4N9s3yAzi8o6Qo7RHX5djVc +SNgafV4jj7jt2Ilh6KOy9dshtLoEkS1NmiqfVe2go2auXZdyGm+I2yzKWdKGDO0J +diTtYf1sA0PgNXdSyDC03TZl +-----END PRIVATE KEY----- +""" + + SecureCert* = """ +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUe9fr78Dz9PedQ5Sq0uluMWQhX9wwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCSU4xEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMTAzMTcwOTMzMzZaFw0zMTAz +MTUwOTMzMzZaMEUxCzAJBgNVBAYTAklOMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCdNv0SX02aeZ4/Yc+p/Kwd5UVOHlpmK7/TVC/kcjFb +doUuKNn8pnX/fyhgSKpUYut+te7YRiZhqlaLEZKjfy8GBZwXZnJCevFkTvGTTebX +XExLIsLGfJqKeLAdFCQkX8wV3jV1DT5JLV+D5+HWaiiBr38gsl4ZbfyedTF40Jvz +okCmcdlx9bpzX1j/b84L/zSwUyyEcgp5G28FJh5TnxAeDHJpOVjr8XMb/xoNqiDF +6NwF96hvOZC14mZ1TxxW5bUzXprsy0l52pmhdN3Crz11+t2h519hRKHxT6/l5pTx +/+dApXiP6hMV04CQJNnas3NyRxTDR9dNel+3+wD7/PRTAgMBAAGjUzBRMB0GA1Ud +DgQWBBRkSY1AkGUpVNxG5fYocfgFODtQmTAfBgNVHSMEGDAWgBRkSY1AkGUpVNxG +5fYocfgFODtQmTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBt +D71VH7F8GOQXITFXCrHwEq1Fx3ScuSnL04NJrXw/e9huzLVQOchAYp/EIn4x2utN +S31dt94wvi/IysOVbR1LatYNF5kKgGj2Wc6DH0PswBMk8R1G8QMeCz+hCjf1VDHe +AAW1x2q20rJAvUrT6cRBQqeiMzQj0OaJbvfnd2hu0/d0DFkcuGVgBa2zlbG5rbdU +Jnq7MQfSaZHd0uBgiKkS+Zw6XaYfWfByCAGSnUqRdOChiJ2stFVLvu+9oQ+PJjJt +Er1u9bKTUyeuYpqXr2BP9dqphwu8R4NFVUg6DIRpMFMsybaL7KAd4hD22RXCvc0m +uLu7KODi+eW62MHqs4N2 +-----END CERTIFICATE----- +""" diff --git a/tests/testall.nim b/tests/testall.nim index 51d4c49d..80f6b764 100644 --- a/tests/testall.nim +++ b/tests/testall.nim @@ -1,2 +1,3 @@ import ./testframes import ./testwebsockets +import ./testtlswebsockets diff --git a/tests/testframes.nim b/tests/testframes.nim index 8fe80fbc..18b32be3 100644 --- a/tests/testframes.nim +++ b/tests/testframes.nim @@ -1,7 +1,7 @@ import unittest -include ../src/ws -include ../src/random +include ../ws/ws +include ../ws/random # TODO: Fix Test. diff --git a/tests/testtlswebsockets.nim b/tests/testtlswebsockets.nim new file mode 100644 index 00000000..455d3711 --- /dev/null +++ b/tests/testtlswebsockets.nim @@ -0,0 +1,225 @@ +import std/strutils, httputils + +import pkg/[asynctest, + chronos, + chronos/apps/http/shttpserver, + stew/byteutils] + +import ../ws/[ws, stream], + ../examples/tlsserver + +import ./keys + +var server: SecureHttpServerRef +let address = initTAddress("127.0.0.1:8888") +let serverFlags = {HttpServerFlags.Secure, HttpServerFlags.NotifyDisconnect} +let socketFlags = {ServerFlags.TcpNoDelay, ServerFlags.ReuseAddr} +let clientFlags = {NoVerifyHost, NoVerifyServerName} + +let secureKey = TLSPrivateKey.init(SecureKey) +let secureCert = TLSCertificate.init(SecureCert) + +suite "Test websocket TLS handshake": + teardown: + await server.closeWait() + + test "Test for websocket TLS incorrect protocol": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return + + let request = r.get() + check request.uri.path == "/wss" + expect WSProtoMismatchError: + var ws = await createServer(request, "proto") + check ws.readyState == ReadyState.Closed + + return await request.respond(Http200, "Connection established") + + let res = SecureHttpServerRef.new( + address, cb, + serverFlags = serverFlags, + socketFlags = socketFlags, + tlsPrivateKey = secureKey, + tlsCertificate = secureCert) + + server = res.get() + server.start() + + expect WSFailedUpgradeError: + discard await WebSocket.tlsConnect( + "127.0.0.1", + Port(8888), + path = "/wss", + protocols = @["wrongproto"], + clientFlags) + + test "Test for websocket TLS incorrect version": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return + + let request = r.get() + check request.uri.path == "/wss" + expect WSVersionError: + var ws = await createServer(request, "proto") + check ws.readyState == ReadyState.Closed + + return await request.respond(Http200, "Connection established") + + let res = SecureHttpServerRef.new( + address, cb, + serverFlags = serverFlags, + socketFlags = socketFlags, + tlsPrivateKey = secureKey, + tlsCertificate = secureCert) + + server = res.get() + server.start() + + expect WSFailedUpgradeError: + discard await WebSocket.tlsConnect( + "127.0.0.1", + Port(8888), + path = "/wss", + protocols = @["wrongproto"], + clientFlags, + version = 14) + + test "Test for websocket TLS client headers": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + check r.isOk() + let request = r.get() + check request.uri.path == "/wss" + check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii() + check request.headers.getString("Upgrade").toUpperAscii() == "websocket".toUpperAscii() + check request.headers.getString("Cache-Control").toUpperAscii() == "no-cache".toUpperAscii() + check request.headers.getString("Sec-WebSocket-Version") == $WSDefaultVersion + + check request.headers.contains("Sec-WebSocket-Key") + + discard await request.respond( Http200,"Connection established") + + let res = SecureHttpServerRef.new( + address, cb, + serverFlags = serverFlags, + socketFlags = socketFlags, + tlsPrivateKey = secureKey, + tlsCertificate = secureCert) + + server = res.get() + server.start() + + expect WSFailedUpgradeError: + discard await WebSocket.tlsConnect( + "127.0.0.1", + Port(8888), + path = "/wss", + protocols = @["proto"], + clientFlags) + +suite "Test websocket TLS transmission": + teardown: + await server.closeWait() + + test "Server - test reading simple frame": + let testString = "Hello!" + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + + let request = r.get() + check request.uri.path == "/wss" + let ws = await createServer(request, "proto") + let servRes = await ws.recv() + check string.fromBytes(servRes) == testString + await ws.close() + return dumbResponse() + + let res = SecureHttpServerRef.new( + address, cb, + serverFlags = serverFlags, + socketFlags = socketFlags, + tlsPrivateKey = secureKey, + tlsCertificate = secureCert) + + server = res.get() + server.start() + + let wsClient = await WebSocket.tlsConnect( + "127.0.0.1", + Port(8888), + path = "/wss", + protocols = @["proto"], + clientFlags) + + await wsClient.send(testString) + await wsClient.close() + + test "Client - test reading simple frame": + let testString = "Hello!" + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + + let request = r.get() + check request.uri.path == "/wss" + let ws = await createServer(request, "proto") + let servRes = await ws.recv() + check string.fromBytes(servRes) == testString + await ws.close() + return dumbResponse() + + let res = SecureHttpServerRef.new( + address, cb, + serverFlags = serverFlags, + socketFlags = socketFlags, + tlsPrivateKey = secureKey, + tlsCertificate = secureCert) + + server = res.get() + server.start() + + let wsClient = await WebSocket.tlsConnect( + "127.0.0.1", + Port(8888), + path = "/wss", + protocols = @["proto"], + clientFlags) + + await wsClient.send(testString) + await wsClient.close() + + test "Client - test reading simple frame": + let testString = "Hello!" + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return dumbResponse() + + let request = r.get() + check request.uri.path == "/wss" + let ws = await createServer(request, "proto") + await ws.send(testString) + await ws.close() + return dumbResponse() + + let res = SecureHttpServerRef.new( + address, cb, + serverFlags = serverFlags, + socketFlags = socketFlags, + tlsPrivateKey = secureKey, + tlsCertificate = secureCert) + + server = res.get() + server.start() + + let wsClient = await WebSocket.tlsConnect( + "127.0.0.1", + Port(8888), + path = "/wss", + protocols = @["proto"], + clientFlags) + + var clientRes = await wsClient.recv() + check string.fromBytes(clientRes) == testString + await wsClient.close() diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index 6c370c86..5a12f8d4 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -4,26 +4,29 @@ import pkg/[asynctest, chronos, chronos/apps/http/httpserver, stew/byteutils] -import ../src/ws, ../src/stream + +import ../ws/[ws, stream] var server: HttpServerRef let address = initTAddress("127.0.0.1:8888") suite "Test handshake": teardown: + await server.stop() await server.closeWait() test "Test for incorrect protocol": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return + let request = r.get() check request.uri.path == "/ws" expect WSProtoMismatchError: var ws = await createServer(request, "proto") check ws.readyState == ReadyState.Closed - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -36,15 +39,16 @@ suite "Test handshake": test "Test for incorrect version": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return + let request = r.get() check request.uri.path == "/ws" expect WSVersionError: var ws = await createServer(request, "proto") check ws.readyState == ReadyState.Closed - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -58,7 +62,9 @@ suite "Test handshake": test "Test for client headers": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return + let request = r.get() check request.uri.path == "/ws" check request.headers.getString("Connection").toUpperAscii() == "Upgrade".toUpperAscii() @@ -68,8 +74,7 @@ suite "Test handshake": check request.headers.contains("Sec-WebSocket-Key") - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -80,23 +85,49 @@ suite "Test handshake": path = "/ws", protocols = @["proto"]) + test "Test for incorrect scheme": + proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isErr(): + return + + let request = r.get() + check request.uri.path == "/ws" + + expect WSProtoMismatchError: + var ws = await createServer(request, "proto") + check ws.readyState == ReadyState.Closed + + return await request.respond(Http200, "Connection established") + + let res = HttpServerRef.new(address, cb) + server = res.get() + server.start() + + let uri = "wx://127.0.0.1:8888/ws" + expect WSWrongUriSchemeError: + discard await WebSocket.connect( + parseUri(uri), + protocols = @["proto"]) + suite "Test transmission": teardown: await server.closeWait() + test "Server - test reading simple frame": let testString = "Hello!" proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return + let request = r.get() check request.uri.path == "/ws" let ws = await createServer(request, "proto") let servRes = await ws.recv() check string.fromBytes(servRes) == testString - await ws.stream.closeWait() + await ws.close() - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -106,18 +137,21 @@ suite "Test transmission": path = "/ws", protocols = @["proto"]) await wsClient.send(testString) + await wsClient.close() test "Client - test reading simple frame": let testString = "Hello!" proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return + let request = r.get() check request.uri.path == "/ws" let ws = await createServer(request, "proto") await ws.send(testString) - await ws.stream.closeWait() - let res = HttpServerRef.new( - address, cb) + await ws.close() + + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -128,6 +162,7 @@ suite "Test transmission": protocols = @["proto"]) var clientRes = await wsClient.recv() + await wsClient.close() check string.fromBytes(clientRes) == testString suite "Test ping-pong": @@ -137,7 +172,9 @@ suite "Test ping-pong": test "Server - test ping-pong control messages": var ping, pong = false proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return + let request = r.get() check request.uri.path == "/ws" let ws = await createServer( @@ -146,11 +183,11 @@ suite "Test ping-pong": onPong = proc() = pong = true ) + await ws.ping() await ws.close() - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -171,7 +208,9 @@ suite "Test ping-pong": test "Client - test ping-pong control messages": var ping, pong = false proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return + let request = r.get() check request.uri.path == "/ws" let ws = await createServer( @@ -182,8 +221,9 @@ suite "Test ping-pong": ) discard await ws.recv() - let res = HttpServerRef.new( - address, cb) + await ws.close() + + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -208,13 +248,14 @@ suite "Test framing": test "should split message into frames": let testString = "1234567890" - var done = newFuture[void]() proc cb(r: RequestFence): Future[HttpResponseRef]{.async.} = - check r.isOk() + if r.isErr(): + return dumbResponse() + let request = r.get() check request.uri.path == "/ws" - let ws = await createServer(request, "proto") + let ws = await createServer(request, "proto") let frame1 = await ws.readFrame() check not isNil(frame1) var data1 = newSeq[byte](frame1.remainder().int) @@ -227,11 +268,10 @@ suite "Test framing": let read2 = await ws.stream.reader.readOnce(addr data2[0], data2.len) check read2 == 5 - await ws.stream.closeWait() - done.complete() + await ws.close() + return dumbResponse() - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -243,20 +283,22 @@ suite "Test framing": frameSize = 5) await wsClient.send(testString) - await done + await wsClient.close() test "should fail to read past max message size": let testString = "1234567890" proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} = - check r.isOk() + if r.isErr(): + return dumbResponse() + let request = r.get() check request.uri.path == "/ws" let ws = await createServer(request, "proto") await ws.send(testString) - await ws.stream.closeWait() + await ws.close() + return dumbResponse() - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -269,19 +311,24 @@ suite "Test framing": expect WSMaxMessageSizeError: discard await wsClient.recv(5) + await wsClient.close() + suite "Test Closing": teardown: await server.closeWait() test "Server closing": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return dumbResponse() + let request = r.get() check request.uri.path == "/ws" let ws = await createServer(request, "proto") await ws.close() - let res = HttpServerRef.new( - address, cb) + return dumbResponse() + + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -296,12 +343,18 @@ suite "Test Closing": test "Server closing with status": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return dumbResponse() + let request = r.get() check request.uri.path == "/ws" - proc closeServer(status: Status, reason: string): CloseResult {.gcsafe.} = - check status == Status.TooLarge - check reason == "Message too big!" + proc closeServer(status: Status, reason: string): CloseResult + {.gcsafe, raises: [Defect].} = + try: + check status == Status.TooLarge + check reason == "Message too big!" + except Exception as exc: + raise newException(Defect, exc.msg) return (Status.Fulfilled, "") @@ -311,15 +364,19 @@ suite "Test Closing": onClose = closeServer) await ws.close() + return dumbResponse() - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() - proc clientClose(status: Status, reason: string): CloseResult {.gcsafe.} = - check status == Status.Fulfilled - return (Status.TooLarge, "Message too big!") + proc clientClose(status: Status, reason: string): CloseResult + {.gcsafe, raises: [Defect].} = + try: + check status == Status.Fulfilled + return (Status.TooLarge, "Message too big!") + except Exception as exc: + raise newException(Defect, exc.msg) let wsClient = await WebSocket.connect( "127.0.0.1", @@ -333,14 +390,17 @@ suite "Test Closing": test "Client closing": proc cb(r: RequestFence): Future[HttpResponseRef] {.async.} = - check r.isOk() + if r.isErr(): + return dumbResponse() + let request = r.get() check request.uri.path == "/ws" let ws = await createServer(request, "proto") discard await ws.recv() + await ws.close() + return dumbResponse() - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() @@ -353,28 +413,39 @@ suite "Test Closing": test "Client closing with status": proc cb(r: RequestFence): Future[HttpResponseRef] {.async, gcsafe.} = - check r.isOk() + if r.isErr(): + return dumbResponse() + let request = r.get() check request.uri.path == "/ws" - proc closeServer(status: Status, reason: string): CloseResult {.gcsafe.} = - check status == Status.Fulfilled - return (Status.TooLarge, "Message too big!") + proc closeServer(status: Status, reason: string): CloseResult + {.gcsafe, raises: [Defect].} = + try: + check status == Status.Fulfilled + return (Status.TooLarge, "Message too big!") + except Exception as exc: + raise newException(Defect, exc.msg) let ws = await createServer( request, "proto", onClose = closeServer) discard await ws.recv() + await ws.close() + return dumbResponse() - let res = HttpServerRef.new( - address, cb) + let res = HttpServerRef.new(address, cb) server = res.get() server.start() - proc clientClose(status: Status, reason: string): CloseResult {.gcsafe.} = - check status == Status.TooLarge - check reason == "Message too big!" - return (Status.Fulfilled, "") + proc clientClose(status: Status, reason: string): CloseResult + {.gcsafe, raises: [Defect].} = + try: + check status == Status.TooLarge + check reason == "Message too big!" + return (Status.Fulfilled, "") + except Exception as exc: + raise newException(Defect, exc.msg) let wsClient = await WebSocket.connect( "127.0.0.1", diff --git a/ws.nimble b/ws.nimble index b86b81fd..399f061e 100644 --- a/ws.nimble +++ b/ws.nimble @@ -3,16 +3,20 @@ version = "0.1.0" author = "Status Research & Development GmbH" description = "WS protocol implementation" license = "MIT" +skipDirs = @["examples", "test"] -requires "nim == 1.2.6" +requires "nim >= 1.2.6" requires "chronos >= 2.5.2" requires "httputils >= 0.2.0" requires "chronicles >= 0.10.0" -requires "urlly >= 0.2.0" requires "stew >= 0.1.0" -requires "eth" requires "asynctest >= 0.2.0 & < 0.3.0" requires "nimcrypto" +requires "bearssl" task test, "run tests": exec "nim c -r --opt:speed -d:debug --verbosity:0 --hints:off ./tests/testall.nim" + rmFile "./tests/testall" + rmFile "./tests/testwebsockets" + rmFile "./tests/testframes" + rmFile "./tests/testtlswebsockets" diff --git a/src/random.nim b/ws/random.nim similarity index 65% rename from src/random.nim rename to ws/random.nim index e3ea2d65..2862b02d 100644 --- a/src/random.nim +++ b/ws/random.nim @@ -1,8 +1,23 @@ import bearssl +export bearssl ## Random helpers: similar as in stdlib, but with BrHmacDrbgContext rng const randMax = 18_446_744_073_709_551_615'u64 +proc newRng*(): ref BrHmacDrbgContext = + # You should only create one instance of the RNG per application / library + # Ref is used so that it can be shared between components + # TODO consider moving to bearssl + var seeder = brPrngSeederSystem(nil) + if seeder == nil: + return nil + + var rng = (ref BrHmacDrbgContext)() + brHmacDrbgInit(addr rng[], addr sha256Vtable, nil, 0) + if seeder(addr rng.vtable) == 0: + return nil + rng + proc rand*(rng: var BrHmacDrbgContext, max: Natural): int = if max == 0: return 0 var x: uint64 diff --git a/src/stream.nim b/ws/stream.nim similarity index 82% rename from src/stream.nim rename to ws/stream.nim index 1d9ccbeb..fdd05df9 100644 --- a/src/stream.nim +++ b/ws/stream.nim @@ -6,9 +6,9 @@ import strutils const - HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec) HeaderSep = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')] - MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets + HttpHeadersTimeout = timer.seconds(120) # timeout for receiving headers (120 sec) + MaxHttpHeadersSize = 8192 # maximum size of HTTP headers in octets proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = var buffer = newSeq[byte](MaxHttpHeadersSize) @@ -42,12 +42,12 @@ proc readHeaders*(rstream: AsyncStreamReader): Future[seq[byte]] {.async.} = if error: buffer.setLen(0) + return buffer -proc closeWait*(wsStream : AsyncStream): Future[void] {.async.} = - if not wsStream.writer.tsource.closed(): - await wsStream.writer.tsource.closeWait() - if not wsStream.reader.tsource.closed(): - await wsStream.reader.tsource.closeWait() +proc closeWait*(wsStream : AsyncStream) {.async.} = + await allFutures( + wsStream.writer.tsource.closeWait(), + wsStream.reader.tsource.closeWait()) # TODO: Implement stream read and write wrapper. diff --git a/src/ws.nim b/ws/ws.nim similarity index 92% rename from src/ws.nim rename to ws/ws.nim index 120a797f..c35f9d50 100644 --- a/src/ws.nim +++ b/ws/ws.nim @@ -1,3 +1,5 @@ +{.push raises: [Defect].} + import std/[tables, strutils, uri, @@ -7,13 +9,13 @@ import pkg/[chronos, chronos/apps/http/httptable, chronos/apps/http/httpserver, chronos/streams/asyncstream, + chronos/streams/tlsstream, chronicles, httputils, stew/byteutils, stew/endians2, stew/base64, stew/base10, - eth/keys, nimcrypto/sha] import ./random, ./stream @@ -125,14 +127,14 @@ type length: uint64 ## Message size. consumed: uint64 ## how much has been consumed from the frame - ControlCb* = proc() {.gcsafe.} + ControlCb* = proc() {.gcsafe, raises: [Defect].} CloseResult* = tuple code: Status reason: string CloseCb* = proc(code: Status, reason: string): - CloseResult {.gcsafe.} + CloseResult {.gcsafe, raises: [Defect].} WebSocket* = ref object stream*: AsyncStream @@ -208,9 +210,11 @@ proc handshake*( let cKey = ws.key & WSGuid let acceptKey = Base64Pad.encode(sha1.digest(cKey.toOpenArray(0, cKey.high)).data) + var headerData = [ + ("Connection", "Upgrade"), + ("Upgrade", "webSocket" ), + ("Sec-WebSocket-Accept", acceptKey)] - var headerData = [("Connection", "Upgrade"),("Upgrade", "webSocket" ), - ("Sec-WebSocket-Accept", acceptKey)] var headers = HttpTable.init(headerData) if ws.protocol != "": headers.add("Sec-WebSocket-Protocol", ws.protocol) @@ -404,7 +408,7 @@ proc handleControl*(ws: WebSocket, frame: Frame) {.async.} = if frame.length > 125: raise newException(WSPayloadTooLarge, - "Control message payload is freater than 125 bytes!") + "Control message payload is greater than 125 bytes!") try: # Process control frame payload. @@ -439,7 +443,7 @@ proc readFrame*(ws: WebSocket): Future[Frame] {.async.} = ## try: - while ws.readyState != ReadyState.Closed: # read until a data frame arrives + while ws.readyState != ReadyState.Closed: # Grab the header. var header = newSeq[byte](2) await ws.stream.reader.readExactly(addr header[0], 2) @@ -525,7 +529,7 @@ proc recv*( size: int): Future[int] {.async.} = ## Attempts to read up to `size` bytes ## - ## Will read as many frames as necesary + ## Will read as many frames as necessary ## to fill the buffer until either ## the message ends (frame.fin) or ## the buffer is full. If no data is on @@ -643,7 +647,8 @@ proc close*( proc initiateHandshake( uri: Uri, address: TransportAddress, - headers: HttpTable): Future[AsyncStream] {.async.} = + headers: HttpTable, + flags: set[TLSFlags] = {}): Future[AsyncStream] {.async.} = ## Initiate handshake with server var transp: StreamTransport @@ -654,11 +659,27 @@ proc initiateHandshake( TransportError, "Cannot connect to " & $transp.remoteAddress() & " Error: " & exc.msg) + let requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers let reader = newAsyncStreamReader(transp) let writer = newAsyncStreamWriter(transp) - let requestHeader = "GET " & uri.path & " HTTP/1.1" & CRLF & $headers - await writer.write(requestHeader) - let res = await reader.readHeaders() + var stream: AsyncStream + + var res: seq[byte] + if uri.scheme == "https": + let tlsstream = newTLSClientAsyncStream(reader, writer, "", flags = flags) + stream = AsyncStream( + reader: tlsstream.reader, + writer: tlsstream.writer) + + await tlsstream.writer.write(requestHeader) + res = await tlsstream.reader.readHeaders() + else: + stream = AsyncStream( + reader: reader, + writer: writer) + await stream.writer.write(requestHeader) + res = await stream.reader.readHeaders() + if res.len == 0: raise newException(ValueError, "Empty response from server") @@ -674,14 +695,13 @@ proc initiateHandshake( " Header reason: " & resHeader.reason() & " Address: " & $transp.remoteAddress()) - return AsyncStream( - reader: reader, - writer: writer) + return stream proc connect*( _: type WebSocket, uri: Uri, protocols: seq[string] = @[], + flags: set[TLSFlags] = {}, version = WSDefaultVersion, frameSize = WSDefaultFrameSize, onPing: ControlCb = nil, @@ -695,8 +715,10 @@ proc connect*( case uri.scheme of "ws": uri.scheme = "http" + of "wss": + uri.scheme = "https" else: - raise newException(WSWrongUriSchemeError, "uri scheme has to be 'ws'") + raise newException(WSWrongUriSchemeError, "uri scheme has to be 'ws' or 'wss'") var headerData = [ ("Connection", "Upgrade"), @@ -711,7 +733,7 @@ proc connect*( headers.add("Sec-WebSocket-Protocol", protocols.join(", ")) let address = initTAddress(uri.hostname & ":" & uri.port) - let stream = await initiateHandshake(uri, address, headers) + let stream = await initiateHandshake(uri, address, headers, flags) # Client data should be masked. return WebSocket( @@ -748,6 +770,36 @@ proc connect*( return await WebSocket.connect( parseUri(uri), protocols, + {}, + version, + frameSize, + onPing, + onPong, + onClose) + +proc tlsConnect*( + _: type WebSocket, + host: string, + port: Port, + path: string, + protocols: seq[string] = @[], + flags: set[TLSFlags] = {}, + version = WSDefaultVersion, + frameSize = WSDefaultFrameSize, + onPing: ControlCb = nil, + onPong: ControlCb = nil, + onClose: CloseCb = nil): Future[WebSocket] {.async.} = + + var uri = "wss://" & host & ":" & $port + if path.startsWith("/"): + uri.add path + else: + uri.add "/" & path + + return await WebSocket.connect( + parseUri(uri), + protocols, + flags, version, frameSize, onPing,