mirror of https://github.com/vacp2p/nim-webrtc.git
fix: stun and dtls close (#22)
* fix: stun & dtls close * chore: change the test names & variable names * fix: add isClosed to DtlsConn * feat: remove `closeEvent`/`join`/`cleanup` and replace those by `onClose` sequence of proc to be executed on close
This commit is contained in:
parent
e9bda6babf
commit
26cad2a383
|
@ -81,3 +81,40 @@ suite "DTLS":
|
|||
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
|
||||
await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
|
||||
await allFutures(udp1.close(), udp2.close(), udp3.close())
|
||||
|
||||
asyncTest "Two DTLS nodes connecting to each other, closing the created connections then re-connect the nodes":
|
||||
# Related to https://github.com/vacp2p/nim-webrtc/pull/22
|
||||
let
|
||||
localAddr1 = initTAddress("127.0.0.1:4444")
|
||||
localAddr2 = initTAddress("127.0.0.1:5555")
|
||||
udp1 = UdpTransport.new(localAddr1)
|
||||
udp2 = UdpTransport.new(localAddr2)
|
||||
stun1 = Stun.new(udp1)
|
||||
stun2 = Stun.new(udp2)
|
||||
dtls1 = Dtls.new(stun1)
|
||||
dtls2 = Dtls.new(stun2)
|
||||
var
|
||||
serverConnFut = dtls1.accept()
|
||||
clientConn = await dtls2.connect(localAddr1)
|
||||
serverConn = await serverConnFut
|
||||
|
||||
await serverConn.write(@[1'u8, 2, 3, 4])
|
||||
await clientConn.write(@[5'u8, 6, 7, 8])
|
||||
check (await serverConn.read()) == @[5'u8, 6, 7, 8]
|
||||
check (await clientConn.read()) == @[1'u8, 2, 3, 4]
|
||||
await allFutures(serverConn.close(), clientConn.close())
|
||||
check serverConn.isClosed() and clientConn.isClosed()
|
||||
|
||||
serverConnFut = dtls1.accept()
|
||||
clientConn = await dtls2.connect(localAddr1)
|
||||
serverConn = await serverConnFut
|
||||
|
||||
await serverConn.write(@[5'u8, 6, 7, 8])
|
||||
await clientConn.write(@[1'u8, 2, 3, 4])
|
||||
check (await serverConn.read()) == @[1'u8, 2, 3, 4]
|
||||
check (await clientConn.read()) == @[5'u8, 6, 7, 8]
|
||||
|
||||
await allFutures(serverConn.close(), clientConn.close())
|
||||
await allFutures(dtls1.stop(), dtls2.stop())
|
||||
await allFutures(stun1.stop(), stun2.stop())
|
||||
await allFutures(udp1.close(), udp2.close())
|
||||
|
|
|
@ -21,6 +21,8 @@ logScope:
|
|||
const DtlsConnTracker* = "webrtc.dtls.conn"
|
||||
|
||||
type
|
||||
DtlsConnOnClose* = proc() {.raises: [], gcsafe.}
|
||||
|
||||
MbedTLSCtx = object
|
||||
ssl: mbedtls_ssl_context
|
||||
config: mbedtls_ssl_config
|
||||
|
@ -34,7 +36,7 @@ type
|
|||
DtlsConn* = ref object
|
||||
# DtlsConn is a Dtls connection receiving and sending data using
|
||||
# the underlying Stun Connection
|
||||
conn*: StunConn # The wrapper protocol Stun Connection
|
||||
conn: StunConn # The wrapper protocol Stun Connection
|
||||
raddr: TransportAddress # Remote address
|
||||
dataRecv: seq[byte] # data received which will be read by SCTP
|
||||
dataToSend: seq[byte]
|
||||
|
@ -43,7 +45,7 @@ type
|
|||
|
||||
# Close connection management
|
||||
closed: bool
|
||||
closeEvent: AsyncEvent
|
||||
onClose: seq[DtlsConnOnClose]
|
||||
|
||||
# Local and Remote certificate, needed by wrapped protocol DataChannel
|
||||
# and by libp2p
|
||||
|
@ -53,6 +55,9 @@ type
|
|||
# Mbed-TLS contexts
|
||||
ctx: MbedTLSCtx
|
||||
|
||||
proc isClosed*(self: DtlsConn): bool =
|
||||
return self.closed
|
||||
|
||||
proc getRemoteCertificateCallback(
|
||||
ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32
|
||||
): cint {.cdecl.} =
|
||||
|
@ -100,7 +105,6 @@ proc new*(T: type DtlsConn, conn: StunConn): T =
|
|||
var self = T(conn: conn)
|
||||
self.raddr = conn.raddr
|
||||
self.closed = false
|
||||
self.closeEvent = newAsyncEvent()
|
||||
return self
|
||||
|
||||
proc dtlsConnInit(self: DtlsConn) =
|
||||
|
@ -160,10 +164,10 @@ proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) =
|
|||
except MbedTLSError as exc:
|
||||
raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc)
|
||||
|
||||
proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} =
|
||||
## Wait for the Dtls Connection to be closed
|
||||
proc addOnClose*(self: DtlsConn, onCloseProc: DtlsConnOnClose) =
|
||||
## Adds a proc to be called when DtlsConn is closed
|
||||
##
|
||||
await self.closeEvent.wait()
|
||||
self.onClose.add(onCloseProc)
|
||||
|
||||
proc dtlsHandshake*(
|
||||
self: DtlsConn, isServer: bool
|
||||
|
@ -217,7 +221,10 @@ proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} =
|
|||
await self.conn.write(self.dataToSend)
|
||||
self.dataToSend = @[]
|
||||
untrackCounter(DtlsConnTracker)
|
||||
self.closeEvent.fire()
|
||||
await self.conn.close()
|
||||
for onCloseProc in self.onClose:
|
||||
onCloseProc()
|
||||
self.onClose = @[]
|
||||
|
||||
proc write*(
|
||||
self: DtlsConn, msg: seq[byte]
|
||||
|
|
|
@ -28,12 +28,8 @@ logScope:
|
|||
const DtlsTransportTracker* = "webrtc.dtls.transport"
|
||||
|
||||
type
|
||||
DtlsConnAndCleanup = object
|
||||
connection: DtlsConn
|
||||
cleanup: Future[void].Raising([])
|
||||
|
||||
Dtls* = ref object of RootObj
|
||||
connections: Table[TransportAddress, DtlsConnAndCleanup]
|
||||
connections: Table[TransportAddress, DtlsConn]
|
||||
transport: Stun
|
||||
laddr: TransportAddress
|
||||
started: bool
|
||||
|
@ -46,7 +42,7 @@ type
|
|||
|
||||
proc new*(T: type Dtls, transport: Stun): T =
|
||||
var self = T(
|
||||
connections: initTable[TransportAddress, DtlsConnAndCleanup](),
|
||||
connections: initTable[TransportAddress, DtlsConn](),
|
||||
transport: transport,
|
||||
laddr: transport.laddr,
|
||||
started: true,
|
||||
|
@ -72,10 +68,8 @@ proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} =
|
|||
|
||||
self.started = false
|
||||
let
|
||||
allCloses = toSeq(self.connections.values()).mapIt(it.connection.close())
|
||||
allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup)
|
||||
allCloses = toSeq(self.connections.values()).mapIt(it.close())
|
||||
await noCancel allFutures(allCloses)
|
||||
await noCancel allFutures(allCleanup)
|
||||
untrackCounter(DtlsTransportTracker)
|
||||
|
||||
proc localCertificate*(self: Dtls): seq[byte] =
|
||||
|
@ -85,14 +79,11 @@ proc localCertificate*(self: Dtls): seq[byte] =
|
|||
proc localAddress*(self: Dtls): TransportAddress =
|
||||
self.laddr
|
||||
|
||||
proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} =
|
||||
# Waiting for a connection to be closed to remove it from the table
|
||||
try:
|
||||
await conn.join()
|
||||
except CancelledError as exc:
|
||||
discard
|
||||
|
||||
proc addConnToTable(self: Dtls, conn: DtlsConn) =
|
||||
proc cleanup() =
|
||||
self.connections.del(conn.remoteAddress())
|
||||
self.connections[conn.remoteAddress()] = conn
|
||||
conn.addOnClose(cleanup)
|
||||
|
||||
proc accept*(
|
||||
self: Dtls
|
||||
|
@ -114,8 +105,7 @@ proc accept*(
|
|||
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
|
||||
)
|
||||
await res.dtlsHandshake(true)
|
||||
self.connections[raddr] =
|
||||
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
|
||||
self.addConnToTable(res)
|
||||
break
|
||||
except WebRtcError as exc:
|
||||
trace "Handshake fails, try accept another connection", raddr, error = exc.msg
|
||||
|
@ -136,8 +126,7 @@ proc connect*(
|
|||
|
||||
try:
|
||||
await res.dtlsHandshake(false)
|
||||
self.connections[raddr] =
|
||||
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
|
||||
self.addConnToTable(res)
|
||||
except WebRtcError as exc:
|
||||
trace "Handshake fails", raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
|
|
|
@ -28,6 +28,7 @@ type
|
|||
StunUsernameProvider* = proc(): string {.raises: [], gcsafe.}
|
||||
StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.}
|
||||
StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.}
|
||||
StunConnOnClose* = proc() {.raises: [], gcsafe.}
|
||||
|
||||
StunConn* = ref object
|
||||
udp*: UdpTransport # The wrapper protocol: UDP Transport
|
||||
|
@ -37,8 +38,10 @@ type
|
|||
stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be
|
||||
# processed by the stun message handler
|
||||
handlesFut*: Future[void] # Stun Message handler
|
||||
closeEvent: AsyncEvent
|
||||
|
||||
# Close connection management
|
||||
closed*: bool
|
||||
onClose: seq[StunConnOnClose]
|
||||
|
||||
# Is ice-controlling and iceTiebreaker, not fully implemented yet.
|
||||
iceControlling: bool
|
||||
|
@ -201,7 +204,6 @@ proc new*(
|
|||
laddr: udp.laddr,
|
||||
raddr: raddr,
|
||||
closed: false,
|
||||
closeEvent: newAsyncEvent(),
|
||||
dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
|
||||
stunMsgs: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
|
||||
iceControlling: iceControlling,
|
||||
|
@ -215,10 +217,10 @@ proc new*(
|
|||
trackCounter(StunConnectionTracker)
|
||||
return self
|
||||
|
||||
proc join*(self: StunConn) {.async: (raises: [CancelledError]).} =
|
||||
## Wait for the Stun Connection to be closed
|
||||
proc addOnClose*(self: StunConn, onCloseProc: StunConnOnClose) =
|
||||
## Adds a proc to be called when StunConn is closed
|
||||
##
|
||||
await self.closeEvent.wait()
|
||||
self.onClose.add(onCloseProc)
|
||||
|
||||
proc close*(self: StunConn) {.async: (raises: []).} =
|
||||
## Close a Stun Connection
|
||||
|
@ -227,7 +229,9 @@ proc close*(self: StunConn) {.async: (raises: []).} =
|
|||
debug "Try to close an already closed StunConn"
|
||||
return
|
||||
await self.handlesFut.cancelAndWait()
|
||||
self.closeEvent.fire()
|
||||
for onCloseProc in self.onClose:
|
||||
onCloseProc()
|
||||
self.onClose = @[]
|
||||
self.closed = true
|
||||
untrackCounter(StunConnectionTracker)
|
||||
|
||||
|
|
|
@ -32,6 +32,12 @@ type
|
|||
|
||||
rng: ref HmacDrbgContext
|
||||
|
||||
proc addConnToTable(self: Stun, conn: StunConn) =
|
||||
proc cleanup() =
|
||||
self.connections.del(conn.raddr)
|
||||
self.connections[conn.raddr] = conn
|
||||
conn.addOnClose(cleanup)
|
||||
|
||||
proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} =
|
||||
## Accept a Stun Connection
|
||||
##
|
||||
|
@ -53,17 +59,9 @@ proc connect*(
|
|||
do:
|
||||
let res = StunConn.new(self.udp, raddr, false, self.usernameProvider,
|
||||
self.usernameChecker, self.passwordProvider, self.rng)
|
||||
self.connections[raddr] = res
|
||||
self.addConnToTable(res)
|
||||
return res
|
||||
|
||||
proc cleanupStunConn(self: Stun, conn: StunConn) {.async: (raises: []).} =
|
||||
# Waiting for a connection to be closed to remove it from the table
|
||||
try:
|
||||
await conn.join()
|
||||
self.connections.del(conn.raddr)
|
||||
except CancelledError as exc:
|
||||
warn "Error cleaning up Stun Connection", error=exc.msg
|
||||
|
||||
proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
|
||||
while true:
|
||||
let (buf, raddr) = await self.udp.read()
|
||||
|
@ -71,9 +69,8 @@ proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
|
|||
if not self.connections.hasKey(raddr):
|
||||
stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider,
|
||||
self.usernameChecker, self.passwordProvider, self.rng)
|
||||
self.connections[raddr] = stunConn
|
||||
self.addConnToTable(stunConn)
|
||||
await self.pendingConn.addLast(stunConn)
|
||||
asyncSpawn self.cleanupStunConn(stunConn)
|
||||
else:
|
||||
try:
|
||||
stunConn = self.connections[raddr]
|
||||
|
|
Loading…
Reference in New Issue