few fixes on udp & try to fix dtls read and write
This commit is contained in:
parent
a65d905fb8
commit
11031a4706
|
@ -7,7 +7,8 @@
|
|||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import std/times
|
||||
import times, sequtils
|
||||
import strutils # to remove
|
||||
import chronos, chronicles
|
||||
import ./utils, ../webrtc_connection
|
||||
|
||||
|
@ -40,14 +41,16 @@ type
|
|||
|
||||
proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
echo "Send: ", len
|
||||
let self = cast[ptr DtlsConn](ctx)
|
||||
let self = cast[DtlsConn](ctx)
|
||||
self.sendEvent.fire()
|
||||
|
||||
proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
echo "Recv: ", len
|
||||
let self = cast[ptr DtlsConn](ctx)[]
|
||||
|
||||
let x = self.read()
|
||||
var self = cast[DtlsConn](ctx)[]
|
||||
echo "Recv: ", self.recvData[0].len(), " ", len
|
||||
echo ctx.repr
|
||||
result = self.recvData[0].len().cint
|
||||
copyMem(buf, addr self.recvData[0][0], self.recvData[0].len())
|
||||
self.recvData.delete(0..0)
|
||||
|
||||
method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.async.} =
|
||||
await procCall(WebRTCConn(self).init(conn, address))
|
||||
|
@ -67,13 +70,6 @@ method read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
|||
method close*(self: DtlsConn) {.async.} =
|
||||
discard
|
||||
|
||||
proc main {.async.} =
|
||||
let laddr = initTAddress("127.0.0.1:" & "4242")
|
||||
var dtls = DtlsConn()
|
||||
await dtls.init(nil, laddr)
|
||||
|
||||
waitFor(main())
|
||||
|
||||
type
|
||||
Dtls* = ref object of RootObj
|
||||
ctr_drbg: mbedtls_ctr_drbg_context
|
||||
|
@ -99,41 +95,75 @@ proc stop*(self: Dtls) =
|
|||
warn "Already stopped"
|
||||
return
|
||||
|
||||
self.stopped = false
|
||||
self.started = false
|
||||
|
||||
proc handshake(self: DtlsConn) {.async.} =
|
||||
var endpoint =
|
||||
if self.ssl.private_conf.private_endpoint == MBEDTLS_SSL_IS_SERVER:
|
||||
MBEDTLS_ERR_SSL_WANT_READ
|
||||
else:
|
||||
MBEDTLS_ERR_SSL_WANT_WRITE
|
||||
|
||||
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
|
||||
echo "State: ", toHex(self.ssl.private_state.int)
|
||||
if endpoint == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
self.recvData.add(await self.conn.read())
|
||||
echo "=====> ", self.recvData.len()
|
||||
let res = mbedtls_ssl_handshake_step(addr self.ssl)
|
||||
if res == MBEDTLS_ERR_SSL_WANT_READ or res == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
echo "Result handshake step: ", (-res).toHex, " ",
|
||||
(-MBEDTLS_ERR_SSL_WANT_READ).toHex, " ",
|
||||
(-MBEDTLS_ERR_SSL_WANT_WRITE).toHex
|
||||
if res == MBEDTLS_ERR_SSL_WANT_READ or res == MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
echo if res == MBEDTLS_ERR_SSL_WANT_READ: "WANT_READ" else: "WANT_WRITE"
|
||||
continue
|
||||
elif res != 0:
|
||||
break # raise whatever
|
||||
endpoint = res
|
||||
|
||||
proc accept*(self: Dtls, conn: WebRTCConn): DtlsConn {.async.} =
|
||||
proc accept*(self: Dtls, conn: WebRTCConn): Future[DtlsConn] {.async.} =
|
||||
echo "1"
|
||||
var
|
||||
srvcert = self.generateCertificate()
|
||||
pkey = self.generateKey()
|
||||
srvcert = self.ctr_drbg.generateCertificate()
|
||||
pkey = self.ctr_drbg.generateKey()
|
||||
selfvar = self
|
||||
res = DtlsConn()
|
||||
let v = cast[pointer](res)
|
||||
echo v.repr
|
||||
|
||||
result = Dtls()
|
||||
result.init(conn, self.address)
|
||||
mb_ssl_init(result.ssl)
|
||||
mb_ssl_config_init(result.config)
|
||||
mb_ssl_config_defaults(result.config,
|
||||
await res.init(conn, self.address)
|
||||
mb_ssl_init(res.ssl)
|
||||
mb_ssl_config_init(res.config)
|
||||
mb_ssl_config_defaults(res.config,
|
||||
MBEDTLS_SSL_IS_SERVER,
|
||||
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT)
|
||||
mb_ssl_conf_rng(result.config, mbedtls_ctr_drbg_random, self.ctr_drbg)
|
||||
mb_ssl_conf_read_timeout(result.config, 10000) # in milliseconds
|
||||
mb_ssl_conf_ca_chain(result.config, srvcert.next, nil)
|
||||
mb_ssl_conf_own_cert(result.config, srvcert, pkey)
|
||||
mbedtls_ssl_set_timer_cb(addr result.ssl, cast[pointer](addr result.timer),
|
||||
mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, self.ctr_drbg)
|
||||
mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds
|
||||
mb_ssl_conf_ca_chain(res.config, srvcert.next, nil)
|
||||
mb_ssl_conf_own_cert(res.config, srvcert, pkey)
|
||||
mbedtls_ssl_set_timer_cb(addr res.ssl, cast[pointer](addr res.timer),
|
||||
mbedtls_timing_set_delay,
|
||||
mbedtls_timing_get_delay)
|
||||
# Add the cookie management (it works without, but it's more secure)
|
||||
mb_ssl_setup(result.ssl, result.config)
|
||||
mb_ssl_session_reset(result.ssl)
|
||||
mb_ssl_set_bio(result.ssl, cast[pointer](result),
|
||||
mb_ssl_setup(res.ssl, res.config)
|
||||
mb_ssl_session_reset(res.ssl)
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res),
|
||||
dtlsSend, dtlsRecv, nil)
|
||||
await result.handshake()
|
||||
await res.handshake()
|
||||
return res
|
||||
|
||||
proc dial*(self: Dtls, address: TransportAddress): DtlsConn =
|
||||
discard
|
||||
|
||||
import ../udp_connection
|
||||
proc main() {.async.} =
|
||||
let laddr = initTAddress("127.0.0.1:4433")
|
||||
let udp = UdpConn()
|
||||
await udp.init(nil, laddr)
|
||||
let dtls = Dtls()
|
||||
dtls.start(laddr)
|
||||
echo "Before accept"
|
||||
let x = await dtls.accept(udp)
|
||||
echo "After accept"
|
||||
|
||||
waitFor(main())
|
||||
|
|
|
@ -24,14 +24,14 @@ proc mbedtls_pk_rsa*(pk: mbedtls_pk_context): ptr mbedtls_rsa_context =
|
|||
else:
|
||||
return nil
|
||||
|
||||
proc generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
|
||||
template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
|
||||
var res: mbedtls_pk_context
|
||||
mb_pk_init(res)
|
||||
discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))
|
||||
mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 4096, 65537)
|
||||
return res
|
||||
res
|
||||
|
||||
proc generateCertificate*(random: mbedtls_ctr_drbg_context): mbedtls_x509_crt =
|
||||
template generateCertificate*(random: mbedtls_ctr_drbg_context): mbedtls_x509_crt =
|
||||
let
|
||||
name = "C=FR,O=webrtc,CN=webrtc"
|
||||
time_format = initTimeFormat("YYYYMMddHHmmss")
|
||||
|
@ -55,4 +55,6 @@ proc generateCertificate*(random: mbedtls_ctr_drbg_context): mbedtls_x509_crt =
|
|||
let serial_hex = mb_mpi_read_string(serial_mpi, 16)
|
||||
mb_x509write_crt_set_serial(write_cert, serial_mpi)
|
||||
let buf = mb_x509write_crt_pem(write_cert, 4096, mbedtls_ctr_drbg_random, random)
|
||||
mb_x509_crt_parse(result, buf)
|
||||
var res: mbedtls_x509_crt
|
||||
mb_x509_crt_parse(res, buf)
|
||||
res
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import sequtils
|
||||
import chronos, chronicles
|
||||
import webrtc_connection
|
||||
|
||||
|
@ -20,7 +21,7 @@ type
|
|||
recvEvent: AsyncEvent
|
||||
|
||||
method init(self: UdpConn, conn: WebRTCConn, address: TransportAddress) {.async.} =
|
||||
procCall(WebRTCConn(self).init(conn, address))
|
||||
await procCall(WebRTCConn(self).init(conn, address))
|
||||
|
||||
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
||||
let msg = udp.getMessage()
|
||||
|
@ -33,12 +34,12 @@ method init(self: UdpConn, conn: WebRTCConn, address: TransportAddress) {.async.
|
|||
method close(self: UdpConn) {.async.} =
|
||||
self.udp.close()
|
||||
if not self.conn.isNil():
|
||||
self.conn.close()
|
||||
await self.conn.close()
|
||||
|
||||
method write(self: UdpConn, msg: seq[byte]) {.async.} =
|
||||
await self.udp.sendTo(self.address, msg)
|
||||
|
||||
method read(self: UdpConn): seq[byte] {.async.} =
|
||||
method read(self: UdpConn): Future[seq[byte]] {.async.} =
|
||||
while self.recvData.len() <= 0:
|
||||
self.recvEvent.clear()
|
||||
await self.recvEvent.wait()
|
||||
|
|
Loading…
Reference in New Issue