Dtls comments + close + minor improvements

This commit is contained in:
Ludovic Chenut 2024-03-01 13:49:05 +01:00
parent 4c1eb13926
commit 359a81df4a
No known key found for this signature in database
GPG Key ID: D9A59B1907F1D50C
5 changed files with 162 additions and 68 deletions

23
examples/ping.nim Normal file
View File

@ -0,0 +1,23 @@
import chronos, stew/byteutils
import ../webrtc/udp_connection
import ../webrtc/stun/stun_connection
import ../webrtc/dtls/dtls
import ../webrtc/sctp
proc main() {.async.} =
let laddr = initTAddress("127.0.0.1:4244")
let udp = UdpConn()
udp.init(laddr)
let stun = StunConn()
stun.init(udp, laddr)
let dtls = Dtls()
dtls.init(stun, laddr)
let sctp = Sctp.new(dtls, laddr)
let conn = await sctp.connect(initTAddress("127.0.0.1:4242"), sctpPort = 13)
while true:
await conn.write("ping".toBytes)
let msg = await conn.read()
echo "Received: ", string.fromBytes(msg.data)
await sleepAsync(1.seconds)
waitFor(main())

View File

@ -19,7 +19,7 @@ proc main() {.async.} =
let stun = StunConn()
stun.init(udp, laddr)
let dtls = Dtls()
dtls.start(stun, laddr)
dtls.init(stun, laddr)
let sctp = Sctp.new(dtls, laddr)
sctp.listen(13)
while true:

View File

@ -7,7 +7,7 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
import times, deques, tables
import times, deques, tables, sequtils
import chronos, chronicles
import ./utils, ../stun/stun_connection
@ -29,11 +29,22 @@ import mbedtls/timing
logScope:
topics = "webrtc dtls"
# TODO: Check the viability of the add/pop first/last of the asyncqueue with the limit.
# There might be some errors (or crashes) in weird cases with the no wait option
# Implementation of a DTLS client and a DTLS Server by using the mbedtls library.
# Multiple things here are unintuitive partly because of the callbacks
# used by mbedtls and that those callbacks cannot be async.
#
# TODO:
# - Check the viability of the add/pop first/last of the asyncqueue with the limit.
# There might be some errors (or crashes) with some edge cases with the no wait option
# - Not critical - Check how to make a better use of MBEDTLS_ERR_SSL_WANT_WRITE
# - Not critical - May be interesting to split Dtls and DtlsConn into two files
const
PendingHandshakeLimit = 1024
# This limit is arbitrary, it could be interesting to make it configurable.
const PendingHandshakeLimit = 1024
# -- DtlsConn --
# A Dtls connection to a specific IP address recovered by the receiving part of
# the Udp "connection"
type
DtlsError* = object of CatchableError
@ -43,6 +54,8 @@ type
raddr*: TransportAddress
dataRecv: AsyncQueue[seq[byte]]
sendFuture: Future[void]
closed: bool
closeEvent: AsyncEvent
timer: mbedtls_timing_delay_context
@ -57,55 +70,99 @@ type
localCert: seq[byte]
remoteCert: seq[byte]
proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
var self = cast[DtlsConn](ctx)
var toWrite = newSeq[byte](len)
if len > 0:
copyMem(addr toWrite[0], buf, len)
trace "dtls send", len
self.sendFuture = self.conn.write(self.raddr, toWrite)
result = len.cint
proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
let self = cast[DtlsConn](ctx)
if self.dataRecv.len() == 0:
return MBEDTLS_ERR_SSL_WANT_READ
var dataRecv = self.dataRecv.popFirstNoWait()
copyMem(buf, addr dataRecv[0], dataRecv.len())
result = dataRecv.len().cint
trace "dtls receive", len, result
proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} =
proc init(self: DtlsConn, conn: StunConn, laddr: TransportAddress) =
self.conn = conn
self.laddr = laddr
self.dataRecv = newAsyncQueue[seq[byte]]()
self.closed = false
self.closeEvent = newAsyncEvent()
proc join(self: DtlsConn) {.async.} =
await self.closeEvent.wait()
proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} =
var shouldRead = isServer
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
if shouldRead:
if isServer:
case self.raddr.family
of AddressFamily.IPv4:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
of AddressFamily.IPv6:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
else:
raise newException(DtlsError, "Remote address isn't an IP address")
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
self.sendFuture = nil
let res = mb_ssl_handshake_step(self.ssl)
if not self.sendFuture.isNil():
await self.sendFuture
shouldRead = false
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
continue
elif res == MBEDTLS_ERR_SSL_WANT_READ:
shouldRead = true
continue
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
mb_ssl_session_reset(self.ssl)
shouldRead = isServer
continue
elif res != 0:
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
proc close*(self: DtlsConn) {.async.} =
if self.closed:
debug "Try to close DtlsConn twice"
return
self.closed = true
self.sendFuture = nil
# TODO: proc mbedtls_ssl_close_notify => template mb_ssl_close_notify in nim-mbedtls
let x = mbedtls_ssl_close_notify(addr self.ssl)
if not self.sendFuture.isNil():
await self.sendFuture
self.closeEvent.fire()
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
if self.closed:
debug "Try to write on an already closed DtlsConn"
return
var buf = msg
try:
let sendFuture = newFuture[void]("DtlsConn write")
self.sendFuture = nil
let write = mb_ssl_write(self.ssl, buf)
if not self.sendFuture.isNil():
await self.sendFuture
trace "Dtls write", msgLen = msg.len(), actuallyWrote = write
except MbedTLSError as exc:
trace "Dtls write error", errorMsg = exc.msg
raise exc
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
if self.closed:
debug "Try to read on an already closed DtlsConn"
return
var res = newSeq[byte](8192)
while true:
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
# TODO: exception catching
let length = mb_ssl_read(self.ssl, res)
# TODO: Find a clear way to use the template `mb_ssl_read` without
# messing up things with exception
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
if length == MBEDTLS_ERR_SSL_WANT_READ:
continue
if length < 0:
trace "dtls read", error = $(length.cint.mbedtls_high_level_strerr())
raise newException(DtlsError, $(length.cint.mbedtls_high_level_strerr()))
res.setLen(length)
return res
proc close*(self: DtlsConn) {.async.} =
discard
# -- Dtls --
# The Dtls object read every messages from the UdpConn/StunConn and, if the address
# is not yet stored in the Table `Connection`, adds it to the `pendingHandshake` queue
# to be accepted later, if the address is stored, add the message received to the
# corresponding DtlsConn `dataRecv` queue.
type
Dtls* = ref object of RootObj
@ -130,7 +187,7 @@ proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])],
return
aq.addLastNoWait((raddr, buf))
proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
proc init*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
if self.started:
warn "Already started"
return
@ -159,43 +216,16 @@ proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
self.localCert = newSeq[byte](self.serverCert.raw.len)
copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len)
proc stop*(self: Dtls) =
proc stop*(self: Dtls) {.async.} =
if not self.started:
warn "Already stopped"
return
await allFutures(toSeq(self.connections.values()).mapIt(it.close()))
self.readLoop.cancel()
self.started = false
proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} =
var shouldRead = isServer
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
if shouldRead:
if isServer:
case self.raddr.family
of AddressFamily.IPv4:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
of AddressFamily.IPv6:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
else:
raise newException(DtlsError, "Remote address isn't an IP address")
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
self.sendFuture = nil
let res = mb_ssl_handshake_step(self.ssl)
if not self.sendFuture.isNil(): await self.sendFuture
shouldRead = false
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
continue
elif res == MBEDTLS_ERR_SSL_WANT_READ:
shouldRead = true
continue
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
mb_ssl_session_reset(self.ssl)
shouldRead = true
continue
elif res != 0:
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
# -- Remote / Local certificate getter --
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
conn.remoteCert
@ -206,8 +236,14 @@ proc localCertificate*(conn: DtlsConn): seq[byte] =
proc localCertificate*(self: Dtls): seq[byte] =
self.localCert
# -- MbedTLS Callbacks --
proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt,
state: cint, pflags: ptr uint32): cint {.cdecl.} =
# verify is the procedure called by mbedtls when receiving the remote
# certificate. It's usually used to verify the validity of the certificate.
# We use this procedure to store the remote certificate as it's mandatory
# to have it for the Prologue of the Noise protocol, aswell as the localCertificate.
var self = cast[DtlsConn](ctx)
let cert = pcert[]
@ -215,12 +251,45 @@ proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt,
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
return 0
proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsSend is the procedure called by mbedtls when data needs to be sent.
# As the StunConn's write proc is asynchronous and dtlsSend cannot be async,
# we store the future of this write and await it after the end of the
# function (see write or dtlsHanshake for example).
var self = cast[DtlsConn](ctx)
var toWrite = newSeq[byte](len)
if len > 0:
copyMem(addr toWrite[0], buf, len)
trace "dtls send", len
self.sendFuture = self.conn.write(self.raddr, toWrite)
result = len.cint
proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsRecv is the procedure called by mbedtls when data needs to be received.
# As we cannot asynchronously await for data to be received, we use a data received
# queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await
# when the mbedtls proc resumed (see read or dtlsHandshake for example)
let self = cast[DtlsConn](ctx)
if self.dataRecv.len() == 0:
return MBEDTLS_ERR_SSL_WANT_READ
var dataRecv = self.dataRecv.popFirstNoWait()
copyMem(buf, addr dataRecv[0], dataRecv.len())
result = dataRecv.len().cint
trace "dtls receive", len, result
# -- Dtls Accept / Connect procedures --
proc removeConnection(self: Dtls, conn: DtlsConn, raddr: TransportAddress) {.async.} =
await conn.join()
self.connections.del(raddr)
proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
var
selfvar = self
res = DtlsConn()
await res.init(self.conn, self.laddr)
res.init(self.conn, self.laddr)
mb_ssl_init(res.ssl)
mb_ssl_config_init(res.config)
mb_ssl_cookie_init(res.cookie)
@ -248,8 +317,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
mb_ssl_session_reset(res.ssl)
mb_ssl_set_verify(res.ssl, verify, res)
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
mb_ssl_set_bio(res.ssl, cast[pointer](res),
dtlsSend, dtlsRecv, nil)
mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
while true:
let (raddr, buf) = await self.pendingHandshakes.popFirst()
try:
@ -257,6 +325,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
res.dataRecv.addLastNoWait(buf)
self.connections[raddr] = res
await res.dtlsHandshake(true)
asyncSpawn self.removeConnection(res, raddr)
break
except CatchableError as exc:
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
@ -269,7 +338,7 @@ proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
selfvar = self
res = DtlsConn()
await res.init(self.conn, self.laddr)
res.init(self.conn, self.laddr)
mb_ssl_init(res.ssl)
mb_ssl_config_init(res.config)
@ -303,6 +372,7 @@ proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
try:
await res.dtlsHandshake(false)
asyncSpawn self.removeConnection(res, raddr)
except CatchableError as exc:
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
self.connections.del(raddr)

View File

@ -11,7 +11,6 @@ import std/sha1, sequtils, typetraits, std/md5
import binary_serialization,
stew/byteutils,
chronos
import ../utils
# -- Utils --

View File

@ -17,6 +17,8 @@ import sctp, datachannel
logScope:
topics = "webrtc"
# TODO: Implement a connect (or dial) procedure
type
WebRTC* = ref object
udp*: UdpConn
@ -29,7 +31,7 @@ proc new*(T: typedesc[WebRTC], address: TransportAddress): T =
var webrtc = T(udp: UdpConn(), stun: StunConn(), dtls: Dtls())
webrtc.udp.init(address)
webrtc.stun.init(webrtc.udp, address)
webrtc.dtls.start(webrtc.stun, address)
webrtc.dtls.init(webrtc.stun, address)
webrtc.sctp = Sctp.new(webrtc.dtls, address)
return webrtc