Sctp comments + refacto

This commit is contained in:
Ludovic Chenut 2024-03-06 13:47:32 +01:00
parent 359a81df4a
commit f49ca90491
No known key found for this signature in database
GPG Key ID: D9A59B1907F1D50C
2 changed files with 88 additions and 83 deletions

View File

@ -20,6 +20,7 @@ import mbedtls/md
import chronicles
# This sequence is used for debugging.
const mb_ssl_states* = @[
"MBEDTLS_SSL_HELLO_REQUEST",
"MBEDTLS_SSL_CLIENT_HELLO",
@ -53,14 +54,6 @@ const mb_ssl_states* = @[
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH"
]
proc mbedtls_pk_rsa*(pk: mbedtls_pk_context): ptr mbedtls_rsa_context =
var key = pk
case mbedtls_pk_get_type(addr key)
of MBEDTLS_PK_RSA:
return cast[ptr mbedtls_rsa_context](pk.private_pk_ctx)
else:
return nil
template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
var res: mbedtls_pk_context
mb_pk_init(res)
@ -72,6 +65,7 @@ template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
template generateCertificate*(random: mbedtls_ctr_drbg_context,
issuer_key: mbedtls_pk_context): mbedtls_x509_crt =
let
# To be honest, I have no clue what to put here as a name
name = "C=FR,O=Status,CN=webrtc"
time_format = initTimeFormat("YYYYMMddHHmmss")
time_from = times.now().format(time_format)

View File

@ -18,9 +18,18 @@ export chronicles
logScope:
topics = "webrtc sctp"
# Implementation of an Sctp client and server using the usrsctp library.
# Usrsctp is usable as a single thread but it's not the intended way to
# use it. There's a lot of callbacks calling each other in a synchronous
# way where we want to be able to call asynchronous procedure, but cannot.
# TODO:
# - Replace doAssert by a proper exception management
# - Find a clean way to manage SCTP ports
# - Unregister address when closing
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
type
SctpError* = object of CatchableError
@ -67,26 +76,43 @@ type
sentAddress: TransportAddress
sentFuture: Future[void]
# Those two objects are only here for debugging purpose
# These three objects are used for debugging/trace only
SctpChunk = object
chunkType: uint8
flag: uint8
length {.bin_value: it.data.len() + 4.}: uint16
data {.bin_len: it.length - 4.}: seq[byte]
SctpPacketStructure = object
SctpPacketHeader = object
srcPort: uint16
dstPort: uint16
verifTag: uint32
checksum: uint32
const
IPPROTO_SCTP = 132
SctpPacketStructure = object
header: SctpPacketHeader
chunks: seq[SctpChunk]
proc newSctpError(msg: string): ref SctpError =
result = newException(SctpError, msg)
const IPPROTO_SCTP = 132
proc getSctpPacket(buffer: seq[byte]): SctpPacketStructure =
# Only used for debugging/trace
result.header = Binary.decode(buffer, SctpPacketHeader)
var size = sizeof(SctpPacketStructure)
while size < buffer.len:
let chunk = Binary.decode(buffer[size..^1], SctpChunk)
result.chunks.add(chunk)
size.inc(chunk.length.int)
while size mod 4 != 0:
# padding; could use `size.inc(-size %% 4)` instead but it lacks clarity
size.inc(1)
# -- Asynchronous wrapper --
template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
# usrsctpAwait is template which set `sentFuture` to nil then calls (usually)
# an usrsctp function. If during the synchronous run of the usrsctp function
# `sendCallback` is called, then `sentFuture` is set and waited.
self.sentFuture = nil
when type(body) is void:
body
@ -96,45 +122,7 @@ template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
if self.sentFuture != nil: await self.sentFuture
res
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
proc printSctpPacket(buffer: seq[byte]) =
let s = Binary.decode(buffer, SctpPacketStructure)
echo " => \e[31;1mStructure\e[0m: ", s
var size = sizeof(SctpPacketStructure)
var i = 1
while size < buffer.len:
let c = Binary.decode(buffer[size..^1], SctpChunk)
echo " ===> \e[32;1mChunk ", i, "\e[0m ", c
i.inc()
size.inc(c.length.int)
while size mod 4 != 0:
size.inc()
proc packetPretty(packet: cstring): string =
let data = $packet
let ctn = data[23..^16]
result = data[1..14]
if ctn.len > 30:
result = result & ctn[0..14] & " ... " & ctn[^14..^1]
else:
result = result & ctn
proc new(T: typedesc[SctpConn],
sctp: Sctp,
udp: DatagramTransport,
address: TransportAddress,
sctpSocket: ptr socket): T =
T(sctp: sctp,
state: Connecting,
udp: udp,
address: address,
sctpSocket: sctpSocket,
connectEvent: AsyncEvent(),
#TODO add some limit for backpressure?
dataRecv: newAsyncQueue[SctpMessage]()
)
# -- SctpConn --
proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
T(conn: conn,
@ -142,10 +130,12 @@ proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
state: Connecting,
connectEvent: AsyncEvent(),
acceptEvent: AsyncEvent(),
dataRecv: newAsyncQueue[SctpMessage]() #TODO add some limit for backpressure?
dataRecv: newAsyncQueue[SctpMessage]() # TODO add some limit for backpressure?
)
proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
# Used by DataChannel, returns SctpMessage in order to get the stream
# and protocol ids
return await self.dataRecv.popFirst()
proc toFlags(params: SctpMessageParameters): uint16 =
@ -154,23 +144,24 @@ proc toFlags(params: SctpMessageParameters): uint16 =
if params.unordered:
result = result or SCTP_UNORDERED
proc write*(
self: SctpConn,
buf: seq[byte],
sendParams = default(SctpMessageParameters),
) {.async.} =
trace "Write", buf, sctp = cast[uint64](self), sock = cast[uint64](self.sctpSocket)
proc write*(self: SctpConn, buf: seq[byte],
sendParams = default(SctpMessageParameters)) {.async.} =
# Used by DataChannel, writes buf on the Dtls connection.
trace "Write", buf
self.sctp.sentAddress = self.address
var cpy = buf
let sendvErr =
if sendParams == default(SctpMessageParameters):
# If writes is called by DataChannel, sendParams should never
# be the default value. This split is useful for testing.
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
nil, 0, SCTP_SENDV_NOINFO.cuint, 0)
else:
let sendInfo = sctp_sndinfo(
snd_sid: sendParams.streamId,
# TODO: swapBytes => htonl?
snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags)
self.usrsctpAwait:
@ -178,29 +169,26 @@ proc write*(
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
SCTP_SENDV_SNDINFO.cuint, 0)
if sendvErr < 0:
perror("usrsctp_sendv") # TODO: throw an exception
trace "write sendv error?", sendvErr, sendParams
# TODO: throw an exception
perror("usrsctp_sendv")
proc write*(self: SctpConn, s: string) {.async.} =
await self.write(s.toBytes())
proc close*(self: SctpConn) {.async.} =
self.usrsctpAwait: self.sctpSocket.usrsctp_close()
self.usrsctpAwait:
self.sctpSocket.usrsctp_close()
# -- usrsctp receive data callbacks --
proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when we receive data after
# connection has been established.
let
conn = cast[SctpConn](data)
events = usrsctp_get_events(sock)
trace "Handle Upcall", events, state = conn.state
if conn.state == Connecting:
if bitand(events, SCTP_EVENT_ERROR) != 0:
warn "Cannot connect", address = conn.address
conn.state = Closed
elif bitand(events, SCTP_EVENT_WRITE) != 0:
conn.state = Connected
conn.connectEvent.fire()
trace "Handle Upcall", events
if bitand(events, SCTP_EVENT_READ) != 0:
var
message = SctpMessage(
@ -212,8 +200,8 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
rnLen = sizeof(sctp_recvv_rn).SockLen
infotype: uint
flags: int
trace "recv from", sockuint64=cast[uint64](sock)
let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]), message.data.len.uint,
let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]),
message.data.len.uint,
cast[ptr SockAddr](addr address),
cast[ptr SockLen](addr addressLen),
cast[pointer](addr message.info),
@ -239,11 +227,12 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
except AsyncQueueFullError:
trace "Queue full, dropping packet"
elif bitand(events, SCTP_EVENT_WRITE) != 0:
trace "sctp event write in the upcall"
debug "sctp event write in the upcall"
else:
warn "Handle Upcall unexpected event", events
proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when accepting a connection.
trace "Handle Accept"
var
sconn: Sockaddr_conn
@ -266,6 +255,27 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
addr recvinfo, sizeof(recvinfo).SockLen)
conn.acceptEvent.fire()
proc handleConnect(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when connecting
trace "Handle Connect"
let
conn = cast[SctpConn](data)
events = usrsctp_get_events(sock)
trace "Handle Upcall", events, state = conn.state
if conn.state == Connecting:
if bitand(events, SCTP_EVENT_ERROR) != 0:
warn "Cannot connect", address = conn.address
conn.state = Closed
elif bitand(events, SCTP_EVENT_WRITE) != 0:
conn.state = Connected
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, data)
conn.connectEvent.fire()
else:
warn "should be connecting", currentState = conn.state
# -- usrsctp send data callback --
proc sendCallback(ctx: pointer,
buffer: pointer,
length: uint,
@ -273,20 +283,20 @@ proc sendCallback(ctx: pointer,
set_df: uint8): cint {.cdecl.} =
let data = usrsctp_dumppacket(buffer, length, SCTP_DUMP_OUTBOUND)
if data != nil:
trace "sendCallback", data = data.packetPretty(), length
trace "sendCallback", sctpPacket = data.getSctpPacket(), length
usrsctp_freedumpbuffer(data)
let sctpConn = cast[SctpConn](ctx)
let buf = @(buffer.makeOpenArray(byte, int(length)))
proc testSend() {.async.} =
try:
trace "Send To", address = sctpConn.address
# printSctpPacket(buf)
# TODO: defined it printSctpPacket(buf)
await sctpConn.conn.write(buf)
except CatchableError as exc:
trace "Send Failed", message = exc.msg
sctpConn.sentFuture = testSend()
# -- Sctp --
proc timersHandler() {.async.} =
while true:
await sleepAsync(500.milliseconds)
@ -315,6 +325,7 @@ proc new*(T: typedesc[Sctp], dtls: Dtls, laddr: TransportAddress): T =
return sctp
proc stop*(self: Sctp) {.async.} =
# TODO: close every connections
discard self.usrsctpAwait usrsctp_finish()
self.udp.close()
@ -324,14 +335,14 @@ proc readLoopProc(res: SctpConn) {.async.} =
msg = await res.conn.read()
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
if not data.isNil():
trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty()
trace "Receive data", remoteAddress = res.conn.raddr,
sctpPacket = data.getSctpPacket()
usrsctp_freedumpbuffer(data)
# printSctpPacket(msg) TODO: defined it
usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0)
proc accept*(self: Sctp): Future[SctpConn] {.async.} =
if not self.isServer:
raise newSctpError("Not a server")
raise newException(SctpError, "Not a server")
var res = SctpConn.new(await self.dtls.accept(), self)
usrsctp_register_address(cast[pointer](res))
res.readLoop = res.readLoopProc()
@ -373,7 +384,7 @@ proc connect*(self: Sctp,
var nodelay: uint32 = 1
var recvinfo: uint32 = 1
doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn))
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleConnect, cast[pointer](conn))
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
addr nodelay, sizeof(nodelay).SockLen)
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
@ -391,6 +402,6 @@ proc connect*(self: Sctp,
conn.state = Connecting
conn.connectEvent.clear()
await conn.connectEvent.wait()
# TODO: check connection state, if closed throw some exception I guess
# TODO: check connection state, if closed throw an exception
self.connections[address] = conn
return conn