Sctp comments + refacto

This commit is contained in:
Ludovic Chenut 2024-03-06 13:47:32 +01:00
parent 359a81df4a
commit f49ca90491
No known key found for this signature in database
GPG Key ID: D9A59B1907F1D50C
2 changed files with 88 additions and 83 deletions

View File

@ -20,6 +20,7 @@ import mbedtls/md
import chronicles import chronicles
# This sequence is used for debugging.
const mb_ssl_states* = @[ const mb_ssl_states* = @[
"MBEDTLS_SSL_HELLO_REQUEST", "MBEDTLS_SSL_HELLO_REQUEST",
"MBEDTLS_SSL_CLIENT_HELLO", "MBEDTLS_SSL_CLIENT_HELLO",
@ -53,14 +54,6 @@ const mb_ssl_states* = @[
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH"
] ]
proc mbedtls_pk_rsa*(pk: mbedtls_pk_context): ptr mbedtls_rsa_context =
var key = pk
case mbedtls_pk_get_type(addr key)
of MBEDTLS_PK_RSA:
return cast[ptr mbedtls_rsa_context](pk.private_pk_ctx)
else:
return nil
template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context = template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
var res: mbedtls_pk_context var res: mbedtls_pk_context
mb_pk_init(res) mb_pk_init(res)
@ -72,6 +65,7 @@ template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
template generateCertificate*(random: mbedtls_ctr_drbg_context, template generateCertificate*(random: mbedtls_ctr_drbg_context,
issuer_key: mbedtls_pk_context): mbedtls_x509_crt = issuer_key: mbedtls_pk_context): mbedtls_x509_crt =
let let
# To be honest, I have no clue what to put here as a name
name = "C=FR,O=Status,CN=webrtc" name = "C=FR,O=Status,CN=webrtc"
time_format = initTimeFormat("YYYYMMddHHmmss") time_format = initTimeFormat("YYYYMMddHHmmss")
time_from = times.now().format(time_format) time_from = times.now().format(time_format)

View File

@ -18,9 +18,18 @@ export chronicles
logScope: logScope:
topics = "webrtc sctp" topics = "webrtc sctp"
# Implementation of an Sctp client and server using the usrsctp library.
# Usrsctp is usable as a single thread but it's not the intended way to
# use it. There's a lot of callbacks calling each other in a synchronous
# way where we want to be able to call asynchronous procedure, but cannot.
# TODO: # TODO:
# - Replace doAssert by a proper exception management # - Replace doAssert by a proper exception management
# - Find a clean way to manage SCTP ports # - Find a clean way to manage SCTP ports
# - Unregister address when closing
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
type type
SctpError* = object of CatchableError SctpError* = object of CatchableError
@ -67,26 +76,43 @@ type
sentAddress: TransportAddress sentAddress: TransportAddress
sentFuture: Future[void] sentFuture: Future[void]
# Those two objects are only here for debugging purpose # These three objects are used for debugging/trace only
SctpChunk = object SctpChunk = object
chunkType: uint8 chunkType: uint8
flag: uint8 flag: uint8
length {.bin_value: it.data.len() + 4.}: uint16 length {.bin_value: it.data.len() + 4.}: uint16
data {.bin_len: it.length - 4.}: seq[byte] data {.bin_len: it.length - 4.}: seq[byte]
SctpPacketStructure = object SctpPacketHeader = object
srcPort: uint16 srcPort: uint16
dstPort: uint16 dstPort: uint16
verifTag: uint32 verifTag: uint32
checksum: uint32 checksum: uint32
const SctpPacketStructure = object
IPPROTO_SCTP = 132 header: SctpPacketHeader
chunks: seq[SctpChunk]
proc newSctpError(msg: string): ref SctpError = const IPPROTO_SCTP = 132
result = newException(SctpError, msg)
proc getSctpPacket(buffer: seq[byte]): SctpPacketStructure =
# Only used for debugging/trace
result.header = Binary.decode(buffer, SctpPacketHeader)
var size = sizeof(SctpPacketStructure)
while size < buffer.len:
let chunk = Binary.decode(buffer[size..^1], SctpChunk)
result.chunks.add(chunk)
size.inc(chunk.length.int)
while size mod 4 != 0:
# padding; could use `size.inc(-size %% 4)` instead but it lacks clarity
size.inc(1)
# -- Asynchronous wrapper --
template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped = template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
# usrsctpAwait is template which set `sentFuture` to nil then calls (usually)
# an usrsctp function. If during the synchronous run of the usrsctp function
# `sendCallback` is called, then `sentFuture` is set and waited.
self.sentFuture = nil self.sentFuture = nil
when type(body) is void: when type(body) is void:
body body
@ -96,45 +122,7 @@ template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
if self.sentFuture != nil: await self.sentFuture if self.sentFuture != nil: await self.sentFuture
res res
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".} # -- SctpConn --
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
proc printSctpPacket(buffer: seq[byte]) =
let s = Binary.decode(buffer, SctpPacketStructure)
echo " => \e[31;1mStructure\e[0m: ", s
var size = sizeof(SctpPacketStructure)
var i = 1
while size < buffer.len:
let c = Binary.decode(buffer[size..^1], SctpChunk)
echo " ===> \e[32;1mChunk ", i, "\e[0m ", c
i.inc()
size.inc(c.length.int)
while size mod 4 != 0:
size.inc()
proc packetPretty(packet: cstring): string =
let data = $packet
let ctn = data[23..^16]
result = data[1..14]
if ctn.len > 30:
result = result & ctn[0..14] & " ... " & ctn[^14..^1]
else:
result = result & ctn
proc new(T: typedesc[SctpConn],
sctp: Sctp,
udp: DatagramTransport,
address: TransportAddress,
sctpSocket: ptr socket): T =
T(sctp: sctp,
state: Connecting,
udp: udp,
address: address,
sctpSocket: sctpSocket,
connectEvent: AsyncEvent(),
#TODO add some limit for backpressure?
dataRecv: newAsyncQueue[SctpMessage]()
)
proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T = proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
T(conn: conn, T(conn: conn,
@ -142,10 +130,12 @@ proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
state: Connecting, state: Connecting,
connectEvent: AsyncEvent(), connectEvent: AsyncEvent(),
acceptEvent: AsyncEvent(), acceptEvent: AsyncEvent(),
dataRecv: newAsyncQueue[SctpMessage]() #TODO add some limit for backpressure? dataRecv: newAsyncQueue[SctpMessage]() # TODO add some limit for backpressure?
) )
proc read*(self: SctpConn): Future[SctpMessage] {.async.} = proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
# Used by DataChannel, returns SctpMessage in order to get the stream
# and protocol ids
return await self.dataRecv.popFirst() return await self.dataRecv.popFirst()
proc toFlags(params: SctpMessageParameters): uint16 = proc toFlags(params: SctpMessageParameters): uint16 =
@ -154,23 +144,24 @@ proc toFlags(params: SctpMessageParameters): uint16 =
if params.unordered: if params.unordered:
result = result or SCTP_UNORDERED result = result or SCTP_UNORDERED
proc write*( proc write*(self: SctpConn, buf: seq[byte],
self: SctpConn, sendParams = default(SctpMessageParameters)) {.async.} =
buf: seq[byte], # Used by DataChannel, writes buf on the Dtls connection.
sendParams = default(SctpMessageParameters), trace "Write", buf
) {.async.} =
trace "Write", buf, sctp = cast[uint64](self), sock = cast[uint64](self.sctpSocket)
self.sctp.sentAddress = self.address self.sctp.sentAddress = self.address
var cpy = buf var cpy = buf
let sendvErr = let sendvErr =
if sendParams == default(SctpMessageParameters): if sendParams == default(SctpMessageParameters):
# If writes is called by DataChannel, sendParams should never
# be the default value. This split is useful for testing.
self.usrsctpAwait: self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0, self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
nil, 0, SCTP_SENDV_NOINFO.cuint, 0) nil, 0, SCTP_SENDV_NOINFO.cuint, 0)
else: else:
let sendInfo = sctp_sndinfo( let sendInfo = sctp_sndinfo(
snd_sid: sendParams.streamId, snd_sid: sendParams.streamId,
# TODO: swapBytes => htonl?
snd_ppid: sendParams.protocolId.swapBytes(), snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags) snd_flags: sendParams.toFlags)
self.usrsctpAwait: self.usrsctpAwait:
@ -178,29 +169,26 @@ proc write*(
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen, cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
SCTP_SENDV_SNDINFO.cuint, 0) SCTP_SENDV_SNDINFO.cuint, 0)
if sendvErr < 0: if sendvErr < 0:
perror("usrsctp_sendv") # TODO: throw an exception # TODO: throw an exception
trace "write sendv error?", sendvErr, sendParams perror("usrsctp_sendv")
proc write*(self: SctpConn, s: string) {.async.} = proc write*(self: SctpConn, s: string) {.async.} =
await self.write(s.toBytes()) await self.write(s.toBytes())
proc close*(self: SctpConn) {.async.} = proc close*(self: SctpConn) {.async.} =
self.usrsctpAwait: self.sctpSocket.usrsctp_close() self.usrsctpAwait:
self.sctpSocket.usrsctp_close()
# -- usrsctp receive data callbacks --
proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when we receive data after
# connection has been established.
let let
conn = cast[SctpConn](data) conn = cast[SctpConn](data)
events = usrsctp_get_events(sock) events = usrsctp_get_events(sock)
trace "Handle Upcall", events, state = conn.state trace "Handle Upcall", events
if conn.state == Connecting:
if bitand(events, SCTP_EVENT_ERROR) != 0:
warn "Cannot connect", address = conn.address
conn.state = Closed
elif bitand(events, SCTP_EVENT_WRITE) != 0:
conn.state = Connected
conn.connectEvent.fire()
if bitand(events, SCTP_EVENT_READ) != 0: if bitand(events, SCTP_EVENT_READ) != 0:
var var
message = SctpMessage( message = SctpMessage(
@ -212,8 +200,8 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
rnLen = sizeof(sctp_recvv_rn).SockLen rnLen = sizeof(sctp_recvv_rn).SockLen
infotype: uint infotype: uint
flags: int flags: int
trace "recv from", sockuint64=cast[uint64](sock) let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]),
let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]), message.data.len.uint, message.data.len.uint,
cast[ptr SockAddr](addr address), cast[ptr SockAddr](addr address),
cast[ptr SockLen](addr addressLen), cast[ptr SockLen](addr addressLen),
cast[pointer](addr message.info), cast[pointer](addr message.info),
@ -239,11 +227,12 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
except AsyncQueueFullError: except AsyncQueueFullError:
trace "Queue full, dropping packet" trace "Queue full, dropping packet"
elif bitand(events, SCTP_EVENT_WRITE) != 0: elif bitand(events, SCTP_EVENT_WRITE) != 0:
trace "sctp event write in the upcall" debug "sctp event write in the upcall"
else: else:
warn "Handle Upcall unexpected event", events warn "Handle Upcall unexpected event", events
proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when accepting a connection.
trace "Handle Accept" trace "Handle Accept"
var var
sconn: Sockaddr_conn sconn: Sockaddr_conn
@ -266,6 +255,27 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
addr recvinfo, sizeof(recvinfo).SockLen) addr recvinfo, sizeof(recvinfo).SockLen)
conn.acceptEvent.fire() conn.acceptEvent.fire()
proc handleConnect(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when connecting
trace "Handle Connect"
let
conn = cast[SctpConn](data)
events = usrsctp_get_events(sock)
trace "Handle Upcall", events, state = conn.state
if conn.state == Connecting:
if bitand(events, SCTP_EVENT_ERROR) != 0:
warn "Cannot connect", address = conn.address
conn.state = Closed
elif bitand(events, SCTP_EVENT_WRITE) != 0:
conn.state = Connected
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, data)
conn.connectEvent.fire()
else:
warn "should be connecting", currentState = conn.state
# -- usrsctp send data callback --
proc sendCallback(ctx: pointer, proc sendCallback(ctx: pointer,
buffer: pointer, buffer: pointer,
length: uint, length: uint,
@ -273,20 +283,20 @@ proc sendCallback(ctx: pointer,
set_df: uint8): cint {.cdecl.} = set_df: uint8): cint {.cdecl.} =
let data = usrsctp_dumppacket(buffer, length, SCTP_DUMP_OUTBOUND) let data = usrsctp_dumppacket(buffer, length, SCTP_DUMP_OUTBOUND)
if data != nil: if data != nil:
trace "sendCallback", data = data.packetPretty(), length trace "sendCallback", sctpPacket = data.getSctpPacket(), length
usrsctp_freedumpbuffer(data) usrsctp_freedumpbuffer(data)
let sctpConn = cast[SctpConn](ctx) let sctpConn = cast[SctpConn](ctx)
let buf = @(buffer.makeOpenArray(byte, int(length))) let buf = @(buffer.makeOpenArray(byte, int(length)))
proc testSend() {.async.} = proc testSend() {.async.} =
try: try:
trace "Send To", address = sctpConn.address trace "Send To", address = sctpConn.address
# printSctpPacket(buf)
# TODO: defined it printSctpPacket(buf)
await sctpConn.conn.write(buf) await sctpConn.conn.write(buf)
except CatchableError as exc: except CatchableError as exc:
trace "Send Failed", message = exc.msg trace "Send Failed", message = exc.msg
sctpConn.sentFuture = testSend() sctpConn.sentFuture = testSend()
# -- Sctp --
proc timersHandler() {.async.} = proc timersHandler() {.async.} =
while true: while true:
await sleepAsync(500.milliseconds) await sleepAsync(500.milliseconds)
@ -315,6 +325,7 @@ proc new*(T: typedesc[Sctp], dtls: Dtls, laddr: TransportAddress): T =
return sctp return sctp
proc stop*(self: Sctp) {.async.} = proc stop*(self: Sctp) {.async.} =
# TODO: close every connections
discard self.usrsctpAwait usrsctp_finish() discard self.usrsctpAwait usrsctp_finish()
self.udp.close() self.udp.close()
@ -324,14 +335,14 @@ proc readLoopProc(res: SctpConn) {.async.} =
msg = await res.conn.read() msg = await res.conn.read()
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND) data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
if not data.isNil(): if not data.isNil():
trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty() trace "Receive data", remoteAddress = res.conn.raddr,
sctpPacket = data.getSctpPacket()
usrsctp_freedumpbuffer(data) usrsctp_freedumpbuffer(data)
# printSctpPacket(msg) TODO: defined it
usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0) usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0)
proc accept*(self: Sctp): Future[SctpConn] {.async.} = proc accept*(self: Sctp): Future[SctpConn] {.async.} =
if not self.isServer: if not self.isServer:
raise newSctpError("Not a server") raise newException(SctpError, "Not a server")
var res = SctpConn.new(await self.dtls.accept(), self) var res = SctpConn.new(await self.dtls.accept(), self)
usrsctp_register_address(cast[pointer](res)) usrsctp_register_address(cast[pointer](res))
res.readLoop = res.readLoopProc() res.readLoop = res.readLoopProc()
@ -373,7 +384,7 @@ proc connect*(self: Sctp,
var nodelay: uint32 = 1 var nodelay: uint32 = 1
var recvinfo: uint32 = 1 var recvinfo: uint32 = 1
doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1) doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn)) doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleConnect, cast[pointer](conn))
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY, doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
addr nodelay, sizeof(nodelay).SockLen) addr nodelay, sizeof(nodelay).SockLen)
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO, doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
@ -391,6 +402,6 @@ proc connect*(self: Sctp,
conn.state = Connecting conn.state = Connecting
conn.connectEvent.clear() conn.connectEvent.clear()
await conn.connectEvent.wait() await conn.connectEvent.wait()
# TODO: check connection state, if closed throw some exception I guess # TODO: check connection state, if closed throw an exception
self.connections[address] = conn self.connections[address] = conn
return conn return conn