fix: stun and dtls close (#22)

* fix: stun & dtls close

* chore: change the test names & variable names

* fix: add isClosed to DtlsConn

* feat: remove `closeEvent`/`join`/`cleanup` and replace those by `onClose` sequence of proc to be executed on close
This commit is contained in:
Ludovic Chenut 2024-08-30 12:19:09 +02:00 committed by GitHub
parent e9bda6babf
commit 26cad2a383
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 79 additions and 45 deletions

View File

@ -81,3 +81,40 @@ suite "DTLS":
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
await allFutures(udp1.close(), udp2.close(), udp3.close())
asyncTest "Two DTLS nodes connecting to each other, closing the created connections then re-connect the nodes":
# Related to https://github.com/vacp2p/nim-webrtc/pull/22
let
localAddr1 = initTAddress("127.0.0.1:4444")
localAddr2 = initTAddress("127.0.0.1:5555")
udp1 = UdpTransport.new(localAddr1)
udp2 = UdpTransport.new(localAddr2)
stun1 = Stun.new(udp1)
stun2 = Stun.new(udp2)
dtls1 = Dtls.new(stun1)
dtls2 = Dtls.new(stun2)
var
serverConnFut = dtls1.accept()
clientConn = await dtls2.connect(localAddr1)
serverConn = await serverConnFut
await serverConn.write(@[1'u8, 2, 3, 4])
await clientConn.write(@[5'u8, 6, 7, 8])
check (await serverConn.read()) == @[5'u8, 6, 7, 8]
check (await clientConn.read()) == @[1'u8, 2, 3, 4]
await allFutures(serverConn.close(), clientConn.close())
check serverConn.isClosed() and clientConn.isClosed()
serverConnFut = dtls1.accept()
clientConn = await dtls2.connect(localAddr1)
serverConn = await serverConnFut
await serverConn.write(@[5'u8, 6, 7, 8])
await clientConn.write(@[1'u8, 2, 3, 4])
check (await serverConn.read()) == @[1'u8, 2, 3, 4]
check (await clientConn.read()) == @[5'u8, 6, 7, 8]
await allFutures(serverConn.close(), clientConn.close())
await allFutures(dtls1.stop(), dtls2.stop())
await allFutures(stun1.stop(), stun2.stop())
await allFutures(udp1.close(), udp2.close())

View File

@ -21,6 +21,8 @@ logScope:
const DtlsConnTracker* = "webrtc.dtls.conn"
type
DtlsConnOnClose* = proc() {.raises: [], gcsafe.}
MbedTLSCtx = object
ssl: mbedtls_ssl_context
config: mbedtls_ssl_config
@ -34,7 +36,7 @@ type
DtlsConn* = ref object
# DtlsConn is a Dtls connection receiving and sending data using
# the underlying Stun Connection
conn*: StunConn # The wrapper protocol Stun Connection
conn: StunConn # The wrapper protocol Stun Connection
raddr: TransportAddress # Remote address
dataRecv: seq[byte] # data received which will be read by SCTP
dataToSend: seq[byte]
@ -43,7 +45,7 @@ type
# Close connection management
closed: bool
closeEvent: AsyncEvent
onClose: seq[DtlsConnOnClose]
# Local and Remote certificate, needed by wrapped protocol DataChannel
# and by libp2p
@ -53,6 +55,9 @@ type
# Mbed-TLS contexts
ctx: MbedTLSCtx
proc isClosed*(self: DtlsConn): bool =
return self.closed
proc getRemoteCertificateCallback(
ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32
): cint {.cdecl.} =
@ -100,7 +105,6 @@ proc new*(T: type DtlsConn, conn: StunConn): T =
var self = T(conn: conn)
self.raddr = conn.raddr
self.closed = false
self.closeEvent = newAsyncEvent()
return self
proc dtlsConnInit(self: DtlsConn) =
@ -160,10 +164,10 @@ proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) =
except MbedTLSError as exc:
raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc)
proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} =
## Wait for the Dtls Connection to be closed
proc addOnClose*(self: DtlsConn, onCloseProc: DtlsConnOnClose) =
## Adds a proc to be called when DtlsConn is closed
##
await self.closeEvent.wait()
self.onClose.add(onCloseProc)
proc dtlsHandshake*(
self: DtlsConn, isServer: bool
@ -217,7 +221,10 @@ proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} =
await self.conn.write(self.dataToSend)
self.dataToSend = @[]
untrackCounter(DtlsConnTracker)
self.closeEvent.fire()
await self.conn.close()
for onCloseProc in self.onClose:
onCloseProc()
self.onClose = @[]
proc write*(
self: DtlsConn, msg: seq[byte]

View File

@ -28,12 +28,8 @@ logScope:
const DtlsTransportTracker* = "webrtc.dtls.transport"
type
DtlsConnAndCleanup = object
connection: DtlsConn
cleanup: Future[void].Raising([])
Dtls* = ref object of RootObj
connections: Table[TransportAddress, DtlsConnAndCleanup]
connections: Table[TransportAddress, DtlsConn]
transport: Stun
laddr: TransportAddress
started: bool
@ -46,7 +42,7 @@ type
proc new*(T: type Dtls, transport: Stun): T =
var self = T(
connections: initTable[TransportAddress, DtlsConnAndCleanup](),
connections: initTable[TransportAddress, DtlsConn](),
transport: transport,
laddr: transport.laddr,
started: true,
@ -72,10 +68,8 @@ proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} =
self.started = false
let
allCloses = toSeq(self.connections.values()).mapIt(it.connection.close())
allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup)
allCloses = toSeq(self.connections.values()).mapIt(it.close())
await noCancel allFutures(allCloses)
await noCancel allFutures(allCleanup)
untrackCounter(DtlsTransportTracker)
proc localCertificate*(self: Dtls): seq[byte] =
@ -85,14 +79,11 @@ proc localCertificate*(self: Dtls): seq[byte] =
proc localAddress*(self: Dtls): TransportAddress =
self.laddr
proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} =
# Waiting for a connection to be closed to remove it from the table
try:
await conn.join()
except CancelledError as exc:
discard
self.connections.del(conn.remoteAddress())
proc addConnToTable(self: Dtls, conn: DtlsConn) =
proc cleanup() =
self.connections.del(conn.remoteAddress())
self.connections[conn.remoteAddress()] = conn
conn.addOnClose(cleanup)
proc accept*(
self: Dtls
@ -114,8 +105,7 @@ proc accept*(
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
)
await res.dtlsHandshake(true)
self.connections[raddr] =
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
self.addConnToTable(res)
break
except WebRtcError as exc:
trace "Handshake fails, try accept another connection", raddr, error = exc.msg
@ -136,8 +126,7 @@ proc connect*(
try:
await res.dtlsHandshake(false)
self.connections[raddr] =
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
self.addConnToTable(res)
except WebRtcError as exc:
trace "Handshake fails", raddr, error = exc.msg
self.connections.del(raddr)

View File

@ -28,6 +28,7 @@ type
StunUsernameProvider* = proc(): string {.raises: [], gcsafe.}
StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.}
StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.}
StunConnOnClose* = proc() {.raises: [], gcsafe.}
StunConn* = ref object
udp*: UdpTransport # The wrapper protocol: UDP Transport
@ -37,8 +38,10 @@ type
stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be
# processed by the stun message handler
handlesFut*: Future[void] # Stun Message handler
closeEvent: AsyncEvent
# Close connection management
closed*: bool
onClose: seq[StunConnOnClose]
# Is ice-controlling and iceTiebreaker, not fully implemented yet.
iceControlling: bool
@ -201,7 +204,6 @@ proc new*(
laddr: udp.laddr,
raddr: raddr,
closed: false,
closeEvent: newAsyncEvent(),
dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
stunMsgs: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
iceControlling: iceControlling,
@ -215,10 +217,10 @@ proc new*(
trackCounter(StunConnectionTracker)
return self
proc join*(self: StunConn) {.async: (raises: [CancelledError]).} =
## Wait for the Stun Connection to be closed
proc addOnClose*(self: StunConn, onCloseProc: StunConnOnClose) =
## Adds a proc to be called when StunConn is closed
##
await self.closeEvent.wait()
self.onClose.add(onCloseProc)
proc close*(self: StunConn) {.async: (raises: []).} =
## Close a Stun Connection
@ -227,7 +229,9 @@ proc close*(self: StunConn) {.async: (raises: []).} =
debug "Try to close an already closed StunConn"
return
await self.handlesFut.cancelAndWait()
self.closeEvent.fire()
for onCloseProc in self.onClose:
onCloseProc()
self.onClose = @[]
self.closed = true
untrackCounter(StunConnectionTracker)

View File

@ -32,6 +32,12 @@ type
rng: ref HmacDrbgContext
proc addConnToTable(self: Stun, conn: StunConn) =
proc cleanup() =
self.connections.del(conn.raddr)
self.connections[conn.raddr] = conn
conn.addOnClose(cleanup)
proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} =
## Accept a Stun Connection
##
@ -53,17 +59,9 @@ proc connect*(
do:
let res = StunConn.new(self.udp, raddr, false, self.usernameProvider,
self.usernameChecker, self.passwordProvider, self.rng)
self.connections[raddr] = res
self.addConnToTable(res)
return res
proc cleanupStunConn(self: Stun, conn: StunConn) {.async: (raises: []).} =
# Waiting for a connection to be closed to remove it from the table
try:
await conn.join()
self.connections.del(conn.raddr)
except CancelledError as exc:
warn "Error cleaning up Stun Connection", error=exc.msg
proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
while true:
let (buf, raddr) = await self.udp.read()
@ -71,9 +69,8 @@ proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
if not self.connections.hasKey(raddr):
stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider,
self.usernameChecker, self.passwordProvider, self.rng)
self.connections[raddr] = stunConn
self.addConnToTable(stunConn)
await self.pendingConn.addLast(stunConn)
asyncSpawn self.cleanupStunConn(stunConn)
else:
try:
stunConn = self.connections[raddr]