commit
474f3d30ad
|
@ -0,0 +1,29 @@
|
|||
import chronos, stew/byteutils
|
||||
import ../webrtc/udp_connection
|
||||
import ../webrtc/stun/stun_connection
|
||||
import ../webrtc/dtls/dtls
|
||||
import ../webrtc/sctp
|
||||
|
||||
proc sendPong(conn: SctpConn) {.async.} =
|
||||
var i = 0
|
||||
while true:
|
||||
let msg = await conn.read()
|
||||
echo "Received: ", string.fromBytes(msg.data)
|
||||
await conn.write(("pong " & $i).toBytes)
|
||||
i.inc()
|
||||
|
||||
proc main() {.async.} =
|
||||
let laddr = initTAddress("127.0.0.1:4242")
|
||||
let udp = UdpConn()
|
||||
udp.init(laddr)
|
||||
let stun = StunConn()
|
||||
stun.init(udp, laddr)
|
||||
let dtls = Dtls()
|
||||
dtls.start(stun, laddr)
|
||||
let sctp = Sctp.new(dtls, laddr)
|
||||
sctp.listen(13)
|
||||
while true:
|
||||
let conn = await sctp.accept()
|
||||
asyncSpawn conn.sendPong()
|
||||
|
||||
waitFor(main())
|
|
@ -0,0 +1,25 @@
|
|||
import ../webrtc/datachannel
|
||||
import chronos/unittest2/asynctests
|
||||
import binary_serialization
|
||||
|
||||
suite "DataChannel encoding":
|
||||
test "DataChannelOpenMessage":
|
||||
let msg = @[
|
||||
0x03'u8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72]
|
||||
check msg == Binary.encode(Binary.decode(msg, DataChannelMessage))
|
||||
check Binary.decode(msg, DataChannelMessage).openMessage ==
|
||||
DataChannelOpenMessage(
|
||||
channelType: Reliable,
|
||||
priority: 0,
|
||||
reliabilityParameter: 0,
|
||||
labelLength: 3,
|
||||
protocolLength: 3,
|
||||
label: @[102, 111, 111],
|
||||
protocol: @[98, 97, 114]
|
||||
)
|
||||
|
||||
test "DataChannelAck":
|
||||
let msg = @[0x02'u8]
|
||||
check msg == Binary.encode(Binary.decode(msg, DataChannelMessage))
|
||||
check Binary.decode(msg, DataChannelMessage).messageType == Ack
|
|
@ -3,7 +3,7 @@ version = "0.0.1"
|
|||
author = "Status Research & Development GmbH"
|
||||
description = "Webrtc stack"
|
||||
license = "MIT"
|
||||
#installDirs = @["usrsctp"]
|
||||
installDirs = @["usrsctp", "webrtc"]
|
||||
|
||||
requires "nim >= 1.2.0",
|
||||
"chronicles >= 0.10.2",
|
||||
|
|
|
@ -0,0 +1,223 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import tables
|
||||
|
||||
import chronos,
|
||||
chronicles,
|
||||
binary_serialization
|
||||
|
||||
import sctp
|
||||
|
||||
export binary_serialization
|
||||
|
||||
logScope:
|
||||
topics = "webrtc datachannel"
|
||||
|
||||
type
|
||||
DataChannelProtocolIds* {.size: 4.} = enum
|
||||
WebRtcDcep = 50
|
||||
WebRtcString = 51
|
||||
WebRtcBinary = 53
|
||||
WebRtcStringEmpty = 56
|
||||
WebRtcBinaryEmpty = 57
|
||||
|
||||
DataChannelMessageType* {.size: 1.} = enum
|
||||
Reserved = 0x00
|
||||
Ack = 0x02
|
||||
Open = 0x03
|
||||
|
||||
DataChannelMessage* = object
|
||||
case messageType*: DataChannelMessageType
|
||||
of Open: openMessage*: DataChannelOpenMessage
|
||||
else: discard
|
||||
|
||||
DataChannelType {.size: 1.} = enum
|
||||
Reliable = 0x00
|
||||
PartialReliableRexmit = 0x01
|
||||
PartialReliableTimed = 0x02
|
||||
ReliableUnordered = 0x80
|
||||
PartialReliableRexmitUnordered = 0x81
|
||||
PartialReliableTimedUnorderd = 0x82
|
||||
|
||||
DataChannelOpenMessage* = object
|
||||
channelType*: DataChannelType
|
||||
priority*: uint16
|
||||
reliabilityParameter*: uint32
|
||||
labelLength* {.bin_value: it.label.len.}: uint16
|
||||
protocolLength* {.bin_value: it.protocol.len.}: uint16
|
||||
label* {.bin_len: it.labelLength.}: seq[byte]
|
||||
protocol* {.bin_len: it.protocolLength.}: seq[byte]
|
||||
|
||||
proc ordered(t: DataChannelType): bool =
|
||||
t in [Reliable, PartialReliableRexmit, PartialReliableTimed]
|
||||
|
||||
type
|
||||
#TODO handle closing
|
||||
DataChannelStream* = ref object
|
||||
id: uint16
|
||||
conn: SctpConn
|
||||
reliability: DataChannelType
|
||||
reliabilityParameter: uint32
|
||||
receivedData: AsyncQueue[seq[byte]]
|
||||
acked: bool
|
||||
|
||||
#TODO handle closing
|
||||
DataChannelConnection* = ref object
|
||||
readLoopFut: Future[void]
|
||||
streams: Table[uint16, DataChannelStream]
|
||||
streamId: uint16
|
||||
conn*: SctpConn
|
||||
incomingStreams: AsyncQueue[DataChannelStream]
|
||||
|
||||
proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} =
|
||||
let x = await stream.receivedData.popFirst()
|
||||
trace "read", length=x.len(), id=stream.id
|
||||
return x
|
||||
|
||||
proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} =
|
||||
trace "write", length=buf.len(), id=stream.id
|
||||
var
|
||||
sendInfo = SctpMessageParameters(
|
||||
streamId: stream.id,
|
||||
endOfRecord: true,
|
||||
protocolId: uint32(WebRtcBinary)
|
||||
)
|
||||
|
||||
if stream.acked:
|
||||
sendInfo.unordered = not stream.reliability.ordered
|
||||
#TODO add reliability params
|
||||
|
||||
if buf.len == 0:
|
||||
trace "Datachannel write empty"
|
||||
sendInfo.protocolId = uint32(WebRtcBinaryEmpty)
|
||||
await stream.conn.write(@[0'u8], sendInfo)
|
||||
else:
|
||||
await stream.conn.write(buf, sendInfo)
|
||||
|
||||
proc sendControlMessage(stream: DataChannelStream, msg: DataChannelMessage) {.async.} =
|
||||
let
|
||||
encoded = Binary.encode(msg)
|
||||
sendInfo = SctpMessageParameters(
|
||||
streamId: stream.id,
|
||||
endOfRecord: true,
|
||||
protocolId: uint32(WebRtcDcep)
|
||||
)
|
||||
trace "send control message", msg
|
||||
|
||||
await stream.conn.write(encoded, sendInfo)
|
||||
|
||||
proc openStream*(
|
||||
conn: DataChannelConnection,
|
||||
noiseHandshake: bool,
|
||||
reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} =
|
||||
let streamId: uint16 =
|
||||
if not noiseHandshake:
|
||||
let res = conn.streamId
|
||||
conn.streamId += 2
|
||||
res
|
||||
else:
|
||||
0
|
||||
|
||||
trace "open stream", streamId
|
||||
if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0:
|
||||
raise newException(ValueError, "reliabilityParameter should be 0")
|
||||
|
||||
if streamId in conn.streams:
|
||||
raise newException(ValueError, "streamId already used")
|
||||
|
||||
#TODO: we should request more streams when required
|
||||
# https://github.com/sctplab/usrsctp/blob/a0cbf4681474fab1e89d9e9e2d5c3694fce50359/programs/rtcweb.c#L304C16-L304C16
|
||||
|
||||
var stream = DataChannelStream(
|
||||
id: streamId, conn: conn.conn,
|
||||
reliability: reliability,
|
||||
reliabilityParameter: reliabilityParameter,
|
||||
receivedData: newAsyncQueue[seq[byte]]()
|
||||
)
|
||||
|
||||
conn.streams[streamId] = stream
|
||||
|
||||
let
|
||||
msg = DataChannelMessage(
|
||||
messageType: Open,
|
||||
openMessage: DataChannelOpenMessage(
|
||||
channelType: reliability,
|
||||
reliabilityParameter: reliabilityParameter
|
||||
)
|
||||
)
|
||||
await stream.sendControlMessage(msg)
|
||||
return stream
|
||||
|
||||
proc handleData(conn: DataChannelConnection, msg: SctpMessage) =
|
||||
let streamId = msg.params.streamId
|
||||
trace "handle data message", streamId, ppid = msg.params.protocolId, data = msg.data
|
||||
|
||||
if streamId notin conn.streams:
|
||||
raise newException(ValueError, "got data for unknown streamid")
|
||||
|
||||
let stream = conn.streams[streamId]
|
||||
|
||||
#TODO handle string vs binary
|
||||
if msg.params.protocolId in [uint32(WebRtcStringEmpty), uint32(WebRtcBinaryEmpty)]:
|
||||
# PPID indicate empty message
|
||||
stream.receivedData.addLastNoWait(@[])
|
||||
else:
|
||||
stream.receivedData.addLastNoWait(msg.data)
|
||||
|
||||
proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} =
|
||||
let
|
||||
decoded = Binary.decode(msg.data, DataChannelMessage)
|
||||
streamId = msg.params.streamId
|
||||
|
||||
trace "handle control message", decoded, streamId = msg.params.streamId
|
||||
if decoded.messageType == Ack:
|
||||
if streamId notin conn.streams:
|
||||
raise newException(ValueError, "got ack for unknown streamid")
|
||||
conn.streams[streamId].acked = true
|
||||
elif decoded.messageType == Open:
|
||||
if streamId in conn.streams:
|
||||
raise newException(ValueError, "got open for already existing streamid")
|
||||
|
||||
let stream = DataChannelStream(
|
||||
id: streamId, conn: conn.conn,
|
||||
reliability: decoded.openMessage.channelType,
|
||||
reliabilityParameter: decoded.openMessage.reliabilityParameter,
|
||||
receivedData: newAsyncQueue[seq[byte]]()
|
||||
)
|
||||
|
||||
conn.streams[streamId] = stream
|
||||
conn.incomingStreams.addLastNoWait(stream)
|
||||
|
||||
await stream.sendControlMessage(DataChannelMessage(messageType: Ack))
|
||||
|
||||
proc readLoop(conn: DataChannelConnection) {.async.} =
|
||||
try:
|
||||
while true:
|
||||
let message = await conn.conn.read()
|
||||
# TODO: might be necessary to check the others protocolId at some point
|
||||
if message.params.protocolId == uint32(WebRtcDcep):
|
||||
#TODO should we really await?
|
||||
await conn.handleControl(message)
|
||||
else:
|
||||
conn.handleData(message)
|
||||
|
||||
except CatchableError as exc:
|
||||
discard
|
||||
|
||||
proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} =
|
||||
return await conn.incomingStreams.popFirst()
|
||||
|
||||
proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection =
|
||||
result = DataChannelConnection(
|
||||
conn: conn,
|
||||
incomingStreams: newAsyncQueue[DataChannelStream](),
|
||||
streamId: 1'u16 # TODO: Serveur == 1, client == 2
|
||||
)
|
||||
conn.readLoopFut = conn.readLoop()
|
|
@ -7,9 +7,9 @@
|
|||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import times, sequtils
|
||||
import times, deques, tables
|
||||
import chronos, chronicles
|
||||
import ./utils, ../webrtc_connection
|
||||
import ./utils, ../stun/stun_connection
|
||||
|
||||
import mbedtls/ssl
|
||||
import mbedtls/ssl_cookie
|
||||
|
@ -29,11 +29,19 @@ import mbedtls/timing
|
|||
logScope:
|
||||
topics = "webrtc dtls"
|
||||
|
||||
# TODO: Check the viability of the add/pop first/last of the asyncqueue with the limit.
|
||||
# There might be some errors (or crashes) in weird cases with the no wait option
|
||||
|
||||
const
|
||||
PendingHandshakeLimit = 1024
|
||||
|
||||
type
|
||||
DtlsError* = object of CatchableError
|
||||
DtlsConn* = ref object of WebRTCConn
|
||||
recvData: seq[seq[byte]]
|
||||
recvEvent: AsyncEvent
|
||||
DtlsConn* = ref object
|
||||
conn: StunConn
|
||||
laddr: TransportAddress
|
||||
raddr*: TransportAddress
|
||||
dataRecv: AsyncQueue[seq[byte]]
|
||||
sendFuture: Future[void]
|
||||
|
||||
timer: mbedtls_timing_delay_context
|
||||
|
@ -46,75 +54,131 @@ type
|
|||
ctr_drbg: mbedtls_ctr_drbg_context
|
||||
entropy: mbedtls_entropy_context
|
||||
|
||||
localCert: seq[byte]
|
||||
remoteCert: seq[byte]
|
||||
|
||||
proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
var self = cast[DtlsConn](ctx)
|
||||
var toWrite = newSeq[byte](len)
|
||||
if len > 0:
|
||||
copyMem(addr toWrite[0], buf, len)
|
||||
self.sendFuture = self.conn.write(toWrite)
|
||||
trace "dtls send", len
|
||||
self.sendFuture = self.conn.write(self.raddr, toWrite)
|
||||
result = len.cint
|
||||
|
||||
proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
var self = cast[DtlsConn](ctx)
|
||||
result = self.recvData[0].len().cint
|
||||
copyMem(buf, addr self.recvData[0][0], self.recvData[0].len())
|
||||
self.recvData.delete(0..0)
|
||||
let self = cast[DtlsConn](ctx)
|
||||
if self.dataRecv.len() == 0:
|
||||
return MBEDTLS_ERR_SSL_WANT_READ
|
||||
|
||||
method init*(self: DtlsConn, conn: WebRTCConn, address: TransportAddress) {.async.} =
|
||||
await procCall(WebRTCConn(self).init(conn, address))
|
||||
var dataRecv = self.dataRecv.popFirstNoWait()
|
||||
copyMem(buf, addr dataRecv[0], dataRecv.len())
|
||||
result = dataRecv.len().cint
|
||||
trace "dtls receive", len, result
|
||||
|
||||
method write*(self: DtlsConn, msg: seq[byte]) {.async.} =
|
||||
proc init*(self: DtlsConn, conn: StunConn, laddr: TransportAddress) {.async.} =
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
self.dataRecv = newAsyncQueue[seq[byte]]()
|
||||
|
||||
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
|
||||
trace "Dtls write", length = msg.len()
|
||||
var buf = msg
|
||||
discard mbedtls_ssl_write(addr self.ssl, cast[ptr byte](addr buf[0]), buf.len().uint)
|
||||
|
||||
method read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
||||
return await self.conn.read()
|
||||
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
||||
var res = newSeq[byte](8192)
|
||||
while true:
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
|
||||
if length == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
continue
|
||||
if length < 0:
|
||||
trace "dtls read", error = $(length.mbedtls_high_level_strerr())
|
||||
res.setLen(length)
|
||||
return res
|
||||
|
||||
method close*(self: DtlsConn) {.async.} =
|
||||
proc close*(self: DtlsConn) {.async.} =
|
||||
discard
|
||||
|
||||
method getRemoteAddress*(self: DtlsConn): TransportAddress =
|
||||
self.conn.getRemoteAddress()
|
||||
|
||||
type
|
||||
Dtls* = ref object of RootObj
|
||||
address: TransportAddress
|
||||
connections: Table[TransportAddress, DtlsConn]
|
||||
pendingHandshakes: AsyncQueue[(TransportAddress, seq[byte])]
|
||||
conn: StunConn
|
||||
laddr: TransportAddress
|
||||
started: bool
|
||||
readLoop: Future[void]
|
||||
ctr_drbg: mbedtls_ctr_drbg_context
|
||||
entropy: mbedtls_entropy_context
|
||||
|
||||
proc start*(self: Dtls, address: TransportAddress) =
|
||||
serverPrivKey: mbedtls_pk_context
|
||||
serverCert: mbedtls_x509_crt
|
||||
localCert: seq[byte]
|
||||
|
||||
proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])],
|
||||
raddr: TransportAddress, buf: seq[byte]) =
|
||||
for kv in aq.mitems():
|
||||
if kv[0] == raddr:
|
||||
kv[1] = buf
|
||||
return
|
||||
aq.addLastNoWait((raddr, buf))
|
||||
|
||||
proc start*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
|
||||
if self.started:
|
||||
warn "Already started"
|
||||
return
|
||||
|
||||
self.address = address
|
||||
proc readLoop() {.async.} =
|
||||
while true:
|
||||
let (buf, raddr) = await self.conn.read()
|
||||
if self.connections.hasKey(raddr):
|
||||
self.connections[raddr].dataRecv.addLastNoWait(buf)
|
||||
else:
|
||||
self.pendingHandshakes.updateOrAdd(raddr, buf)
|
||||
|
||||
self.connections = initTable[TransportAddress, DtlsConn]()
|
||||
self.pendingHandshakes = newAsyncQueue[(TransportAddress, seq[byte])](PendingHandshakeLimit)
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
self.started = true
|
||||
self.readLoop = readLoop()
|
||||
|
||||
mb_ctr_drbg_init(self.ctr_drbg)
|
||||
mb_entropy_init(self.entropy)
|
||||
mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, self.entropy, nil, 0)
|
||||
|
||||
self.serverPrivKey = self.ctr_drbg.generateKey()
|
||||
self.serverCert = self.ctr_drbg.generateCertificate(self.serverPrivKey)
|
||||
self.localCert = newSeq[byte](self.serverCert.raw.len)
|
||||
copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len)
|
||||
|
||||
proc stop*(self: Dtls) =
|
||||
if not self.started:
|
||||
warn "Already stopped"
|
||||
return
|
||||
|
||||
self.readLoop.cancel()
|
||||
self.started = false
|
||||
|
||||
proc serverHandshake(self: DtlsConn) {.async.} =
|
||||
var shouldRead = true
|
||||
|
||||
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
|
||||
if shouldRead:
|
||||
self.recvData.add(await self.conn.read())
|
||||
var ta = self.getRemoteAddress()
|
||||
case ta.family
|
||||
case self.raddr.family
|
||||
of AddressFamily.IPv4:
|
||||
mb_ssl_set_client_transport_id(self.ssl, ta.address_v4)
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
|
||||
of AddressFamily.IPv6:
|
||||
mb_ssl_set_client_transport_id(self.ssl, ta.address_v6)
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
|
||||
else:
|
||||
raise newException(DtlsError, "Remote address isn't an IP address")
|
||||
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
self.sendFuture = nil
|
||||
let res = mb_ssl_handshake_step(self.ssl)
|
||||
shouldRead = false
|
||||
if not self.sendFuture.isNil(): await self.sendFuture
|
||||
shouldRead = false
|
||||
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_WANT_READ or
|
||||
|
@ -127,25 +191,46 @@ proc serverHandshake(self: DtlsConn) {.async.} =
|
|||
continue
|
||||
elif res != 0:
|
||||
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
|
||||
# var remoteCertPtr = mbedtls_ssl_get_peer_cert(addr self.ssl)
|
||||
# let remoteCert = remoteCertPtr[]
|
||||
# self.remoteCert = newSeq[byte](remoteCert.raw.len)
|
||||
# copyMem(addr self.remoteCert[0], remoteCert.raw.p, remoteCert.raw.len)
|
||||
|
||||
proc accept*(self: Dtls, conn: WebRTCConn): Future[DtlsConn] {.async.} =
|
||||
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
|
||||
conn.remoteCert
|
||||
|
||||
proc localCertificate*(conn: DtlsConn): seq[byte] =
|
||||
conn.localCert
|
||||
|
||||
proc localCertificate*(self: Dtls): seq[byte] =
|
||||
self.localCert
|
||||
|
||||
proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt,
|
||||
state: cint, pflags: ptr uint32): cint {.cdecl.} =
|
||||
var self = cast[DtlsConn](ctx)
|
||||
let cert = pcert[]
|
||||
|
||||
self.remoteCert = newSeq[byte](cert.raw.len)
|
||||
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
|
||||
return 0
|
||||
|
||||
proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
|
||||
var
|
||||
selfvar = self
|
||||
res = DtlsConn()
|
||||
let v = cast[pointer](res)
|
||||
|
||||
await res.init(conn, self.address)
|
||||
await res.init(self.conn, self.laddr)
|
||||
mb_ssl_init(res.ssl)
|
||||
mb_ssl_config_init(res.config)
|
||||
mb_ssl_cookie_init(res.cookie)
|
||||
mb_ssl_cache_init(res.cache)
|
||||
|
||||
mb_ctr_drbg_init(res.ctr_drbg)
|
||||
mb_entropy_init(res.entropy)
|
||||
mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0)
|
||||
res.ctr_drbg = self.ctr_drbg
|
||||
res.entropy = self.entropy
|
||||
|
||||
var pkey = res.ctr_drbg.generateKey()
|
||||
var srvcert = res.ctr_drbg.generateCertificate(pkey)
|
||||
var pkey = self.serverPrivKey
|
||||
var srvcert = self.serverCert
|
||||
res.localCert = self.localCert
|
||||
|
||||
mb_ssl_config_defaults(res.config,
|
||||
MBEDTLS_SSL_IS_SERVER,
|
||||
|
@ -160,22 +245,38 @@ proc accept*(self: Dtls, conn: WebRTCConn): Future[DtlsConn] {.async.} =
|
|||
mb_ssl_set_timer_cb(res.ssl, res.timer)
|
||||
mb_ssl_setup(res.ssl, res.config)
|
||||
mb_ssl_session_reset(res.ssl)
|
||||
mbedtls_ssl_set_verify(addr res.ssl, verify, cast[pointer](res))
|
||||
mbedtls_ssl_conf_authmode(addr res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) # TODO: create template
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res),
|
||||
dtlsSend, dtlsRecv, nil)
|
||||
await res.serverHandshake()
|
||||
while true:
|
||||
let (raddr, buf) = await self.pendingHandshakes.popFirst()
|
||||
try:
|
||||
res.raddr = raddr
|
||||
res.dataRecv.addLastNoWait(buf)
|
||||
self.connections[raddr] = res
|
||||
await res.serverHandshake()
|
||||
break
|
||||
except CatchableError as exc:
|
||||
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
continue
|
||||
return res
|
||||
|
||||
proc dial*(self: Dtls, address: TransportAddress): DtlsConn =
|
||||
proc dial*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
|
||||
discard
|
||||
|
||||
import ../udp_connection
|
||||
proc main() {.async.} =
|
||||
let laddr = initTAddress("127.0.0.1:4433")
|
||||
let udp = UdpConn()
|
||||
await udp.init(nil, laddr)
|
||||
let dtls = Dtls()
|
||||
dtls.start(laddr)
|
||||
let x = await dtls.accept(udp)
|
||||
echo "After accept"
|
||||
|
||||
waitFor(main())
|
||||
#import ../udp_connection
|
||||
#import stew/byteutils
|
||||
#proc main() {.async.} =
|
||||
# let laddr = initTAddress("127.0.0.1:4433")
|
||||
# let udp = UdpConn()
|
||||
# await udp.init(laddr)
|
||||
# let stun = StunConn()
|
||||
# await stun.init(udp, laddr)
|
||||
# let dtls = Dtls()
|
||||
# dtls.start(stun, laddr)
|
||||
# let x = await dtls.accept()
|
||||
# echo "Recv: <", string.fromBytes(await x.read()), ">"
|
||||
#
|
||||
#waitFor(main())
|
||||
|
|
428
webrtc/sctp.nim
428
webrtc/sctp.nim
|
@ -8,14 +8,20 @@
|
|||
# those terms.
|
||||
|
||||
import tables, bitops, posix, strutils, sequtils
|
||||
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils]
|
||||
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils, endians2]
|
||||
import usrsctp
|
||||
import dtls/dtls
|
||||
import binary_serialization
|
||||
|
||||
export chronicles
|
||||
|
||||
logScope:
|
||||
topics = "webrtc sctp"
|
||||
|
||||
# TODO:
|
||||
# - Replace doAssert by a proper exception management
|
||||
# - Find a clean way to manage SCTP ports
|
||||
|
||||
type
|
||||
SctpError* = object of CatchableError
|
||||
|
||||
|
@ -24,27 +30,56 @@ type
|
|||
Connected
|
||||
Closed
|
||||
|
||||
SctpConnection* = ref object
|
||||
SctpMessageParameters* = object
|
||||
protocolId*: uint32
|
||||
streamId*: uint16
|
||||
endOfRecord*: bool
|
||||
unordered*: bool
|
||||
|
||||
SctpMessage* = ref object
|
||||
data*: seq[byte]
|
||||
info: sctp_recvv_rn
|
||||
params*: SctpMessageParameters
|
||||
|
||||
SctpConn* = ref object
|
||||
conn*: DtlsConn
|
||||
state: SctpState
|
||||
connectEvent: AsyncEvent
|
||||
acceptEvent: AsyncEvent
|
||||
readLoop: Future[void]
|
||||
sctp: Sctp
|
||||
udp: DatagramTransport
|
||||
address: TransportAddress
|
||||
sctpSocket: ptr socket
|
||||
recvEvent: AsyncEvent
|
||||
dataRecv: seq[byte]
|
||||
dataRecv: AsyncQueue[SctpMessage]
|
||||
sentFuture: Future[void]
|
||||
|
||||
Sctp* = ref object
|
||||
dtls: Dtls
|
||||
udp: DatagramTransport
|
||||
connections: Table[TransportAddress, SctpConnection]
|
||||
connections: Table[TransportAddress, SctpConn]
|
||||
gotConnection: AsyncEvent
|
||||
timersHandler: Future[void]
|
||||
isServer: bool
|
||||
sockServer: ptr socket
|
||||
pendingConnections: seq[SctpConnection]
|
||||
sentFuture: Future[void]
|
||||
sentConnection: SctpConnection
|
||||
pendingConnections: seq[SctpConn]
|
||||
pendingConnections2: Table[SockAddr, SctpConn]
|
||||
sentConnection: SctpConn
|
||||
sentAddress: TransportAddress
|
||||
sentFuture: Future[void]
|
||||
|
||||
# Those two objects are only here for debugging purpose
|
||||
SctpChunk = object
|
||||
chunkType: uint8
|
||||
flag: uint8
|
||||
length {.bin_value: it.data.len() + 4.}: uint16
|
||||
data {.bin_len: it.length - 4.}: seq[byte]
|
||||
|
||||
SctpPacketStructure = object
|
||||
srcPort: uint16
|
||||
dstPort: uint16
|
||||
verifTag: uint32
|
||||
checksum: uint32
|
||||
|
||||
const
|
||||
IPPROTO_SCTP = 132
|
||||
|
@ -52,19 +87,32 @@ const
|
|||
proc newSctpError(msg: string): ref SctpError =
|
||||
result = newException(SctpError, msg)
|
||||
|
||||
template usrsctpAwait(sctp: Sctp, body: untyped): untyped =
|
||||
sctp.sentFuture = nil
|
||||
template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
|
||||
self.sentFuture = nil
|
||||
when type(body) is void:
|
||||
body
|
||||
if sctp.sentFuture != nil: await sctp.sentFuture
|
||||
if self.sentFuture != nil: await self.sentFuture
|
||||
else:
|
||||
let res = body
|
||||
if sctp.sentFuture != nil: await sctp.sentFuture
|
||||
if self.sentFuture != nil: await self.sentFuture
|
||||
res
|
||||
|
||||
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
|
||||
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
|
||||
|
||||
proc printSctpPacket(buffer: seq[byte]) =
|
||||
let s = Binary.decode(buffer, SctpPacketStructure)
|
||||
echo " => \e[31;1mStructure\e[0m: ", s
|
||||
var size = sizeof(SctpPacketStructure)
|
||||
var i = 1
|
||||
while size < buffer.len:
|
||||
let c = Binary.decode(buffer[size..^1], SctpChunk)
|
||||
echo " ===> \e[32;1mChunk ", i, "\e[0m ", c
|
||||
i.inc()
|
||||
size.inc(c.length.int)
|
||||
while size mod 4 != 0:
|
||||
size.inc()
|
||||
|
||||
proc packetPretty(packet: cstring): string =
|
||||
let data = $packet
|
||||
let ctn = data[23..^16]
|
||||
|
@ -74,7 +122,7 @@ proc packetPretty(packet: cstring): string =
|
|||
else:
|
||||
result = result & ctn
|
||||
|
||||
proc new(T: typedesc[SctpConnection],
|
||||
proc new(T: typedesc[SctpConn],
|
||||
sctp: Sctp,
|
||||
udp: DatagramTransport,
|
||||
address: TransportAddress,
|
||||
|
@ -85,36 +133,67 @@ proc new(T: typedesc[SctpConnection],
|
|||
address: address,
|
||||
sctpSocket: sctpSocket,
|
||||
connectEvent: AsyncEvent(),
|
||||
recvEvent: AsyncEvent())
|
||||
#TODO add some limit for backpressure?
|
||||
dataRecv: newAsyncQueue[SctpMessage]()
|
||||
)
|
||||
|
||||
proc read*(self: SctpConnection): Future[seq[byte]] {.async.} =
|
||||
trace "Read"
|
||||
if self.dataRecv.len == 0:
|
||||
self.recvEvent.clear()
|
||||
await self.recvEvent.wait()
|
||||
let res = self.dataRecv
|
||||
self.dataRecv = @[]
|
||||
return res
|
||||
proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
|
||||
T(conn: conn,
|
||||
sctp: sctp,
|
||||
state: Connecting,
|
||||
connectEvent: AsyncEvent(),
|
||||
acceptEvent: AsyncEvent(),
|
||||
dataRecv: newAsyncQueue[SctpMessage]() #TODO add some limit for backpressure?
|
||||
)
|
||||
|
||||
proc write*(self: SctpConnection, buf: seq[byte]) {.async.} =
|
||||
trace "Write", buf
|
||||
proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
|
||||
return await self.dataRecv.popFirst()
|
||||
|
||||
proc toFlags(params: SctpMessageParameters): uint16 =
|
||||
if params.endOfRecord:
|
||||
result = result or SCTP_EOR
|
||||
if params.unordered:
|
||||
result = result or SCTP_UNORDERED
|
||||
|
||||
proc write*(
|
||||
self: SctpConn,
|
||||
buf: seq[byte],
|
||||
sendParams = default(SctpMessageParameters),
|
||||
) {.async.} =
|
||||
trace "Write", buf, sctp = cast[uint64](self), sock = cast[uint64](self.sctpSocket)
|
||||
self.sctp.sentConnection = self
|
||||
self.sctp.sentAddress = self.address
|
||||
let sendvErr = self.sctp.usrsctpAwait:
|
||||
self.sctpSocket.usrsctp_sendv(unsafeAddr buf[0], buf.len.uint,
|
||||
nil, 0, nil, 0,
|
||||
SCTP_SENDV_NOINFO, 0)
|
||||
|
||||
proc write*(self: SctpConnection, s: string) {.async.} =
|
||||
var cpy = buf
|
||||
var
|
||||
(sendInfo, infoType) =
|
||||
if sendParams != default(SctpMessageParameters):
|
||||
(sctp_sndinfo(
|
||||
snd_sid: sendParams.streamId,
|
||||
snd_ppid: sendParams.protocolId.swapBytes(),
|
||||
snd_flags: sendParams.toFlags
|
||||
), cuint(SCTP_SENDV_SNDINFO))
|
||||
else:
|
||||
(default(sctp_sndinfo), cuint(SCTP_SENDV_NOINFO))
|
||||
sendvErr = self.usrsctpAwait:
|
||||
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len.uint, nil, 0,
|
||||
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
|
||||
infoType, 0)
|
||||
if sendvErr < 0:
|
||||
perror("usrsctp_sendv") # TODO: throw an exception
|
||||
trace "write sendv error?", sendvErr, sendParams
|
||||
|
||||
proc write*(self: SctpConn, s: string) {.async.} =
|
||||
await self.write(s.toBytes())
|
||||
|
||||
proc close*(self: SctpConnection) {.async.} =
|
||||
self.sctp.usrsctpAwait: self.sctpSocket.usrsctp_close()
|
||||
proc close*(self: SctpConn) {.async.} =
|
||||
self.usrsctpAwait: self.sctpSocket.usrsctp_close()
|
||||
|
||||
proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
||||
let
|
||||
conn = cast[SctpConn](data)
|
||||
events = usrsctp_get_events(sock)
|
||||
conn = cast[SctpConnection](data)
|
||||
|
||||
trace "Handle Upcall", events
|
||||
if conn.state == Connecting:
|
||||
if bitand(events, SCTP_EVENT_ERROR) != 0:
|
||||
|
@ -125,17 +204,20 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
|||
conn.connectEvent.fire()
|
||||
elif bitand(events, SCTP_EVENT_READ) != 0:
|
||||
var
|
||||
buffer = newSeq[byte](4096)
|
||||
message = SctpMessage(
|
||||
data: newSeq[byte](4096)
|
||||
)
|
||||
address: Sockaddr_storage
|
||||
rn: sctp_recvv_rn
|
||||
addressLen = sizeof(Sockaddr_storage).SockLen
|
||||
rnLen = sizeof(sctp_recvv_rn).SockLen
|
||||
infotype: uint
|
||||
flags: int
|
||||
let n = sock.usrsctp_recvv(cast[pointer](addr buffer[0]), buffer.len.uint,
|
||||
trace "recv from", sockuint64=cast[uint64](sock)
|
||||
let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]), message.data.len.uint,
|
||||
cast[ptr SockAddr](addr address),
|
||||
cast[ptr SockLen](addr addressLen),
|
||||
cast[pointer](addr rn),
|
||||
cast[pointer](addr message.info),
|
||||
cast[ptr SockLen](addr rnLen),
|
||||
cast[ptr cuint](addr infotype),
|
||||
cast[ptr cint](addr flags))
|
||||
|
@ -143,59 +225,78 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
|||
perror("usrsctp_recvv")
|
||||
return
|
||||
elif n > 0:
|
||||
# It might be necessary to check if infotype == SCTP_RECVV_RCVINFO
|
||||
message.data.delete(n..<message.data.len())
|
||||
trace "message info from handle upcall", msginfo = message.info
|
||||
message.params = SctpMessageParameters(
|
||||
protocolId: message.info.recvv_rcvinfo.rcv_ppid.swapBytes(),
|
||||
streamId: message.info.recvv_rcvinfo.rcv_sid
|
||||
)
|
||||
if bitand(flags, MSG_NOTIFICATION) != 0:
|
||||
trace "Notification received", length = n
|
||||
else:
|
||||
conn.dataRecv = conn.dataRecv.concat(buffer[0..<n])
|
||||
conn.recvEvent.fire()
|
||||
try:
|
||||
conn.dataRecv.addLastNoWait(message)
|
||||
except AsyncQueueFullError:
|
||||
trace "Queue full, dropping packet"
|
||||
elif bitand(events, SCTP_EVENT_WRITE) != 0:
|
||||
trace "sctp event write in the upcall"
|
||||
else:
|
||||
warn "Handle Upcall unexpected event", events
|
||||
|
||||
proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
||||
trace "Handle Accept"
|
||||
var
|
||||
sconn: Sockaddr_conn
|
||||
slen: Socklen = sizeof(Sockaddr_conn).uint32
|
||||
let
|
||||
sctp = cast[Sctp](data)
|
||||
sctpSocket = usrsctp_accept(sctp.sockServer, nil, nil)
|
||||
sctpSocket = usrsctp_accept(sctp.sockServer, cast[ptr SockAddr](addr sconn), addr slen)
|
||||
|
||||
doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1)
|
||||
let conn = SctpConnection.new(sctp, sctp.udp, sctp.sentAddress, sctpSocket)
|
||||
sctp.connections[sctp.sentAddress] = conn
|
||||
sctp.pendingConnections.add(conn)
|
||||
let conn = cast[SctpConn](sconn.sconn_addr)
|
||||
conn.sctpSocket = sctpSocket
|
||||
conn.state = Connected
|
||||
doAssert 0 == sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
|
||||
sctp.gotConnection.fire()
|
||||
var nodelay: uint32 = 1
|
||||
var recvinfo: uint32 = 1
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
|
||||
addr nodelay, sizeof(nodelay).SockLen)
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
|
||||
addr recvinfo, sizeof(recvinfo).SockLen)
|
||||
conn.acceptEvent.fire()
|
||||
|
||||
proc getOrCreateConnection(self: Sctp,
|
||||
udp: DatagramTransport,
|
||||
address: TransportAddress,
|
||||
sctpPort: uint16 = 5000): Future[SctpConnection] {.async.} =
|
||||
#TODO remove the = 5000
|
||||
if self.connections.hasKey(address):
|
||||
return self.connections[address]
|
||||
trace "Create Connection", address
|
||||
let
|
||||
sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||
conn = SctpConnection.new(self, udp, address, sctpSocket)
|
||||
var on: int = 1
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP,
|
||||
SCTP_RECVRCVINFO,
|
||||
addr on,
|
||||
sizeof(on).SockLen)
|
||||
doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
|
||||
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn))
|
||||
var sconn: Sockaddr_conn
|
||||
sconn.sconn_family = AF_CONN
|
||||
sconn.sconn_port = htons(sctpPort)
|
||||
sconn.sconn_addr = cast[pointer](self)
|
||||
self.sentConnection = conn
|
||||
self.sentAddress = address
|
||||
let connErr = self.usrsctpAwait:
|
||||
conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)))
|
||||
doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno) # TODO raise
|
||||
self.connections[address] = conn
|
||||
return conn
|
||||
# proc getOrCreateConnection(self: Sctp,
|
||||
# udp: DatagramTransport,
|
||||
# address: TransportAddress,
|
||||
# sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
|
||||
# #TODO remove the = 5000
|
||||
# if self.connections.hasKey(address):
|
||||
# return self.connections[address]
|
||||
# trace "Create Connection", address
|
||||
# let
|
||||
# sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||
# conn = SctpConn.new(self, udp, address, sctpSocket)
|
||||
# var on: int = 1
|
||||
# doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP,
|
||||
# SCTP_RECVRCVINFO,
|
||||
# addr on,
|
||||
# sizeof(on).SockLen)
|
||||
# doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
|
||||
# doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn))
|
||||
# var sconn: Sockaddr_conn
|
||||
# sconn.sconn_family = AF_CONN
|
||||
# sconn.sconn_port = htons(sctpPort)
|
||||
# sconn.sconn_addr = cast[pointer](self)
|
||||
# self.sentConnection = conn
|
||||
# self.sentAddress = address
|
||||
# let connErr = self.usrsctpAwait:
|
||||
# conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)))
|
||||
# doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno)
|
||||
# self.connections[address] = conn
|
||||
# return conn
|
||||
|
||||
proc sendCallback(address: pointer,
|
||||
proc sendCallback(ctx: pointer,
|
||||
buffer: pointer,
|
||||
length: uint,
|
||||
tos: uint8,
|
||||
|
@ -204,42 +305,22 @@ proc sendCallback(address: pointer,
|
|||
if data != nil:
|
||||
trace "sendCallback", data = data.packetPretty(), length
|
||||
usrsctp_freedumpbuffer(data)
|
||||
let sctp = cast[Sctp](address)
|
||||
let sctpConn = cast[SctpConn](ctx)
|
||||
let buf = @(buffer.makeOpenArray(byte, int(length)))
|
||||
proc testSend() {.async.} =
|
||||
try:
|
||||
let
|
||||
buf = @(buffer.makeOpenArray(byte, int(length)))
|
||||
address = sctp.sentAddress
|
||||
trace "Send To", address
|
||||
await sendTo(sctp.udp, address, buf, int(length))
|
||||
trace "Send To", address = sctpConn.address
|
||||
# TODO: defined it printSctpPacket(buf)
|
||||
await sctpConn.conn.write(buf)
|
||||
except CatchableError as exc:
|
||||
trace "Send Failed", message = exc.msg
|
||||
sctp.sentFuture = testSend()
|
||||
sctpConn.sentFuture = testSend()
|
||||
|
||||
proc timersHandler() {.async.} =
|
||||
while true:
|
||||
await sleepAsync(500.milliseconds)
|
||||
usrsctp_handle_timers(500)
|
||||
|
||||
proc startServer*(self: Sctp, sctpPort: uint16 = 5000) =
|
||||
if self.isServer:
|
||||
trace "Try to start the server twice"
|
||||
return
|
||||
self.isServer = true
|
||||
doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2)
|
||||
doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0)
|
||||
let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||
var on: int = 1
|
||||
doAssert 0 == usrsctp_set_non_blocking(sock, 1)
|
||||
var sin: Sockaddr_in
|
||||
sin.sin_family = posix.AF_INET.uint16
|
||||
sin.sin_port = htons(sctpPort)
|
||||
sin.sin_addr.s_addr = htonl(INADDR_ANY)
|
||||
doAssert 0 == usrsctp_bind(sock, cast[ptr SockAddr](addr sin), SockLen(sizeof(Sockaddr_in)))
|
||||
doAssert 0 >= usrsctp_listen(sock, 1)
|
||||
doAssert 0 == sock.usrsctp_set_upcall(handleAccept, cast[pointer](self))
|
||||
self.sockServer = sock
|
||||
|
||||
proc stopServer*(self: Sctp) =
|
||||
if not self.isServer:
|
||||
trace "Try to close a client"
|
||||
|
@ -251,70 +332,117 @@ proc stopServer*(self: Sctp) =
|
|||
pc.sctpSocket.usrsctp_close()
|
||||
self.sockServer.usrsctp_close()
|
||||
|
||||
proc new*(T: typedesc[Sctp], port: uint16 = 9899): T =
|
||||
logScope: topics = "webrtc sctp"
|
||||
let sctp = T(gotConnection: newAsyncEvent())
|
||||
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
||||
let
|
||||
msg = udp.getMessage()
|
||||
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
|
||||
if data != nil:
|
||||
if sctp.isServer:
|
||||
trace "onReceive (server)", data = data.packetPretty(), length = msg.len(), address
|
||||
else:
|
||||
trace "onReceive (client)", data = data.packetPretty(), length = msg.len(), address
|
||||
usrsctp_freedumpbuffer(data)
|
||||
proc new*(T: typedesc[Sctp], dtls: Dtls, laddr: TransportAddress): T =
|
||||
let sctp = T(gotConnection: newAsyncEvent(),
|
||||
timersHandler: timersHandler(),
|
||||
dtls: dtls)
|
||||
|
||||
if sctp.isServer:
|
||||
sctp.sentAddress = address
|
||||
usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
|
||||
else:
|
||||
let conn = await sctp.getOrCreateConnection(udp, address)
|
||||
sctp.sentConnection = conn
|
||||
sctp.sentAddress = address
|
||||
usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
|
||||
let
|
||||
localAddr = TransportAddress(family: AddressFamily.IPv4, port: Port(port))
|
||||
laddr = initTAddress("127.0.0.1:" & $port)
|
||||
udp = newDatagramTransport(onReceive, local = laddr)
|
||||
trace "local address", localAddr, laddr
|
||||
sctp.udp = udp
|
||||
sctp.timersHandler = timersHandler()
|
||||
|
||||
usrsctp_init_nothreads(port, sendCallback, printf)
|
||||
usrsctp_init_nothreads(laddr.port.uint16, sendCallback, printf)
|
||||
discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
|
||||
discard usrsctp_sysctl_set_sctp_ecn_enable(1)
|
||||
usrsctp_register_address(cast[pointer](sctp))
|
||||
|
||||
return sctp
|
||||
|
||||
#proc new*(T: typedesc[Sctp], port: uint16 = 9899): T =
|
||||
# logScope: topics = "webrtc sctp"
|
||||
# let sctp = T(gotConnection: newAsyncEvent())
|
||||
# proc onReceive(udp: DatagramTransport, raddr: TransportAddress) {.async, gcsafe.} =
|
||||
# let
|
||||
# msg = udp.getMessage()
|
||||
# data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
|
||||
# if data != nil:
|
||||
# if sctp.isServer:
|
||||
# trace "onReceive (server)", data = data.packetPretty(), length = msg.len(), raddr
|
||||
# else:
|
||||
# trace "onReceive (client)", data = data.packetPretty(), length = msg.len(), raddr
|
||||
# usrsctp_freedumpbuffer(data)
|
||||
#
|
||||
# if sctp.isServer:
|
||||
# sctp.sentAddress = raddr
|
||||
# usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
|
||||
# else:
|
||||
# let conn = await sctp.getOrCreateConnection(udp, raddr)
|
||||
# sctp.sentConnection = conn
|
||||
# sctp.sentAddress = raddr
|
||||
# usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
|
||||
# let
|
||||
# localAddr = TransportAddress(family: AddressFamily.IPv4, port: Port(port))
|
||||
# laddr = initTAddress("127.0.0.1:" & $port)
|
||||
# udp = newDatagramTransport(onReceive, local = laddr)
|
||||
# trace "local address", localAddr, laddr
|
||||
# sctp.udp = udp
|
||||
# sctp.timersHandler = timersHandler()
|
||||
#
|
||||
# usrsctp_init_nothreads(port, sendCallback, printf)
|
||||
# discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
|
||||
# discard usrsctp_sysctl_set_sctp_ecn_enable(1)
|
||||
# usrsctp_register_address(cast[pointer](sctp))
|
||||
#
|
||||
# return sctp
|
||||
|
||||
proc stop*(self: Sctp) {.async.} =
|
||||
discard self.usrsctpAwait usrsctp_finish()
|
||||
self.udp.close()
|
||||
|
||||
proc listen*(self: Sctp): Future[SctpConnection] {.async.} =
|
||||
proc readLoopProc(res: SctpConn) {.async.} =
|
||||
while true:
|
||||
let
|
||||
msg = await res.conn.read()
|
||||
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
|
||||
if not data.isNil():
|
||||
trace "Receive data", remoteAddress = res.conn.raddr, data = data.packetPretty()
|
||||
usrsctp_freedumpbuffer(data)
|
||||
res.sctp.sentConnection = res
|
||||
usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0)
|
||||
|
||||
proc accept*(self: Sctp): Future[SctpConn] {.async.} =
|
||||
if not self.isServer:
|
||||
raise newSctpError("Not a server")
|
||||
trace "Listening"
|
||||
if self.pendingConnections.len == 0:
|
||||
self.gotConnection.clear()
|
||||
await self.gotConnection.wait()
|
||||
let res = self.pendingConnections[0]
|
||||
self.pendingConnections.delete(0)
|
||||
var res = SctpConn.new(await self.dtls.accept, self)
|
||||
usrsctp_register_address(cast[pointer](res))
|
||||
res.readLoop = res.readLoopProc()
|
||||
res.acceptEvent.clear()
|
||||
await res.acceptEvent.wait()
|
||||
return res
|
||||
|
||||
proc listen*(self: Sctp, sctpPort: uint16 = 5000) =
|
||||
if self.isServer:
|
||||
trace "Try to start the server twice"
|
||||
return
|
||||
self.isServer = true
|
||||
trace "Listening", sctpPort
|
||||
doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2)
|
||||
doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0)
|
||||
doAssert 0 == usrsctp_sysctl_set_sctp_delayed_sack_time_default(0)
|
||||
let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||
var on: int = 1
|
||||
doAssert 0 == usrsctp_set_non_blocking(sock, 1)
|
||||
var sin: Sockaddr_in
|
||||
sin.sin_family = posix.AF_INET.uint16
|
||||
sin.sin_port = htons(sctpPort)
|
||||
sin.sin_addr.s_addr = htonl(INADDR_ANY)
|
||||
doAssert 0 == usrsctp_bind(sock, cast[ptr SockAddr](addr sin), SockLen(sizeof(Sockaddr_in)))
|
||||
doAssert 0 >= usrsctp_listen(sock, 1)
|
||||
doAssert 0 == sock.usrsctp_set_upcall(handleAccept, cast[pointer](self))
|
||||
self.sockServer = sock
|
||||
|
||||
proc connect*(self: Sctp,
|
||||
address: TransportAddress,
|
||||
sctpPort: uint16 = 5000): Future[SctpConnection] {.async.} =
|
||||
trace "Connect", address
|
||||
let conn = await self.getOrCreateConnection(self.udp, address, sctpPort)
|
||||
if conn.state == Connected:
|
||||
return conn
|
||||
try:
|
||||
await conn.connectEvent.wait()
|
||||
except CancelledError as exc:
|
||||
conn.sctpSocket.usrsctp_close()
|
||||
return nil
|
||||
if conn.state != Connected:
|
||||
raise newSctpError("Cannot connect to " & $address)
|
||||
return conn
|
||||
sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
|
||||
discard
|
||||
|
||||
# proc connect*(self: Sctp,
|
||||
# address: TransportAddress,
|
||||
# sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
|
||||
# trace "Connect", address, sctpPort
|
||||
# let conn = await self.getOrCreateConnection(self.udp, address, sctpPort)
|
||||
# if conn.state == Connected:
|
||||
# return conn
|
||||
# try:
|
||||
# await conn.connectEvent.wait() # TODO: clear?
|
||||
# except CancelledError as exc:
|
||||
# conn.sctpSocket.usrsctp_close()
|
||||
# return nil
|
||||
# if conn.state != Connected:
|
||||
# raise newSctpError("Cannot connect to " & $address)
|
||||
# return conn
|
||||
|
|
|
@ -8,44 +8,39 @@
|
|||
# those terms.
|
||||
|
||||
import chronos
|
||||
import ../webrtc_connection, stun
|
||||
import ../udp_connection, stun
|
||||
|
||||
type
|
||||
StunConn* = ref object of WebRTCConn
|
||||
recvData: seq[seq[byte]]
|
||||
recvEvent: AsyncEvent
|
||||
StunConn* = ref object
|
||||
conn: UdpConn
|
||||
laddr: TransportAddress
|
||||
dataRecv: AsyncQueue[(seq[byte], TransportAddress)]
|
||||
handlesFut: Future[void]
|
||||
|
||||
proc handles(self: StunConn) {.async.} =
|
||||
while true: # TODO: while not self.conn.atEof()
|
||||
let msg = await self.conn.read()
|
||||
let (msg, raddr) = await self.conn.read()
|
||||
if Stun.isMessage(msg):
|
||||
let res = Stun.getResponse(msg, self.address)
|
||||
echo "\e[35;1m<STUN>\e[0m"
|
||||
let res = Stun.getResponse(msg, self.laddr)
|
||||
if res.isSome():
|
||||
await self.conn.write(res.get())
|
||||
await self.conn.write(raddr, res.get())
|
||||
else:
|
||||
self.recvData.add(msg)
|
||||
self.recvEvent.fire()
|
||||
self.dataRecv.addLastNoWait((msg, raddr))
|
||||
|
||||
method init(self: StunConn, conn: WebRTCConn, address: TransportAddress) {.async.} =
|
||||
await procCall(WebRTCConn(self).init(conn, address))
|
||||
proc init*(self: StunConn, conn: UdpConn, laddr: TransportAddress) =
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
|
||||
self.recvEvent = newAsyncEvent()
|
||||
self.handlesFut = handles()
|
||||
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
|
||||
self.handlesFut = self.handles()
|
||||
|
||||
method close(self: StunConn) {.async.} =
|
||||
proc close*(self: StunConn) {.async.} =
|
||||
self.handlesFut.cancel() # check before?
|
||||
self.conn.close()
|
||||
await self.conn.close()
|
||||
|
||||
method write(self: StunConn, msg: seq[byte]) {.async.} =
|
||||
await self.conn.write(msg)
|
||||
proc write*(self: StunConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
|
||||
await self.conn.write(raddr, msg)
|
||||
|
||||
method read(self: StunConn): Future[seq[byte]] {.async.} =
|
||||
while self.recvData.len() <= 0:
|
||||
self.recvEvent.clear()
|
||||
await self.recvEvent.wait()
|
||||
result = self.recvData[0]
|
||||
self.recvData.delete(0..0)
|
||||
|
||||
method getRemoteAddress*(self: StunConn): TransportAddress =
|
||||
self.conn.getRemoteAddress()
|
||||
proc read*(self: StunConn): Future[(seq[byte], TransportAddress)] {.async.} =
|
||||
return await self.dataRecv.popFirst()
|
||||
|
|
|
@ -9,47 +9,34 @@
|
|||
|
||||
import sequtils
|
||||
import chronos, chronicles
|
||||
import webrtc_connection
|
||||
|
||||
logScope:
|
||||
topics = "webrtc udp"
|
||||
|
||||
type
|
||||
UdpConn* = ref object of WebRTCConn
|
||||
UdpConn* = ref object
|
||||
laddr*: TransportAddress
|
||||
udp: DatagramTransport
|
||||
remote: TransportAddress
|
||||
recvData: seq[seq[byte]]
|
||||
recvEvent: AsyncEvent
|
||||
dataRecv: AsyncQueue[(seq[byte], TransportAddress)]
|
||||
|
||||
method init(self: UdpConn, conn: WebRTCConn, addrss: TransportAddress) {.async.} =
|
||||
await procCall(WebRTCConn(self).init(conn, addrss))
|
||||
proc init*(self: UdpConn, laddr: TransportAddress) =
|
||||
self.laddr = laddr
|
||||
|
||||
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
||||
let msg = udp.getMessage()
|
||||
echo "\e[33m<UDP>\e[0;1m onReceive\e[0m: ", udp.getMessage().len()
|
||||
self.remote = address
|
||||
self.recvData.add(msg)
|
||||
self.recvEvent.fire()
|
||||
echo "\e[33m<UDP>\e[0;1m onReceive\e[0m"
|
||||
self.dataRecv.addLastNoWait((msg, address))
|
||||
|
||||
self.recvEvent = newAsyncEvent()
|
||||
self.udp = newDatagramTransport(onReceive, local = addrss)
|
||||
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
|
||||
self.udp = newDatagramTransport(onReceive, local = laddr)
|
||||
|
||||
method close(self: UdpConn) {.async.} =
|
||||
proc close*(self: UdpConn) {.async.} =
|
||||
self.udp.close()
|
||||
if not self.conn.isNil():
|
||||
await self.conn.close()
|
||||
|
||||
method write(self: UdpConn, msg: seq[byte]) {.async.} =
|
||||
proc write*(self: UdpConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
|
||||
echo "\e[33m<UDP>\e[0;1m write\e[0m"
|
||||
await self.udp.sendTo(self.remote, msg)
|
||||
await self.udp.sendTo(raddr, msg)
|
||||
|
||||
method read(self: UdpConn): Future[seq[byte]] {.async.} =
|
||||
proc read*(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} =
|
||||
echo "\e[33m<UDP>\e[0;1m read\e[0m"
|
||||
while self.recvData.len() <= 0:
|
||||
self.recvEvent.clear()
|
||||
await self.recvEvent.wait()
|
||||
result = self.recvData[0]
|
||||
self.recvData.delete(0..0)
|
||||
|
||||
method getRemoteAddress*(self: UdpConn): TransportAddress =
|
||||
self.remote
|
||||
return await self.dataRecv.popFirst()
|
||||
|
|
|
@ -4,7 +4,7 @@ import strformat, os
|
|||
import nativesockets
|
||||
|
||||
# C include directory
|
||||
const root = currentSourcePath.parentDir
|
||||
const root = currentSourcePath.parentDir.parentDir
|
||||
const usrsctpInclude = root/"usrsctp"/"usrsctplib"
|
||||
|
||||
{.passc: fmt"-I{usrsctpInclude}".}
|
||||
|
@ -47,30 +47,29 @@ const usrsctpInclude = root/"usrsctp"/"usrsctplib"
|
|||
{.passc: "-DHAVE_NETINET_IP_ICMP_H=1".}
|
||||
{.passc: "-DHAVE_NET_ROUTE_H=1".}
|
||||
{.passc: "-D_GNU_SOURCE".}
|
||||
{.passc: "-I./usrsctp/usrsctplib".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_input.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_asconf.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_pcb.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_usrreq.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_cc_functions.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_auth.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_userspace.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_output.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_callout.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_crc32.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_sysctl.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_sha1.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_timer.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctputil.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_bsd_addr.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_peeloff.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_indata.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_ss_functions.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/user_socket.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet6/sctp6_usrreq.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/user_mbuf.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/user_environment.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/user_recv_thread.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_input.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_asconf.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_pcb.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_usrreq.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_cc_functions.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_auth.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_userspace.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_output.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_callout.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_crc32.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_sysctl.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_sha1.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_timer.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctputil.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_bsd_addr.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_peeloff.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_indata.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_ss_functions.c".}
|
||||
{.compile: usrsctpInclude / "user_socket.c".}
|
||||
{.compile: usrsctpInclude / "netinet6/sctp6_usrreq.c".}
|
||||
{.compile: usrsctpInclude / "user_mbuf.c".}
|
||||
{.compile: usrsctpInclude / "user_environment.c".}
|
||||
{.compile: usrsctpInclude / "user_recv_thread.c".}
|
||||
const
|
||||
MSG_NOTIFICATION* = 0x00002000
|
||||
AF_CONN* = 123
|
||||
|
|
|
@ -8,41 +8,34 @@
|
|||
# those terms.
|
||||
|
||||
import chronos, chronicles
|
||||
import stun/stun
|
||||
|
||||
import udp_connection
|
||||
import stun/stun_connection
|
||||
import dtls/dtls
|
||||
import sctp, datachannel
|
||||
|
||||
logScope:
|
||||
topics = "webrtc"
|
||||
|
||||
let fut = newFuture[void]()
|
||||
type
|
||||
WebRTC* = object
|
||||
udp: DatagramTransport
|
||||
WebRTC* = ref object
|
||||
udp*: UdpConn
|
||||
stun*: StunConn
|
||||
dtls*: Dtls
|
||||
sctp*: Sctp
|
||||
port: int
|
||||
|
||||
proc new*(T: typedesc[WebRTC], port: uint16 = 42657): T =
|
||||
logScope: topics = "webrtc"
|
||||
var webrtc = T()
|
||||
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
||||
let
|
||||
msg = udp.getMessage()
|
||||
if Stun.isMessage(msg):
|
||||
let res = Stun.getResponse(msg, address)
|
||||
if res.isSome():
|
||||
await udp.sendTo(address, res.get())
|
||||
|
||||
trace "onReceive", isStun = Stun.isMessage(msg)
|
||||
if not fut.completed(): fut.complete()
|
||||
|
||||
let
|
||||
laddr = initTAddress("127.0.0.1:" & $port)
|
||||
udp = newDatagramTransport(onReceive, local = laddr)
|
||||
trace "local address", laddr
|
||||
webrtc.udp = udp
|
||||
proc new*(T: typedesc[WebRTC], address: TransportAddress): T =
|
||||
var webrtc = T(udp: UdpConn(), stun: StunConn(), dtls: Dtls())
|
||||
webrtc.udp.init(address)
|
||||
webrtc.stun.init(webrtc.udp, address)
|
||||
webrtc.dtls.start(webrtc.stun, address)
|
||||
webrtc.sctp = Sctp.new(webrtc.dtls, address)
|
||||
return webrtc
|
||||
#
|
||||
#proc main {.async.} =
|
||||
# echo "/ip4/127.0.0.1/udp/42657/webrtc/certhash/uEiDKBGpmOW3zQhiCHagHZ8igwfKNIp8rQCJWd5E5mIhGHw/p2p/12D3KooWFjMiMZLaCKEZRvMqKp5qUGduS6iBZ9RWQgYZXYtAAaPC"
|
||||
# discard WebRTC.new()
|
||||
# await fut
|
||||
# await sleepAsync(10.seconds)
|
||||
#
|
||||
#waitFor(main())
|
||||
|
||||
proc listen*(w: WebRTC) =
|
||||
w.sctp.listen()
|
||||
|
||||
proc accept*(w: WebRTC): Future[DataChannelConnection] {.async.} =
|
||||
let sctpConn = await w.sctp.accept()
|
||||
result = DataChannelConnection.new(sctpConn)
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2023 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import chronos
|
||||
|
||||
type
|
||||
WebRTCConn* = ref object of RootObj
|
||||
conn*: WebRTCConn
|
||||
address*: TransportAddress
|
||||
# isClosed: bool
|
||||
# isEof: bool
|
||||
|
||||
method init*(self: WebRTCConn, conn: WebRTCConn, address: TransportAddress) {.async, base.} =
|
||||
self.conn = conn
|
||||
self.address = address
|
||||
|
||||
method close*(self: WebRTCConn) {.async, base.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method write*(self: WebRTCConn, msg: seq[byte]) {.async, base.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method read*(self: WebRTCConn): Future[seq[byte]] {.async, base.} =
|
||||
doAssert(false, "not implemented!")
|
||||
|
||||
method getRemoteAddress*(self: WebRTCConn): TransportAddress {.base.} =
|
||||
doAssert(false, "not implemented")
|
Loading…
Reference in New Issue