Multiple fixes on stun / udp / dtls

This commit is contained in:
Ludovic Chenut 2023-10-10 17:38:31 +02:00
parent 397c84238a
commit 2a9b8298eb
No known key found for this signature in database
GPG Key ID: D9A59B1907F1D50C
4 changed files with 111 additions and 62 deletions

View File

@ -7,9 +7,9 @@
# This file may not be copied, modified, or distributed except according to
# those terms.
import times, sequtils
import times, deques, tables
import chronos, chronicles
import ./utils, ../webrtc_connection
import ./utils, ../stun/stun_connection
import mbedtls/ssl
import mbedtls/ssl_cookie
@ -29,11 +29,19 @@ import mbedtls/timing
logScope:
topics = "webrtc dtls"
# TODO: Check the viability of the add/pop first/last of the asyncqueue with the limit.
# There might be some errors (or crashes) in weird cases with the no wait option
const
PendingHandshakeLimit = 1024
type
DtlsError* = object of CatchableError
DtlsConn* = ref object of WebRTCConn
recvData: seq[seq[byte]]
recvEvent: AsyncEvent
DtlsConn* = ref object
conn: StunConn
laddr: TransportAddress
raddr: TransportAddress
dataRecv: AsyncQueue[seq[byte]]
sendFuture: Future[void]
timer: mbedtls_timing_delay_context
@ -51,70 +59,99 @@ proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
var toWrite = newSeq[byte](len)
if len > 0:
copyMem(addr toWrite[0], buf, len)
self.sendFuture = self.conn.write(toWrite)
self.sendFuture = self.conn.write(self.raddr, toWrite)
result = len.cint
proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
var self = cast[DtlsConn](ctx)
result = self.recvData[0].len().cint
copyMem(buf, addr self.recvData[0][0], self.recvData[0].len())
self.recvData.delete(0..0)
var
self = cast[DtlsConn](ctx)
dataRecv = self.dataRecv.popFirstNoWait()
copyMem(buf, addr dataRecv[0], dataRecv.len())
result = dataRecv.len().cint
method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.async.} =
await procCall(WebRTCConn(self).init(conn, address))
proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} =
self.conn = conn
self.laddr = laddr
self.dataRecv = newAsyncQueue[seq[byte]]()
method write*(self: DtlsConn, msg: seq[byte]) {.async.} =
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
var buf = msg
discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](addr buf[0]), buf.len().uint)
method read*(self: DtlsConn): Future[seq[byte]] {.async.} =
return await self.conn.read()
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
var res = newSeq[byte](8192)
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
res.setLen(length)
return res
method close*(self: DtlsConn) {.async.} =
proc close*(self: DtlsConn) {.async.} =
discard
method getRemoteAddress*(self: DtlsConn): TransportAddress =
self.conn.getRemoteAddress()
type
Dtls* = ref object of RootObj
address: TransportAddress
connections: Table[TransportAddress, DtlsConn]
pendingHandshakes: AsyncQueue[(TransportAddress, seq[byte])]
conn: StunConn
laddr: TransportAddress
started: bool
readLoop: Future[void]
proc start*(self: Dtls, address: TransportAddress) =
proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])],
raddr: TransportAddress, buf: seq[byte]) =
for (k, v) in aq.mitems():
if k == raddr:
v = buf
return
aq.addLastNoWait((raddr, buf))
proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
if self.started:
warn "Already started"
return
self.address = address
proc readLoop() {.async.} =
while true:
let (buf, raddr) = await self.conn.read()
if self.connections.hasKey(raddr):
self.connections[raddr].dataRecv.addLastNoWait(buf)
else:
self.pendingHandshakes.updateOrAdd(raddr, buf)
self.connections = initTable[TransportAddress, DtlsConn]()
self.pendingHandshakes = newAsyncQueue[(TransportAddress, seq[byte])](PendingHandshakeLimit)
self.conn = conn
self.laddr = laddr
self.started = true
self.readLoop = readLoop()
proc stop*(self: Dtls) =
if not self.started:
warn "Already stopped"
return
self.readLoop.cancel()
self.started = false
proc serverHandshake(self: DtlsConn) {.async.} =
var shouldRead = true
case self.raddr.family
of AddressFamily.IPv4:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
of AddressFamily.IPv6:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
else:
raise newException(DtlsError, "Remote address isn't an IP address")
var shouldRead = true
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
if shouldRead:
self.recvData.add(await self.conn.read())
var ta = self.getRemoteAddress()
case ta.family
of AddressFamily.IPv4:
mb_ssl_set_client_transport_id(self.ssl, ta.address_v4)
of AddressFamily.IPv6:
mb_ssl_set_client_transport_id(self.ssl, ta.address_v6)
else:
raise newException(DtlsError, "Remote address isn't an IP address")
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
self.sendFuture = nil
let res = mb_ssl_handshake_step(self.ssl)
shouldRead = false
if not self.sendFuture.isNil(): await self.sendFuture
shouldRead = false
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
continue
elif res == MBEDTLS_ERR_SSL_WANT_READ or
@ -128,13 +165,13 @@ proc serverHandshake(self: DtlsConn) {.async.} =
elif res != 0:
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
proc accept*(self: Dtls, conn: WebRTCConn): Future[DtlsConn] {.async.} =
proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
var
selfvar = self
res = DtlsConn()
let v = cast[pointer](res)
await res.init(conn, self.address)
await res.init(self.conn, self.laddr)
mb_ssl_init(res.ssl)
mb_ssl_config_init(res.config)
mb_ssl_cookie_init(res.cookie)
@ -162,20 +199,32 @@ proc accept*(self: Dtls, conn: WebRTCConn): Future[DtlsConn] {.async.} =
mb_ssl_session_reset(res.ssl)
mb_ssl_set_bio(res.ssl, cast[pointer](res),
dtlsSend, dtlsRecv, nil)
await res.serverHandshake()
while true:
let (raddr, buf) = await self.pendingHandshakes.popFirst()
try:
res.raddr = raddr
res.dataRecv.addLastNoWait(buf)
self.connections[raddr] = res
await res.serverHandshake()
except CatchableError as exc:
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
self.connections.del(raddr)
continue
return res
proc dial*(self: Dtls, address: TransportAddress): DtlsConn =
proc dial*(self: Dtls, raddr: TransportAddress): DtlsConn =
discard
import ../udp_connection
proc main() {.async.} =
let laddr = initTAddress("127.0.0.1:4433")
let udp = UdpConn()
await udp.init(nil, laddr)
await udp.init(laddr)
let stun = StunConn()
await stun.init(udp, laddr)
let dtls = Dtls()
dtls.start(laddr)
let x = await dtls.accept(udp)
dtls.start(stun, laddr)
let x = await dtls.accept()
echo "After accept"
waitFor(main())

View File

@ -10,6 +10,7 @@
import tables, bitops, posix, strutils, sequtils
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils]
import usrsctp
import dtls/dtls
export chronicles
@ -36,6 +37,7 @@ type
params*: SctpMessageParameters
SctpConnection* = ref object
conn: DtlsConn
state: SctpState
connectEvent: AsyncEvent
sctp: Sctp
@ -45,8 +47,9 @@ type
dataRecv: AsyncQueue[SctpMessage]
Sctp* = ref object
dtls: Dtls
udp: DatagramTransport
connections: Table[TransportAddress, SctpConnection]
#connections: Table[TransportAddress, SctpConnection]
gotConnection: AsyncEvent
timersHandler: Future[void]
isServer: bool

View File

@ -19,27 +19,27 @@ type
proc handles(self: StunConn) {.async.} =
while true: # TODO: while not self.conn.atEof()
let (msg, address) = await self.conn.read()
let (msg, raddr) = await self.conn.read()
if Stun.isMessage(msg):
let res = Stun.getResponse(msg, self.laddr)
if res.isSome():
await self.conn.write(res.get())
await self.conn.write(raddr, res.get())
else:
self.dataRecv.addLastNoWait((msg, address))
self.dataRecv.addLastNoWait((msg, raddr))
proc init(self: StunConn, conn: UdpConn, laddr: TransportAddress) {.async.} =
proc init*(self: StunConn, conn: UdpConn, laddr: TransportAddress) {.async.} =
self.conn = conn
self.laddr = laddr
self.dataRecv = newAsyncQueue()
self.handlesFut = handles()
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
self.handlesFut = self.handles()
proc close(self: StunConn) {.async.} =
proc close*(self: StunConn) {.async.} =
self.handlesFut.cancel() # check before?
self.conn.close()
await self.conn.close()
proc write(self: StunConn, msg: seq[byte]) {.async.} =
await self.conn.write(msg)
proc write*(self: StunConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
await self.conn.write(raddr, msg)
proc read(self: StunConn): Future[(seq[byte], TransportAddress)] {.async.} =
proc read*(self: StunConn): Future[(seq[byte], TransportAddress)] {.async.} =
return await self.dataRecv.popFirst()

View File

@ -9,7 +9,6 @@
import sequtils
import chronos, chronicles
import webrtc_connection
logScope:
topics = "webrtc udp"
@ -20,7 +19,7 @@ type
udp: DatagramTransport
dataRecv: AsyncQueue[(seq[byte], TransportAddress)]
proc init(self: UdpConn, laddr: TransportAddress) {.async.} =
proc init*(self: UdpConn, laddr: TransportAddress) {.async.} =
self.laddr = laddr
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
@ -28,18 +27,16 @@ proc init(self: UdpConn, laddr: TransportAddress) {.async.} =
echo "\e[33m<UDP>\e[0;1m onReceive\e[0m: ", msg.len()
self.dataRecv.addLastNoWait((msg, address))
self.dataRecv = newAsyncQueue()
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
self.udp = newDatagramTransport(onReceive, local = laddr)
proc close(self: UdpConn) {.async.} =
proc close*(self: UdpConn) {.async.} =
self.udp.close()
if not self.conn.isNil():
await self.conn.close()
proc write(self: UdpConn, msg: seq[byte]) {.async.} =
proc write*(self: UdpConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
echo "\e[33m<UDP>\e[0;1m write\e[0m"
await self.udp.sendTo(self.remote, msg)
await self.udp.sendTo(raddr, msg)
proc read(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} =
proc read*(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} =
echo "\e[33m<UDP>\e[0;1m read\e[0m"
return await self.dataRecv.popFirst()