diff --git a/webrtc/dtls/dtls.nim b/webrtc/dtls/dtls.nim index 22398e9..9d44f1f 100644 --- a/webrtc/dtls/dtls.nim +++ b/webrtc/dtls/dtls.nim @@ -84,14 +84,16 @@ proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} = proc write*(self: DtlsConn, msg: seq[byte]) {.async.} = trace "Dtls write", length = msg.len() var buf = msg - discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](addr buf[0]), buf.len().uint) + # TODO: exception catching + discard mb_ssl_write(self.ssl, buf) proc read*(self: DtlsConn): Future[seq[byte]] {.async.} = var res = newSeq[byte](8192) while true: let tmp = await self.dataRecv.popFirst() self.dataRecv.addFirstNoWait(tmp) - let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint) + # TODO: exception catching + let length = mb_ssl_read(self.ssl, res) if length == MBEDTLS_ERR_SSL_WANT_READ: continue if length < 0: @@ -191,10 +193,37 @@ proc serverHandshake(self: DtlsConn) {.async.} = continue elif res != 0: raise newException(DtlsError, $(res.mbedtls_high_level_strerr())) - # var remoteCertPtr = mbedtls_ssl_get_peer_cert(addr self.ssl) - # let remoteCert = remoteCertPtr[] - # self.remoteCert = newSeq[byte](remoteCert.raw.len) - # copyMem(addr self.remoteCert[0], remoteCert.raw.p, remoteCert.raw.len) + +proc clientHandshake(self: DtlsConn) {.async.} = + var shouldRead = false + while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: + if shouldRead: + case self.raddr.family + of AddressFamily.IPv4: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4) + of AddressFamily.IPv6: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6) + else: + raise newException(DtlsError, "Remote address isn't an IP address") + let tmp = await self.dataRecv.popFirst() + self.dataRecv.addFirstNoWait(tmp) + self.sendFuture = nil + let res = mb_ssl_handshake_step(self.ssl) + if not self.sendFuture.isNil(): await self.sendFuture + shouldRead = false + if res == MBEDTLS_ERR_SSL_WANT_WRITE: + continue + elif res == MBEDTLS_ERR_SSL_WANT_READ: + # or self.ssl.private_state == MBEDTLS_SSL_SERVER_KEY_EXCHANGE: + # TODO: Might need to check directly on mbedtls C source + shouldRead = true + continue + elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: + mb_ssl_session_reset(self.ssl) + shouldRead = true + continue + elif res != 0: + raise newException(DtlsError, $(res.mbedtls_high_level_strerr())) proc remoteCertificate*(conn: DtlsConn): seq[byte] = conn.remoteCert @@ -245,8 +274,8 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} = mb_ssl_set_timer_cb(res.ssl, res.timer) mb_ssl_setup(res.ssl, res.config) mb_ssl_session_reset(res.ssl) - mbedtls_ssl_set_verify(addr res.ssl, verify, cast[pointer](res)) - mbedtls_ssl_conf_authmode(addr res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) # TODO: create template + mb_ssl_set_verify(res.ssl, verify, res) + mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil) while true: @@ -263,20 +292,46 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} = continue return res -proc dial*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} = - discard +proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} = + var + selfvar = self + res = DtlsConn() -#import ../udp_connection -#import stew/byteutils -#proc main() {.async.} = -# let laddr = initTAddress("127.0.0.1:4433") -# let udp = UdpConn() -# await udp.init(laddr) -# let stun = StunConn() -# await stun.init(udp, laddr) -# let dtls = Dtls() -# dtls.start(stun, laddr) -# let x = await dtls.accept() -# echo "Recv: <", string.fromBytes(await x.read()), ">" -# -#waitFor(main()) + await res.init(self.conn, self.laddr) + mb_ssl_init(res.ssl) + mb_ssl_config_init(res.config) + + var pkey = res.ctr_drbg.generateKey() + var srvcert = res.ctr_drbg.generateCertificate(pkey) + res.localCert = newSeq[byte](srvcert.raw.len) + copyMem(addr res.localCert[0], srvcert.raw.p, srvcert.raw.len) + + mb_ctr_drbg_init(res.ctr_drbg) + mb_entropy_init(res.entropy) + mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0) + + mb_ssl_config_defaults(res.config, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT) + mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds + mb_ssl_conf_ca_chain(res.config, srvcert.next, nil) + mb_ssl_set_timer_cb(res.ssl, res.timer) + mb_ssl_setup(res.ssl, res.config) + mb_ssl_set_verify(res.ssl, verify, res) + mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) + mb_ssl_set_bio(res.ssl, cast[pointer](res), + dtlsSend, dtlsRecv, nil) + + res.raddr = raddr + self.connections[raddr] = res + + try: + await res.clientHandshake() + except CatchableError as exc: + trace "Handshake fail", remoteAddress = raddr, error = exc.msg + self.connections.del(raddr) + raise exc + + return res diff --git a/webrtc/sctp.nim b/webrtc/sctp.nim index 20fd3ab..d77058b 100644 --- a/webrtc/sctp.nim +++ b/webrtc/sctp.nim @@ -398,7 +398,7 @@ proc readLoopProc(res: SctpConn) {.async.} = proc accept*(self: Sctp): Future[SctpConn] {.async.} = if not self.isServer: raise newSctpError("Not a server") - var res = SctpConn.new(await self.dtls.accept, self) + var res = SctpConn.new(await self.dtls.accept(), self) usrsctp_register_address(cast[pointer](res)) res.readLoop = res.readLoopProc() res.acceptEvent.clear() @@ -429,20 +429,21 @@ proc listen*(self: Sctp, sctpPort: uint16 = 5000) = proc connect*(self: Sctp, address: TransportAddress, sctpPort: uint16 = 5000): Future[SctpConn] {.async.} = - discard + let + sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil) + res = SctpConn.new(await self.dtls.connect(address), self) -# proc connect*(self: Sctp, -# address: TransportAddress, -# sctpPort: uint16 = 5000): Future[SctpConn] {.async.} = -# trace "Connect", address, sctpPort -# let conn = await self.getOrCreateConnection(self.udp, address, sctpPort) -# if conn.state == Connected: -# return conn -# try: -# await conn.connectEvent.wait() # TODO: clear? -# except CancelledError as exc: -# conn.sctpSocket.usrsctp_close() -# return nil -# if conn.state != Connected: -# raise newSctpError("Cannot connect to " & $address) -# return conn + #usrsctp_register_address(cast[pointer](res)) + +# trace "Connect", address, sctpPort +# let conn = await self.getOrCreateConnection(self.udp, address, sctpPort) +# if conn.state == Connected: +# return conn +# try: +# await conn.connectEvent.wait() # TODO: clear? +# except CancelledError as exc: +# conn.sctpSocket.usrsctp_close() +# return nil +# if conn.state != Connected: +# raise newSctpError("Cannot connect to " & $address) +# return conn