diff --git a/examples/sctp_both.nim b/examples/sctp_both.nim new file mode 100644 index 0000000..bb861fa --- /dev/null +++ b/examples/sctp_both.nim @@ -0,0 +1,25 @@ +import chronos, stew/byteutils +import ../webrtc/sctp as sc + +let sctp = Sctp.new(port = 4242) +proc serv(fut: Future[void]) {.async.} = + sctp.startServer(13) + fut.complete() + let conn = await sctp.listen() + echo "await read()" + let msg = await conn.read() + echo "read() finished" + echo "Receive: ", string.fromBytes(msg) + await conn.close() + sctp.stopServer() + +proc main() {.async.} = + let fut = Future[void]() + asyncSpawn serv(fut) + await fut + let address = TransportAddress(initTAddress("127.0.0.1:4242")) + let conn = await sctp.connect(address, sctpPort = 13) + await conn.write("test".toBytes) + await conn.close() + +waitFor(main()) diff --git a/examples/sctp_client.nim b/examples/sctp_client.nim new file mode 100644 index 0000000..b49b1d8 --- /dev/null +++ b/examples/sctp_client.nim @@ -0,0 +1,14 @@ +import chronos, stew/byteutils +import ../webrtc/sctp + +proc main() {.async.} = + let + sctp = Sctp.new(port = 4244) + address = TransportAddress(initTAddress("127.0.0.1:4242")) + conn = await sctp.connect(address, sctpPort = 13) + await conn.write("test".toBytes) + let msg = await conn.read() + echo "Client read() finished ; receive: ", string.fromBytes(msg) + await conn.close() + +waitFor(main()) diff --git a/examples/sctp_server.nim b/examples/sctp_server.nim new file mode 100644 index 0000000..429e247 --- /dev/null +++ b/examples/sctp_server.nim @@ -0,0 +1,13 @@ +import chronos, stew/byteutils +import ../webrtc/sctp + +proc main() {.async.} = + let sctp = Sctp.new(port = 4242) + sctp.startServer(13) + let conn = await sctp.listen() + let msg = await conn.read() + echo "Receive: ", string.fromBytes(msg) + await conn.close() + sctp.stopServer() + +waitFor(main()) diff --git a/webrtc/dtls.nim b/webrtc/dtls.nim deleted file mode 100644 index a904011..0000000 --- a/webrtc/dtls.nim +++ /dev/null @@ -1,142 +0,0 @@ -# Nim-WebRTC -# Copyright (c) 2023 Status Research & Development GmbH -# Licensed under either of -# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) -# * MIT license ([LICENSE-MIT](LICENSE-MIT)) -# at your option. -# This file may not be copied, modified, or distributed except according to -# those terms. - -import std/times -import chronos -import webrtc_connection - -import mbedtls/ssl -import mbedtls/pk -import mbedtls/md -import mbedtls/entropy -import mbedtls/ctr_drbg -import mbedtls/rsa -import mbedtls/x509 -import mbedtls/x509_crt -import mbedtls/bignum -import mbedtls/error -import mbedtls/net_sockets - -type - DtlsConn* = ref object of WebRTCConn - recvData: seq[seq[byte]] - recvEvent: AsyncEvent - sendEvent: AsyncEvent - - entropy: mbedtls_entropy_context - ctr_drbg: mbedtls_ctr_drbg_context - - config: mbedtls_ssl_config - ssl: mbedtls_ssl_context - -proc mbedtls_pk_rsa(pk: mbedtls_pk_context): ptr mbedtls_rsa_context = - var key = pk - case mbedtls_pk_get_type(addr key): - of MBEDTLS_PK_RSA: - return cast[ptr mbedtls_rsa_context](pk.private_pk_ctx) - else: - return nil - -proc generateKey(self: DtlsConn): mbedtls_pk_context = - var res: mbedtls_pk_context - mb_pk_init(res) - discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) - mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, self.ctr_drbg, 4096, 65537) - return res - -proc generateCertificate(self: DtlsConn): mbedtls_x509_crt = - let - name = "C=FR,O=webrtc,CN=wbrtc" - time_format = initTimeFormat("YYYYMMddHHmmss") - time_from = times.now().format(time_format) - time_to = (times.now() + times.years(1)).format(time_format) - - - var issuer_key = self.generateKey() - var write_cert: mbedtls_x509write_cert - var serial_mpi: mbedtls_mpi - mb_x509write_crt_init(write_cert) - mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256); - mb_x509write_crt_set_subject_key(write_cert, issuer_key) - mb_x509write_crt_set_issuer_key(write_cert, issuer_key) - mb_x509write_crt_set_subject_name(write_cert, name) - mb_x509write_crt_set_issuer_name(write_cert, name) - mb_x509write_crt_set_validity(write_cert, time_from, time_to) - mb_x509write_crt_set_basic_constraints(write_cert, 0, -1) - mb_x509write_crt_set_subject_key_identifier(write_cert) - mb_x509write_crt_set_authority_key_identifier(write_cert) - mb_mpi_init(serial_mpi) - let serial_hex = mb_mpi_read_string(serial_mpi, 16) - mb_x509write_crt_set_serial(write_cert, serial_mpi) - let buf = mb_x509write_crt_pem(write_cert, 4096, mbedtls_ctr_drbg_random, self.ctr_drbg) - mb_x509_crt_parse(result, buf) - -proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = - echo "dtlsSend: " - let self = cast[ptr DtlsConn](ctx) - self.sendEvent.fire() - -proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = - echo "dtlsRecv: " - let self = cast[ptr DtlsConn](ctx)[] - self.recvEvent.fire() - -method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.async.} = - await procCall(WebRTCConn(self).init(conn, address)) - self.recvEvent = AsyncEvent() - self.sendEvent = AsyncEvent() - - mb_ctr_drbg_init(self.ctr_drbg) - mb_entropy_init(self.entropy) - mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, - self.entropy, nil, 0) - var - srvcert = self.generateCertificate() - pkey = self.generateKey() - selfvar = self - - mb_ssl_init(self.ssl) - mb_ssl_config_init(self.config) - mb_ssl_config_defaults(self.config, MBEDTLS_SSL_IS_SERVER, - MBEDTLS_SSL_TRANSPORT_DATAGRAM, - MBEDTLS_SSL_PRESET_DEFAULT) - mb_ssl_conf_rng(self.config, mbedtls_ctr_drbg_random, self.ctr_drbg) - mb_ssl_conf_read_timeout(self.config, 10000) # in milliseconds - mb_ssl_conf_ca_chain(self.config, srvcert.next, nil) - mb_ssl_conf_own_cert(self.config, srvcert, pkey) - # cookies ? - mb_ssl_setup(self.ssl, self.config) - mb_ssl_session_reset(self.ssl) - mb_ssl_set_bio(self.ssl, cast[pointer](addr selfvar), - dtlsSend, dtlsRecv, nil) - while true: - mb_ssl_handshake(self.ssl) - -method close*(self: DtlsConn) {.async.} = - discard - -method write*(self: DtlsConn, msg: seq[byte]) {.async.} = - var buf = msg - self.sendEvent.clear() - discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](buf.cstring), buf.len()) - await self.sendEvent.wait() - -method read*(self: DtlsConn): Future[seq[byte]] {.async.} = - var res = newString(4096) - self.recvEvent.clear() - discard mbedtls_ssl_read(addr self.ssl, cast[ptr byte](res.cstring), 4096) - await self.recvEvent.wait() - -proc main {.async.} = - let laddr = initTAddress("127.0.0.1:" & "4242") - var dtls = DtlsConn() - await dtls.init(nil, laddr) - let cert = dtls.generateCertificate() - -waitFor(main()) diff --git a/webrtc/dtls/dtls.nim b/webrtc/dtls/dtls.nim new file mode 100644 index 0000000..e78b75b --- /dev/null +++ b/webrtc/dtls/dtls.nim @@ -0,0 +1,181 @@ +# Nim-WebRTC +# Copyright (c) 2023 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import times, sequtils +import chronos, chronicles +import ./utils, ../webrtc_connection + +import mbedtls/ssl +import mbedtls/ssl_cookie +import mbedtls/ssl_cache +import mbedtls/pk +import mbedtls/md +import mbedtls/entropy +import mbedtls/ctr_drbg +import mbedtls/rsa +import mbedtls/x509 +import mbedtls/x509_crt +import mbedtls/bignum +import mbedtls/error +import mbedtls/net_sockets +import mbedtls/timing + +logScope: + topics = "webrtc dtls" + +type + DtlsError* = object of CatchableError + DtlsConn* = ref object of WebRTCConn + recvData: seq[seq[byte]] + recvEvent: AsyncEvent + sendFuture: Future[void] + + timer: mbedtls_timing_delay_context + + ssl: mbedtls_ssl_context + config: mbedtls_ssl_config + cookie: mbedtls_ssl_cookie_ctx + cache: mbedtls_ssl_cache_context + + ctr_drbg: mbedtls_ctr_drbg_context + entropy: mbedtls_entropy_context + +proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + var self = cast[DtlsConn](ctx) + var toWrite = newSeq[byte](len) + if len > 0: + copyMem(addr toWrite[0], buf, len) + self.sendFuture = self.conn.write(toWrite) + result = len.cint + +proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + var self = cast[DtlsConn](ctx) + result = self.recvData[0].len().cint + copyMem(buf, addr self.recvData[0][0], self.recvData[0].len()) + self.recvData.delete(0..0) + +method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.async.} = + await procCall(WebRTCConn(self).init(conn, address)) + +method write*(self: DtlsConn, msg: seq[byte]) {.async.} = + var buf = msg + discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](addr buf[0]), buf.len().uint) + +method read*(self: DtlsConn): Future[seq[byte]] {.async.} = + return await self.conn.read() + +method close*(self: DtlsConn) {.async.} = + discard + +method getRemoteAddress*(self: DtlsConn): TransportAddress = + self.conn.getRemoteAddress() + +type + Dtls* = ref object of RootObj + address: TransportAddress + started: bool + +proc start*(self: Dtls, address: TransportAddress) = + if self.started: + warn "Already started" + return + + self.address = address + self.started = true + +proc stop*(self: Dtls) = + if not self.started: + warn "Already stopped" + return + + self.started = false + +proc serverHandshake(self: DtlsConn) {.async.} = + var shouldRead = true + + while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: + if shouldRead: + self.recvData.add(await self.conn.read()) + var ta = self.getRemoteAddress() + case ta.family + of AddressFamily.IPv4: + mb_ssl_set_client_transport_id(self.ssl, ta.address_v4) + of AddressFamily.IPv6: + mb_ssl_set_client_transport_id(self.ssl, ta.address_v6) + else: + raise newException(DtlsError, "Remote address isn't an IP address") + + self.sendFuture = nil + let res = mb_ssl_handshake_step(self.ssl) + shouldRead = false + if not self.sendFuture.isNil(): await self.sendFuture + if res == MBEDTLS_ERR_SSL_WANT_WRITE: + continue + elif res == MBEDTLS_ERR_SSL_WANT_READ or + self.ssl.private_state == MBEDTLS_SSL_CLIENT_KEY_EXCHANGE: + shouldRead = true + continue + elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: + mb_ssl_session_reset(self.ssl) + shouldRead = true + continue + elif res != 0: + raise newException(DtlsError, $(res.mbedtls_high_level_strerr())) + +proc accept*(self: Dtls, conn: WebRTCConn): Future[DtlsConn] {.async.} = + var + selfvar = self + res = DtlsConn() + let v = cast[pointer](res) + + await res.init(conn, self.address) + mb_ssl_init(res.ssl) + mb_ssl_config_init(res.config) + mb_ssl_cookie_init(res.cookie) + mb_ssl_cache_init(res.cache) + + mb_ctr_drbg_init(res.ctr_drbg) + mb_entropy_init(res.entropy) + mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0) + + var pkey = res.ctr_drbg.generateKey() + var srvcert = res.ctr_drbg.generateCertificate(pkey) + + mb_ssl_config_defaults(res.config, + MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT) + mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds + mb_ssl_conf_ca_chain(res.config, srvcert.next, nil) + mb_ssl_conf_own_cert(res.config, srvcert, pkey) + mb_ssl_cookie_setup(res.cookie, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_dtls_cookies(res.config, res.cookie) + mb_ssl_set_timer_cb(res.ssl, res.timer) + mb_ssl_setup(res.ssl, res.config) + mb_ssl_session_reset(res.ssl) + mb_ssl_set_bio(res.ssl, cast[pointer](res), + dtlsSend, dtlsRecv, nil) + await res.serverHandshake() + return res + +proc dial*(self: Dtls, address: TransportAddress): DtlsConn = + discard + +import ../udp_connection +proc main() {.async.} = + let laddr = initTAddress("127.0.0.1:4433") + let udp = UdpConn() + await udp.init(nil, laddr) + let dtls = Dtls() + dtls.start(laddr) + let x = await dtls.accept(udp) + echo "After accept" + +waitFor(main()) diff --git a/webrtc/dtls/utils.nim b/webrtc/dtls/utils.nim new file mode 100644 index 0000000..6f9ad5b --- /dev/null +++ b/webrtc/dtls/utils.nim @@ -0,0 +1,102 @@ +# Nim-WebRTC +# Copyright (c) 2023 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import std/times + +import stew/byteutils + +import mbedtls/pk +import mbedtls/rsa +import mbedtls/ctr_drbg +import mbedtls/x509_crt +import mbedtls/bignum +import mbedtls/md + +import chronicles + +const mb_ssl_states* = @[ + "MBEDTLS_SSL_HELLO_REQUEST", + "MBEDTLS_SSL_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_HELLO", + "MBEDTLS_SSL_SERVER_CERTIFICATE", + "MBEDTLS_SSL_SERVER_KEY_EXCHANGE", + "MBEDTLS_SSL_CERTIFICATE_REQUEST", + "MBEDTLS_SSL_SERVER_HELLO_DONE", + "MBEDTLS_SSL_CLIENT_CERTIFICATE", + "MBEDTLS_SSL_CLIENT_KEY_EXCHANGE", + "MBEDTLS_SSL_CERTIFICATE_VERIFY", + "MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC", + "MBEDTLS_SSL_CLIENT_FINISHED", + "MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC", + "MBEDTLS_SSL_SERVER_FINISHED", + "MBEDTLS_SSL_FLUSH_BUFFERS", + "MBEDTLS_SSL_HANDSHAKE_WRAPUP", + "MBEDTLS_SSL_NEW_SESSION_TICKET", + "MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT", + "MBEDTLS_SSL_HELLO_RETRY_REQUEST", + "MBEDTLS_SSL_ENCRYPTED_EXTENSIONS", + "MBEDTLS_SSL_END_OF_EARLY_DATA", + "MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY", + "MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED", + "MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO", + "MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST", + "MBEDTLS_SSL_HANDSHAKE_OVER", + "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET", + "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" +] + +proc mbedtls_pk_rsa*(pk: mbedtls_pk_context): ptr mbedtls_rsa_context = + var key = pk + case mbedtls_pk_get_type(addr key) + of MBEDTLS_PK_RSA: + return cast[ptr mbedtls_rsa_context](pk.private_pk_ctx) + else: + return nil + +template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context = + var res: mbedtls_pk_context + mb_pk_init(res) + discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) + mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 2048, 65537) + let x = mb_pk_rsa(res) + res + +template generateCertificate*(random: mbedtls_ctr_drbg_context, + issuer_key: mbedtls_pk_context): mbedtls_x509_crt = + let + name = "C=FR,O=Status,CN=webrtc" + time_format = initTimeFormat("YYYYMMddHHmmss") + time_from = times.now().format(time_format) + time_to = (times.now() + times.years(1)).format(time_format) + + var write_cert: mbedtls_x509write_cert + var serial_mpi: mbedtls_mpi + mb_x509write_crt_init(write_cert) + mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256); + mb_x509write_crt_set_subject_key(write_cert, issuer_key) + mb_x509write_crt_set_issuer_key(write_cert, issuer_key) + mb_x509write_crt_set_subject_name(write_cert, name) + mb_x509write_crt_set_issuer_name(write_cert, name) + mb_x509write_crt_set_validity(write_cert, time_from, time_to) + mb_x509write_crt_set_basic_constraints(write_cert, 0, -1) + mb_x509write_crt_set_subject_key_identifier(write_cert) + mb_x509write_crt_set_authority_key_identifier(write_cert) + mb_mpi_init(serial_mpi) + let serial_hex = mb_mpi_read_string(serial_mpi, 16) + mb_x509write_crt_set_serial(write_cert, serial_mpi) + let buf = + try: + mb_x509write_crt_pem(write_cert, 2048, mbedtls_ctr_drbg_random, random) + except MbedTLSError as e: + raise e + var res: mbedtls_x509_crt + mb_x509_crt_parse(res, buf) + res diff --git a/webrtc/sctp.nim b/webrtc/sctp.nim index d86833d..7206c37 100644 --- a/webrtc/sctp.nim +++ b/webrtc/sctp.nim @@ -8,7 +8,7 @@ # those terms. import tables, bitops, posix, strutils, sequtils -import chronos, chronicles, stew/ranges/ptr_arith +import chronos, chronicles, stew/[ranges/ptr_arith, byteutils] import usrsctp export chronicles @@ -101,10 +101,13 @@ proc write*(self: SctpConnection, buf: seq[byte]) {.async.} = self.sctp.sentConnection = self self.sctp.sentAddress = self.address let sendvErr = self.sctp.usrsctpAwait: - self.sctpSocket.usrsctp_sendv(addr buf[0], buf.len.uint, + self.sctpSocket.usrsctp_sendv(unsafeAddr buf[0], buf.len.uint, nil, 0, nil, 0, SCTP_SENDV_NOINFO, 0) +proc write*(self: SctpConnection, s: string) {.async.} = + await self.write(s.toBytes()) + proc close*(self: SctpConnection) {.async.} = self.sctp.usrsctpAwait: self.sctpSocket.usrsctp_close() @@ -143,7 +146,7 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = if bitand(flags, MSG_NOTIFICATION) != 0: trace "Notification received", length = n else: - conn.dataRecv = conn.dataRecv.concat(buffer[0..n]) + conn.dataRecv = conn.dataRecv.concat(buffer[0..\e[0;1m onReceive\e[0m: ", udp.getMessage().len() + self.remote = address self.recvData.add(msg) self.recvEvent.fire() self.recvEvent = newAsyncEvent() - self.udp = newDatagramTransport(onReceive) + self.udp = newDatagramTransport(onReceive, local = addrss) method close(self: UdpConn) {.async.} = self.udp.close() if not self.conn.isNil(): - self.conn.close() + await self.conn.close() method write(self: UdpConn, msg: seq[byte]) {.async.} = - await self.udp.sendTo(self.address, msg) + echo "\e[33m\e[0;1m write\e[0m" + await self.udp.sendTo(self.remote, msg) -method read(self: UdpConn): seq[byte] {.async.} = +method read(self: UdpConn): Future[seq[byte]] {.async.} = + echo "\e[33m\e[0;1m read\e[0m" while self.recvData.len() <= 0: self.recvEvent.clear() await self.recvEvent.wait() result = self.recvData[0] self.recvData.delete(0..0) + +method getRemoteAddress*(self: UdpConn): TransportAddress = + self.remote diff --git a/webrtc/webrtc.nim b/webrtc/webrtc.nim index ce8eb0b..f8cdc36 100644 --- a/webrtc/webrtc.nim +++ b/webrtc/webrtc.nim @@ -8,7 +8,7 @@ # those terms. import chronos, chronicles -import stun +import stun/stun logScope: topics = "webrtc" diff --git a/webrtc/webrtc_connection.nim b/webrtc/webrtc_connection.nim index eae2b1b..103d48c 100644 --- a/webrtc/webrtc_connection.nim +++ b/webrtc/webrtc_connection.nim @@ -11,8 +11,8 @@ import chronos type WebRTCConn* = ref object of RootObj - conn: WebRTCConn - address: TransportAddress + conn*: WebRTCConn + address*: TransportAddress # isClosed: bool # isEof: bool @@ -28,3 +28,6 @@ method write*(self: WebRTCConn, msg: seq[byte]) {.async, base.} = method read*(self: WebRTCConn): Future[seq[byte]] {.async, base.} = doAssert(false, "not implemented!") + +method getRemoteAddress*(self: WebRTCConn): TransportAddress {.base.} = + doAssert(false, "not implemented")