A lot of fixes

This commit is contained in:
Ludovic Chenut 2024-02-23 11:06:59 +01:00
parent 9a6657922a
commit 2591a158ba
No known key found for this signature in database
GPG Key ID: D9A59B1907F1D50C
4 changed files with 75 additions and 27 deletions

View File

@ -72,13 +72,17 @@ type
DataChannelConnection* = ref object DataChannelConnection* = ref object
readLoopFut: Future[void] readLoopFut: Future[void]
streams: Table[uint16, DataChannelStream] streams: Table[uint16, DataChannelStream]
streamId: uint16
conn*: SctpConn conn*: SctpConn
incomingStreams: AsyncQueue[DataChannelStream] incomingStreams: AsyncQueue[DataChannelStream]
proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} = proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} =
return await stream.receivedData.popLast() let x = await stream.receivedData.popFirst()
trace "read", length=x.len(), id=stream.id
return x
proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} = proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} =
trace "write", length=buf.len(), id=stream.id
var var
sendInfo = SctpMessageParameters( sendInfo = SctpMessageParameters(
streamId: stream.id, streamId: stream.id,
@ -105,14 +109,23 @@ proc sendControlMessage(stream: DataChannelStream, msg: DataChannelMessage) {.as
endOfRecord: true, endOfRecord: true,
protocolId: uint32(WebRtcDcep) protocolId: uint32(WebRtcDcep)
) )
trace "send control message", msg
await stream.conn.write(encoded, sendInfo) await stream.conn.write(encoded, sendInfo)
proc openStream*( proc openStream*(
conn: DataChannelConnection, conn: DataChannelConnection,
streamId: uint16, noiseHandshake: bool,
reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} = reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} =
let streamId: uint16 =
if not noiseHandshake:
let res = conn.streamId
conn.streamId += 2
res
else:
0
trace "open stream", streamId
if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0: if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0:
raise newException(ValueError, "reliabilityParameter should be 0") raise newException(ValueError, "reliabilityParameter should be 0")
@ -144,6 +157,7 @@ proc openStream*(
proc handleData(conn: DataChannelConnection, msg: SctpMessage) = proc handleData(conn: DataChannelConnection, msg: SctpMessage) =
let streamId = msg.params.streamId let streamId = msg.params.streamId
trace "handle data message", streamId, ppid = msg.params.protocolId, data = msg.data
if streamId notin conn.streams: if streamId notin conn.streams:
raise newException(ValueError, "got data for unknown streamid") raise newException(ValueError, "got data for unknown streamid")
@ -162,6 +176,7 @@ proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} =
decoded = Binary.decode(msg.data, DataChannelMessage) decoded = Binary.decode(msg.data, DataChannelMessage)
streamId = msg.params.streamId streamId = msg.params.streamId
trace "handle control message", decoded, streamId = msg.params.streamId
if decoded.messageType == Ack: if decoded.messageType == Ack:
if streamId notin conn.streams: if streamId notin conn.streams:
raise newException(ValueError, "got ack for unknown streamid") raise newException(ValueError, "got ack for unknown streamid")
@ -178,6 +193,7 @@ proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} =
) )
conn.streams[streamId] = stream conn.streams[streamId] = stream
conn.incomingStreams.addLastNoWait(stream)
await stream.sendControlMessage(DataChannelMessage(messageType: Ack)) await stream.sendControlMessage(DataChannelMessage(messageType: Ack))
@ -185,6 +201,7 @@ proc readLoop(conn: DataChannelConnection) {.async.} =
try: try:
while true: while true:
let message = await conn.conn.read() let message = await conn.conn.read()
# TODO: might be necessary to check the others protocolId at some point
if message.params.protocolId == uint32(WebRtcDcep): if message.params.protocolId == uint32(WebRtcDcep):
#TODO should we really await? #TODO should we really await?
await conn.handleControl(message) await conn.handleControl(message)
@ -195,12 +212,12 @@ proc readLoop(conn: DataChannelConnection) {.async.} =
discard discard
proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} = proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} =
if isNil(conn.readLoopFut):
conn.readLoopFut = conn.readLoop()
return await conn.incomingStreams.popFirst() return await conn.incomingStreams.popFirst()
proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection = proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection =
DataChannelConnection( result = DataChannelConnection(
conn: conn, conn: conn,
incomingStreams: newAsyncQueue[DataChannelStream]() incomingStreams: newAsyncQueue[DataChannelStream](),
streamId: 1'u16 # TODO: Serveur == 1, client == 2
) )
conn.readLoopFut = conn.readLoop()

View File

@ -8,9 +8,10 @@
# those terms. # those terms.
import tables, bitops, posix, strutils, sequtils import tables, bitops, posix, strutils, sequtils
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils] import chronos, chronicles, stew/[ranges/ptr_arith, byteutils, endians2]
import usrsctp import usrsctp
import dtls/dtls import dtls/dtls
import binary_serialization
export chronicles export chronicles
@ -37,7 +38,7 @@ type
SctpMessage* = ref object SctpMessage* = ref object
data*: seq[byte] data*: seq[byte]
info: sctp_rcvinfo info: sctp_recvv_rn
params*: SctpMessageParameters params*: SctpMessageParameters
SctpConn* = ref object SctpConn* = ref object
@ -67,6 +68,19 @@ type
sentAddress: TransportAddress sentAddress: TransportAddress
sentFuture: Future[void] sentFuture: Future[void]
# Those two objects are only here for debugging purpose
SctpChunk = object
chunkType: uint8
flag: uint8
length {.bin_value: it.data.len() + 4.}: uint16
data {.bin_len: it.length - 4.}: seq[byte]
SctpPacketStructure = object
srcPort: uint16
dstPort: uint16
verifTag: uint32
checksum: uint32
const const
IPPROTO_SCTP = 132 IPPROTO_SCTP = 132
@ -86,6 +100,19 @@ template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".} proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.} proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
proc printSctpPacket(buffer: seq[byte]) =
let s = Binary.decode(buffer, SctpPacketStructure)
echo " => \e[31;1mStructure\e[0m: ", s
var size = sizeof(SctpPacketStructure)
var i = 1
while size < buffer.len:
let c = Binary.decode(buffer[size..^1], SctpChunk)
echo " ===> \e[32;1mChunk ", i, "\e[0m ", c
i.inc()
size.inc(c.length.int)
while size mod 4 != 0:
size.inc()
proc packetPretty(packet: cstring): string = proc packetPretty(packet: cstring): string =
let data = $packet let data = $packet
let ctn = data[23..^16] let ctn = data[23..^16]
@ -137,22 +164,23 @@ proc write*(
self.sctp.sentConnection = self self.sctp.sentConnection = self
self.sctp.sentAddress = self.address self.sctp.sentAddress = self.address
let var cpy = buf
var
(sendInfo, infoType) = (sendInfo, infoType) =
if sendParams != default(SctpMessageParameters): if sendParams != default(SctpMessageParameters):
(sctp_sndinfo( (sctp_sndinfo(
snd_sid: sendParams.streamId, snd_sid: sendParams.streamId,
snd_ppid: sendParams.protocolId, snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags snd_flags: sendParams.toFlags
), cuint(SCTP_SENDV_SNDINFO)) ), cuint(SCTP_SENDV_SNDINFO))
else: else:
(default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO)) (default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO))
sendvErr = self.usrsctpAwait: sendvErr = self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(unsafeAddr buf[0], buf.len.uint, self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len.uint, nil, 0,
nil, 0, unsafeAddr sendInfo, sizeof(sendInfo).SockLen, cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
infoType, 0) infoType, 0)
if sendvErr < 0: if sendvErr < 0:
perror("usrsctp_sendv") perror("usrsctp_sendv") # TODO: throw an exception
trace "write sendv error?", sendvErr, sendParams trace "write sendv error?", sendvErr, sendParams
proc write*(self: SctpConn, s: string) {.async.} = proc write*(self: SctpConn, s: string) {.async.} =
@ -182,7 +210,7 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
address: Sockaddr_storage address: Sockaddr_storage
rn: sctp_recvv_rn rn: sctp_recvv_rn
addressLen = sizeof(Sockaddr_storage).SockLen addressLen = sizeof(Sockaddr_storage).SockLen
rnLen = sizeof(message.info).SockLen rnLen = sizeof(sctp_recvv_rn).SockLen
infotype: uint infotype: uint
flags: int flags: int
trace "recv from", sockuint64=cast[uint64](sock) trace "recv from", sockuint64=cast[uint64](sock)
@ -197,11 +225,12 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
perror("usrsctp_recvv") perror("usrsctp_recvv")
return return
elif n > 0: elif n > 0:
if infotype == SCTP_RECVV_RCVINFO: # It might be necessary to check if infotype == SCTP_RECVV_RCVINFO
message.data.delete(n..<message.data.len())
trace "message info from handle upcall", msginfo = message.info
message.params = SctpMessageParameters( message.params = SctpMessageParameters(
#TODO endianness? protocolId: message.info.recvv_rcvinfo.rcv_ppid.swapBytes(),
protocolId: message.info.rcv_ppid, streamId: message.info.recvv_rcvinfo.rcv_sid
streamId: message.info.rcv_sid
) )
if bitand(flags, MSG_NOTIFICATION) != 0: if bitand(flags, MSG_NOTIFICATION) != 0:
trace "Notification received", length = n trace "Notification received", length = n
@ -229,11 +258,12 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
conn.sctpSocket = sctpSocket conn.sctpSocket = sctpSocket
conn.state = Connected conn.state = Connected
var nodelay: uint32 = 1 var nodelay: uint32 = 1
var recvinfo: uint32 = 1
doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn)) doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
SCTP_NODELAY, addr nodelay, sizeof(nodelay).SockLen)
addr nodelay, doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
sizeof(nodelay).SockLen) addr recvinfo, sizeof(recvinfo).SockLen)
conn.acceptEvent.fire() conn.acceptEvent.fire()
# proc getOrCreateConnection(self: Sctp, # proc getOrCreateConnection(self: Sctp,
@ -276,10 +306,11 @@ proc sendCallback(ctx: pointer,
trace "sendCallback", data = data.packetPretty(), length trace "sendCallback", data = data.packetPretty(), length
usrsctp_freedumpbuffer(data) usrsctp_freedumpbuffer(data)
let sctpConn = cast[SctpConn](ctx) let sctpConn = cast[SctpConn](ctx)
let buf = @(buffer.makeOpenArray(byte, int(length)))
proc testSend() {.async.} = proc testSend() {.async.} =
try: try:
let buf = @(buffer.makeOpenArray(byte, int(length)))
trace "Send To", address = sctpConn.address trace "Send To", address = sctpConn.address
# TODO: defined it printSctpPacket(buf)
await sctpConn.conn.write(buf) await sctpConn.conn.write(buf)
except CatchableError as exc: except CatchableError as exc:
trace "Send Failed", message = exc.msg trace "Send Failed", message = exc.msg
@ -355,11 +386,9 @@ proc stop*(self: Sctp) {.async.} =
proc readLoopProc(res: SctpConn) {.async.} = proc readLoopProc(res: SctpConn) {.async.} =
while true: while true:
trace "Read Loop Proc Before"
let let
msg = await res.conn.read() msg = await res.conn.read()
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND) data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
trace "Read Loop Proc Before", isnil=data.isNil()
if not data.isNil(): if not data.isNil():
trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty() trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty()
usrsctp_freedumpbuffer(data) usrsctp_freedumpbuffer(data)
@ -384,6 +413,7 @@ proc listen*(self: Sctp, sctpPort: uint16 = 5000) =
trace "Listening", sctpPort trace "Listening", sctpPort
doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2) doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2)
doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0) doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0)
doAssert 0 == usrsctp_sysctl_set_sctp_delayed_sack_time_default(0)
let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil) let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
var on: int = 1 var on: int = 1
doAssert 0 == usrsctp_set_non_blocking(sock, 1) doAssert 0 == usrsctp_set_non_blocking(sock, 1)

View File

@ -21,6 +21,7 @@ proc handles(self: StunConn) {.async.} =
while true: # TODO: while not self.conn.atEof() while true: # TODO: while not self.conn.atEof()
let (msg, raddr) = await self.conn.read() let (msg, raddr) = await self.conn.read()
if Stun.isMessage(msg): if Stun.isMessage(msg):
echo "\e[35;1m<STUN>\e[0m"
let res = Stun.getResponse(msg, self.laddr) let res = Stun.getResponse(msg, self.laddr)
if res.isSome(): if res.isSome():
await self.conn.write(raddr, res.get()) await self.conn.write(raddr, res.get())

View File

@ -24,7 +24,7 @@ proc init*(self: UdpConn, laddr: TransportAddress) =
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} = proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
let msg = udp.getMessage() let msg = udp.getMessage()
echo "\e[33m<UDP>\e[0;1m onReceive\e[0m: ", msg.len() echo "\e[33m<UDP>\e[0;1m onReceive\e[0m"
self.dataRecv.addLastNoWait((msg, address)) self.dataRecv.addLastNoWait((msg, address))
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]() self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()