Sctp comments + refacto
This commit is contained in:
parent 359a81df4a
commit f49ca90491

@@ -20,6 +20,7 @@ import mbedtls/md
import chronicles

# This sequence is used for debugging.
const mb_ssl_states* = @[
  "MBEDTLS_SSL_HELLO_REQUEST",
  "MBEDTLS_SSL_CLIENT_HELLO",
@@ -53,14 +54,6 @@ const mb_ssl_states* = @[
  "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH"
]

proc mbedtls_pk_rsa*(pk: mbedtls_pk_context): ptr mbedtls_rsa_context =
  var key = pk
  case mbedtls_pk_get_type(addr key)
  of MBEDTLS_PK_RSA:
    return cast[ptr mbedtls_rsa_context](pk.private_pk_ctx)
  else:
    return nil

template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
  var res: mbedtls_pk_context
  mb_pk_init(res)

@@ -72,6 +65,7 @@ template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
template generateCertificate*(random: mbedtls_ctr_drbg_context,
                              issuer_key: mbedtls_pk_context): mbedtls_x509_crt =
  let
    # To be honest, I have no clue what to put here as a name
    name = "C=FR,O=Status,CN=webrtc"
    time_format = initTimeFormat("YYYYMMddHHmmss")
    time_from = times.now().format(time_format)

webrtc/sctp.nim (161 changed lines)
@@ -18,9 +18,18 @@ export chronicles
logScope:
  topics = "webrtc sctp"

# Implementation of an SCTP client and server using the usrsctp library.
# Usrsctp can be used from a single thread, but that is not its intended
# mode of use: many callbacks call each other synchronously in places
# where we would like to call asynchronous procedures, but cannot.

# TODO:
# - Replace doAssert with proper exception management
# - Find a clean way to manage SCTP ports
# - Unregister address when closing
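
For orientation, the data path implemented by the procedures later in this file can be summarised as follows (a sketch added for readability; it only restates calls that appear below):

# Inbound:  readLoopProc reads a datagram from the DTLS connection and feeds it
#           to usrsctp with usrsctp_conninput; usrsctp then invokes handleUpcall,
#           which drains the reassembled message with usrsctp_recvv and pushes
#           it onto the dataRecv queue consumed by read().
# Outbound: write() hands data to usrsctp_sendv; usrsctp passes the resulting
#           packet to sendCallback, which writes it back onto the DTLS connection.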

proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}

type
  SctpError* = object of CatchableError

@@ -67,26 +76,43 @@ type
    sentAddress: TransportAddress
    sentFuture: Future[void]

  # Those two objects are only here for debugging purpose
  # These three objects are used for debugging/trace only
  SctpChunk = object
    chunkType: uint8
    flag: uint8
    length {.bin_value: it.data.len() + 4.}: uint16
    data {.bin_len: it.length - 4.}: seq[byte]

  SctpPacketStructure = object
  SctpPacketHeader = object
    srcPort: uint16
    dstPort: uint16
    verifTag: uint32
    checksum: uint32

const
  IPPROTO_SCTP = 132
  SctpPacketStructure = object
    header: SctpPacketHeader
    chunks: seq[SctpChunk]

proc newSctpError(msg: string): ref SctpError =
  result = newException(SctpError, msg)
const IPPROTO_SCTP = 132

proc getSctpPacket(buffer: seq[byte]): SctpPacketStructure =
  # Only used for debugging/trace
  result.header = Binary.decode(buffer, SctpPacketHeader)
  var size = sizeof(SctpPacketHeader) # the 12-byte SCTP common header
  while size < buffer.len:
    let chunk = Binary.decode(buffer[size..^1], SctpChunk)
    result.chunks.add(chunk)
    size.inc(chunk.length.int)
    while size mod 4 != 0:
      # padding; could use `size.inc(-size %% 4)` instead, but it lacks clarity
      size.inc(1)
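
To make the 4-byte alignment above concrete, a small self-contained helper and a worked value (hypothetical code, not part of the module):

func nextChunkOffset(start, chunkLen: int): int =
  ## Offset of the chunk following one that starts at `start` with a `length`
  ## field of `chunkLen`; SCTP chunks are padded to a multiple of 4 bytes.
  result = start + chunkLen
  while result mod 4 != 0:
    result.inc(1)

assert nextChunkOffset(12, 6) == 20 # 12-byte common header + 6-byte chunk padded to 8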

# -- Asynchronous wrapper --

template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
  # usrsctpAwait is a template which sets `sentFuture` to nil and then calls
  # (usually) a usrsctp function. If `sendCallback` is called during the
  # synchronous run of that usrsctp function, `sentFuture` is set and awaited.
  self.sentFuture = nil
  when type(body) is void:
    body
@@ -96,45 +122,7 @@ template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
  if self.sentFuture != nil: await self.sentFuture
  res
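
For reference, a call such as `let res = self.usrsctpAwait: someUsrsctpCall()` behaves roughly like this hand-expanded sketch (`someUsrsctpCall` is a placeholder, not a real procedure):

self.sentFuture = nil
let res = someUsrsctpCall()   # may synchronously invoke sendCallback,
                              # which stores a Future in self.sentFuture
if self.sentFuture != nil:
  await self.sentFuture       # wait for the DTLS write started by the callback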

proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}

proc printSctpPacket(buffer: seq[byte]) =
  let s = Binary.decode(buffer, SctpPacketStructure)
  echo " => \e[31;1mStructure\e[0m: ", s
  var size = sizeof(SctpPacketStructure)
  var i = 1
  while size < buffer.len:
    let c = Binary.decode(buffer[size..^1], SctpChunk)
    echo " ===> \e[32;1mChunk ", i, "\e[0m ", c
    i.inc()
    size.inc(c.length.int)
    while size mod 4 != 0:
      size.inc()

proc packetPretty(packet: cstring): string =
  let data = $packet
  let ctn = data[23..^16]
  result = data[1..14]
  if ctn.len > 30:
    result = result & ctn[0..14] & " ... " & ctn[^14..^1]
  else:
    result = result & ctn

proc new(T: typedesc[SctpConn],
         sctp: Sctp,
         udp: DatagramTransport,
         address: TransportAddress,
         sctpSocket: ptr socket): T =
  T(sctp: sctp,
    state: Connecting,
    udp: udp,
    address: address,
    sctpSocket: sctpSocket,
    connectEvent: AsyncEvent(),
    #TODO add some limit for backpressure?
    dataRecv: newAsyncQueue[SctpMessage]()
  )

# -- SctpConn --

proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
  T(conn: conn,
@@ -142,10 +130,12 @@ proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
    state: Connecting,
    connectEvent: AsyncEvent(),
    acceptEvent: AsyncEvent(),
    dataRecv: newAsyncQueue[SctpMessage]() #TODO add some limit for backpressure?
    dataRecv: newAsyncQueue[SctpMessage]() # TODO add some limit for backpressure?
  )

proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
  # Used by DataChannel, returns SctpMessage in order to get the stream
  # and protocol ids
  return await self.dataRecv.popFirst()

proc toFlags(params: SctpMessageParameters): uint16 =
@@ -154,23 +144,24 @@ proc toFlags(params: SctpMessageParameters): uint16 =
  if params.unordered:
    result = result or SCTP_UNORDERED

proc write*(
    self: SctpConn,
    buf: seq[byte],
    sendParams = default(SctpMessageParameters),
) {.async.} =
  trace "Write", buf, sctp = cast[uint64](self), sock = cast[uint64](self.sctpSocket)
proc write*(self: SctpConn, buf: seq[byte],
            sendParams = default(SctpMessageParameters)) {.async.} =
  # Used by DataChannel, writes buf on the Dtls connection.
  trace "Write", buf
  self.sctp.sentAddress = self.address

  var cpy = buf
  let sendvErr =
    if sendParams == default(SctpMessageParameters):
      # If write is called by DataChannel, sendParams should never
      # be the default value. This split is useful for testing.
      self.usrsctpAwait:
        self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
                                      nil, 0, SCTP_SENDV_NOINFO.cuint, 0)
    else:
      let sendInfo = sctp_sndinfo(
        snd_sid: sendParams.streamId,
        # TODO: swapBytes => htonl?
        snd_ppid: sendParams.protocolId.swapBytes(),
        snd_flags: sendParams.toFlags)
      self.usrsctpAwait:
@@ -178,29 +169,26 @@ proc write*(
          cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
          SCTP_SENDV_SNDINFO.cuint, 0)
  if sendvErr < 0:
    perror("usrsctp_sendv") # TODO: throw an exception
    trace "write sendv error?", sendvErr, sendParams
    # TODO: throw an exception
    perror("usrsctp_sendv")
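
A short usage sketch of the two branches above (the SctpMessageParameters field names are inferred from their usages in this file; its full definition is not shown in this diff):

proc writeExample(conn: SctpConn) {.async.} =
  # Default parameters: data is sent with SCTP_SENDV_NOINFO.
  await conn.write(@[1'u8, 2, 3])
  # Explicit parameters, as a DataChannel would pass them: SCTP_SENDV_SNDINFO
  # is used, carrying the stream id, protocol id and ordering flag.
  await conn.write(@[4'u8, 5, 6],
                   SctpMessageParameters(streamId: 1, protocolId: 50, unordered: true))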

proc write*(self: SctpConn, s: string) {.async.} =
  await self.write(s.toBytes())

proc close*(self: SctpConn) {.async.} =
  self.usrsctpAwait: self.sctpSocket.usrsctp_close()
  self.usrsctpAwait:
    self.sctpSocket.usrsctp_close()

# -- usrsctp receive data callbacks --

proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
  # Callback procedure called when we receive data after
  # the connection has been established.
  let
    conn = cast[SctpConn](data)
    events = usrsctp_get_events(sock)

  trace "Handle Upcall", events, state = conn.state
  if conn.state == Connecting:
    if bitand(events, SCTP_EVENT_ERROR) != 0:
      warn "Cannot connect", address = conn.address
      conn.state = Closed
    elif bitand(events, SCTP_EVENT_WRITE) != 0:
      conn.state = Connected
      conn.connectEvent.fire()

  trace "Handle Upcall", events
  if bitand(events, SCTP_EVENT_READ) != 0:
    var
      message = SctpMessage(
@@ -212,8 +200,8 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
      rnLen = sizeof(sctp_recvv_rn).SockLen
      infotype: uint
      flags: int
    trace "recv from", sockuint64=cast[uint64](sock)
    let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]), message.data.len.uint,
    let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]),
                               message.data.len.uint,
                               cast[ptr SockAddr](addr address),
                               cast[ptr SockLen](addr addressLen),
                               cast[pointer](addr message.info),
@@ -239,11 +227,12 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
    except AsyncQueueFullError:
      trace "Queue full, dropping packet"
  elif bitand(events, SCTP_EVENT_WRITE) != 0:
    trace "sctp event write in the upcall"
    debug "sctp event write in the upcall"
  else:
    warn "Handle Upcall unexpected event", events

proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
  # Callback procedure called when accepting a connection.
  trace "Handle Accept"
  var
    sconn: Sockaddr_conn
@@ -266,6 +255,27 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
                             addr recvinfo, sizeof(recvinfo).SockLen)
  conn.acceptEvent.fire()

proc handleConnect(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
  # Callback procedure called when connecting
  trace "Handle Connect"
  let
    conn = cast[SctpConn](data)
    events = usrsctp_get_events(sock)

  trace "Handle Upcall", events, state = conn.state
  if conn.state == Connecting:
    if bitand(events, SCTP_EVENT_ERROR) != 0:
      warn "Cannot connect", address = conn.address
      conn.state = Closed
    elif bitand(events, SCTP_EVENT_WRITE) != 0:
      conn.state = Connected
      doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, data)
      conn.connectEvent.fire()
  else:
    warn "should be connecting", currentState = conn.state

# -- usrsctp send data callback --

proc sendCallback(ctx: pointer,
                  buffer: pointer,
                  length: uint,
@@ -273,20 +283,20 @@ proc sendCallback(
                  set_df: uint8): cint {.cdecl.} =
  let data = usrsctp_dumppacket(buffer, length, SCTP_DUMP_OUTBOUND)
  if data != nil:
    trace "sendCallback", data = data.packetPretty(), length
    trace "sendCallback", sctpPacket = data.getSctpPacket(), length
    usrsctp_freedumpbuffer(data)
  let sctpConn = cast[SctpConn](ctx)
  let buf = @(buffer.makeOpenArray(byte, int(length)))
  proc testSend() {.async.} =
    try:
      trace "Send To", address = sctpConn.address
      # printSctpPacket(buf)
      # TODO: define it: printSctpPacket(buf)
      await sctpConn.conn.write(buf)
    except CatchableError as exc:
      trace "Send Failed", message = exc.msg
  sctpConn.sentFuture = testSend()

# -- Sctp --

proc timersHandler() {.async.} =
  while true:
    await sleepAsync(500.milliseconds)
@@ -315,6 +325,7 @@ proc new*(T: typedesc[Sctp], dtls: Dtls, laddr: TransportAddress): T =
  return sctp

proc stop*(self: Sctp) {.async.} =
  # TODO: close every connection
  discard self.usrsctpAwait usrsctp_finish()
  self.udp.close()

@@ -324,14 +335,14 @@ proc readLoopProc(res: SctpConn) {.async.} =
      msg = await res.conn.read()
      data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
    if not data.isNil():
      trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty()
      trace "Receive data", remoteAddress = res.conn.raddr,
        sctpPacket = data.getSctpPacket()
      usrsctp_freedumpbuffer(data)
    # printSctpPacket(msg) TODO: define it
    usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0)

proc accept*(self: Sctp): Future[SctpConn] {.async.} =
  if not self.isServer:
    raise newSctpError("Not a server")
    raise newException(SctpError, "Not a server")
  var res = SctpConn.new(await self.dtls.accept(), self)
  usrsctp_register_address(cast[pointer](res))
  res.readLoop = res.readLoopProc()
@@ -373,7 +384,7 @@ proc connect*(self: Sctp,
  var nodelay: uint32 = 1
  var recvinfo: uint32 = 1
  doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
  doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn))
  doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleConnect, cast[pointer](conn))
  doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
                                                   addr nodelay, sizeof(nodelay).SockLen)
  doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
@@ -391,6 +402,6 @@ proc connect*(self: Sctp,
  conn.state = Connecting
  conn.connectEvent.clear()
  await conn.connectEvent.wait()
  # TODO: check connection state, if closed throw some exception I guess
  # TODO: check connection state, if closed throw an exception
  self.connections[address] = conn
  return conn
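
Finally, a rough end-to-end usage sketch of the API touched by this commit (the full `connect` parameter list is folded out of the hunk above, so the single `address` argument is an assumption):

proc echoServer(sctp: Sctp) {.async.} =
  # Accept an SCTP-over-DTLS connection and echo the first message back.
  let conn = await sctp.accept()
  let msg = await conn.read()
  await conn.write(msg.data)
  await conn.close()

proc clientExample(sctp: Sctp, address: TransportAddress) {.async.} =
  # connect() resolves once handleConnect fires connectEvent.
  let conn = await sctp.connect(address)
  await conn.write("hello")
  await conn.close()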