mirror of https://github.com/vacp2p/nim-webrtc.git
Dtls comments + close + minor improvements
This commit is contained in:
parent
4c1eb13926
commit
359a81df4a
|
@ -0,0 +1,23 @@
|
|||
import chronos, stew/byteutils
|
||||
import ../webrtc/udp_connection
|
||||
import ../webrtc/stun/stun_connection
|
||||
import ../webrtc/dtls/dtls
|
||||
import ../webrtc/sctp
|
||||
|
||||
proc main() {.async.} =
|
||||
let laddr = initTAddress("127.0.0.1:4244")
|
||||
let udp = UdpConn()
|
||||
udp.init(laddr)
|
||||
let stun = StunConn()
|
||||
stun.init(udp, laddr)
|
||||
let dtls = Dtls()
|
||||
dtls.init(stun, laddr)
|
||||
let sctp = Sctp.new(dtls, laddr)
|
||||
let conn = await sctp.connect(initTAddress("127.0.0.1:4242"), sctpPort = 13)
|
||||
while true:
|
||||
await conn.write("ping".toBytes)
|
||||
let msg = await conn.read()
|
||||
echo "Received: ", string.fromBytes(msg.data)
|
||||
await sleepAsync(1.seconds)
|
||||
|
||||
waitFor(main())
|
|
@ -19,7 +19,7 @@ proc main() {.async.} =
|
|||
let stun = StunConn()
|
||||
stun.init(udp, laddr)
|
||||
let dtls = Dtls()
|
||||
dtls.start(stun, laddr)
|
||||
dtls.init(stun, laddr)
|
||||
let sctp = Sctp.new(dtls, laddr)
|
||||
sctp.listen(13)
|
||||
while true:
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import times, deques, tables
|
||||
import times, deques, tables, sequtils
|
||||
import chronos, chronicles
|
||||
import ./utils, ../stun/stun_connection
|
||||
|
||||
|
@ -29,11 +29,22 @@ import mbedtls/timing
|
|||
logScope:
|
||||
topics = "webrtc dtls"
|
||||
|
||||
# TODO: Check the viability of the add/pop first/last of the asyncqueue with the limit.
|
||||
# There might be some errors (or crashes) in weird cases with the no wait option
|
||||
# Implementation of a DTLS client and a DTLS Server by using the mbedtls library.
|
||||
# Multiple things here are unintuitive partly because of the callbacks
|
||||
# used by mbedtls and that those callbacks cannot be async.
|
||||
#
|
||||
# TODO:
|
||||
# - Check the viability of the add/pop first/last of the asyncqueue with the limit.
|
||||
# There might be some errors (or crashes) with some edge cases with the no wait option
|
||||
# - Not critical - Check how to make a better use of MBEDTLS_ERR_SSL_WANT_WRITE
|
||||
# - Not critical - May be interesting to split Dtls and DtlsConn into two files
|
||||
|
||||
const
|
||||
PendingHandshakeLimit = 1024
|
||||
# This limit is arbitrary, it could be interesting to make it configurable.
|
||||
const PendingHandshakeLimit = 1024
|
||||
|
||||
# -- DtlsConn --
|
||||
# A Dtls connection to a specific IP address recovered by the receiving part of
|
||||
# the Udp "connection"
|
||||
|
||||
type
|
||||
DtlsError* = object of CatchableError
|
||||
|
@ -43,6 +54,8 @@ type
|
|||
raddr*: TransportAddress
|
||||
dataRecv: AsyncQueue[seq[byte]]
|
||||
sendFuture: Future[void]
|
||||
closed: bool
|
||||
closeEvent: AsyncEvent
|
||||
|
||||
timer: mbedtls_timing_delay_context
|
||||
|
||||
|
@ -57,55 +70,99 @@ type
|
|||
localCert: seq[byte]
|
||||
remoteCert: seq[byte]
|
||||
|
||||
proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
var self = cast[DtlsConn](ctx)
|
||||
var toWrite = newSeq[byte](len)
|
||||
if len > 0:
|
||||
copyMem(addr toWrite[0], buf, len)
|
||||
trace "dtls send", len
|
||||
self.sendFuture = self.conn.write(self.raddr, toWrite)
|
||||
result = len.cint
|
||||
|
||||
proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
let self = cast[DtlsConn](ctx)
|
||||
if self.dataRecv.len() == 0:
|
||||
return MBEDTLS_ERR_SSL_WANT_READ
|
||||
|
||||
var dataRecv = self.dataRecv.popFirstNoWait()
|
||||
copyMem(buf, addr dataRecv[0], dataRecv.len())
|
||||
result = dataRecv.len().cint
|
||||
trace "dtls receive", len, result
|
||||
|
||||
proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} =
|
||||
proc init(self: DtlsConn, conn: StunConn, laddr: TransportAddress) =
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
self.dataRecv = newAsyncQueue[seq[byte]]()
|
||||
self.closed = false
|
||||
self.closeEvent = newAsyncEvent()
|
||||
|
||||
proc join(self: DtlsConn) {.async.} =
|
||||
await self.closeEvent.wait()
|
||||
|
||||
proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} =
|
||||
var shouldRead = isServer
|
||||
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
|
||||
if shouldRead:
|
||||
if isServer:
|
||||
case self.raddr.family
|
||||
of AddressFamily.IPv4:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
|
||||
of AddressFamily.IPv6:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
|
||||
else:
|
||||
raise newException(DtlsError, "Remote address isn't an IP address")
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
self.sendFuture = nil
|
||||
let res = mb_ssl_handshake_step(self.ssl)
|
||||
if not self.sendFuture.isNil():
|
||||
await self.sendFuture
|
||||
shouldRead = false
|
||||
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
shouldRead = true
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
|
||||
mb_ssl_session_reset(self.ssl)
|
||||
shouldRead = isServer
|
||||
continue
|
||||
elif res != 0:
|
||||
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
|
||||
|
||||
proc close*(self: DtlsConn) {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to close DtlsConn twice"
|
||||
return
|
||||
|
||||
self.closed = true
|
||||
self.sendFuture = nil
|
||||
# TODO: proc mbedtls_ssl_close_notify => template mb_ssl_close_notify in nim-mbedtls
|
||||
let x = mbedtls_ssl_close_notify(addr self.ssl)
|
||||
if not self.sendFuture.isNil():
|
||||
await self.sendFuture
|
||||
self.closeEvent.fire()
|
||||
|
||||
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to write on an already closed DtlsConn"
|
||||
return
|
||||
var buf = msg
|
||||
try:
|
||||
let sendFuture = newFuture[void]("DtlsConn write")
|
||||
self.sendFuture = nil
|
||||
let write = mb_ssl_write(self.ssl, buf)
|
||||
if not self.sendFuture.isNil():
|
||||
await self.sendFuture
|
||||
trace "Dtls write", msgLen = msg.len(), actuallyWrote = write
|
||||
except MbedTLSError as exc:
|
||||
trace "Dtls write error", errorMsg = exc.msg
|
||||
raise exc
|
||||
|
||||
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to read on an already closed DtlsConn"
|
||||
return
|
||||
var res = newSeq[byte](8192)
|
||||
while true:
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
# TODO: exception catching
|
||||
let length = mb_ssl_read(self.ssl, res)
|
||||
# TODO: Find a clear way to use the template `mb_ssl_read` without
|
||||
# messing up things with exception
|
||||
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
|
||||
if length == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
continue
|
||||
if length < 0:
|
||||
trace "dtls read", error = $(length.cint.mbedtls_high_level_strerr())
|
||||
raise newException(DtlsError, $(length.cint.mbedtls_high_level_strerr()))
|
||||
res.setLen(length)
|
||||
return res
|
||||
|
||||
proc close*(self: DtlsConn) {.async.} =
|
||||
discard
|
||||
# -- Dtls --
|
||||
# The Dtls object read every messages from the UdpConn/StunConn and, if the address
|
||||
# is not yet stored in the Table `Connection`, adds it to the `pendingHandshake` queue
|
||||
# to be accepted later, if the address is stored, add the message received to the
|
||||
# corresponding DtlsConn `dataRecv` queue.
|
||||
|
||||
type
|
||||
Dtls* = ref object of RootObj
|
||||
|
@ -130,7 +187,7 @@ proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])],
|
|||
return
|
||||
aq.addLastNoWait((raddr, buf))
|
||||
|
||||
proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
|
||||
proc init*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
|
||||
if self.started:
|
||||
warn "Already started"
|
||||
return
|
||||
|
@ -159,43 +216,16 @@ proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
|
|||
self.localCert = newSeq[byte](self.serverCert.raw.len)
|
||||
copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len)
|
||||
|
||||
proc stop*(self: Dtls) =
|
||||
proc stop*(self: Dtls) {.async.} =
|
||||
if not self.started:
|
||||
warn "Already stopped"
|
||||
return
|
||||
|
||||
await allFutures(toSeq(self.connections.values()).mapIt(it.close()))
|
||||
self.readLoop.cancel()
|
||||
self.started = false
|
||||
|
||||
proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} =
|
||||
var shouldRead = isServer
|
||||
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
|
||||
if shouldRead:
|
||||
if isServer:
|
||||
case self.raddr.family
|
||||
of AddressFamily.IPv4:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
|
||||
of AddressFamily.IPv6:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
|
||||
else:
|
||||
raise newException(DtlsError, "Remote address isn't an IP address")
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
self.sendFuture = nil
|
||||
let res = mb_ssl_handshake_step(self.ssl)
|
||||
if not self.sendFuture.isNil(): await self.sendFuture
|
||||
shouldRead = false
|
||||
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
shouldRead = true
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
|
||||
mb_ssl_session_reset(self.ssl)
|
||||
shouldRead = true
|
||||
continue
|
||||
elif res != 0:
|
||||
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
|
||||
# -- Remote / Local certificate getter --
|
||||
|
||||
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
|
||||
conn.remoteCert
|
||||
|
@ -206,8 +236,14 @@ proc localCertificate*(conn: DtlsConn): seq[byte] =
|
|||
proc localCertificate*(self: Dtls): seq[byte] =
|
||||
self.localCert
|
||||
|
||||
# -- MbedTLS Callbacks --
|
||||
|
||||
proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt,
|
||||
state: cint, pflags: ptr uint32): cint {.cdecl.} =
|
||||
# verify is the procedure called by mbedtls when receiving the remote
|
||||
# certificate. It's usually used to verify the validity of the certificate.
|
||||
# We use this procedure to store the remote certificate as it's mandatory
|
||||
# to have it for the Prologue of the Noise protocol, aswell as the localCertificate.
|
||||
var self = cast[DtlsConn](ctx)
|
||||
let cert = pcert[]
|
||||
|
||||
|
@ -215,12 +251,45 @@ proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt,
|
|||
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
|
||||
return 0
|
||||
|
||||
proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
# dtlsSend is the procedure called by mbedtls when data needs to be sent.
|
||||
# As the StunConn's write proc is asynchronous and dtlsSend cannot be async,
|
||||
# we store the future of this write and await it after the end of the
|
||||
# function (see write or dtlsHanshake for example).
|
||||
var self = cast[DtlsConn](ctx)
|
||||
var toWrite = newSeq[byte](len)
|
||||
if len > 0:
|
||||
copyMem(addr toWrite[0], buf, len)
|
||||
trace "dtls send", len
|
||||
self.sendFuture = self.conn.write(self.raddr, toWrite)
|
||||
result = len.cint
|
||||
|
||||
proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
# dtlsRecv is the procedure called by mbedtls when data needs to be received.
|
||||
# As we cannot asynchronously await for data to be received, we use a data received
|
||||
# queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await
|
||||
# when the mbedtls proc resumed (see read or dtlsHandshake for example)
|
||||
let self = cast[DtlsConn](ctx)
|
||||
if self.dataRecv.len() == 0:
|
||||
return MBEDTLS_ERR_SSL_WANT_READ
|
||||
|
||||
var dataRecv = self.dataRecv.popFirstNoWait()
|
||||
copyMem(buf, addr dataRecv[0], dataRecv.len())
|
||||
result = dataRecv.len().cint
|
||||
trace "dtls receive", len, result
|
||||
|
||||
# -- Dtls Accept / Connect procedures --
|
||||
|
||||
proc removeConnection(self: Dtls, conn: DtlsConn, raddr: TransportAddress) {.async.} =
|
||||
await conn.join()
|
||||
self.connections.del(raddr)
|
||||
|
||||
proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
|
||||
var
|
||||
selfvar = self
|
||||
res = DtlsConn()
|
||||
|
||||
await res.init(self.conn, self.laddr)
|
||||
res.init(self.conn, self.laddr)
|
||||
mb_ssl_init(res.ssl)
|
||||
mb_ssl_config_init(res.config)
|
||||
mb_ssl_cookie_init(res.cookie)
|
||||
|
@ -248,8 +317,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
|
|||
mb_ssl_session_reset(res.ssl)
|
||||
mb_ssl_set_verify(res.ssl, verify, res)
|
||||
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res),
|
||||
dtlsSend, dtlsRecv, nil)
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
|
||||
while true:
|
||||
let (raddr, buf) = await self.pendingHandshakes.popFirst()
|
||||
try:
|
||||
|
@ -257,6 +325,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
|
|||
res.dataRecv.addLastNoWait(buf)
|
||||
self.connections[raddr] = res
|
||||
await res.dtlsHandshake(true)
|
||||
asyncSpawn self.removeConnection(res, raddr)
|
||||
break
|
||||
except CatchableError as exc:
|
||||
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
|
||||
|
@ -269,7 +338,7 @@ proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
|
|||
selfvar = self
|
||||
res = DtlsConn()
|
||||
|
||||
await res.init(self.conn, self.laddr)
|
||||
res.init(self.conn, self.laddr)
|
||||
mb_ssl_init(res.ssl)
|
||||
mb_ssl_config_init(res.config)
|
||||
|
||||
|
@ -303,6 +372,7 @@ proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
|
|||
|
||||
try:
|
||||
await res.dtlsHandshake(false)
|
||||
asyncSpawn self.removeConnection(res, raddr)
|
||||
except CatchableError as exc:
|
||||
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
|
|
|
@ -11,7 +11,6 @@ import std/sha1, sequtils, typetraits, std/md5
|
|||
import binary_serialization,
|
||||
stew/byteutils,
|
||||
chronos
|
||||
import ../utils
|
||||
|
||||
# -- Utils --
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@ import sctp, datachannel
|
|||
logScope:
|
||||
topics = "webrtc"
|
||||
|
||||
# TODO: Implement a connect (or dial) procedure
|
||||
|
||||
type
|
||||
WebRTC* = ref object
|
||||
udp*: UdpConn
|
||||
|
@ -29,7 +31,7 @@ proc new*(T: typedesc[WebRTC], address: TransportAddress): T =
|
|||
var webrtc = T(udp: UdpConn(), stun: StunConn(), dtls: Dtls())
|
||||
webrtc.udp.init(address)
|
||||
webrtc.stun.init(webrtc.udp, address)
|
||||
webrtc.dtls.start(webrtc.stun, address)
|
||||
webrtc.dtls.init(webrtc.stun, address)
|
||||
webrtc.sctp = Sctp.new(webrtc.dtls, address)
|
||||
return webrtc
|
||||
|
||||
|
|
Loading…
Reference in New Issue