mirror of https://github.com/vacp2p/nim-webrtc.git
fix: stun and dtls close (#22)
* fix: stun & dtls close * chore: change the test names & variable names * fix: add isClosed to DtlsConn * feat: remove `closeEvent`/`join`/`cleanup` and replace those by `onClose` sequence of proc to be executed on close
This commit is contained in:
parent
e9bda6babf
commit
26cad2a383
|
@ -81,3 +81,40 @@ suite "DTLS":
|
||||||
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
|
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
|
||||||
await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
|
await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
|
||||||
await allFutures(udp1.close(), udp2.close(), udp3.close())
|
await allFutures(udp1.close(), udp2.close(), udp3.close())
|
||||||
|
|
||||||
|
asyncTest "Two DTLS nodes connecting to each other, closing the created connections then re-connect the nodes":
|
||||||
|
# Related to https://github.com/vacp2p/nim-webrtc/pull/22
|
||||||
|
let
|
||||||
|
localAddr1 = initTAddress("127.0.0.1:4444")
|
||||||
|
localAddr2 = initTAddress("127.0.0.1:5555")
|
||||||
|
udp1 = UdpTransport.new(localAddr1)
|
||||||
|
udp2 = UdpTransport.new(localAddr2)
|
||||||
|
stun1 = Stun.new(udp1)
|
||||||
|
stun2 = Stun.new(udp2)
|
||||||
|
dtls1 = Dtls.new(stun1)
|
||||||
|
dtls2 = Dtls.new(stun2)
|
||||||
|
var
|
||||||
|
serverConnFut = dtls1.accept()
|
||||||
|
clientConn = await dtls2.connect(localAddr1)
|
||||||
|
serverConn = await serverConnFut
|
||||||
|
|
||||||
|
await serverConn.write(@[1'u8, 2, 3, 4])
|
||||||
|
await clientConn.write(@[5'u8, 6, 7, 8])
|
||||||
|
check (await serverConn.read()) == @[5'u8, 6, 7, 8]
|
||||||
|
check (await clientConn.read()) == @[1'u8, 2, 3, 4]
|
||||||
|
await allFutures(serverConn.close(), clientConn.close())
|
||||||
|
check serverConn.isClosed() and clientConn.isClosed()
|
||||||
|
|
||||||
|
serverConnFut = dtls1.accept()
|
||||||
|
clientConn = await dtls2.connect(localAddr1)
|
||||||
|
serverConn = await serverConnFut
|
||||||
|
|
||||||
|
await serverConn.write(@[5'u8, 6, 7, 8])
|
||||||
|
await clientConn.write(@[1'u8, 2, 3, 4])
|
||||||
|
check (await serverConn.read()) == @[1'u8, 2, 3, 4]
|
||||||
|
check (await clientConn.read()) == @[5'u8, 6, 7, 8]
|
||||||
|
|
||||||
|
await allFutures(serverConn.close(), clientConn.close())
|
||||||
|
await allFutures(dtls1.stop(), dtls2.stop())
|
||||||
|
await allFutures(stun1.stop(), stun2.stop())
|
||||||
|
await allFutures(udp1.close(), udp2.close())
|
||||||
|
|
|
@ -21,6 +21,8 @@ logScope:
|
||||||
const DtlsConnTracker* = "webrtc.dtls.conn"
|
const DtlsConnTracker* = "webrtc.dtls.conn"
|
||||||
|
|
||||||
type
|
type
|
||||||
|
DtlsConnOnClose* = proc() {.raises: [], gcsafe.}
|
||||||
|
|
||||||
MbedTLSCtx = object
|
MbedTLSCtx = object
|
||||||
ssl: mbedtls_ssl_context
|
ssl: mbedtls_ssl_context
|
||||||
config: mbedtls_ssl_config
|
config: mbedtls_ssl_config
|
||||||
|
@ -34,7 +36,7 @@ type
|
||||||
DtlsConn* = ref object
|
DtlsConn* = ref object
|
||||||
# DtlsConn is a Dtls connection receiving and sending data using
|
# DtlsConn is a Dtls connection receiving and sending data using
|
||||||
# the underlying Stun Connection
|
# the underlying Stun Connection
|
||||||
conn*: StunConn # The wrapper protocol Stun Connection
|
conn: StunConn # The wrapper protocol Stun Connection
|
||||||
raddr: TransportAddress # Remote address
|
raddr: TransportAddress # Remote address
|
||||||
dataRecv: seq[byte] # data received which will be read by SCTP
|
dataRecv: seq[byte] # data received which will be read by SCTP
|
||||||
dataToSend: seq[byte]
|
dataToSend: seq[byte]
|
||||||
|
@ -43,7 +45,7 @@ type
|
||||||
|
|
||||||
# Close connection management
|
# Close connection management
|
||||||
closed: bool
|
closed: bool
|
||||||
closeEvent: AsyncEvent
|
onClose: seq[DtlsConnOnClose]
|
||||||
|
|
||||||
# Local and Remote certificate, needed by wrapped protocol DataChannel
|
# Local and Remote certificate, needed by wrapped protocol DataChannel
|
||||||
# and by libp2p
|
# and by libp2p
|
||||||
|
@ -53,6 +55,9 @@ type
|
||||||
# Mbed-TLS contexts
|
# Mbed-TLS contexts
|
||||||
ctx: MbedTLSCtx
|
ctx: MbedTLSCtx
|
||||||
|
|
||||||
|
proc isClosed*(self: DtlsConn): bool =
|
||||||
|
return self.closed
|
||||||
|
|
||||||
proc getRemoteCertificateCallback(
|
proc getRemoteCertificateCallback(
|
||||||
ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32
|
ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32
|
||||||
): cint {.cdecl.} =
|
): cint {.cdecl.} =
|
||||||
|
@ -100,7 +105,6 @@ proc new*(T: type DtlsConn, conn: StunConn): T =
|
||||||
var self = T(conn: conn)
|
var self = T(conn: conn)
|
||||||
self.raddr = conn.raddr
|
self.raddr = conn.raddr
|
||||||
self.closed = false
|
self.closed = false
|
||||||
self.closeEvent = newAsyncEvent()
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
proc dtlsConnInit(self: DtlsConn) =
|
proc dtlsConnInit(self: DtlsConn) =
|
||||||
|
@ -160,10 +164,10 @@ proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) =
|
||||||
except MbedTLSError as exc:
|
except MbedTLSError as exc:
|
||||||
raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc)
|
raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc)
|
||||||
|
|
||||||
proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} =
|
proc addOnClose*(self: DtlsConn, onCloseProc: DtlsConnOnClose) =
|
||||||
## Wait for the Dtls Connection to be closed
|
## Adds a proc to be called when DtlsConn is closed
|
||||||
##
|
##
|
||||||
await self.closeEvent.wait()
|
self.onClose.add(onCloseProc)
|
||||||
|
|
||||||
proc dtlsHandshake*(
|
proc dtlsHandshake*(
|
||||||
self: DtlsConn, isServer: bool
|
self: DtlsConn, isServer: bool
|
||||||
|
@ -217,7 +221,10 @@ proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} =
|
||||||
await self.conn.write(self.dataToSend)
|
await self.conn.write(self.dataToSend)
|
||||||
self.dataToSend = @[]
|
self.dataToSend = @[]
|
||||||
untrackCounter(DtlsConnTracker)
|
untrackCounter(DtlsConnTracker)
|
||||||
self.closeEvent.fire()
|
await self.conn.close()
|
||||||
|
for onCloseProc in self.onClose:
|
||||||
|
onCloseProc()
|
||||||
|
self.onClose = @[]
|
||||||
|
|
||||||
proc write*(
|
proc write*(
|
||||||
self: DtlsConn, msg: seq[byte]
|
self: DtlsConn, msg: seq[byte]
|
||||||
|
|
|
@ -28,12 +28,8 @@ logScope:
|
||||||
const DtlsTransportTracker* = "webrtc.dtls.transport"
|
const DtlsTransportTracker* = "webrtc.dtls.transport"
|
||||||
|
|
||||||
type
|
type
|
||||||
DtlsConnAndCleanup = object
|
|
||||||
connection: DtlsConn
|
|
||||||
cleanup: Future[void].Raising([])
|
|
||||||
|
|
||||||
Dtls* = ref object of RootObj
|
Dtls* = ref object of RootObj
|
||||||
connections: Table[TransportAddress, DtlsConnAndCleanup]
|
connections: Table[TransportAddress, DtlsConn]
|
||||||
transport: Stun
|
transport: Stun
|
||||||
laddr: TransportAddress
|
laddr: TransportAddress
|
||||||
started: bool
|
started: bool
|
||||||
|
@ -46,7 +42,7 @@ type
|
||||||
|
|
||||||
proc new*(T: type Dtls, transport: Stun): T =
|
proc new*(T: type Dtls, transport: Stun): T =
|
||||||
var self = T(
|
var self = T(
|
||||||
connections: initTable[TransportAddress, DtlsConnAndCleanup](),
|
connections: initTable[TransportAddress, DtlsConn](),
|
||||||
transport: transport,
|
transport: transport,
|
||||||
laddr: transport.laddr,
|
laddr: transport.laddr,
|
||||||
started: true,
|
started: true,
|
||||||
|
@ -72,10 +68,8 @@ proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} =
|
||||||
|
|
||||||
self.started = false
|
self.started = false
|
||||||
let
|
let
|
||||||
allCloses = toSeq(self.connections.values()).mapIt(it.connection.close())
|
allCloses = toSeq(self.connections.values()).mapIt(it.close())
|
||||||
allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup)
|
|
||||||
await noCancel allFutures(allCloses)
|
await noCancel allFutures(allCloses)
|
||||||
await noCancel allFutures(allCleanup)
|
|
||||||
untrackCounter(DtlsTransportTracker)
|
untrackCounter(DtlsTransportTracker)
|
||||||
|
|
||||||
proc localCertificate*(self: Dtls): seq[byte] =
|
proc localCertificate*(self: Dtls): seq[byte] =
|
||||||
|
@ -85,14 +79,11 @@ proc localCertificate*(self: Dtls): seq[byte] =
|
||||||
proc localAddress*(self: Dtls): TransportAddress =
|
proc localAddress*(self: Dtls): TransportAddress =
|
||||||
self.laddr
|
self.laddr
|
||||||
|
|
||||||
proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} =
|
proc addConnToTable(self: Dtls, conn: DtlsConn) =
|
||||||
# Waiting for a connection to be closed to remove it from the table
|
proc cleanup() =
|
||||||
try:
|
|
||||||
await conn.join()
|
|
||||||
except CancelledError as exc:
|
|
||||||
discard
|
|
||||||
|
|
||||||
self.connections.del(conn.remoteAddress())
|
self.connections.del(conn.remoteAddress())
|
||||||
|
self.connections[conn.remoteAddress()] = conn
|
||||||
|
conn.addOnClose(cleanup)
|
||||||
|
|
||||||
proc accept*(
|
proc accept*(
|
||||||
self: Dtls
|
self: Dtls
|
||||||
|
@ -114,8 +105,7 @@ proc accept*(
|
||||||
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
|
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
|
||||||
)
|
)
|
||||||
await res.dtlsHandshake(true)
|
await res.dtlsHandshake(true)
|
||||||
self.connections[raddr] =
|
self.addConnToTable(res)
|
||||||
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
|
|
||||||
break
|
break
|
||||||
except WebRtcError as exc:
|
except WebRtcError as exc:
|
||||||
trace "Handshake fails, try accept another connection", raddr, error = exc.msg
|
trace "Handshake fails, try accept another connection", raddr, error = exc.msg
|
||||||
|
@ -136,8 +126,7 @@ proc connect*(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await res.dtlsHandshake(false)
|
await res.dtlsHandshake(false)
|
||||||
self.connections[raddr] =
|
self.addConnToTable(res)
|
||||||
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
|
|
||||||
except WebRtcError as exc:
|
except WebRtcError as exc:
|
||||||
trace "Handshake fails", raddr, error = exc.msg
|
trace "Handshake fails", raddr, error = exc.msg
|
||||||
self.connections.del(raddr)
|
self.connections.del(raddr)
|
||||||
|
|
|
@ -28,6 +28,7 @@ type
|
||||||
StunUsernameProvider* = proc(): string {.raises: [], gcsafe.}
|
StunUsernameProvider* = proc(): string {.raises: [], gcsafe.}
|
||||||
StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.}
|
StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.}
|
||||||
StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.}
|
StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.}
|
||||||
|
StunConnOnClose* = proc() {.raises: [], gcsafe.}
|
||||||
|
|
||||||
StunConn* = ref object
|
StunConn* = ref object
|
||||||
udp*: UdpTransport # The wrapper protocol: UDP Transport
|
udp*: UdpTransport # The wrapper protocol: UDP Transport
|
||||||
|
@ -37,8 +38,10 @@ type
|
||||||
stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be
|
stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be
|
||||||
# processed by the stun message handler
|
# processed by the stun message handler
|
||||||
handlesFut*: Future[void] # Stun Message handler
|
handlesFut*: Future[void] # Stun Message handler
|
||||||
closeEvent: AsyncEvent
|
|
||||||
|
# Close connection management
|
||||||
closed*: bool
|
closed*: bool
|
||||||
|
onClose: seq[StunConnOnClose]
|
||||||
|
|
||||||
# Is ice-controlling and iceTiebreaker, not fully implemented yet.
|
# Is ice-controlling and iceTiebreaker, not fully implemented yet.
|
||||||
iceControlling: bool
|
iceControlling: bool
|
||||||
|
@ -201,7 +204,6 @@ proc new*(
|
||||||
laddr: udp.laddr,
|
laddr: udp.laddr,
|
||||||
raddr: raddr,
|
raddr: raddr,
|
||||||
closed: false,
|
closed: false,
|
||||||
closeEvent: newAsyncEvent(),
|
|
||||||
dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
|
dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
|
||||||
stunMsgs: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
|
stunMsgs: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
|
||||||
iceControlling: iceControlling,
|
iceControlling: iceControlling,
|
||||||
|
@ -215,10 +217,10 @@ proc new*(
|
||||||
trackCounter(StunConnectionTracker)
|
trackCounter(StunConnectionTracker)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
proc join*(self: StunConn) {.async: (raises: [CancelledError]).} =
|
proc addOnClose*(self: StunConn, onCloseProc: StunConnOnClose) =
|
||||||
## Wait for the Stun Connection to be closed
|
## Adds a proc to be called when StunConn is closed
|
||||||
##
|
##
|
||||||
await self.closeEvent.wait()
|
self.onClose.add(onCloseProc)
|
||||||
|
|
||||||
proc close*(self: StunConn) {.async: (raises: []).} =
|
proc close*(self: StunConn) {.async: (raises: []).} =
|
||||||
## Close a Stun Connection
|
## Close a Stun Connection
|
||||||
|
@ -227,7 +229,9 @@ proc close*(self: StunConn) {.async: (raises: []).} =
|
||||||
debug "Try to close an already closed StunConn"
|
debug "Try to close an already closed StunConn"
|
||||||
return
|
return
|
||||||
await self.handlesFut.cancelAndWait()
|
await self.handlesFut.cancelAndWait()
|
||||||
self.closeEvent.fire()
|
for onCloseProc in self.onClose:
|
||||||
|
onCloseProc()
|
||||||
|
self.onClose = @[]
|
||||||
self.closed = true
|
self.closed = true
|
||||||
untrackCounter(StunConnectionTracker)
|
untrackCounter(StunConnectionTracker)
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,12 @@ type
|
||||||
|
|
||||||
rng: ref HmacDrbgContext
|
rng: ref HmacDrbgContext
|
||||||
|
|
||||||
|
proc addConnToTable(self: Stun, conn: StunConn) =
|
||||||
|
proc cleanup() =
|
||||||
|
self.connections.del(conn.raddr)
|
||||||
|
self.connections[conn.raddr] = conn
|
||||||
|
conn.addOnClose(cleanup)
|
||||||
|
|
||||||
proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} =
|
proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} =
|
||||||
## Accept a Stun Connection
|
## Accept a Stun Connection
|
||||||
##
|
##
|
||||||
|
@ -53,17 +59,9 @@ proc connect*(
|
||||||
do:
|
do:
|
||||||
let res = StunConn.new(self.udp, raddr, false, self.usernameProvider,
|
let res = StunConn.new(self.udp, raddr, false, self.usernameProvider,
|
||||||
self.usernameChecker, self.passwordProvider, self.rng)
|
self.usernameChecker, self.passwordProvider, self.rng)
|
||||||
self.connections[raddr] = res
|
self.addConnToTable(res)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
proc cleanupStunConn(self: Stun, conn: StunConn) {.async: (raises: []).} =
|
|
||||||
# Waiting for a connection to be closed to remove it from the table
|
|
||||||
try:
|
|
||||||
await conn.join()
|
|
||||||
self.connections.del(conn.raddr)
|
|
||||||
except CancelledError as exc:
|
|
||||||
warn "Error cleaning up Stun Connection", error=exc.msg
|
|
||||||
|
|
||||||
proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
|
proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
|
||||||
while true:
|
while true:
|
||||||
let (buf, raddr) = await self.udp.read()
|
let (buf, raddr) = await self.udp.read()
|
||||||
|
@ -71,9 +69,8 @@ proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
|
||||||
if not self.connections.hasKey(raddr):
|
if not self.connections.hasKey(raddr):
|
||||||
stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider,
|
stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider,
|
||||||
self.usernameChecker, self.passwordProvider, self.rng)
|
self.usernameChecker, self.passwordProvider, self.rng)
|
||||||
self.connections[raddr] = stunConn
|
self.addConnToTable(stunConn)
|
||||||
await self.pendingConn.addLast(stunConn)
|
await self.pendingConn.addLast(stunConn)
|
||||||
asyncSpawn self.cleanupStunConn(stunConn)
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
stunConn = self.connections[raddr]
|
stunConn = self.connections[raddr]
|
||||||
|
|
Loading…
Reference in New Issue