From 359a81df4a408bfbcef4f548d594b6449490e6e5 Mon Sep 17 00:00:00 2001 From: Ludovic Chenut Date: Fri, 1 Mar 2024 13:49:05 +0100 Subject: [PATCH] Dtls comments + close + minor improvements --- examples/ping.nim | 23 ++++ examples/pong.nim | 2 +- webrtc/dtls/dtls.nim | 200 +++++++++++++++++++++----------- webrtc/stun/stun_attributes.nim | 1 - webrtc/webrtc.nim | 4 +- 5 files changed, 162 insertions(+), 68 deletions(-) create mode 100644 examples/ping.nim diff --git a/examples/ping.nim b/examples/ping.nim new file mode 100644 index 0000000..70c9f01 --- /dev/null +++ b/examples/ping.nim @@ -0,0 +1,23 @@ +import chronos, stew/byteutils +import ../webrtc/udp_connection +import ../webrtc/stun/stun_connection +import ../webrtc/dtls/dtls +import ../webrtc/sctp + +proc main() {.async.} = + let laddr = initTAddress("127.0.0.1:4244") + let udp = UdpConn() + udp.init(laddr) + let stun = StunConn() + stun.init(udp, laddr) + let dtls = Dtls() + dtls.init(stun, laddr) + let sctp = Sctp.new(dtls, laddr) + let conn = await sctp.connect(initTAddress("127.0.0.1:4242"), sctpPort = 13) + while true: + await conn.write("ping".toBytes) + let msg = await conn.read() + echo "Received: ", string.fromBytes(msg.data) + await sleepAsync(1.seconds) + +waitFor(main()) diff --git a/examples/pong.nim b/examples/pong.nim index 69abf57..e881585 100644 --- a/examples/pong.nim +++ b/examples/pong.nim @@ -19,7 +19,7 @@ proc main() {.async.} = let stun = StunConn() stun.init(udp, laddr) let dtls = Dtls() - dtls.start(stun, laddr) + dtls.init(stun, laddr) let sctp = Sctp.new(dtls, laddr) sctp.listen(13) while true: diff --git a/webrtc/dtls/dtls.nim b/webrtc/dtls/dtls.nim index 808a5f1..5d858fb 100644 --- a/webrtc/dtls/dtls.nim +++ b/webrtc/dtls/dtls.nim @@ -7,7 +7,7 @@ # This file may not be copied, modified, or distributed except according to # those terms. -import times, deques, tables +import times, deques, tables, sequtils import chronos, chronicles import ./utils, ../stun/stun_connection @@ -29,11 +29,22 @@ import mbedtls/timing logScope: topics = "webrtc dtls" -# TODO: Check the viability of the add/pop first/last of the asyncqueue with the limit. -# There might be some errors (or crashes) in weird cases with the no wait option +# Implementation of a DTLS client and a DTLS Server by using the mbedtls library. +# Multiple things here are unintuitive partly because of the callbacks +# used by mbedtls and that those callbacks cannot be async. +# +# TODO: +# - Check the viability of the add/pop first/last of the asyncqueue with the limit. +# There might be some errors (or crashes) with some edge cases with the no wait option +# - Not critical - Check how to make a better use of MBEDTLS_ERR_SSL_WANT_WRITE +# - Not critical - May be interesting to split Dtls and DtlsConn into two files -const - PendingHandshakeLimit = 1024 +# This limit is arbitrary, it could be interesting to make it configurable. +const PendingHandshakeLimit = 1024 + +# -- DtlsConn -- +# A Dtls connection to a specific IP address recovered by the receiving part of +# the Udp "connection" type DtlsError* = object of CatchableError @@ -43,6 +54,8 @@ type raddr*: TransportAddress dataRecv: AsyncQueue[seq[byte]] sendFuture: Future[void] + closed: bool + closeEvent: AsyncEvent timer: mbedtls_timing_delay_context @@ -57,55 +70,99 @@ type localCert: seq[byte] remoteCert: seq[byte] -proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = - var self = cast[DtlsConn](ctx) - var toWrite = newSeq[byte](len) - if len > 0: - copyMem(addr toWrite[0], buf, len) - trace "dtls send", len - self.sendFuture = self.conn.write(self.raddr, toWrite) - result = len.cint - -proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = - let self = cast[DtlsConn](ctx) - if self.dataRecv.len() == 0: - return MBEDTLS_ERR_SSL_WANT_READ - - var dataRecv = self.dataRecv.popFirstNoWait() - copyMem(buf, addr dataRecv[0], dataRecv.len()) - result = dataRecv.len().cint - trace "dtls receive", len, result - -proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} = +proc init(self: DtlsConn, conn: StunConn, laddr: TransportAddress) = self.conn = conn self.laddr = laddr self.dataRecv = newAsyncQueue[seq[byte]]() + self.closed = false + self.closeEvent = newAsyncEvent() + +proc join(self: DtlsConn) {.async.} = + await self.closeEvent.wait() + +proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} = + var shouldRead = isServer + while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: + if shouldRead: + if isServer: + case self.raddr.family + of AddressFamily.IPv4: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4) + of AddressFamily.IPv6: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6) + else: + raise newException(DtlsError, "Remote address isn't an IP address") + let tmp = await self.dataRecv.popFirst() + self.dataRecv.addFirstNoWait(tmp) + self.sendFuture = nil + let res = mb_ssl_handshake_step(self.ssl) + if not self.sendFuture.isNil(): + await self.sendFuture + shouldRead = false + if res == MBEDTLS_ERR_SSL_WANT_WRITE: + continue + elif res == MBEDTLS_ERR_SSL_WANT_READ: + shouldRead = true + continue + elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: + mb_ssl_session_reset(self.ssl) + shouldRead = isServer + continue + elif res != 0: + raise newException(DtlsError, $(res.mbedtls_high_level_strerr())) + +proc close*(self: DtlsConn) {.async.} = + if self.closed: + debug "Try to close DtlsConn twice" + return + + self.closed = true + self.sendFuture = nil + # TODO: proc mbedtls_ssl_close_notify => template mb_ssl_close_notify in nim-mbedtls + let x = mbedtls_ssl_close_notify(addr self.ssl) + if not self.sendFuture.isNil(): + await self.sendFuture + self.closeEvent.fire() proc write*(self: DtlsConn, msg: seq[byte]) {.async.} = + if self.closed: + debug "Try to write on an already closed DtlsConn" + return var buf = msg try: + let sendFuture = newFuture[void]("DtlsConn write") + self.sendFuture = nil let write = mb_ssl_write(self.ssl, buf) + if not self.sendFuture.isNil(): + await self.sendFuture trace "Dtls write", msgLen = msg.len(), actuallyWrote = write except MbedTLSError as exc: trace "Dtls write error", errorMsg = exc.msg raise exc proc read*(self: DtlsConn): Future[seq[byte]] {.async.} = + if self.closed: + debug "Try to read on an already closed DtlsConn" + return var res = newSeq[byte](8192) while true: let tmp = await self.dataRecv.popFirst() self.dataRecv.addFirstNoWait(tmp) - # TODO: exception catching - let length = mb_ssl_read(self.ssl, res) + # TODO: Find a clear way to use the template `mb_ssl_read` without + # messing up things with exception + let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint) if length == MBEDTLS_ERR_SSL_WANT_READ: continue if length < 0: - trace "dtls read", error = $(length.cint.mbedtls_high_level_strerr()) + raise newException(DtlsError, $(length.cint.mbedtls_high_level_strerr())) res.setLen(length) return res -proc close*(self: DtlsConn) {.async.} = - discard +# -- Dtls -- +# The Dtls object read every messages from the UdpConn/StunConn and, if the address +# is not yet stored in the Table `Connection`, adds it to the `pendingHandshake` queue +# to be accepted later, if the address is stored, add the message received to the +# corresponding DtlsConn `dataRecv` queue. type Dtls* = ref object of RootObj @@ -130,7 +187,7 @@ proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])], return aq.addLastNoWait((raddr, buf)) -proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) = +proc init*(self: Dtls, conn: StunConn, laddr: TransportAddress) = if self.started: warn "Already started" return @@ -159,43 +216,16 @@ proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) = self.localCert = newSeq[byte](self.serverCert.raw.len) copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len) -proc stop*(self: Dtls) = +proc stop*(self: Dtls) {.async.} = if not self.started: warn "Already stopped" return + await allFutures(toSeq(self.connections.values()).mapIt(it.close())) self.readLoop.cancel() self.started = false -proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} = - var shouldRead = isServer - while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: - if shouldRead: - if isServer: - case self.raddr.family - of AddressFamily.IPv4: - mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4) - of AddressFamily.IPv6: - mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6) - else: - raise newException(DtlsError, "Remote address isn't an IP address") - let tmp = await self.dataRecv.popFirst() - self.dataRecv.addFirstNoWait(tmp) - self.sendFuture = nil - let res = mb_ssl_handshake_step(self.ssl) - if not self.sendFuture.isNil(): await self.sendFuture - shouldRead = false - if res == MBEDTLS_ERR_SSL_WANT_WRITE: - continue - elif res == MBEDTLS_ERR_SSL_WANT_READ: - shouldRead = true - continue - elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: - mb_ssl_session_reset(self.ssl) - shouldRead = true - continue - elif res != 0: - raise newException(DtlsError, $(res.mbedtls_high_level_strerr())) +# -- Remote / Local certificate getter -- proc remoteCertificate*(conn: DtlsConn): seq[byte] = conn.remoteCert @@ -206,8 +236,14 @@ proc localCertificate*(conn: DtlsConn): seq[byte] = proc localCertificate*(self: Dtls): seq[byte] = self.localCert +# -- MbedTLS Callbacks -- + proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32): cint {.cdecl.} = + # verify is the procedure called by mbedtls when receiving the remote + # certificate. It's usually used to verify the validity of the certificate. + # We use this procedure to store the remote certificate as it's mandatory + # to have it for the Prologue of the Noise protocol, aswell as the localCertificate. var self = cast[DtlsConn](ctx) let cert = pcert[] @@ -215,12 +251,45 @@ proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt, copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len) return 0 +proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + # dtlsSend is the procedure called by mbedtls when data needs to be sent. + # As the StunConn's write proc is asynchronous and dtlsSend cannot be async, + # we store the future of this write and await it after the end of the + # function (see write or dtlsHanshake for example). + var self = cast[DtlsConn](ctx) + var toWrite = newSeq[byte](len) + if len > 0: + copyMem(addr toWrite[0], buf, len) + trace "dtls send", len + self.sendFuture = self.conn.write(self.raddr, toWrite) + result = len.cint + +proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + # dtlsRecv is the procedure called by mbedtls when data needs to be received. + # As we cannot asynchronously await for data to be received, we use a data received + # queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await + # when the mbedtls proc resumed (see read or dtlsHandshake for example) + let self = cast[DtlsConn](ctx) + if self.dataRecv.len() == 0: + return MBEDTLS_ERR_SSL_WANT_READ + + var dataRecv = self.dataRecv.popFirstNoWait() + copyMem(buf, addr dataRecv[0], dataRecv.len()) + result = dataRecv.len().cint + trace "dtls receive", len, result + +# -- Dtls Accept / Connect procedures -- + +proc removeConnection(self: Dtls, conn: DtlsConn, raddr: TransportAddress) {.async.} = + await conn.join() + self.connections.del(raddr) + proc accept*(self: Dtls): Future[DtlsConn] {.async.} = var selfvar = self res = DtlsConn() - await res.init(self.conn, self.laddr) + res.init(self.conn, self.laddr) mb_ssl_init(res.ssl) mb_ssl_config_init(res.config) mb_ssl_cookie_init(res.cookie) @@ -248,8 +317,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} = mb_ssl_session_reset(res.ssl) mb_ssl_set_verify(res.ssl, verify, res) mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) - mb_ssl_set_bio(res.ssl, cast[pointer](res), - dtlsSend, dtlsRecv, nil) + mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil) while true: let (raddr, buf) = await self.pendingHandshakes.popFirst() try: @@ -257,6 +325,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} = res.dataRecv.addLastNoWait(buf) self.connections[raddr] = res await res.dtlsHandshake(true) + asyncSpawn self.removeConnection(res, raddr) break except CatchableError as exc: trace "Handshake fail", remoteAddress = raddr, error = exc.msg @@ -269,7 +338,7 @@ proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} = selfvar = self res = DtlsConn() - await res.init(self.conn, self.laddr) + res.init(self.conn, self.laddr) mb_ssl_init(res.ssl) mb_ssl_config_init(res.config) @@ -303,6 +372,7 @@ proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} = try: await res.dtlsHandshake(false) + asyncSpawn self.removeConnection(res, raddr) except CatchableError as exc: trace "Handshake fail", remoteAddress = raddr, error = exc.msg self.connections.del(raddr) diff --git a/webrtc/stun/stun_attributes.nim b/webrtc/stun/stun_attributes.nim index eaffc82..bd6179f 100644 --- a/webrtc/stun/stun_attributes.nim +++ b/webrtc/stun/stun_attributes.nim @@ -11,7 +11,6 @@ import std/sha1, sequtils, typetraits, std/md5 import binary_serialization, stew/byteutils, chronos -import ../utils # -- Utils -- diff --git a/webrtc/webrtc.nim b/webrtc/webrtc.nim index 1c39eb8..d93a8b8 100644 --- a/webrtc/webrtc.nim +++ b/webrtc/webrtc.nim @@ -17,6 +17,8 @@ import sctp, datachannel logScope: topics = "webrtc" +# TODO: Implement a connect (or dial) procedure + type WebRTC* = ref object udp*: UdpConn @@ -29,7 +31,7 @@ proc new*(T: typedesc[WebRTC], address: TransportAddress): T = var webrtc = T(udp: UdpConn(), stun: StunConn(), dtls: Dtls()) webrtc.udp.init(address) webrtc.stun.init(webrtc.udp, address) - webrtc.dtls.start(webrtc.stun, address) + webrtc.dtls.init(webrtc.stun, address) webrtc.sctp = Sctp.new(webrtc.dtls, address) return webrtc