Multiple fixes

This commit is contained in:
Ludovic Chenut 2023-08-18 11:47:20 +02:00
parent d875ba1ad8
commit 237d723374
No known key found for this signature in database
GPG Key ID: D9A59B1907F1D50C
8 changed files with 46 additions and 32 deletions

View File

@ -8,7 +8,7 @@
# those terms. # those terms.
import std/times import std/times
import chronos import chronos, chronicles
import webrtc_connection import webrtc_connection
import mbedtls/ssl import mbedtls/ssl
@ -22,6 +22,10 @@ import mbedtls/x509_crt
import mbedtls/bignum import mbedtls/bignum
import mbedtls/error import mbedtls/error
import mbedtls/net_sockets import mbedtls/net_sockets
import mbedtls/timing
logScope:
topics = "webrtc dtls"
type type
DtlsConn* = ref object of WebRTCConn DtlsConn* = ref object of WebRTCConn
@ -31,6 +35,7 @@ type
entropy: mbedtls_entropy_context entropy: mbedtls_entropy_context
ctr_drbg: mbedtls_ctr_drbg_context ctr_drbg: mbedtls_ctr_drbg_context
timer: mbedtls_timing_delay_context
config: mbedtls_ssl_config config: mbedtls_ssl_config
ssl: mbedtls_ssl_context ssl: mbedtls_ssl_context
@ -57,7 +62,6 @@ proc generateCertificate(self: DtlsConn): mbedtls_x509_crt =
time_from = times.now().format(time_format) time_from = times.now().format(time_format)
time_to = (times.now() + times.years(1)).format(time_format) time_to = (times.now() + times.years(1)).format(time_format)
var issuer_key = self.generateKey() var issuer_key = self.generateKey()
var write_cert: mbedtls_x509write_cert var write_cert: mbedtls_x509write_cert
var serial_mpi: mbedtls_mpi var serial_mpi: mbedtls_mpi
@ -78,14 +82,15 @@ proc generateCertificate(self: DtlsConn): mbedtls_x509_crt =
mb_x509_crt_parse(result, buf) mb_x509_crt_parse(result, buf)
proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
echo "dtlsSend: " echo "Send: ", len
let self = cast[ptr DtlsConn](ctx) let self = cast[ptr DtlsConn](ctx)
self.sendEvent.fire() self.sendEvent.fire()
proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
echo "dtlsRecv: " echo "Recv: ", len
let self = cast[ptr DtlsConn](ctx)[] let self = cast[ptr DtlsConn](ctx)[]
self.recvEvent.fire()
let x = self.read()
method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.async.} = method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.async.} =
await procCall(WebRTCConn(self).init(conn, address)) await procCall(WebRTCConn(self).init(conn, address))
@ -110,7 +115,10 @@ method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.asyn
mb_ssl_conf_read_timeout(self.config, 10000) # in milliseconds mb_ssl_conf_read_timeout(self.config, 10000) # in milliseconds
mb_ssl_conf_ca_chain(self.config, srvcert.next, nil) mb_ssl_conf_ca_chain(self.config, srvcert.next, nil)
mb_ssl_conf_own_cert(self.config, srvcert, pkey) mb_ssl_conf_own_cert(self.config, srvcert, pkey)
# cookies ? mbedtls_ssl_set_timer_cb(addr self.ssl, cast[pointer](addr self.timer),
mbedtls_timing_set_delay,
mbedtls_timing_get_delay)
# cookie ?
mb_ssl_setup(self.ssl, self.config) mb_ssl_setup(self.ssl, self.config)
mb_ssl_session_reset(self.ssl) mb_ssl_session_reset(self.ssl)
mb_ssl_set_bio(self.ssl, cast[pointer](addr selfvar), mb_ssl_set_bio(self.ssl, cast[pointer](addr selfvar),
@ -118,25 +126,21 @@ method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.asyn
while true: while true:
mb_ssl_handshake(self.ssl) mb_ssl_handshake(self.ssl)
method close*(self: DtlsConn) {.async.} =
discard
method write*(self: DtlsConn, msg: seq[byte]) {.async.} = method write*(self: DtlsConn, msg: seq[byte]) {.async.} =
var buf = msg var buf = msg
self.sendEvent.clear() self.sendEvent.clear()
discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](buf.cstring), buf.len()) discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](addr buf[0]), buf.len().uint)
await self.sendEvent.wait() await self.sendEvent.wait()
method read*(self: DtlsConn): Future[seq[byte]] {.async.} = method read*(self: DtlsConn): Future[seq[byte]] {.async.} =
var res = newString(4096) return await self.conn.read()
self.recvEvent.clear()
discard mbedtls_ssl_read(addr self.ssl, cast[ptr byte](res.cstring), 4096) method close*(self: DtlsConn) {.async.} =
await self.recvEvent.wait() discard
proc main {.async.} = proc main {.async.} =
let laddr = initTAddress("127.0.0.1:" & "4242") let laddr = initTAddress("127.0.0.1:" & "4242")
var dtls = DtlsConn() var dtls = DtlsConn()
await dtls.init(nil, laddr) await dtls.init(nil, laddr)
let cert = dtls.generateCertificate()
waitFor(main()) waitFor(main())

View File

@ -8,7 +8,7 @@
# those terms. # those terms.
import tables, bitops, posix, strutils, sequtils import tables, bitops, posix, strutils, sequtils
import chronos, chronicles, stew/ranges/ptr_arith import chronos, chronicles, stew/[ranges/ptr_arith, byteutils]
import usrsctp import usrsctp
export chronicles export chronicles
@ -101,10 +101,13 @@ proc write*(self: SctpConnection, buf: seq[byte]) {.async.} =
self.sctp.sentConnection = self self.sctp.sentConnection = self
self.sctp.sentAddress = self.address self.sctp.sentAddress = self.address
let sendvErr = self.sctp.usrsctpAwait: let sendvErr = self.sctp.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(addr buf[0], buf.len.uint, self.sctpSocket.usrsctp_sendv(unsafeAddr buf[0], buf.len.uint,
nil, 0, nil, 0, nil, 0, nil, 0,
SCTP_SENDV_NOINFO, 0) SCTP_SENDV_NOINFO, 0)
proc write*(self: SctpConnection, s: string) {.async.} =
await self.write(s.toBytes())
proc close*(self: SctpConnection) {.async.} = proc close*(self: SctpConnection) {.async.} =
self.sctp.usrsctpAwait: self.sctpSocket.usrsctp_close() self.sctp.usrsctpAwait: self.sctpSocket.usrsctp_close()
@ -143,7 +146,7 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
if bitand(flags, MSG_NOTIFICATION) != 0: if bitand(flags, MSG_NOTIFICATION) != 0:
trace "Notification received", length = n trace "Notification received", length = n
else: else:
conn.dataRecv = conn.dataRecv.concat(buffer[0..n]) conn.dataRecv = conn.dataRecv.concat(buffer[0..<n])
conn.recvEvent.fire() conn.recvEvent.fire()
else: else:
warn "Handle Upcall unexpected event", events warn "Handle Upcall unexpected event", events
@ -237,7 +240,7 @@ proc startServer*(self: Sctp, sctpPort: uint16 = 5000) =
doAssert 0 == sock.usrsctp_set_upcall(handleAccept, cast[pointer](self)) doAssert 0 == sock.usrsctp_set_upcall(handleAccept, cast[pointer](self))
self.sockServer = sock self.sockServer = sock
proc closeServer(self: Sctp) = proc stopServer*(self: Sctp) =
if not self.isServer: if not self.isServer:
trace "Try to close a client" trace "Try to close a client"
return return
@ -254,7 +257,7 @@ proc new*(T: typedesc[Sctp], port: uint16 = 9899): T =
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} = proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
let let
msg = udp.getMessage() msg = udp.getMessage()
data = usrsctp_dumppacket(addr msg[0], uint(msg.len), SCTP_DUMP_INBOUND) data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
if data != nil: if data != nil:
if sctp.isServer: if sctp.isServer:
trace "onReceive (server)", data = data.packetPretty(), length = msg.len(), address trace "onReceive (server)", data = data.packetPretty(), length = msg.len(), address
@ -264,12 +267,12 @@ proc new*(T: typedesc[Sctp], port: uint16 = 9899): T =
if sctp.isServer: if sctp.isServer:
sctp.sentAddress = address sctp.sentAddress = address
usrsctp_conninput(cast[pointer](sctp), addr msg[0], uint(msg.len), 0) usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
else: else:
let conn = await sctp.getOrCreateConnection(udp, address) let conn = await sctp.getOrCreateConnection(udp, address)
sctp.sentConnection = conn sctp.sentConnection = conn
sctp.sentAddress = address sctp.sentAddress = address
usrsctp_conninput(cast[pointer](sctp), addr msg[0], uint(msg.len), 0) usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
let let
localAddr = TransportAddress(family: AddressFamily.IPv4, port: Port(port)) localAddr = TransportAddress(family: AddressFamily.IPv4, port: Port(port))
laddr = initTAddress("127.0.0.1:" & $port) laddr = initTAddress("127.0.0.1:" & $port)
@ -285,6 +288,10 @@ proc new*(T: typedesc[Sctp], port: uint16 = 9899): T =
return sctp return sctp
proc stop*(self: Sctp) {.async.} =
discard self.usrsctpAwait usrsctp_finish()
self.udp.close()
proc listen*(self: Sctp): Future[SctpConnection] {.async.} = proc listen*(self: Sctp): Future[SctpConnection] {.async.} =
if not self.isServer: if not self.isServer:
raise newSctpError("Not a server") raise newSctpError("Not a server")

View File

@ -13,7 +13,7 @@ import chronos,
binary_serialization, binary_serialization,
stew/objects, stew/objects,
stew/byteutils stew/byteutils
import stunattributes import stun_attributes
export binary_serialization export binary_serialization

View File

@ -11,7 +11,7 @@ import sequtils, typetraits
import binary_serialization, import binary_serialization,
stew/byteutils, stew/byteutils,
chronos chronos
import utils import ../utils
type type
StunAttributeEncodingError* = object of CatchableError StunAttributeEncodingError* = object of CatchableError

View File

@ -8,7 +8,7 @@
# those terms. # those terms.
import chronos import chronos
import webrtc_connection, stun import ../webrtc_connection, stun
type type
StunConn* = ref object of WebRTCConn StunConn* = ref object of WebRTCConn
@ -24,8 +24,8 @@ proc handles(self: StunConn) {.async.} =
if res.isSome(): if res.isSome():
await self.conn.write(res.get()) await self.conn.write(res.get())
else: else:
recvData.add(msg) self.recvData.add(msg)
recvEvent.fire() self.recvEvent.fire()
method init(self: StunConn, conn: WebRTCConn, address: TransportAddress) {.async.} = method init(self: StunConn, conn: WebRTCConn, address: TransportAddress) {.async.} =
await procCall(WebRTCConn(self).init(conn, address)) await procCall(WebRTCConn(self).init(conn, address))

View File

@ -7,9 +7,12 @@
# This file may not be copied, modified, or distributed except according to # This file may not be copied, modified, or distributed except according to
# those terms. # those terms.
import chronos import chronos, chronicles
import webrtc_connection import webrtc_connection
logScope:
topics = "webrtc udp"
type type
UdpConn* = ref object of WebRTCConn UdpConn* = ref object of WebRTCConn
udp: DatagramTransport udp: DatagramTransport
@ -25,7 +28,7 @@ method init(self: UdpConn, conn: WebRTCConn, address: TransportAddress) {.async.
self.recvEvent.fire() self.recvEvent.fire()
self.recvEvent = newAsyncEvent() self.recvEvent = newAsyncEvent()
self.udp = newDatagramTransport(onReceive) self.udp = newDatagramTransport(onReceive, local = address)
method close(self: UdpConn) {.async.} = method close(self: UdpConn) {.async.} =
self.udp.close() self.udp.close()

View File

@ -8,7 +8,7 @@
# those terms. # those terms.
import chronos, chronicles import chronos, chronicles
import stun import stun/stun
logScope: logScope:
topics = "webrtc" topics = "webrtc"

View File

@ -11,8 +11,8 @@ import chronos
type type
WebRTCConn* = ref object of RootObj WebRTCConn* = ref object of RootObj
conn: WebRTCConn conn*: WebRTCConn
address: TransportAddress address*: TransportAddress
# isClosed: bool # isClosed: bool
# isEof: bool # isEof: bool