From f49ca90491c8e82661017549f290701ddfde14a0 Mon Sep 17 00:00:00 2001 From: Ludovic Chenut Date: Wed, 6 Mar 2024 13:47:32 +0100 Subject: [PATCH] Sctp comments + refacto --- webrtc/dtls/utils.nim | 10 +-- webrtc/sctp.nim | 161 ++++++++++++++++++++++-------------------- 2 files changed, 88 insertions(+), 83 deletions(-) diff --git a/webrtc/dtls/utils.nim b/webrtc/dtls/utils.nim index 6f9ad5b..ebecd40 100644 --- a/webrtc/dtls/utils.nim +++ b/webrtc/dtls/utils.nim @@ -20,6 +20,7 @@ import mbedtls/md import chronicles +# This sequence is used for debugging. const mb_ssl_states* = @[ "MBEDTLS_SSL_HELLO_REQUEST", "MBEDTLS_SSL_CLIENT_HELLO", @@ -53,14 +54,6 @@ const mb_ssl_states* = @[ "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" ] -proc mbedtls_pk_rsa*(pk: mbedtls_pk_context): ptr mbedtls_rsa_context = - var key = pk - case mbedtls_pk_get_type(addr key) - of MBEDTLS_PK_RSA: - return cast[ptr mbedtls_rsa_context](pk.private_pk_ctx) - else: - return nil - template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context = var res: mbedtls_pk_context mb_pk_init(res) @@ -72,6 +65,7 @@ template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context = template generateCertificate*(random: mbedtls_ctr_drbg_context, issuer_key: mbedtls_pk_context): mbedtls_x509_crt = let + # To be honest, I have no clue what to put here as a name name = "C=FR,O=Status,CN=webrtc" time_format = initTimeFormat("YYYYMMddHHmmss") time_from = times.now().format(time_format) diff --git a/webrtc/sctp.nim b/webrtc/sctp.nim index e0b952d..4d8afde 100644 --- a/webrtc/sctp.nim +++ b/webrtc/sctp.nim @@ -18,9 +18,18 @@ export chronicles logScope: topics = "webrtc sctp" +# Implementation of an Sctp client and server using the usrsctp library. +# Usrsctp is usable as a single thread but it's not the intended way to +# use it. There's a lot of callbacks calling each other in a synchronous +# way where we want to be able to call asynchronous procedure, but cannot. + # TODO: # - Replace doAssert by a proper exception management # - Find a clean way to manage SCTP ports +# - Unregister address when closing + +proc perror(error: cstring) {.importc, cdecl, header: "".} +proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "", gcsafe.} type SctpError* = object of CatchableError @@ -67,26 +76,43 @@ type sentAddress: TransportAddress sentFuture: Future[void] - # Those two objects are only here for debugging purpose + # These three objects are used for debugging/trace only SctpChunk = object chunkType: uint8 flag: uint8 length {.bin_value: it.data.len() + 4.}: uint16 data {.bin_len: it.length - 4.}: seq[byte] - SctpPacketStructure = object + SctpPacketHeader = object srcPort: uint16 dstPort: uint16 verifTag: uint32 checksum: uint32 -const - IPPROTO_SCTP = 132 + SctpPacketStructure = object + header: SctpPacketHeader + chunks: seq[SctpChunk] -proc newSctpError(msg: string): ref SctpError = - result = newException(SctpError, msg) +const IPPROTO_SCTP = 132 + +proc getSctpPacket(buffer: seq[byte]): SctpPacketStructure = + # Only used for debugging/trace + result.header = Binary.decode(buffer, SctpPacketHeader) + var size = sizeof(SctpPacketStructure) + while size < buffer.len: + let chunk = Binary.decode(buffer[size..^1], SctpChunk) + result.chunks.add(chunk) + size.inc(chunk.length.int) + while size mod 4 != 0: + # padding; could use `size.inc(-size %% 4)` instead but it lacks clarity + size.inc(1) + +# -- Asynchronous wrapper -- template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped = + # usrsctpAwait is template which set `sentFuture` to nil then calls (usually) + # an usrsctp function. If during the synchronous run of the usrsctp function + # `sendCallback` is called, then `sentFuture` is set and waited. self.sentFuture = nil when type(body) is void: body @@ -96,45 +122,7 @@ template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped = if self.sentFuture != nil: await self.sentFuture res -proc perror(error: cstring) {.importc, cdecl, header: "".} -proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "", gcsafe.} - -proc printSctpPacket(buffer: seq[byte]) = - let s = Binary.decode(buffer, SctpPacketStructure) - echo " => \e[31;1mStructure\e[0m: ", s - var size = sizeof(SctpPacketStructure) - var i = 1 - while size < buffer.len: - let c = Binary.decode(buffer[size..^1], SctpChunk) - echo " ===> \e[32;1mChunk ", i, "\e[0m ", c - i.inc() - size.inc(c.length.int) - while size mod 4 != 0: - size.inc() - -proc packetPretty(packet: cstring): string = - let data = $packet - let ctn = data[23..^16] - result = data[1..14] - if ctn.len > 30: - result = result & ctn[0..14] & " ... " & ctn[^14..^1] - else: - result = result & ctn - -proc new(T: typedesc[SctpConn], - sctp: Sctp, - udp: DatagramTransport, - address: TransportAddress, - sctpSocket: ptr socket): T = - T(sctp: sctp, - state: Connecting, - udp: udp, - address: address, - sctpSocket: sctpSocket, - connectEvent: AsyncEvent(), - #TODO add some limit for backpressure? - dataRecv: newAsyncQueue[SctpMessage]() - ) +# -- SctpConn -- proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T = T(conn: conn, @@ -142,10 +130,12 @@ proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T = state: Connecting, connectEvent: AsyncEvent(), acceptEvent: AsyncEvent(), - dataRecv: newAsyncQueue[SctpMessage]() #TODO add some limit for backpressure? + dataRecv: newAsyncQueue[SctpMessage]() # TODO add some limit for backpressure? ) proc read*(self: SctpConn): Future[SctpMessage] {.async.} = + # Used by DataChannel, returns SctpMessage in order to get the stream + # and protocol ids return await self.dataRecv.popFirst() proc toFlags(params: SctpMessageParameters): uint16 = @@ -154,23 +144,24 @@ proc toFlags(params: SctpMessageParameters): uint16 = if params.unordered: result = result or SCTP_UNORDERED -proc write*( - self: SctpConn, - buf: seq[byte], - sendParams = default(SctpMessageParameters), - ) {.async.} = - trace "Write", buf, sctp = cast[uint64](self), sock = cast[uint64](self.sctpSocket) +proc write*(self: SctpConn, buf: seq[byte], + sendParams = default(SctpMessageParameters)) {.async.} = + # Used by DataChannel, writes buf on the Dtls connection. + trace "Write", buf self.sctp.sentAddress = self.address var cpy = buf let sendvErr = if sendParams == default(SctpMessageParameters): + # If writes is called by DataChannel, sendParams should never + # be the default value. This split is useful for testing. self.usrsctpAwait: self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0, nil, 0, SCTP_SENDV_NOINFO.cuint, 0) else: let sendInfo = sctp_sndinfo( snd_sid: sendParams.streamId, + # TODO: swapBytes => htonl? snd_ppid: sendParams.protocolId.swapBytes(), snd_flags: sendParams.toFlags) self.usrsctpAwait: @@ -178,29 +169,26 @@ proc write*( cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen, SCTP_SENDV_SNDINFO.cuint, 0) if sendvErr < 0: - perror("usrsctp_sendv") # TODO: throw an exception - trace "write sendv error?", sendvErr, sendParams + # TODO: throw an exception + perror("usrsctp_sendv") proc write*(self: SctpConn, s: string) {.async.} = await self.write(s.toBytes()) proc close*(self: SctpConn) {.async.} = - self.usrsctpAwait: self.sctpSocket.usrsctp_close() + self.usrsctpAwait: + self.sctpSocket.usrsctp_close() + +# -- usrsctp receive data callbacks -- proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called when we receive data after + # connection has been established. let conn = cast[SctpConn](data) events = usrsctp_get_events(sock) - trace "Handle Upcall", events, state = conn.state - if conn.state == Connecting: - if bitand(events, SCTP_EVENT_ERROR) != 0: - warn "Cannot connect", address = conn.address - conn.state = Closed - elif bitand(events, SCTP_EVENT_WRITE) != 0: - conn.state = Connected - conn.connectEvent.fire() - + trace "Handle Upcall", events if bitand(events, SCTP_EVENT_READ) != 0: var message = SctpMessage( @@ -212,8 +200,8 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = rnLen = sizeof(sctp_recvv_rn).SockLen infotype: uint flags: int - trace "recv from", sockuint64=cast[uint64](sock) - let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]), message.data.len.uint, + let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]), + message.data.len.uint, cast[ptr SockAddr](addr address), cast[ptr SockLen](addr addressLen), cast[pointer](addr message.info), @@ -239,11 +227,12 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = except AsyncQueueFullError: trace "Queue full, dropping packet" elif bitand(events, SCTP_EVENT_WRITE) != 0: - trace "sctp event write in the upcall" + debug "sctp event write in the upcall" else: warn "Handle Upcall unexpected event", events proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called when accepting a connection. trace "Handle Accept" var sconn: Sockaddr_conn @@ -266,6 +255,27 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = addr recvinfo, sizeof(recvinfo).SockLen) conn.acceptEvent.fire() +proc handleConnect(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called when connecting + trace "Handle Connect" + let + conn = cast[SctpConn](data) + events = usrsctp_get_events(sock) + + trace "Handle Upcall", events, state = conn.state + if conn.state == Connecting: + if bitand(events, SCTP_EVENT_ERROR) != 0: + warn "Cannot connect", address = conn.address + conn.state = Closed + elif bitand(events, SCTP_EVENT_WRITE) != 0: + conn.state = Connected + doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, data) + conn.connectEvent.fire() + else: + warn "should be connecting", currentState = conn.state + +# -- usrsctp send data callback -- + proc sendCallback(ctx: pointer, buffer: pointer, length: uint, @@ -273,20 +283,20 @@ proc sendCallback(ctx: pointer, set_df: uint8): cint {.cdecl.} = let data = usrsctp_dumppacket(buffer, length, SCTP_DUMP_OUTBOUND) if data != nil: - trace "sendCallback", data = data.packetPretty(), length + trace "sendCallback", sctpPacket = data.getSctpPacket(), length usrsctp_freedumpbuffer(data) let sctpConn = cast[SctpConn](ctx) let buf = @(buffer.makeOpenArray(byte, int(length))) proc testSend() {.async.} = try: trace "Send To", address = sctpConn.address - # printSctpPacket(buf) - # TODO: defined it printSctpPacket(buf) await sctpConn.conn.write(buf) except CatchableError as exc: trace "Send Failed", message = exc.msg sctpConn.sentFuture = testSend() +# -- Sctp -- + proc timersHandler() {.async.} = while true: await sleepAsync(500.milliseconds) @@ -315,6 +325,7 @@ proc new*(T: typedesc[Sctp], dtls: Dtls, laddr: TransportAddress): T = return sctp proc stop*(self: Sctp) {.async.} = + # TODO: close every connections discard self.usrsctpAwait usrsctp_finish() self.udp.close() @@ -324,14 +335,14 @@ proc readLoopProc(res: SctpConn) {.async.} = msg = await res.conn.read() data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND) if not data.isNil(): - trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty() + trace "Receive data", remoteAddress = res.conn.raddr, + sctpPacket = data.getSctpPacket() usrsctp_freedumpbuffer(data) - # printSctpPacket(msg) TODO: defined it usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0) proc accept*(self: Sctp): Future[SctpConn] {.async.} = if not self.isServer: - raise newSctpError("Not a server") + raise newException(SctpError, "Not a server") var res = SctpConn.new(await self.dtls.accept(), self) usrsctp_register_address(cast[pointer](res)) res.readLoop = res.readLoopProc() @@ -373,7 +384,7 @@ proc connect*(self: Sctp, var nodelay: uint32 = 1 var recvinfo: uint32 = 1 doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1) - doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn)) + doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleConnect, cast[pointer](conn)) doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY, addr nodelay, sizeof(nodelay).SockLen) doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO, @@ -391,6 +402,6 @@ proc connect*(self: Sctp, conn.state = Connecting conn.connectEvent.clear() await conn.connectEvent.wait() - # TODO: check connection state, if closed throw some exception I guess + # TODO: check connection state, if closed throw an exception self.connections[address] = conn return conn