From 2591a158baf33784e2d46aee2cee17924c48b501 Mon Sep 17 00:00:00 2001 From: Ludovic Chenut Date: Fri, 23 Feb 2024 11:06:59 +0100 Subject: [PATCH] A lot of fixes --- webrtc/datachannel.nim | 29 +++++++++++--- webrtc/sctp.nim | 70 +++++++++++++++++++++++---------- webrtc/stun/stun_connection.nim | 1 + webrtc/udp_connection.nim | 2 +- 4 files changed, 75 insertions(+), 27 deletions(-) diff --git a/webrtc/datachannel.nim b/webrtc/datachannel.nim index 1fac163..56fba3a 100644 --- a/webrtc/datachannel.nim +++ b/webrtc/datachannel.nim @@ -72,13 +72,17 @@ type DataChannelConnection* = ref object readLoopFut: Future[void] streams: Table[uint16, DataChannelStream] + streamId: uint16 conn*: SctpConn incomingStreams: AsyncQueue[DataChannelStream] proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} = - return await stream.receivedData.popLast() + let x = await stream.receivedData.popFirst() + trace "read", length=x.len(), id=stream.id + return x proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} = + trace "write", length=buf.len(), id=stream.id var sendInfo = SctpMessageParameters( streamId: stream.id, @@ -105,14 +109,23 @@ proc sendControlMessage(stream: DataChannelStream, msg: DataChannelMessage) {.as endOfRecord: true, protocolId: uint32(WebRtcDcep) ) + trace "send control message", msg await stream.conn.write(encoded, sendInfo) proc openStream*( conn: DataChannelConnection, - streamId: uint16, + noiseHandshake: bool, reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} = + let streamId: uint16 = + if not noiseHandshake: + let res = conn.streamId + conn.streamId += 2 + res + else: + 0 + trace "open stream", streamId if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0: raise newException(ValueError, "reliabilityParameter should be 0") @@ -144,6 +157,7 @@ proc openStream*( proc handleData(conn: DataChannelConnection, msg: SctpMessage) = let streamId = msg.params.streamId + trace "handle data message", streamId, ppid = msg.params.protocolId, data = msg.data if streamId notin conn.streams: raise newException(ValueError, "got data for unknown streamid") @@ -162,6 +176,7 @@ proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} = decoded = Binary.decode(msg.data, DataChannelMessage) streamId = msg.params.streamId + trace "handle control message", decoded, streamId = msg.params.streamId if decoded.messageType == Ack: if streamId notin conn.streams: raise newException(ValueError, "got ack for unknown streamid") @@ -178,6 +193,7 @@ proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} = ) conn.streams[streamId] = stream + conn.incomingStreams.addLastNoWait(stream) await stream.sendControlMessage(DataChannelMessage(messageType: Ack)) @@ -185,6 +201,7 @@ proc readLoop(conn: DataChannelConnection) {.async.} = try: while true: let message = await conn.conn.read() + # TODO: might be necessary to check the others protocolId at some point if message.params.protocolId == uint32(WebRtcDcep): #TODO should we really await? await conn.handleControl(message) @@ -195,12 +212,12 @@ proc readLoop(conn: DataChannelConnection) {.async.} = discard proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} = - if isNil(conn.readLoopFut): - conn.readLoopFut = conn.readLoop() return await conn.incomingStreams.popFirst() proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection = - DataChannelConnection( + result = DataChannelConnection( conn: conn, - incomingStreams: newAsyncQueue[DataChannelStream]() + incomingStreams: newAsyncQueue[DataChannelStream](), + streamId: 1'u16 # TODO: Serveur == 1, client == 2 ) + conn.readLoopFut = conn.readLoop() diff --git a/webrtc/sctp.nim b/webrtc/sctp.nim index 099d22e..20fd3ab 100644 --- a/webrtc/sctp.nim +++ b/webrtc/sctp.nim @@ -8,9 +8,10 @@ # those terms. import tables, bitops, posix, strutils, sequtils -import chronos, chronicles, stew/[ranges/ptr_arith, byteutils] +import chronos, chronicles, stew/[ranges/ptr_arith, byteutils, endians2] import usrsctp import dtls/dtls +import binary_serialization export chronicles @@ -37,7 +38,7 @@ type SctpMessage* = ref object data*: seq[byte] - info: sctp_rcvinfo + info: sctp_recvv_rn params*: SctpMessageParameters SctpConn* = ref object @@ -67,6 +68,19 @@ type sentAddress: TransportAddress sentFuture: Future[void] + # Those two objects are only here for debugging purpose + SctpChunk = object + chunkType: uint8 + flag: uint8 + length {.bin_value: it.data.len() + 4.}: uint16 + data {.bin_len: it.length - 4.}: seq[byte] + + SctpPacketStructure = object + srcPort: uint16 + dstPort: uint16 + verifTag: uint32 + checksum: uint32 + const IPPROTO_SCTP = 132 @@ -86,6 +100,19 @@ template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped = proc perror(error: cstring) {.importc, cdecl, header: "".} proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "", gcsafe.} +proc printSctpPacket(buffer: seq[byte]) = + let s = Binary.decode(buffer, SctpPacketStructure) + echo " => \e[31;1mStructure\e[0m: ", s + var size = sizeof(SctpPacketStructure) + var i = 1 + while size < buffer.len: + let c = Binary.decode(buffer[size..^1], SctpChunk) + echo " ===> \e[32;1mChunk ", i, "\e[0m ", c + i.inc() + size.inc(c.length.int) + while size mod 4 != 0: + size.inc() + proc packetPretty(packet: cstring): string = let data = $packet let ctn = data[23..^16] @@ -137,22 +164,23 @@ proc write*( self.sctp.sentConnection = self self.sctp.sentAddress = self.address - let + var cpy = buf + var (sendInfo, infoType) = if sendParams != default(SctpMessageParameters): (sctp_sndinfo( snd_sid: sendParams.streamId, - snd_ppid: sendParams.protocolId, + snd_ppid: sendParams.protocolId.swapBytes(), snd_flags: sendParams.toFlags ), cuint(SCTP_SENDV_SNDINFO)) else: (default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO)) sendvErr = self.usrsctpAwait: - self.sctpSocket.usrsctp_sendv(unsafeAddr buf[0], buf.len.uint, - nil, 0, unsafeAddr sendInfo, sizeof(sendInfo).SockLen, + self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len.uint, nil, 0, + cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen, infoType, 0) if sendvErr < 0: - perror("usrsctp_sendv") + perror("usrsctp_sendv") # TODO: throw an exception trace "write sendv error?", sendvErr, sendParams proc write*(self: SctpConn, s: string) {.async.} = @@ -182,7 +210,7 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = address: Sockaddr_storage rn: sctp_recvv_rn addressLen = sizeof(Sockaddr_storage).SockLen - rnLen = sizeof(message.info).SockLen + rnLen = sizeof(sctp_recvv_rn).SockLen infotype: uint flags: int trace "recv from", sockuint64=cast[uint64](sock) @@ -197,11 +225,12 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = perror("usrsctp_recvv") return elif n > 0: - if infotype == SCTP_RECVV_RCVINFO: - message.params = SctpMessageParameters( - #TODO endianness? - protocolId: message.info.rcv_ppid, - streamId: message.info.rcv_sid + # It might be necessary to check if infotype == SCTP_RECVV_RCVINFO + message.data.delete(n..\e[0m" let res = Stun.getResponse(msg, self.laddr) if res.isSome(): await self.conn.write(raddr, res.get()) diff --git a/webrtc/udp_connection.nim b/webrtc/udp_connection.nim index adf2463..8c873f7 100644 --- a/webrtc/udp_connection.nim +++ b/webrtc/udp_connection.nim @@ -24,7 +24,7 @@ proc init*(self: UdpConn, laddr: TransportAddress) = proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} = let msg = udp.getMessage() - echo "\e[33m\e[0;1m onReceive\e[0m: ", msg.len() + echo "\e[33m\e[0;1m onReceive\e[0m" self.dataRecv.addLastNoWait((msg, address)) self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()