From f6ba79428124dbd89d086b8ece0ceb7d6e6d0010 Mon Sep 17 00:00:00 2001 From: Ludovic Chenut Date: Wed, 28 Feb 2024 13:49:43 +0100 Subject: [PATCH] sctp/dtls client done --- webrtc.nimble | 1 - webrtc/dtls/dtls.nim | 72 ++++++------------ webrtc/sctp.nim | 150 ++++++++++++-------------------------- webrtc/udp_connection.nim | 6 +- 4 files changed, 74 insertions(+), 155 deletions(-) diff --git a/webrtc.nimble b/webrtc.nimble index 0031ecd..c30bde2 100644 --- a/webrtc.nimble +++ b/webrtc.nimble @@ -11,6 +11,5 @@ requires "nim >= 1.2.0", "https://github.com/status-im/nim-binary-serialization.git", "https://github.com/status-im/nim-mbedtls.git" - proc runTest(filename: string) = discard diff --git a/webrtc/dtls/dtls.nim b/webrtc/dtls/dtls.nim index 9d44f1f..808a5f1 100644 --- a/webrtc/dtls/dtls.nim +++ b/webrtc/dtls/dtls.nim @@ -82,10 +82,13 @@ proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} = self.dataRecv = newAsyncQueue[seq[byte]]() proc write*(self: DtlsConn, msg: seq[byte]) {.async.} = - trace "Dtls write", length = msg.len() var buf = msg - # TODO: exception catching - discard mb_ssl_write(self.ssl, buf) + try: + let write = mb_ssl_write(self.ssl, buf) + trace "Dtls write", msgLen = msg.len(), actuallyWrote = write + except MbedTLSError as exc: + trace "Dtls write error", errorMsg = exc.msg + raise exc proc read*(self: DtlsConn): Future[seq[byte]] {.async.} = var res = newSeq[byte](8192) @@ -97,7 +100,7 @@ proc read*(self: DtlsConn): Future[seq[byte]] {.async.} = if length == MBEDTLS_ERR_SSL_WANT_READ: continue if length < 0: - trace "dtls read", error = $(length.mbedtls_high_level_strerr()) + trace "dtls read", error = $(length.cint.mbedtls_high_level_strerr()) res.setLen(length) return res @@ -164,47 +167,18 @@ proc stop*(self: Dtls) = self.readLoop.cancel() self.started = false -proc serverHandshake(self: DtlsConn) {.async.} = - var shouldRead = true +proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} = + var shouldRead = isServer while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: if shouldRead: - case self.raddr.family - of AddressFamily.IPv4: - mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4) - of AddressFamily.IPv6: - mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6) - else: - raise newException(DtlsError, "Remote address isn't an IP address") - let tmp = await self.dataRecv.popFirst() - self.dataRecv.addFirstNoWait(tmp) - self.sendFuture = nil - let res = mb_ssl_handshake_step(self.ssl) - if not self.sendFuture.isNil(): await self.sendFuture - shouldRead = false - if res == MBEDTLS_ERR_SSL_WANT_WRITE: - continue - elif res == MBEDTLS_ERR_SSL_WANT_READ or - self.ssl.private_state == MBEDTLS_SSL_CLIENT_KEY_EXCHANGE: - shouldRead = true - continue - elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: - mb_ssl_session_reset(self.ssl) - shouldRead = true - continue - elif res != 0: - raise newException(DtlsError, $(res.mbedtls_high_level_strerr())) - -proc clientHandshake(self: DtlsConn) {.async.} = - var shouldRead = false - while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: - if shouldRead: - case self.raddr.family - of AddressFamily.IPv4: - mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4) - of AddressFamily.IPv6: - mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6) - else: - raise newException(DtlsError, "Remote address isn't an IP address") + if isServer: + case self.raddr.family + of AddressFamily.IPv4: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4) + of AddressFamily.IPv6: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6) + else: + raise newException(DtlsError, "Remote address isn't an IP address") let tmp = await self.dataRecv.popFirst() self.dataRecv.addFirstNoWait(tmp) self.sendFuture = nil @@ -214,8 +188,6 @@ proc clientHandshake(self: DtlsConn) {.async.} = if res == MBEDTLS_ERR_SSL_WANT_WRITE: continue elif res == MBEDTLS_ERR_SSL_WANT_READ: - # or self.ssl.private_state == MBEDTLS_SSL_SERVER_KEY_EXCHANGE: - # TODO: Might need to check directly on mbedtls C source shouldRead = true continue elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: @@ -284,7 +256,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} = res.raddr = raddr res.dataRecv.addLastNoWait(buf) self.connections[raddr] = res - await res.serverHandshake() + await res.dtlsHandshake(true) break except CatchableError as exc: trace "Handshake fail", remoteAddress = raddr, error = exc.msg @@ -301,6 +273,9 @@ proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} = mb_ssl_init(res.ssl) mb_ssl_config_init(res.config) + res.ctr_drbg = self.ctr_drbg + res.entropy = self.entropy + var pkey = res.ctr_drbg.generateKey() var srvcert = res.ctr_drbg.generateCertificate(pkey) res.localCert = newSeq[byte](srvcert.raw.len) @@ -321,14 +296,13 @@ proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} = mb_ssl_setup(res.ssl, res.config) mb_ssl_set_verify(res.ssl, verify, res) mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) - mb_ssl_set_bio(res.ssl, cast[pointer](res), - dtlsSend, dtlsRecv, nil) + mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil) res.raddr = raddr self.connections[raddr] = res try: - await res.clientHandshake() + await res.dtlsHandshake(false) except CatchableError as exc: trace "Handshake fail", remoteAddress = raddr, error = exc.msg self.connections.del(raddr) diff --git a/webrtc/sctp.nim b/webrtc/sctp.nim index d77058b..8cd5f41 100644 --- a/webrtc/sctp.nim +++ b/webrtc/sctp.nim @@ -64,7 +64,6 @@ type sockServer: ptr socket pendingConnections: seq[SctpConn] pendingConnections2: Table[SockAddr, SctpConn] - sentConnection: SctpConn sentAddress: TransportAddress sentFuture: Future[void] @@ -161,24 +160,23 @@ proc write*( sendParams = default(SctpMessageParameters), ) {.async.} = trace "Write", buf, sctp = cast[uint64](self), sock = cast[uint64](self.sctpSocket) - self.sctp.sentConnection = self self.sctp.sentAddress = self.address var cpy = buf - var - (sendInfo, infoType) = - if sendParams != default(SctpMessageParameters): - (sctp_sndinfo( - snd_sid: sendParams.streamId, - snd_ppid: sendParams.protocolId.swapBytes(), - snd_flags: sendParams.toFlags - ), cuint(SCTP_SENDV_SNDINFO)) - else: - (default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO)) - sendvErr = self.usrsctpAwait: - self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len.uint, nil, 0, - cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen, - infoType, 0) + let sendvErr = + if sendParams == default(SctpMessageParameters): + self.usrsctpAwait: + self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0, + nil, 0, SCTP_SENDV_NOINFO.cuint, 0) + else: + let sendInfo = sctp_sndinfo( + snd_sid: sendParams.streamId, + snd_ppid: sendParams.protocolId.swapBytes(), + snd_flags: sendParams.toFlags) + self.usrsctpAwait: + self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0, + cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen, + SCTP_SENDV_SNDINFO.cuint, 0) if sendvErr < 0: perror("usrsctp_sendv") # TODO: throw an exception trace "write sendv error?", sendvErr, sendParams @@ -194,7 +192,7 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = conn = cast[SctpConn](data) events = usrsctp_get_events(sock) - trace "Handle Upcall", events + trace "Handle Upcall", events, state = conn.state if conn.state == Connecting: if bitand(events, SCTP_EVENT_ERROR) != 0: warn "Cannot connect", address = conn.address @@ -202,7 +200,8 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = elif bitand(events, SCTP_EVENT_WRITE) != 0: conn.state = Connected conn.connectEvent.fire() - elif bitand(events, SCTP_EVENT_READ) != 0: + + if bitand(events, SCTP_EVENT_READ) != 0: var message = SctpMessage( data: newSeq[byte](4096) @@ -253,12 +252,12 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = sctp = cast[Sctp](data) sctpSocket = usrsctp_accept(sctp.sockServer, cast[ptr SockAddr](addr sconn), addr slen) - doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1) let conn = cast[SctpConn](sconn.sconn_addr) conn.sctpSocket = sctpSocket conn.state = Connected var nodelay: uint32 = 1 var recvinfo: uint32 = 1 + doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1) doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn)) doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY, addr nodelay, sizeof(nodelay).SockLen) @@ -266,36 +265,6 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = addr recvinfo, sizeof(recvinfo).SockLen) conn.acceptEvent.fire() -# proc getOrCreateConnection(self: Sctp, -# udp: DatagramTransport, -# address: TransportAddress, -# sctpPort: uint16 = 5000): Future[SctpConn] {.async.} = -# #TODO remove the = 5000 -# if self.connections.hasKey(address): -# return self.connections[address] -# trace "Create Connection", address -# let -# sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil) -# conn = SctpConn.new(self, udp, address, sctpSocket) -# var on: int = 1 -# doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, -# SCTP_RECVRCVINFO, -# addr on, -# sizeof(on).SockLen) -# doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1) -# doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn)) -# var sconn: Sockaddr_conn -# sconn.sconn_family = AF_CONN -# sconn.sconn_port = htons(sctpPort) -# sconn.sconn_addr = cast[pointer](self) -# self.sentConnection = conn -# self.sentAddress = address -# let connErr = self.usrsctpAwait: -# conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn))) -# doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno) -# self.connections[address] = conn -# return conn - proc sendCallback(ctx: pointer, buffer: pointer, length: uint, @@ -310,6 +279,7 @@ proc sendCallback(ctx: pointer, proc testSend() {.async.} = try: trace "Send To", address = sctpConn.address + # printSctpPacket(buf) # TODO: defined it printSctpPacket(buf) await sctpConn.conn.write(buf) except CatchableError as exc: @@ -343,43 +313,6 @@ proc new*(T: typedesc[Sctp], dtls: Dtls, laddr: TransportAddress): T = usrsctp_register_address(cast[pointer](sctp)) return sctp -#proc new*(T: typedesc[Sctp], port: uint16 = 9899): T = -# logScope: topics = "webrtc sctp" -# let sctp = T(gotConnection: newAsyncEvent()) -# proc onReceive(udp: DatagramTransport, raddr: TransportAddress) {.async, gcsafe.} = -# let -# msg = udp.getMessage() -# data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND) -# if data != nil: -# if sctp.isServer: -# trace "onReceive (server)", data = data.packetPretty(), length = msg.len(), raddr -# else: -# trace "onReceive (client)", data = data.packetPretty(), length = msg.len(), raddr -# usrsctp_freedumpbuffer(data) -# -# if sctp.isServer: -# sctp.sentAddress = raddr -# usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0) -# else: -# let conn = await sctp.getOrCreateConnection(udp, raddr) -# sctp.sentConnection = conn -# sctp.sentAddress = raddr -# usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0) -# let -# localAddr = TransportAddress(family: AddressFamily.IPv4, port: Port(port)) -# laddr = initTAddress("127.0.0.1:" & $port) -# udp = newDatagramTransport(onReceive, local = laddr) -# trace "local address", localAddr, laddr -# sctp.udp = udp -# sctp.timersHandler = timersHandler() -# -# usrsctp_init_nothreads(port, sendCallback, printf) -# discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE) -# discard usrsctp_sysctl_set_sctp_ecn_enable(1) -# usrsctp_register_address(cast[pointer](sctp)) -# -# return sctp - proc stop*(self: Sctp) {.async.} = discard self.usrsctpAwait usrsctp_finish() self.udp.close() @@ -392,7 +325,7 @@ proc readLoopProc(res: SctpConn) {.async.} = if not data.isNil(): trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty() usrsctp_freedumpbuffer(data) - res.sctp.sentConnection = res + # printSctpPacket(msg) TODO: defined it usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0) proc accept*(self: Sctp): Future[SctpConn] {.async.} = @@ -431,19 +364,32 @@ proc connect*(self: Sctp, sctpPort: uint16 = 5000): Future[SctpConn] {.async.} = let sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil) - res = SctpConn.new(await self.dtls.connect(address), self) + conn = SctpConn.new(await self.dtls.connect(address), self) - #usrsctp_register_address(cast[pointer](res)) - -# trace "Connect", address, sctpPort -# let conn = await self.getOrCreateConnection(self.udp, address, sctpPort) -# if conn.state == Connected: -# return conn -# try: -# await conn.connectEvent.wait() # TODO: clear? -# except CancelledError as exc: -# conn.sctpSocket.usrsctp_close() -# return nil -# if conn.state != Connected: -# raise newSctpError("Cannot connect to " & $address) -# return conn + trace "Create Connection", address + conn.sctpSocket = sctpSocket + conn.state = Connected + var nodelay: uint32 = 1 + var recvinfo: uint32 = 1 + doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1) + doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn)) + doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY, + addr nodelay, sizeof(nodelay).SockLen) + doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO, + addr recvinfo, sizeof(recvinfo).SockLen) + var sconn: Sockaddr_conn + sconn.sconn_family = AF_CONN + sconn.sconn_port = htons(sctpPort) + sconn.sconn_addr = cast[pointer](conn) + self.sentAddress = address + usrsctp_register_address(cast[pointer](conn)) + conn.readLoop = conn.readLoopProc() + let connErr = self.usrsctpAwait: + conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn))) + doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno) + conn.state = Connecting + conn.connectEvent.clear() + await conn.connectEvent.wait() + # TODO: check connection state, if closed throw some exception I guess + self.connections[address] = conn + return conn diff --git a/webrtc/udp_connection.nim b/webrtc/udp_connection.nim index 8c873f7..cfec28d 100644 --- a/webrtc/udp_connection.nim +++ b/webrtc/udp_connection.nim @@ -24,7 +24,7 @@ proc init*(self: UdpConn, laddr: TransportAddress) = proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} = let msg = udp.getMessage() - echo "\e[33m\e[0;1m onReceive\e[0m" + trace "UDP onReceive", msg self.dataRecv.addLastNoWait((msg, address)) self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]() @@ -34,9 +34,9 @@ proc close*(self: UdpConn) {.async.} = self.udp.close() proc write*(self: UdpConn, raddr: TransportAddress, msg: seq[byte]) {.async.} = - echo "\e[33m\e[0;1m write\e[0m" + trace "UDP write", msg await self.udp.sendTo(raddr, msg) proc read*(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} = - echo "\e[33m\e[0;1m read\e[0m" + trace "UDP read" return await self.dataRecv.popFirst()