Multiple fixes on stun / udp / dtls
This commit is contained in:
parent
397c84238a
commit
2a9b8298eb
|
@ -7,9 +7,9 @@
|
|||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import times, sequtils
|
||||
import times, deques, tables
|
||||
import chronos, chronicles
|
||||
import ./utils, ../webrtc_connection
|
||||
import ./utils, ../stun/stun_connection
|
||||
|
||||
import mbedtls/ssl
|
||||
import mbedtls/ssl_cookie
|
||||
|
@ -29,11 +29,19 @@ import mbedtls/timing
|
|||
logScope:
|
||||
topics = "webrtc dtls"
|
||||
|
||||
# TODO: Check the viability of the add/pop first/last of the asyncqueue with the limit.
|
||||
# There might be some errors (or crashes) in weird cases with the no wait option
|
||||
|
||||
const
|
||||
PendingHandshakeLimit = 1024
|
||||
|
||||
type
|
||||
DtlsError* = object of CatchableError
|
||||
DtlsConn* = ref object of WebRTCConn
|
||||
recvData: seq[seq[byte]]
|
||||
recvEvent: AsyncEvent
|
||||
DtlsConn* = ref object
|
||||
conn: StunConn
|
||||
laddr: TransportAddress
|
||||
raddr: TransportAddress
|
||||
dataRecv: AsyncQueue[seq[byte]]
|
||||
sendFuture: Future[void]
|
||||
|
||||
timer: mbedtls_timing_delay_context
|
||||
|
@ -51,70 +59,99 @@ proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
|||
var toWrite = newSeq[byte](len)
|
||||
if len > 0:
|
||||
copyMem(addr toWrite[0], buf, len)
|
||||
self.sendFuture = self.conn.write(toWrite)
|
||||
self.sendFuture = self.conn.write(self.raddr, toWrite)
|
||||
result = len.cint
|
||||
|
||||
proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
var self = cast[DtlsConn](ctx)
|
||||
result = self.recvData[0].len().cint
|
||||
copyMem(buf, addr self.recvData[0][0], self.recvData[0].len())
|
||||
self.recvData.delete(0..0)
|
||||
var
|
||||
self = cast[DtlsConn](ctx)
|
||||
dataRecv = self.dataRecv.popFirstNoWait()
|
||||
copyMem(buf, addr dataRecv[0], dataRecv.len())
|
||||
result = dataRecv.len().cint
|
||||
|
||||
method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.async.} =
|
||||
await procCall(WebRTCConn(self).init(conn, address))
|
||||
proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} =
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
self.dataRecv = newAsyncQueue[seq[byte]]()
|
||||
|
||||
method write*(self: DtlsConn, msg: seq[byte]) {.async.} =
|
||||
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
|
||||
var buf = msg
|
||||
discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](addr buf[0]), buf.len().uint)
|
||||
|
||||
method read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
||||
return await self.conn.read()
|
||||
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
||||
var res = newSeq[byte](8192)
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
|
||||
res.setLen(length)
|
||||
return res
|
||||
|
||||
method close*(self: DtlsConn) {.async.} =
|
||||
proc close*(self: DtlsConn) {.async.} =
|
||||
discard
|
||||
|
||||
method getRemoteAddress*(self: DtlsConn): TransportAddress =
|
||||
self.conn.getRemoteAddress()
|
||||
|
||||
type
|
||||
Dtls* = ref object of RootObj
|
||||
address: TransportAddress
|
||||
connections: Table[TransportAddress, DtlsConn]
|
||||
pendingHandshakes: AsyncQueue[(TransportAddress, seq[byte])]
|
||||
conn: StunConn
|
||||
laddr: TransportAddress
|
||||
started: bool
|
||||
readLoop: Future[void]
|
||||
|
||||
proc start*(self: Dtls, address: TransportAddress) =
|
||||
proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])],
|
||||
raddr: TransportAddress, buf: seq[byte]) =
|
||||
for (k, v) in aq.mitems():
|
||||
if k == raddr:
|
||||
v = buf
|
||||
return
|
||||
aq.addLastNoWait((raddr, buf))
|
||||
|
||||
proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
|
||||
if self.started:
|
||||
warn "Already started"
|
||||
return
|
||||
|
||||
self.address = address
|
||||
proc readLoop() {.async.} =
|
||||
while true:
|
||||
let (buf, raddr) = await self.conn.read()
|
||||
if self.connections.hasKey(raddr):
|
||||
self.connections[raddr].dataRecv.addLastNoWait(buf)
|
||||
else:
|
||||
self.pendingHandshakes.updateOrAdd(raddr, buf)
|
||||
|
||||
self.connections = initTable[TransportAddress, DtlsConn]()
|
||||
self.pendingHandshakes = newAsyncQueue[(TransportAddress, seq[byte])](PendingHandshakeLimit)
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
self.started = true
|
||||
self.readLoop = readLoop()
|
||||
|
||||
proc stop*(self: Dtls) =
|
||||
if not self.started:
|
||||
warn "Already stopped"
|
||||
return
|
||||
|
||||
self.readLoop.cancel()
|
||||
self.started = false
|
||||
|
||||
proc serverHandshake(self: DtlsConn) {.async.} =
|
||||
var shouldRead = true
|
||||
case self.raddr.family
|
||||
of AddressFamily.IPv4:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
|
||||
of AddressFamily.IPv6:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
|
||||
else:
|
||||
raise newException(DtlsError, "Remote address isn't an IP address")
|
||||
|
||||
var shouldRead = true
|
||||
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
|
||||
if shouldRead:
|
||||
self.recvData.add(await self.conn.read())
|
||||
var ta = self.getRemoteAddress()
|
||||
case ta.family
|
||||
of AddressFamily.IPv4:
|
||||
mb_ssl_set_client_transport_id(self.ssl, ta.address_v4)
|
||||
of AddressFamily.IPv6:
|
||||
mb_ssl_set_client_transport_id(self.ssl, ta.address_v6)
|
||||
else:
|
||||
raise newException(DtlsError, "Remote address isn't an IP address")
|
||||
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
self.sendFuture = nil
|
||||
let res = mb_ssl_handshake_step(self.ssl)
|
||||
shouldRead = false
|
||||
if not self.sendFuture.isNil(): await self.sendFuture
|
||||
shouldRead = false
|
||||
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_WANT_READ or
|
||||
|
@ -128,13 +165,13 @@ proc serverHandshake(self: DtlsConn) {.async.} =
|
|||
elif res != 0:
|
||||
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
|
||||
|
||||
proc accept*(self: Dtls, conn: WebRTCConn): Future[DtlsConn] {.async.} =
|
||||
proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
|
||||
var
|
||||
selfvar = self
|
||||
res = DtlsConn()
|
||||
let v = cast[pointer](res)
|
||||
|
||||
await res.init(conn, self.address)
|
||||
await res.init(self.conn, self.laddr)
|
||||
mb_ssl_init(res.ssl)
|
||||
mb_ssl_config_init(res.config)
|
||||
mb_ssl_cookie_init(res.cookie)
|
||||
|
@ -162,20 +199,32 @@ proc accept*(self: Dtls, conn: WebRTCConn): Future[DtlsConn] {.async.} =
|
|||
mb_ssl_session_reset(res.ssl)
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res),
|
||||
dtlsSend, dtlsRecv, nil)
|
||||
await res.serverHandshake()
|
||||
while true:
|
||||
let (raddr, buf) = await self.pendingHandshakes.popFirst()
|
||||
try:
|
||||
res.raddr = raddr
|
||||
res.dataRecv.addLastNoWait(buf)
|
||||
self.connections[raddr] = res
|
||||
await res.serverHandshake()
|
||||
except CatchableError as exc:
|
||||
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
continue
|
||||
return res
|
||||
|
||||
proc dial*(self: Dtls, address: TransportAddress): DtlsConn =
|
||||
proc dial*(self: Dtls, raddr: TransportAddress): DtlsConn =
|
||||
discard
|
||||
|
||||
import ../udp_connection
|
||||
proc main() {.async.} =
|
||||
let laddr = initTAddress("127.0.0.1:4433")
|
||||
let udp = UdpConn()
|
||||
await udp.init(nil, laddr)
|
||||
await udp.init(laddr)
|
||||
let stun = StunConn()
|
||||
await stun.init(udp, laddr)
|
||||
let dtls = Dtls()
|
||||
dtls.start(laddr)
|
||||
let x = await dtls.accept(udp)
|
||||
dtls.start(stun, laddr)
|
||||
let x = await dtls.accept()
|
||||
echo "After accept"
|
||||
|
||||
waitFor(main())
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
import tables, bitops, posix, strutils, sequtils
|
||||
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils]
|
||||
import usrsctp
|
||||
import dtls/dtls
|
||||
|
||||
export chronicles
|
||||
|
||||
|
@ -36,6 +37,7 @@ type
|
|||
params*: SctpMessageParameters
|
||||
|
||||
SctpConnection* = ref object
|
||||
conn: DtlsConn
|
||||
state: SctpState
|
||||
connectEvent: AsyncEvent
|
||||
sctp: Sctp
|
||||
|
@ -45,8 +47,9 @@ type
|
|||
dataRecv: AsyncQueue[SctpMessage]
|
||||
|
||||
Sctp* = ref object
|
||||
dtls: Dtls
|
||||
udp: DatagramTransport
|
||||
connections: Table[TransportAddress, SctpConnection]
|
||||
#connections: Table[TransportAddress, SctpConnection]
|
||||
gotConnection: AsyncEvent
|
||||
timersHandler: Future[void]
|
||||
isServer: bool
|
||||
|
|
|
@ -19,27 +19,27 @@ type
|
|||
|
||||
proc handles(self: StunConn) {.async.} =
|
||||
while true: # TODO: while not self.conn.atEof()
|
||||
let (msg, address) = await self.conn.read()
|
||||
let (msg, raddr) = await self.conn.read()
|
||||
if Stun.isMessage(msg):
|
||||
let res = Stun.getResponse(msg, self.laddr)
|
||||
if res.isSome():
|
||||
await self.conn.write(res.get())
|
||||
await self.conn.write(raddr, res.get())
|
||||
else:
|
||||
self.dataRecv.addLastNoWait((msg, address))
|
||||
self.dataRecv.addLastNoWait((msg, raddr))
|
||||
|
||||
proc init(self: StunConn, conn: UdpConn, laddr: TransportAddress) {.async.} =
|
||||
proc init*(self: StunConn, conn: UdpConn, laddr: TransportAddress) {.async.} =
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
|
||||
self.dataRecv = newAsyncQueue()
|
||||
self.handlesFut = handles()
|
||||
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
|
||||
self.handlesFut = self.handles()
|
||||
|
||||
proc close(self: StunConn) {.async.} =
|
||||
proc close*(self: StunConn) {.async.} =
|
||||
self.handlesFut.cancel() # check before?
|
||||
self.conn.close()
|
||||
await self.conn.close()
|
||||
|
||||
proc write(self: StunConn, msg: seq[byte]) {.async.} =
|
||||
await self.conn.write(msg)
|
||||
proc write*(self: StunConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
|
||||
await self.conn.write(raddr, msg)
|
||||
|
||||
proc read(self: StunConn): Future[(seq[byte], TransportAddress)] {.async.} =
|
||||
proc read*(self: StunConn): Future[(seq[byte], TransportAddress)] {.async.} =
|
||||
return await self.dataRecv.popFirst()
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
|
||||
import sequtils
|
||||
import chronos, chronicles
|
||||
import webrtc_connection
|
||||
|
||||
logScope:
|
||||
topics = "webrtc udp"
|
||||
|
@ -20,7 +19,7 @@ type
|
|||
udp: DatagramTransport
|
||||
dataRecv: AsyncQueue[(seq[byte], TransportAddress)]
|
||||
|
||||
proc init(self: UdpConn, laddr: TransportAddress) {.async.} =
|
||||
proc init*(self: UdpConn, laddr: TransportAddress) {.async.} =
|
||||
self.laddr = laddr
|
||||
|
||||
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
||||
|
@ -28,18 +27,16 @@ proc init(self: UdpConn, laddr: TransportAddress) {.async.} =
|
|||
echo "\e[33m<UDP>\e[0;1m onReceive\e[0m: ", msg.len()
|
||||
self.dataRecv.addLastNoWait((msg, address))
|
||||
|
||||
self.dataRecv = newAsyncQueue()
|
||||
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
|
||||
self.udp = newDatagramTransport(onReceive, local = laddr)
|
||||
|
||||
proc close(self: UdpConn) {.async.} =
|
||||
proc close*(self: UdpConn) {.async.} =
|
||||
self.udp.close()
|
||||
if not self.conn.isNil():
|
||||
await self.conn.close()
|
||||
|
||||
proc write(self: UdpConn, msg: seq[byte]) {.async.} =
|
||||
proc write*(self: UdpConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
|
||||
echo "\e[33m<UDP>\e[0;1m write\e[0m"
|
||||
await self.udp.sendTo(self.remote, msg)
|
||||
await self.udp.sendTo(raddr, msg)
|
||||
|
||||
proc read(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} =
|
||||
proc read*(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} =
|
||||
echo "\e[33m<UDP>\e[0;1m read\e[0m"
|
||||
return await self.dataRecv.popFirst()
|
||||
|
|
Loading…
Reference in New Issue