Merge pull request #5 from status-im/sctp-dtls-client

Sctp / Dtls Client
This commit is contained in:
Ludovic Chenut 2024-02-28 14:07:57 +01:00 committed by GitHub
commit 1f27c163b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 122 additions and 146 deletions

View File

@ -11,6 +11,5 @@ requires "nim >= 1.2.0",
"https://github.com/status-im/nim-binary-serialization.git",
"https://github.com/status-im/nim-mbedtls.git"
proc runTest(filename: string) =
discard

View File

@ -82,20 +82,25 @@ proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} =
self.dataRecv = newAsyncQueue[seq[byte]]()
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
trace "Dtls write", length = msg.len()
var buf = msg
discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](addr buf[0]), buf.len().uint)
try:
let write = mb_ssl_write(self.ssl, buf)
trace "Dtls write", msgLen = msg.len(), actuallyWrote = write
except MbedTLSError as exc:
trace "Dtls write error", errorMsg = exc.msg
raise exc
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
var res = newSeq[byte](8192)
while true:
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
# TODO: exception catching
let length = mb_ssl_read(self.ssl, res)
if length == MBEDTLS_ERR_SSL_WANT_READ:
continue
if length < 0:
trace "dtls read", error = $(length.mbedtls_high_level_strerr())
trace "dtls read", error = $(length.cint.mbedtls_high_level_strerr())
res.setLen(length)
return res
@ -162,17 +167,18 @@ proc stop*(self: Dtls) =
self.readLoop.cancel()
self.started = false
proc serverHandshake(self: DtlsConn) {.async.} =
var shouldRead = true
proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} =
var shouldRead = isServer
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
if shouldRead:
case self.raddr.family
of AddressFamily.IPv4:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
of AddressFamily.IPv6:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
else:
raise newException(DtlsError, "Remote address isn't an IP address")
if isServer:
case self.raddr.family
of AddressFamily.IPv4:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
of AddressFamily.IPv6:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
else:
raise newException(DtlsError, "Remote address isn't an IP address")
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
self.sendFuture = nil
@ -181,8 +187,7 @@ proc serverHandshake(self: DtlsConn) {.async.} =
shouldRead = false
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
continue
elif res == MBEDTLS_ERR_SSL_WANT_READ or
self.ssl.private_state == MBEDTLS_SSL_CLIENT_KEY_EXCHANGE:
elif res == MBEDTLS_ERR_SSL_WANT_READ:
shouldRead = true
continue
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
@ -191,10 +196,6 @@ proc serverHandshake(self: DtlsConn) {.async.} =
continue
elif res != 0:
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
# var remoteCertPtr = mbedtls_ssl_get_peer_cert(addr self.ssl)
# let remoteCert = remoteCertPtr[]
# self.remoteCert = newSeq[byte](remoteCert.raw.len)
# copyMem(addr self.remoteCert[0], remoteCert.raw.p, remoteCert.raw.len)
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
conn.remoteCert
@ -245,8 +246,8 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
mb_ssl_set_timer_cb(res.ssl, res.timer)
mb_ssl_setup(res.ssl, res.config)
mb_ssl_session_reset(res.ssl)
mbedtls_ssl_set_verify(addr res.ssl, verify, cast[pointer](res))
mbedtls_ssl_conf_authmode(addr res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) # TODO: create template
mb_ssl_set_verify(res.ssl, verify, res)
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
mb_ssl_set_bio(res.ssl, cast[pointer](res),
dtlsSend, dtlsRecv, nil)
while true:
@ -255,7 +256,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
res.raddr = raddr
res.dataRecv.addLastNoWait(buf)
self.connections[raddr] = res
await res.serverHandshake()
await res.dtlsHandshake(true)
break
except CatchableError as exc:
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
@ -263,20 +264,48 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
continue
return res
proc dial*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
discard
proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
var
selfvar = self
res = DtlsConn()
#import ../udp_connection
#import stew/byteutils
#proc main() {.async.} =
# let laddr = initTAddress("127.0.0.1:4433")
# let udp = UdpConn()
# await udp.init(laddr)
# let stun = StunConn()
# await stun.init(udp, laddr)
# let dtls = Dtls()
# dtls.start(stun, laddr)
# let x = await dtls.accept()
# echo "Recv: <", string.fromBytes(await x.read()), ">"
#
#waitFor(main())
await res.init(self.conn, self.laddr)
mb_ssl_init(res.ssl)
mb_ssl_config_init(res.config)
res.ctr_drbg = self.ctr_drbg
res.entropy = self.entropy
var pkey = res.ctr_drbg.generateKey()
var srvcert = res.ctr_drbg.generateCertificate(pkey)
res.localCert = newSeq[byte](srvcert.raw.len)
copyMem(addr res.localCert[0], srvcert.raw.p, srvcert.raw.len)
mb_ctr_drbg_init(res.ctr_drbg)
mb_entropy_init(res.entropy)
mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0)
mb_ssl_config_defaults(res.config,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT)
mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg)
mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds
mb_ssl_conf_ca_chain(res.config, srvcert.next, nil)
mb_ssl_set_timer_cb(res.ssl, res.timer)
mb_ssl_setup(res.ssl, res.config)
mb_ssl_set_verify(res.ssl, verify, res)
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
res.raddr = raddr
self.connections[raddr] = res
try:
await res.dtlsHandshake(false)
except CatchableError as exc:
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
self.connections.del(raddr)
raise exc
return res

View File

@ -64,7 +64,6 @@ type
sockServer: ptr socket
pendingConnections: seq[SctpConn]
pendingConnections2: Table[SockAddr, SctpConn]
sentConnection: SctpConn
sentAddress: TransportAddress
sentFuture: Future[void]
@ -161,24 +160,23 @@ proc write*(
sendParams = default(SctpMessageParameters),
) {.async.} =
trace "Write", buf, sctp = cast[uint64](self), sock = cast[uint64](self.sctpSocket)
self.sctp.sentConnection = self
self.sctp.sentAddress = self.address
var cpy = buf
var
(sendInfo, infoType) =
if sendParams != default(SctpMessageParameters):
(sctp_sndinfo(
snd_sid: sendParams.streamId,
snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags
), cuint(SCTP_SENDV_SNDINFO))
else:
(default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO))
sendvErr = self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len.uint, nil, 0,
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
infoType, 0)
let sendvErr =
if sendParams == default(SctpMessageParameters):
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
nil, 0, SCTP_SENDV_NOINFO.cuint, 0)
else:
let sendInfo = sctp_sndinfo(
snd_sid: sendParams.streamId,
snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags)
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
SCTP_SENDV_SNDINFO.cuint, 0)
if sendvErr < 0:
perror("usrsctp_sendv") # TODO: throw an exception
trace "write sendv error?", sendvErr, sendParams
@ -194,7 +192,7 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
conn = cast[SctpConn](data)
events = usrsctp_get_events(sock)
trace "Handle Upcall", events
trace "Handle Upcall", events, state = conn.state
if conn.state == Connecting:
if bitand(events, SCTP_EVENT_ERROR) != 0:
warn "Cannot connect", address = conn.address
@ -202,7 +200,8 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
elif bitand(events, SCTP_EVENT_WRITE) != 0:
conn.state = Connected
conn.connectEvent.fire()
elif bitand(events, SCTP_EVENT_READ) != 0:
if bitand(events, SCTP_EVENT_READ) != 0:
var
message = SctpMessage(
data: newSeq[byte](4096)
@ -251,14 +250,15 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
slen: Socklen = sizeof(Sockaddr_conn).uint32
let
sctp = cast[Sctp](data)
# TODO: check if sctpSocket != nil
sctpSocket = usrsctp_accept(sctp.sockServer, cast[ptr SockAddr](addr sconn), addr slen)
doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1)
let conn = cast[SctpConn](sconn.sconn_addr)
conn.sctpSocket = sctpSocket
conn.state = Connected
var nodelay: uint32 = 1
var recvinfo: uint32 = 1
doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1)
doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
addr nodelay, sizeof(nodelay).SockLen)
@ -266,36 +266,6 @@ proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
addr recvinfo, sizeof(recvinfo).SockLen)
conn.acceptEvent.fire()
# proc getOrCreateConnection(self: Sctp,
# udp: DatagramTransport,
# address: TransportAddress,
# sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
# #TODO remove the = 5000
# if self.connections.hasKey(address):
# return self.connections[address]
# trace "Create Connection", address
# let
# sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
# conn = SctpConn.new(self, udp, address, sctpSocket)
# var on: int = 1
# doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP,
# SCTP_RECVRCVINFO,
# addr on,
# sizeof(on).SockLen)
# doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
# doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn))
# var sconn: Sockaddr_conn
# sconn.sconn_family = AF_CONN
# sconn.sconn_port = htons(sctpPort)
# sconn.sconn_addr = cast[pointer](self)
# self.sentConnection = conn
# self.sentAddress = address
# let connErr = self.usrsctpAwait:
# conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)))
# doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno)
# self.connections[address] = conn
# return conn
proc sendCallback(ctx: pointer,
buffer: pointer,
length: uint,
@ -310,6 +280,7 @@ proc sendCallback(ctx: pointer,
proc testSend() {.async.} =
try:
trace "Send To", address = sctpConn.address
# printSctpPacket(buf)
# TODO: defined it printSctpPacket(buf)
await sctpConn.conn.write(buf)
except CatchableError as exc:
@ -343,43 +314,6 @@ proc new*(T: typedesc[Sctp], dtls: Dtls, laddr: TransportAddress): T =
usrsctp_register_address(cast[pointer](sctp))
return sctp
#proc new*(T: typedesc[Sctp], port: uint16 = 9899): T =
# logScope: topics = "webrtc sctp"
# let sctp = T(gotConnection: newAsyncEvent())
# proc onReceive(udp: DatagramTransport, raddr: TransportAddress) {.async, gcsafe.} =
# let
# msg = udp.getMessage()
# data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
# if data != nil:
# if sctp.isServer:
# trace "onReceive (server)", data = data.packetPretty(), length = msg.len(), raddr
# else:
# trace "onReceive (client)", data = data.packetPretty(), length = msg.len(), raddr
# usrsctp_freedumpbuffer(data)
#
# if sctp.isServer:
# sctp.sentAddress = raddr
# usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
# else:
# let conn = await sctp.getOrCreateConnection(udp, raddr)
# sctp.sentConnection = conn
# sctp.sentAddress = raddr
# usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
# let
# localAddr = TransportAddress(family: AddressFamily.IPv4, port: Port(port))
# laddr = initTAddress("127.0.0.1:" & $port)
# udp = newDatagramTransport(onReceive, local = laddr)
# trace "local address", localAddr, laddr
# sctp.udp = udp
# sctp.timersHandler = timersHandler()
#
# usrsctp_init_nothreads(port, sendCallback, printf)
# discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
# discard usrsctp_sysctl_set_sctp_ecn_enable(1)
# usrsctp_register_address(cast[pointer](sctp))
#
# return sctp
proc stop*(self: Sctp) {.async.} =
discard self.usrsctpAwait usrsctp_finish()
self.udp.close()
@ -392,13 +326,13 @@ proc readLoopProc(res: SctpConn) {.async.} =
if not data.isNil():
trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty()
usrsctp_freedumpbuffer(data)
res.sctp.sentConnection = res
# printSctpPacket(msg) TODO: defined it
usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0)
proc accept*(self: Sctp): Future[SctpConn] {.async.} =
if not self.isServer:
raise newSctpError("Not a server")
var res = SctpConn.new(await self.dtls.accept, self)
var res = SctpConn.new(await self.dtls.accept(), self)
usrsctp_register_address(cast[pointer](res))
res.readLoop = res.readLoopProc()
res.acceptEvent.clear()
@ -429,20 +363,34 @@ proc listen*(self: Sctp, sctpPort: uint16 = 5000) =
proc connect*(self: Sctp,
address: TransportAddress,
sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
discard
let
sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
conn = SctpConn.new(await self.dtls.connect(address), self)
# proc connect*(self: Sctp,
# address: TransportAddress,
# sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
# trace "Connect", address, sctpPort
# let conn = await self.getOrCreateConnection(self.udp, address, sctpPort)
# if conn.state == Connected:
# return conn
# try:
# await conn.connectEvent.wait() # TODO: clear?
# except CancelledError as exc:
# conn.sctpSocket.usrsctp_close()
# return nil
# if conn.state != Connected:
# raise newSctpError("Cannot connect to " & $address)
# return conn
trace "Create Connection", address
conn.sctpSocket = sctpSocket
conn.state = Connected
var nodelay: uint32 = 1
var recvinfo: uint32 = 1
doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn))
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
addr nodelay, sizeof(nodelay).SockLen)
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
addr recvinfo, sizeof(recvinfo).SockLen)
var sconn: Sockaddr_conn
sconn.sconn_family = AF_CONN
sconn.sconn_port = htons(sctpPort)
sconn.sconn_addr = cast[pointer](conn)
self.sentAddress = address
usrsctp_register_address(cast[pointer](conn))
conn.readLoop = conn.readLoopProc()
let connErr = self.usrsctpAwait:
conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)))
doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno)
conn.state = Connecting
conn.connectEvent.clear()
await conn.connectEvent.wait()
# TODO: check connection state, if closed throw some exception I guess
self.connections[address] = conn
return conn

View File

@ -24,7 +24,7 @@ proc init*(self: UdpConn, laddr: TransportAddress) =
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
let msg = udp.getMessage()
echo "\e[33m<UDP>\e[0;1m onReceive\e[0m"
trace "UDP onReceive", msg
self.dataRecv.addLastNoWait((msg, address))
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
@ -34,9 +34,9 @@ proc close*(self: UdpConn) {.async.} =
self.udp.close()
proc write*(self: UdpConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
echo "\e[33m<UDP>\e[0;1m write\e[0m"
trace "UDP write", msg
await self.udp.sendTo(raddr, msg)
proc read*(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} =
echo "\e[33m<UDP>\e[0;1m read\e[0m"
trace "UDP read"
return await self.dataRecv.popFirst()