fix: stun and dtls close (#22)

* fix: stun & dtls close

* chore: change the test names & variable names

* fix: add isClosed to DtlsConn

* feat: remove `closeEvent`/`join`/`cleanup` and replace those by `onClose` sequence of proc to be executed on close
This commit is contained in:
Ludovic Chenut 2024-08-30 12:19:09 +02:00 committed by GitHub
parent e9bda6babf
commit 26cad2a383
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 79 additions and 45 deletions

View File

@ -81,3 +81,40 @@ suite "DTLS":
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop()) await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
await allFutures(stun1.stop(), stun2.stop(), stun3.stop()) await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
await allFutures(udp1.close(), udp2.close(), udp3.close()) await allFutures(udp1.close(), udp2.close(), udp3.close())
asyncTest "Two DTLS nodes connecting to each other, closing the created connections then re-connect the nodes":
# Related to https://github.com/vacp2p/nim-webrtc/pull/22
let
localAddr1 = initTAddress("127.0.0.1:4444")
localAddr2 = initTAddress("127.0.0.1:5555")
udp1 = UdpTransport.new(localAddr1)
udp2 = UdpTransport.new(localAddr2)
stun1 = Stun.new(udp1)
stun2 = Stun.new(udp2)
dtls1 = Dtls.new(stun1)
dtls2 = Dtls.new(stun2)
var
serverConnFut = dtls1.accept()
clientConn = await dtls2.connect(localAddr1)
serverConn = await serverConnFut
await serverConn.write(@[1'u8, 2, 3, 4])
await clientConn.write(@[5'u8, 6, 7, 8])
check (await serverConn.read()) == @[5'u8, 6, 7, 8]
check (await clientConn.read()) == @[1'u8, 2, 3, 4]
await allFutures(serverConn.close(), clientConn.close())
check serverConn.isClosed() and clientConn.isClosed()
serverConnFut = dtls1.accept()
clientConn = await dtls2.connect(localAddr1)
serverConn = await serverConnFut
await serverConn.write(@[5'u8, 6, 7, 8])
await clientConn.write(@[1'u8, 2, 3, 4])
check (await serverConn.read()) == @[1'u8, 2, 3, 4]
check (await clientConn.read()) == @[5'u8, 6, 7, 8]
await allFutures(serverConn.close(), clientConn.close())
await allFutures(dtls1.stop(), dtls2.stop())
await allFutures(stun1.stop(), stun2.stop())
await allFutures(udp1.close(), udp2.close())

View File

@ -21,6 +21,8 @@ logScope:
const DtlsConnTracker* = "webrtc.dtls.conn" const DtlsConnTracker* = "webrtc.dtls.conn"
type type
DtlsConnOnClose* = proc() {.raises: [], gcsafe.}
MbedTLSCtx = object MbedTLSCtx = object
ssl: mbedtls_ssl_context ssl: mbedtls_ssl_context
config: mbedtls_ssl_config config: mbedtls_ssl_config
@ -34,7 +36,7 @@ type
DtlsConn* = ref object DtlsConn* = ref object
# DtlsConn is a Dtls connection receiving and sending data using # DtlsConn is a Dtls connection receiving and sending data using
# the underlying Stun Connection # the underlying Stun Connection
conn*: StunConn # The wrapper protocol Stun Connection conn: StunConn # The wrapper protocol Stun Connection
raddr: TransportAddress # Remote address raddr: TransportAddress # Remote address
dataRecv: seq[byte] # data received which will be read by SCTP dataRecv: seq[byte] # data received which will be read by SCTP
dataToSend: seq[byte] dataToSend: seq[byte]
@ -43,7 +45,7 @@ type
# Close connection management # Close connection management
closed: bool closed: bool
closeEvent: AsyncEvent onClose: seq[DtlsConnOnClose]
# Local and Remote certificate, needed by wrapped protocol DataChannel # Local and Remote certificate, needed by wrapped protocol DataChannel
# and by libp2p # and by libp2p
@ -53,6 +55,9 @@ type
# Mbed-TLS contexts # Mbed-TLS contexts
ctx: MbedTLSCtx ctx: MbedTLSCtx
proc isClosed*(self: DtlsConn): bool =
return self.closed
proc getRemoteCertificateCallback( proc getRemoteCertificateCallback(
ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32 ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32
): cint {.cdecl.} = ): cint {.cdecl.} =
@ -100,7 +105,6 @@ proc new*(T: type DtlsConn, conn: StunConn): T =
var self = T(conn: conn) var self = T(conn: conn)
self.raddr = conn.raddr self.raddr = conn.raddr
self.closed = false self.closed = false
self.closeEvent = newAsyncEvent()
return self return self
proc dtlsConnInit(self: DtlsConn) = proc dtlsConnInit(self: DtlsConn) =
@ -160,10 +164,10 @@ proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) =
except MbedTLSError as exc: except MbedTLSError as exc:
raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc) raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc)
proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} = proc addOnClose*(self: DtlsConn, onCloseProc: DtlsConnOnClose) =
## Wait for the Dtls Connection to be closed ## Adds a proc to be called when DtlsConn is closed
## ##
await self.closeEvent.wait() self.onClose.add(onCloseProc)
proc dtlsHandshake*( proc dtlsHandshake*(
self: DtlsConn, isServer: bool self: DtlsConn, isServer: bool
@ -217,7 +221,10 @@ proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} =
await self.conn.write(self.dataToSend) await self.conn.write(self.dataToSend)
self.dataToSend = @[] self.dataToSend = @[]
untrackCounter(DtlsConnTracker) untrackCounter(DtlsConnTracker)
self.closeEvent.fire() await self.conn.close()
for onCloseProc in self.onClose:
onCloseProc()
self.onClose = @[]
proc write*( proc write*(
self: DtlsConn, msg: seq[byte] self: DtlsConn, msg: seq[byte]

View File

@ -28,12 +28,8 @@ logScope:
const DtlsTransportTracker* = "webrtc.dtls.transport" const DtlsTransportTracker* = "webrtc.dtls.transport"
type type
DtlsConnAndCleanup = object
connection: DtlsConn
cleanup: Future[void].Raising([])
Dtls* = ref object of RootObj Dtls* = ref object of RootObj
connections: Table[TransportAddress, DtlsConnAndCleanup] connections: Table[TransportAddress, DtlsConn]
transport: Stun transport: Stun
laddr: TransportAddress laddr: TransportAddress
started: bool started: bool
@ -46,7 +42,7 @@ type
proc new*(T: type Dtls, transport: Stun): T = proc new*(T: type Dtls, transport: Stun): T =
var self = T( var self = T(
connections: initTable[TransportAddress, DtlsConnAndCleanup](), connections: initTable[TransportAddress, DtlsConn](),
transport: transport, transport: transport,
laddr: transport.laddr, laddr: transport.laddr,
started: true, started: true,
@ -72,10 +68,8 @@ proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} =
self.started = false self.started = false
let let
allCloses = toSeq(self.connections.values()).mapIt(it.connection.close()) allCloses = toSeq(self.connections.values()).mapIt(it.close())
allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup)
await noCancel allFutures(allCloses) await noCancel allFutures(allCloses)
await noCancel allFutures(allCleanup)
untrackCounter(DtlsTransportTracker) untrackCounter(DtlsTransportTracker)
proc localCertificate*(self: Dtls): seq[byte] = proc localCertificate*(self: Dtls): seq[byte] =
@ -85,14 +79,11 @@ proc localCertificate*(self: Dtls): seq[byte] =
proc localAddress*(self: Dtls): TransportAddress = proc localAddress*(self: Dtls): TransportAddress =
self.laddr self.laddr
proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} = proc addConnToTable(self: Dtls, conn: DtlsConn) =
# Waiting for a connection to be closed to remove it from the table proc cleanup() =
try:
await conn.join()
except CancelledError as exc:
discard
self.connections.del(conn.remoteAddress()) self.connections.del(conn.remoteAddress())
self.connections[conn.remoteAddress()] = conn
conn.addOnClose(cleanup)
proc accept*( proc accept*(
self: Dtls self: Dtls
@ -114,8 +105,7 @@ proc accept*(
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
) )
await res.dtlsHandshake(true) await res.dtlsHandshake(true)
self.connections[raddr] = self.addConnToTable(res)
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
break break
except WebRtcError as exc: except WebRtcError as exc:
trace "Handshake fails, try accept another connection", raddr, error = exc.msg trace "Handshake fails, try accept another connection", raddr, error = exc.msg
@ -136,8 +126,7 @@ proc connect*(
try: try:
await res.dtlsHandshake(false) await res.dtlsHandshake(false)
self.connections[raddr] = self.addConnToTable(res)
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
except WebRtcError as exc: except WebRtcError as exc:
trace "Handshake fails", raddr, error = exc.msg trace "Handshake fails", raddr, error = exc.msg
self.connections.del(raddr) self.connections.del(raddr)

View File

@ -28,6 +28,7 @@ type
StunUsernameProvider* = proc(): string {.raises: [], gcsafe.} StunUsernameProvider* = proc(): string {.raises: [], gcsafe.}
StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.} StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.}
StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.} StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.}
StunConnOnClose* = proc() {.raises: [], gcsafe.}
StunConn* = ref object StunConn* = ref object
udp*: UdpTransport # The wrapper protocol: UDP Transport udp*: UdpTransport # The wrapper protocol: UDP Transport
@ -37,8 +38,10 @@ type
stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be
# processed by the stun message handler # processed by the stun message handler
handlesFut*: Future[void] # Stun Message handler handlesFut*: Future[void] # Stun Message handler
closeEvent: AsyncEvent
# Close connection management
closed*: bool closed*: bool
onClose: seq[StunConnOnClose]
# Is ice-controlling and iceTiebreaker, not fully implemented yet. # Is ice-controlling and iceTiebreaker, not fully implemented yet.
iceControlling: bool iceControlling: bool
@ -201,7 +204,6 @@ proc new*(
laddr: udp.laddr, laddr: udp.laddr,
raddr: raddr, raddr: raddr,
closed: false, closed: false,
closeEvent: newAsyncEvent(),
dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages), dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
stunMsgs: newAsyncQueue[seq[byte]](StunMaxQueuingMessages), stunMsgs: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
iceControlling: iceControlling, iceControlling: iceControlling,
@ -215,10 +217,10 @@ proc new*(
trackCounter(StunConnectionTracker) trackCounter(StunConnectionTracker)
return self return self
proc join*(self: StunConn) {.async: (raises: [CancelledError]).} = proc addOnClose*(self: StunConn, onCloseProc: StunConnOnClose) =
## Wait for the Stun Connection to be closed ## Adds a proc to be called when StunConn is closed
## ##
await self.closeEvent.wait() self.onClose.add(onCloseProc)
proc close*(self: StunConn) {.async: (raises: []).} = proc close*(self: StunConn) {.async: (raises: []).} =
## Close a Stun Connection ## Close a Stun Connection
@ -227,7 +229,9 @@ proc close*(self: StunConn) {.async: (raises: []).} =
debug "Try to close an already closed StunConn" debug "Try to close an already closed StunConn"
return return
await self.handlesFut.cancelAndWait() await self.handlesFut.cancelAndWait()
self.closeEvent.fire() for onCloseProc in self.onClose:
onCloseProc()
self.onClose = @[]
self.closed = true self.closed = true
untrackCounter(StunConnectionTracker) untrackCounter(StunConnectionTracker)

View File

@ -32,6 +32,12 @@ type
rng: ref HmacDrbgContext rng: ref HmacDrbgContext
proc addConnToTable(self: Stun, conn: StunConn) =
proc cleanup() =
self.connections.del(conn.raddr)
self.connections[conn.raddr] = conn
conn.addOnClose(cleanup)
proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} = proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} =
## Accept a Stun Connection ## Accept a Stun Connection
## ##
@ -53,17 +59,9 @@ proc connect*(
do: do:
let res = StunConn.new(self.udp, raddr, false, self.usernameProvider, let res = StunConn.new(self.udp, raddr, false, self.usernameProvider,
self.usernameChecker, self.passwordProvider, self.rng) self.usernameChecker, self.passwordProvider, self.rng)
self.connections[raddr] = res self.addConnToTable(res)
return res return res
proc cleanupStunConn(self: Stun, conn: StunConn) {.async: (raises: []).} =
# Waiting for a connection to be closed to remove it from the table
try:
await conn.join()
self.connections.del(conn.raddr)
except CancelledError as exc:
warn "Error cleaning up Stun Connection", error=exc.msg
proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} = proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
while true: while true:
let (buf, raddr) = await self.udp.read() let (buf, raddr) = await self.udp.read()
@ -71,9 +69,8 @@ proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
if not self.connections.hasKey(raddr): if not self.connections.hasKey(raddr):
stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider, stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider,
self.usernameChecker, self.passwordProvider, self.rng) self.usernameChecker, self.passwordProvider, self.rng)
self.connections[raddr] = stunConn self.addConnToTable(stunConn)
await self.pendingConn.addLast(stunConn) await self.pendingConn.addLast(stunConn)
asyncSpawn self.cleanupStunConn(stunConn)
else: else:
try: try:
stunConn = self.connections[raddr] stunConn = self.connections[raddr]