Merge remote-tracking branch 'origin/master' into ci
This commit is contained in:
commit
e4b98d379e
|
@ -0,0 +1,28 @@
|
|||
# Nim-Webrtc
|
||||
|
||||
![Stability: experimental](https://img.shields.io/badge/stability-experimental-orange.svg)
|
||||
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
|
||||
[![License: Apache](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
|
||||
|
||||
A simple WebRTC stack first implemented for [libp2p WebRTC direct transport](https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md).
|
||||
It uses a wrapper from two different C libraries:
|
||||
- [usrsctp]() for the SCTP stack
|
||||
- [mbedtls]() for the DTLS stack
|
||||
|
||||
## Usage
|
||||
|
||||
## Installation
|
||||
|
||||
## TODO
|
||||
|
||||
## License
|
||||
|
||||
Licensed and distributed under either of
|
||||
|
||||
* MIT license: [LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT
|
||||
|
||||
or
|
||||
|
||||
* Apache License, Version 2.0, ([LICENSE-APACHEv2](LICENSE-APACHEv2) or http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
at your option. This file may not be copied, modified, or distributed except according to those terms.
|
|
@ -11,7 +11,7 @@ fi
|
|||
cd "${root}/usrsctp" && ./bootstrap && ./configure && make && cd -
|
||||
|
||||
# add prelude
|
||||
cat "${root}/prelude.nim" > "${outputFile}"
|
||||
cat "${root}/prelude_usrsctp.nim" > "${outputFile}"
|
||||
|
||||
# assemble list of C files to be compiled
|
||||
for file in `find ${root}/usrsctp/usrsctplib -name '*.c'`; do
|
|
@ -0,0 +1,24 @@
|
|||
import chronos, stew/byteutils
|
||||
import ../webrtc/udp_connection
|
||||
import ../webrtc/stun/stun_connection
|
||||
import ../webrtc/dtls/dtls
|
||||
import ../webrtc/sctp
|
||||
|
||||
proc main() {.async.} =
|
||||
let laddr = initTAddress("127.0.0.1:4244")
|
||||
let udp = UdpConn()
|
||||
udp.init(laddr)
|
||||
let stun = StunConn()
|
||||
stun.init(udp, laddr)
|
||||
let dtls = Dtls()
|
||||
dtls.init(stun, laddr)
|
||||
let sctp = Sctp()
|
||||
sctp.init(dtls, laddr)
|
||||
let conn = await sctp.connect(initTAddress("127.0.0.1:4242"), sctpPort = 13)
|
||||
while true:
|
||||
await conn.write("ping".toBytes)
|
||||
let msg = await conn.read()
|
||||
echo "Received: ", string.fromBytes(msg.data)
|
||||
await sleepAsync(1.seconds)
|
||||
|
||||
waitFor(main())
|
|
@ -0,0 +1,30 @@
|
|||
import chronos, stew/byteutils
|
||||
import ../webrtc/udp_connection
|
||||
import ../webrtc/stun/stun_connection
|
||||
import ../webrtc/dtls/dtls
|
||||
import ../webrtc/sctp
|
||||
|
||||
proc sendPong(conn: SctpConn) {.async.} =
|
||||
var i = 0
|
||||
while true:
|
||||
let msg = await conn.read()
|
||||
echo "Received: ", string.fromBytes(msg.data)
|
||||
await conn.write(("pong " & $i).toBytes)
|
||||
i.inc()
|
||||
|
||||
proc main() {.async.} =
|
||||
let laddr = initTAddress("127.0.0.1:4242")
|
||||
let udp = UdpConn()
|
||||
udp.init(laddr)
|
||||
let stun = StunConn()
|
||||
stun.init(udp, laddr)
|
||||
let dtls = Dtls()
|
||||
dtls.init(stun, laddr)
|
||||
let sctp = Sctp()
|
||||
sctp.init(dtls, laddr)
|
||||
sctp.listen(13)
|
||||
while true:
|
||||
let conn = await sctp.accept()
|
||||
asyncSpawn conn.sendPong()
|
||||
|
||||
waitFor(main())
|
|
@ -0,0 +1,25 @@
|
|||
import chronos, stew/byteutils
|
||||
import ../webrtc/sctp as sc
|
||||
|
||||
let sctp = Sctp.new(port = 4242)
|
||||
proc serv(fut: Future[void]) {.async.} =
|
||||
sctp.startServer(13)
|
||||
fut.complete()
|
||||
let conn = await sctp.listen()
|
||||
echo "await read()"
|
||||
let msg = await conn.read()
|
||||
echo "read() finished"
|
||||
echo "Receive: ", string.fromBytes(msg)
|
||||
await conn.close()
|
||||
sctp.stopServer()
|
||||
|
||||
proc main() {.async.} =
|
||||
let fut = Future[void]()
|
||||
asyncSpawn serv(fut)
|
||||
await fut
|
||||
let address = TransportAddress(initTAddress("127.0.0.1:4242"))
|
||||
let conn = await sctp.connect(address, sctpPort = 13)
|
||||
await conn.write("test".toBytes)
|
||||
await conn.close()
|
||||
|
||||
waitFor(main())
|
|
@ -0,0 +1,14 @@
|
|||
import chronos, stew/byteutils
|
||||
import ../webrtc/sctp
|
||||
|
||||
proc main() {.async.} =
|
||||
let
|
||||
sctp = Sctp.new(port = 4244)
|
||||
address = TransportAddress(initTAddress("127.0.0.1:4242"))
|
||||
conn = await sctp.connect(address, sctpPort = 13)
|
||||
await conn.write("test".toBytes)
|
||||
let msg = await conn.read()
|
||||
echo "Client read() finished ; receive: ", string.fromBytes(msg)
|
||||
await conn.close()
|
||||
|
||||
waitFor(main())
|
|
@ -0,0 +1,13 @@
|
|||
import chronos, stew/byteutils
|
||||
import ../webrtc/sctp
|
||||
|
||||
proc main() {.async.} =
|
||||
let sctp = Sctp.new(port = 4242)
|
||||
sctp.startServer(13)
|
||||
let conn = await sctp.listen()
|
||||
let msg = await conn.read()
|
||||
echo "Receive: ", string.fromBytes(msg)
|
||||
await conn.close()
|
||||
sctp.stopServer()
|
||||
|
||||
waitFor(main())
|
|
@ -0,0 +1,25 @@
|
|||
import ../webrtc/datachannel
|
||||
import chronos/unittest2/asynctests
|
||||
import binary_serialization
|
||||
|
||||
suite "DataChannel encoding":
|
||||
test "DataChannelOpenMessage":
|
||||
let msg = @[
|
||||
0x03'u8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72]
|
||||
check msg == Binary.encode(Binary.decode(msg, DataChannelMessage))
|
||||
check Binary.decode(msg, DataChannelMessage).openMessage ==
|
||||
DataChannelOpenMessage(
|
||||
channelType: Reliable,
|
||||
priority: 0,
|
||||
reliabilityParameter: 0,
|
||||
labelLength: 3,
|
||||
protocolLength: 3,
|
||||
label: @[102, 111, 111],
|
||||
protocol: @[98, 97, 114]
|
||||
)
|
||||
|
||||
test "DataChannelAck":
|
||||
let msg = @[0x02'u8]
|
||||
check msg == Binary.encode(Binary.decode(msg, DataChannelMessage))
|
||||
check Binary.decode(msg, DataChannelMessage).messageType == Ack
|
|
@ -0,0 +1,14 @@
|
|||
import ../webrtc/stun
|
||||
import ./asyncunit
|
||||
import binary_serialization
|
||||
|
||||
suite "Stun suite":
|
||||
test "Stun encoding/decoding with padding":
|
||||
let msg = @[ 0x00'u8, 0x01, 0x00, 0xa4, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d, 0x2b, 0x00, 0x06, 0x00, 0x63, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2b, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2b, 0x76, 0x31, 0x2f, 0x62, 0x71, 0x36, 0x67, 0x69, 0x43, 0x75, 0x4a, 0x38, 0x6e, 0x78, 0x59, 0x46, 0x4a, 0x36, 0x43, 0x63, 0x67, 0x45, 0x59, 0x58, 0x58, 0x2f, 0x78, 0x51, 0x58, 0x56, 0x4c, 0x74, 0x39, 0x71, 0x7a, 0x3a, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2b, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2b, 0x76, 0x31, 0x2f, 0x62, 0x71, 0x36, 0x67, 0x69, 0x43, 0x75, 0x4a, 0x38, 0x6e, 0x78, 0x59, 0x46, 0x4a, 0x36, 0x43, 0x63, 0x67, 0x45, 0x59, 0x58, 0x58, 0x2f, 0x78, 0x51, 0x58, 0x56, 0x4c, 0x74, 0x39, 0x71, 0x7a, 0x00, 0xc0, 0x57, 0x00, 0x04, 0x00, 0x00, 0x03, 0xe7, 0x80, 0x2a, 0x00, 0x08, 0x86, 0x63, 0xfd, 0x45, 0xa9, 0xe5, 0x4c, 0xdb, 0x00, 0x24, 0x00, 0x04, 0x6e, 0x00, 0x1e, 0xff, 0x00, 0x08, 0x00, 0x14, 0x16, 0xff, 0x70, 0x8d, 0x97, 0x0b, 0xd6, 0xa3, 0x5b, 0xac, 0x8f, 0x4c, 0x85, 0xe6, 0xa6, 0xac, 0xaa, 0x7a, 0x68, 0x27, 0x80, 0x28, 0x00, 0x04, 0x79, 0x5e, 0x03, 0xd8 ]
|
||||
check msg == encode(StunMessage.decode(msg))
|
||||
|
||||
test "Error while decoding":
|
||||
let msgLengthFailed = @[ 0x00'u8, 0x01, 0x00, 0xa4, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d ]
|
||||
expect AssertionDefect: discard StunMessage.decode(msgLengthFailed)
|
||||
let msgAttrFailed = @[ 0x00'u8, 0x01, 0x00, 0x08, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d, 0x2b, 0x28, 0x00, 0x05, 0x79, 0x5e, 0x03, 0xd8 ]
|
||||
expect AssertionDefect: discard StunMessage.decode(msgAttrFailed)
|
|
@ -3,8 +3,13 @@ version = "0.0.1"
|
|||
author = "Status Research & Development GmbH"
|
||||
description = "Webrtc stack"
|
||||
license = "MIT"
|
||||
#installDirs = @["usrsctp"]
|
||||
installDirs = @["usrsctp", "webrtc"]
|
||||
|
||||
requires "nim >= 1.2.0",
|
||||
"chronicles >= 0.10.2",
|
||||
"chronos >= 3.0.6"
|
||||
"chronos >= 3.0.6",
|
||||
"https://github.com/status-im/nim-binary-serialization.git",
|
||||
"https://github.com/status-im/nim-mbedtls.git"
|
||||
|
||||
proc runTest(filename: string) =
|
||||
discard
|
||||
|
|
|
@ -0,0 +1,227 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import tables
|
||||
|
||||
import chronos,
|
||||
chronicles,
|
||||
binary_serialization
|
||||
|
||||
import sctp
|
||||
|
||||
export binary_serialization
|
||||
|
||||
logScope:
|
||||
topics = "webrtc datachannel"
|
||||
|
||||
# Implementation of the DataChannel protocol, mostly following
|
||||
# https://www.rfc-editor.org/rfc/rfc8831.html and
|
||||
# https://www.rfc-editor.org/rfc/rfc8832.html
|
||||
|
||||
type
|
||||
DataChannelProtocolIds* {.size: 4.} = enum
|
||||
WebRtcDcep = 50
|
||||
WebRtcString = 51
|
||||
WebRtcBinary = 53
|
||||
WebRtcStringEmpty = 56
|
||||
WebRtcBinaryEmpty = 57
|
||||
|
||||
DataChannelMessageType* {.size: 1.} = enum
|
||||
Reserved = 0x00
|
||||
Ack = 0x02
|
||||
Open = 0x03
|
||||
|
||||
DataChannelMessage* = object
|
||||
case messageType*: DataChannelMessageType
|
||||
of Open: openMessage*: DataChannelOpenMessage
|
||||
else: discard
|
||||
|
||||
DataChannelType {.size: 1.} = enum
|
||||
Reliable = 0x00
|
||||
PartialReliableRexmit = 0x01
|
||||
PartialReliableTimed = 0x02
|
||||
ReliableUnordered = 0x80
|
||||
PartialReliableRexmitUnordered = 0x81
|
||||
PartialReliableTimedUnorderd = 0x82
|
||||
|
||||
DataChannelOpenMessage* = object
|
||||
channelType*: DataChannelType
|
||||
priority*: uint16
|
||||
reliabilityParameter*: uint32
|
||||
labelLength* {.bin_value: it.label.len.}: uint16
|
||||
protocolLength* {.bin_value: it.protocol.len.}: uint16
|
||||
label* {.bin_len: it.labelLength.}: seq[byte]
|
||||
protocol* {.bin_len: it.protocolLength.}: seq[byte]
|
||||
|
||||
proc ordered(t: DataChannelType): bool =
|
||||
t in [Reliable, PartialReliableRexmit, PartialReliableTimed]
|
||||
|
||||
type
|
||||
#TODO handle closing
|
||||
DataChannelStream* = ref object
|
||||
id: uint16
|
||||
conn: SctpConn
|
||||
reliability: DataChannelType
|
||||
reliabilityParameter: uint32
|
||||
receivedData: AsyncQueue[seq[byte]]
|
||||
acked: bool
|
||||
|
||||
#TODO handle closing
|
||||
DataChannelConnection* = ref object
|
||||
readLoopFut: Future[void]
|
||||
streams: Table[uint16, DataChannelStream]
|
||||
streamId: uint16
|
||||
conn*: SctpConn
|
||||
incomingStreams: AsyncQueue[DataChannelStream]
|
||||
|
||||
proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} =
|
||||
let x = await stream.receivedData.popFirst()
|
||||
trace "read", length=x.len(), id=stream.id
|
||||
return x
|
||||
|
||||
proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} =
|
||||
trace "write", length=buf.len(), id=stream.id
|
||||
var
|
||||
sendInfo = SctpMessageParameters(
|
||||
streamId: stream.id,
|
||||
endOfRecord: true,
|
||||
protocolId: uint32(WebRtcBinary)
|
||||
)
|
||||
|
||||
if stream.acked:
|
||||
sendInfo.unordered = not stream.reliability.ordered
|
||||
#TODO add reliability params
|
||||
|
||||
if buf.len == 0:
|
||||
trace "Datachannel write empty"
|
||||
sendInfo.protocolId = uint32(WebRtcBinaryEmpty)
|
||||
await stream.conn.write(@[0'u8], sendInfo)
|
||||
else:
|
||||
await stream.conn.write(buf, sendInfo)
|
||||
|
||||
proc sendControlMessage(stream: DataChannelStream, msg: DataChannelMessage) {.async.} =
|
||||
let
|
||||
encoded = Binary.encode(msg)
|
||||
sendInfo = SctpMessageParameters(
|
||||
streamId: stream.id,
|
||||
endOfRecord: true,
|
||||
protocolId: uint32(WebRtcDcep)
|
||||
)
|
||||
trace "send control message", msg
|
||||
|
||||
await stream.conn.write(encoded, sendInfo)
|
||||
|
||||
proc openStream*(
|
||||
conn: DataChannelConnection,
|
||||
noiseHandshake: bool,
|
||||
reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} =
|
||||
let streamId: uint16 =
|
||||
if not noiseHandshake:
|
||||
let res = conn.streamId
|
||||
conn.streamId += 2
|
||||
res
|
||||
else:
|
||||
0
|
||||
|
||||
trace "open stream", streamId
|
||||
if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0:
|
||||
raise newException(ValueError, "reliabilityParameter should be 0")
|
||||
|
||||
if streamId in conn.streams:
|
||||
raise newException(ValueError, "streamId already used")
|
||||
|
||||
#TODO: we should request more streams when required
|
||||
# https://github.com/sctplab/usrsctp/blob/a0cbf4681474fab1e89d9e9e2d5c3694fce50359/programs/rtcweb.c#L304C16-L304C16
|
||||
|
||||
var stream = DataChannelStream(
|
||||
id: streamId, conn: conn.conn,
|
||||
reliability: reliability,
|
||||
reliabilityParameter: reliabilityParameter,
|
||||
receivedData: newAsyncQueue[seq[byte]]()
|
||||
)
|
||||
|
||||
conn.streams[streamId] = stream
|
||||
|
||||
let
|
||||
msg = DataChannelMessage(
|
||||
messageType: Open,
|
||||
openMessage: DataChannelOpenMessage(
|
||||
channelType: reliability,
|
||||
reliabilityParameter: reliabilityParameter
|
||||
)
|
||||
)
|
||||
await stream.sendControlMessage(msg)
|
||||
return stream
|
||||
|
||||
proc handleData(conn: DataChannelConnection, msg: SctpMessage) =
|
||||
let streamId = msg.params.streamId
|
||||
trace "handle data message", streamId, ppid = msg.params.protocolId, data = msg.data
|
||||
|
||||
if streamId notin conn.streams:
|
||||
raise newException(ValueError, "got data for unknown streamid")
|
||||
|
||||
let stream = conn.streams[streamId]
|
||||
|
||||
#TODO handle string vs binary
|
||||
if msg.params.protocolId in [uint32(WebRtcStringEmpty), uint32(WebRtcBinaryEmpty)]:
|
||||
# PPID indicate empty message
|
||||
stream.receivedData.addLastNoWait(@[])
|
||||
else:
|
||||
stream.receivedData.addLastNoWait(msg.data)
|
||||
|
||||
proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} =
|
||||
let
|
||||
decoded = Binary.decode(msg.data, DataChannelMessage)
|
||||
streamId = msg.params.streamId
|
||||
|
||||
trace "handle control message", decoded, streamId = msg.params.streamId
|
||||
if decoded.messageType == Ack:
|
||||
if streamId notin conn.streams:
|
||||
raise newException(ValueError, "got ack for unknown streamid")
|
||||
conn.streams[streamId].acked = true
|
||||
elif decoded.messageType == Open:
|
||||
if streamId in conn.streams:
|
||||
raise newException(ValueError, "got open for already existing streamid")
|
||||
|
||||
let stream = DataChannelStream(
|
||||
id: streamId, conn: conn.conn,
|
||||
reliability: decoded.openMessage.channelType,
|
||||
reliabilityParameter: decoded.openMessage.reliabilityParameter,
|
||||
receivedData: newAsyncQueue[seq[byte]]()
|
||||
)
|
||||
|
||||
conn.streams[streamId] = stream
|
||||
conn.incomingStreams.addLastNoWait(stream)
|
||||
|
||||
await stream.sendControlMessage(DataChannelMessage(messageType: Ack))
|
||||
|
||||
proc readLoop(conn: DataChannelConnection) {.async.} =
|
||||
try:
|
||||
while true:
|
||||
let message = await conn.conn.read()
|
||||
# TODO: check the protocolId
|
||||
if message.params.protocolId == uint32(WebRtcDcep):
|
||||
#TODO should we really await?
|
||||
await conn.handleControl(message)
|
||||
else:
|
||||
conn.handleData(message)
|
||||
|
||||
except CatchableError as exc:
|
||||
discard
|
||||
|
||||
proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} =
|
||||
return await conn.incomingStreams.popFirst()
|
||||
|
||||
proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection =
|
||||
result = DataChannelConnection(
|
||||
conn: conn,
|
||||
incomingStreams: newAsyncQueue[DataChannelStream](),
|
||||
streamId: 1'u16 # TODO: Serveur == 1, client == 2
|
||||
)
|
||||
result.readLoopFut = result.readLoop()
|
|
@ -0,0 +1,381 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import times, deques, tables, sequtils
|
||||
import chronos, chronicles
|
||||
import ./utils, ../stun/stun_connection
|
||||
|
||||
import mbedtls/ssl
|
||||
import mbedtls/ssl_cookie
|
||||
import mbedtls/ssl_cache
|
||||
import mbedtls/pk
|
||||
import mbedtls/md
|
||||
import mbedtls/entropy
|
||||
import mbedtls/ctr_drbg
|
||||
import mbedtls/rsa
|
||||
import mbedtls/x509
|
||||
import mbedtls/x509_crt
|
||||
import mbedtls/bignum
|
||||
import mbedtls/error
|
||||
import mbedtls/net_sockets
|
||||
import mbedtls/timing
|
||||
|
||||
logScope:
|
||||
topics = "webrtc dtls"
|
||||
|
||||
# Implementation of a DTLS client and a DTLS Server by using the mbedtls library.
|
||||
# Multiple things here are unintuitive partly because of the callbacks
|
||||
# used by mbedtls and that those callbacks cannot be async.
|
||||
#
|
||||
# TODO:
|
||||
# - Check the viability of the add/pop first/last of the asyncqueue with the limit.
|
||||
# There might be some errors (or crashes) with some edge cases with the no wait option
|
||||
# - Not critical - Check how to make a better use of MBEDTLS_ERR_SSL_WANT_WRITE
|
||||
# - Not critical - May be interesting to split Dtls and DtlsConn into two files
|
||||
|
||||
# This limit is arbitrary, it could be interesting to make it configurable.
|
||||
const PendingHandshakeLimit = 1024
|
||||
|
||||
# -- DtlsConn --
|
||||
# A Dtls connection to a specific IP address recovered by the receiving part of
|
||||
# the Udp "connection"
|
||||
|
||||
type
|
||||
DtlsError* = object of CatchableError
|
||||
DtlsConn* = ref object
|
||||
conn: StunConn
|
||||
laddr: TransportAddress
|
||||
raddr*: TransportAddress
|
||||
dataRecv: AsyncQueue[seq[byte]]
|
||||
sendFuture: Future[void]
|
||||
closed: bool
|
||||
closeEvent: AsyncEvent
|
||||
|
||||
timer: mbedtls_timing_delay_context
|
||||
|
||||
ssl: mbedtls_ssl_context
|
||||
config: mbedtls_ssl_config
|
||||
cookie: mbedtls_ssl_cookie_ctx
|
||||
cache: mbedtls_ssl_cache_context
|
||||
|
||||
ctr_drbg: mbedtls_ctr_drbg_context
|
||||
entropy: mbedtls_entropy_context
|
||||
|
||||
localCert: seq[byte]
|
||||
remoteCert: seq[byte]
|
||||
|
||||
proc init(self: DtlsConn, conn: StunConn, laddr: TransportAddress) =
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
self.dataRecv = newAsyncQueue[seq[byte]]()
|
||||
self.closed = false
|
||||
self.closeEvent = newAsyncEvent()
|
||||
|
||||
proc join(self: DtlsConn) {.async.} =
|
||||
await self.closeEvent.wait()
|
||||
|
||||
proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} =
|
||||
var shouldRead = isServer
|
||||
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
|
||||
if shouldRead:
|
||||
if isServer:
|
||||
case self.raddr.family
|
||||
of AddressFamily.IPv4:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
|
||||
of AddressFamily.IPv6:
|
||||
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
|
||||
else:
|
||||
raise newException(DtlsError, "Remote address isn't an IP address")
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
self.sendFuture = nil
|
||||
let res = mb_ssl_handshake_step(self.ssl)
|
||||
if not self.sendFuture.isNil():
|
||||
await self.sendFuture
|
||||
shouldRead = false
|
||||
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
shouldRead = true
|
||||
continue
|
||||
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
|
||||
mb_ssl_session_reset(self.ssl)
|
||||
shouldRead = isServer
|
||||
continue
|
||||
elif res != 0:
|
||||
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
|
||||
|
||||
proc close*(self: DtlsConn) {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to close DtlsConn twice"
|
||||
return
|
||||
|
||||
self.closed = true
|
||||
self.sendFuture = nil
|
||||
# TODO: proc mbedtls_ssl_close_notify => template mb_ssl_close_notify in nim-mbedtls
|
||||
let x = mbedtls_ssl_close_notify(addr self.ssl)
|
||||
if not self.sendFuture.isNil():
|
||||
await self.sendFuture
|
||||
self.closeEvent.fire()
|
||||
|
||||
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to write on an already closed DtlsConn"
|
||||
return
|
||||
var buf = msg
|
||||
try:
|
||||
let sendFuture = newFuture[void]("DtlsConn write")
|
||||
self.sendFuture = nil
|
||||
let write = mb_ssl_write(self.ssl, buf)
|
||||
if not self.sendFuture.isNil():
|
||||
await self.sendFuture
|
||||
trace "Dtls write", msgLen = msg.len(), actuallyWrote = write
|
||||
except MbedTLSError as exc:
|
||||
trace "Dtls write error", errorMsg = exc.msg
|
||||
raise exc
|
||||
|
||||
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to read on an already closed DtlsConn"
|
||||
return
|
||||
var res = newSeq[byte](8192)
|
||||
while true:
|
||||
let tmp = await self.dataRecv.popFirst()
|
||||
self.dataRecv.addFirstNoWait(tmp)
|
||||
# TODO: Find a clear way to use the template `mb_ssl_read` without
|
||||
# messing up things with exception
|
||||
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
|
||||
if length == MBEDTLS_ERR_SSL_WANT_READ:
|
||||
continue
|
||||
if length < 0:
|
||||
raise newException(DtlsError, $(length.cint.mbedtls_high_level_strerr()))
|
||||
res.setLen(length)
|
||||
return res
|
||||
|
||||
# -- Dtls --
|
||||
# The Dtls object read every messages from the UdpConn/StunConn and, if the address
|
||||
# is not yet stored in the Table `Connection`, adds it to the `pendingHandshake` queue
|
||||
# to be accepted later, if the address is stored, add the message received to the
|
||||
# corresponding DtlsConn `dataRecv` queue.
|
||||
|
||||
type
|
||||
Dtls* = ref object of RootObj
|
||||
connections: Table[TransportAddress, DtlsConn]
|
||||
pendingHandshakes: AsyncQueue[(TransportAddress, seq[byte])]
|
||||
conn: StunConn
|
||||
laddr: TransportAddress
|
||||
started: bool
|
||||
readLoop: Future[void]
|
||||
ctr_drbg: mbedtls_ctr_drbg_context
|
||||
entropy: mbedtls_entropy_context
|
||||
|
||||
serverPrivKey: mbedtls_pk_context
|
||||
serverCert: mbedtls_x509_crt
|
||||
localCert: seq[byte]
|
||||
|
||||
proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])],
|
||||
raddr: TransportAddress, buf: seq[byte]) =
|
||||
for kv in aq.mitems():
|
||||
if kv[0] == raddr:
|
||||
kv[1] = buf
|
||||
return
|
||||
aq.addLastNoWait((raddr, buf))
|
||||
|
||||
proc init*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
|
||||
if self.started:
|
||||
warn "Already started"
|
||||
return
|
||||
|
||||
proc readLoop() {.async.} =
|
||||
while true:
|
||||
let (buf, raddr) = await self.conn.read()
|
||||
if self.connections.hasKey(raddr):
|
||||
self.connections[raddr].dataRecv.addLastNoWait(buf)
|
||||
else:
|
||||
self.pendingHandshakes.updateOrAdd(raddr, buf)
|
||||
|
||||
self.connections = initTable[TransportAddress, DtlsConn]()
|
||||
self.pendingHandshakes = newAsyncQueue[(TransportAddress, seq[byte])](PendingHandshakeLimit)
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
self.started = true
|
||||
self.readLoop = readLoop()
|
||||
|
||||
mb_ctr_drbg_init(self.ctr_drbg)
|
||||
mb_entropy_init(self.entropy)
|
||||
mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, self.entropy, nil, 0)
|
||||
|
||||
self.serverPrivKey = self.ctr_drbg.generateKey()
|
||||
self.serverCert = self.ctr_drbg.generateCertificate(self.serverPrivKey)
|
||||
self.localCert = newSeq[byte](self.serverCert.raw.len)
|
||||
copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len)
|
||||
|
||||
proc stop*(self: Dtls) {.async.} =
|
||||
if not self.started:
|
||||
warn "Already stopped"
|
||||
return
|
||||
|
||||
await allFutures(toSeq(self.connections.values()).mapIt(it.close()))
|
||||
self.readLoop.cancel()
|
||||
self.started = false
|
||||
|
||||
# -- Remote / Local certificate getter --
|
||||
|
||||
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
|
||||
conn.remoteCert
|
||||
|
||||
proc localCertificate*(conn: DtlsConn): seq[byte] =
|
||||
conn.localCert
|
||||
|
||||
proc localCertificate*(self: Dtls): seq[byte] =
|
||||
self.localCert
|
||||
|
||||
# -- MbedTLS Callbacks --
|
||||
|
||||
proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt,
|
||||
state: cint, pflags: ptr uint32): cint {.cdecl.} =
|
||||
# verify is the procedure called by mbedtls when receiving the remote
|
||||
# certificate. It's usually used to verify the validity of the certificate.
|
||||
# We use this procedure to store the remote certificate as it's mandatory
|
||||
# to have it for the Prologue of the Noise protocol, aswell as the localCertificate.
|
||||
var self = cast[DtlsConn](ctx)
|
||||
let cert = pcert[]
|
||||
|
||||
self.remoteCert = newSeq[byte](cert.raw.len)
|
||||
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
|
||||
return 0
|
||||
|
||||
proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
# dtlsSend is the procedure called by mbedtls when data needs to be sent.
|
||||
# As the StunConn's write proc is asynchronous and dtlsSend cannot be async,
|
||||
# we store the future of this write and await it after the end of the
|
||||
# function (see write or dtlsHanshake for example).
|
||||
var self = cast[DtlsConn](ctx)
|
||||
var toWrite = newSeq[byte](len)
|
||||
if len > 0:
|
||||
copyMem(addr toWrite[0], buf, len)
|
||||
trace "dtls send", len
|
||||
self.sendFuture = self.conn.write(self.raddr, toWrite)
|
||||
result = len.cint
|
||||
|
||||
proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
|
||||
# dtlsRecv is the procedure called by mbedtls when data needs to be received.
|
||||
# As we cannot asynchronously await for data to be received, we use a data received
|
||||
# queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await
|
||||
# when the mbedtls proc resumed (see read or dtlsHandshake for example)
|
||||
let self = cast[DtlsConn](ctx)
|
||||
if self.dataRecv.len() == 0:
|
||||
return MBEDTLS_ERR_SSL_WANT_READ
|
||||
|
||||
var dataRecv = self.dataRecv.popFirstNoWait()
|
||||
copyMem(buf, addr dataRecv[0], dataRecv.len())
|
||||
result = dataRecv.len().cint
|
||||
trace "dtls receive", len, result
|
||||
|
||||
# -- Dtls Accept / Connect procedures --
|
||||
|
||||
proc removeConnection(self: Dtls, conn: DtlsConn, raddr: TransportAddress) {.async.} =
|
||||
await conn.join()
|
||||
self.connections.del(raddr)
|
||||
|
||||
proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
|
||||
var
|
||||
selfvar = self
|
||||
res = DtlsConn()
|
||||
|
||||
res.init(self.conn, self.laddr)
|
||||
mb_ssl_init(res.ssl)
|
||||
mb_ssl_config_init(res.config)
|
||||
mb_ssl_cookie_init(res.cookie)
|
||||
mb_ssl_cache_init(res.cache)
|
||||
|
||||
res.ctr_drbg = self.ctr_drbg
|
||||
res.entropy = self.entropy
|
||||
|
||||
var pkey = self.serverPrivKey
|
||||
var srvcert = self.serverCert
|
||||
res.localCert = self.localCert
|
||||
|
||||
mb_ssl_config_defaults(res.config,
|
||||
MBEDTLS_SSL_IS_SERVER,
|
||||
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT)
|
||||
mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg)
|
||||
mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds
|
||||
mb_ssl_conf_ca_chain(res.config, srvcert.next, nil)
|
||||
mb_ssl_conf_own_cert(res.config, srvcert, pkey)
|
||||
mb_ssl_cookie_setup(res.cookie, mbedtls_ctr_drbg_random, res.ctr_drbg)
|
||||
mb_ssl_conf_dtls_cookies(res.config, res.cookie)
|
||||
mb_ssl_set_timer_cb(res.ssl, res.timer)
|
||||
mb_ssl_setup(res.ssl, res.config)
|
||||
mb_ssl_session_reset(res.ssl)
|
||||
mb_ssl_set_verify(res.ssl, verify, res)
|
||||
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
|
||||
while true:
|
||||
let (raddr, buf) = await self.pendingHandshakes.popFirst()
|
||||
try:
|
||||
res.raddr = raddr
|
||||
res.dataRecv.addLastNoWait(buf)
|
||||
self.connections[raddr] = res
|
||||
await res.dtlsHandshake(true)
|
||||
asyncSpawn self.removeConnection(res, raddr)
|
||||
break
|
||||
except CatchableError as exc:
|
||||
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
continue
|
||||
return res
|
||||
|
||||
proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
|
||||
var
|
||||
selfvar = self
|
||||
res = DtlsConn()
|
||||
|
||||
res.init(self.conn, self.laddr)
|
||||
mb_ssl_init(res.ssl)
|
||||
mb_ssl_config_init(res.config)
|
||||
|
||||
res.ctr_drbg = self.ctr_drbg
|
||||
res.entropy = self.entropy
|
||||
|
||||
var pkey = res.ctr_drbg.generateKey()
|
||||
var srvcert = res.ctr_drbg.generateCertificate(pkey)
|
||||
res.localCert = newSeq[byte](srvcert.raw.len)
|
||||
copyMem(addr res.localCert[0], srvcert.raw.p, srvcert.raw.len)
|
||||
|
||||
mb_ctr_drbg_init(res.ctr_drbg)
|
||||
mb_entropy_init(res.entropy)
|
||||
mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0)
|
||||
|
||||
mb_ssl_config_defaults(res.config,
|
||||
MBEDTLS_SSL_IS_CLIENT,
|
||||
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT)
|
||||
mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg)
|
||||
mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds
|
||||
mb_ssl_conf_ca_chain(res.config, srvcert.next, nil)
|
||||
mb_ssl_set_timer_cb(res.ssl, res.timer)
|
||||
mb_ssl_setup(res.ssl, res.config)
|
||||
mb_ssl_set_verify(res.ssl, verify, res)
|
||||
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
|
||||
mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
|
||||
|
||||
res.raddr = raddr
|
||||
self.connections[raddr] = res
|
||||
|
||||
try:
|
||||
await res.dtlsHandshake(false)
|
||||
asyncSpawn self.removeConnection(res, raddr)
|
||||
except CatchableError as exc:
|
||||
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
|
||||
self.connections.del(raddr)
|
||||
raise exc
|
||||
|
||||
return res
|
|
@ -0,0 +1,96 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import std/times
|
||||
|
||||
import stew/byteutils
|
||||
|
||||
import mbedtls/pk
|
||||
import mbedtls/rsa
|
||||
import mbedtls/ctr_drbg
|
||||
import mbedtls/x509_crt
|
||||
import mbedtls/bignum
|
||||
import mbedtls/md
|
||||
|
||||
import chronicles
|
||||
|
||||
# This sequence is used for debugging.
|
||||
const mb_ssl_states* = @[
|
||||
"MBEDTLS_SSL_HELLO_REQUEST",
|
||||
"MBEDTLS_SSL_CLIENT_HELLO",
|
||||
"MBEDTLS_SSL_SERVER_HELLO",
|
||||
"MBEDTLS_SSL_SERVER_CERTIFICATE",
|
||||
"MBEDTLS_SSL_SERVER_KEY_EXCHANGE",
|
||||
"MBEDTLS_SSL_CERTIFICATE_REQUEST",
|
||||
"MBEDTLS_SSL_SERVER_HELLO_DONE",
|
||||
"MBEDTLS_SSL_CLIENT_CERTIFICATE",
|
||||
"MBEDTLS_SSL_CLIENT_KEY_EXCHANGE",
|
||||
"MBEDTLS_SSL_CERTIFICATE_VERIFY",
|
||||
"MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC",
|
||||
"MBEDTLS_SSL_CLIENT_FINISHED",
|
||||
"MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC",
|
||||
"MBEDTLS_SSL_SERVER_FINISHED",
|
||||
"MBEDTLS_SSL_FLUSH_BUFFERS",
|
||||
"MBEDTLS_SSL_HANDSHAKE_WRAPUP",
|
||||
"MBEDTLS_SSL_NEW_SESSION_TICKET",
|
||||
"MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT",
|
||||
"MBEDTLS_SSL_HELLO_RETRY_REQUEST",
|
||||
"MBEDTLS_SSL_ENCRYPTED_EXTENSIONS",
|
||||
"MBEDTLS_SSL_END_OF_EARLY_DATA",
|
||||
"MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY",
|
||||
"MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED",
|
||||
"MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO",
|
||||
"MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO",
|
||||
"MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO",
|
||||
"MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST",
|
||||
"MBEDTLS_SSL_HANDSHAKE_OVER",
|
||||
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET",
|
||||
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH"
|
||||
]
|
||||
|
||||
template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
|
||||
var res: mbedtls_pk_context
|
||||
mb_pk_init(res)
|
||||
discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))
|
||||
mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 2048, 65537)
|
||||
let x = mb_pk_rsa(res)
|
||||
res
|
||||
|
||||
template generateCertificate*(random: mbedtls_ctr_drbg_context,
|
||||
issuer_key: mbedtls_pk_context): mbedtls_x509_crt =
|
||||
let
|
||||
# To be honest, I have no clue what to put here as a name
|
||||
name = "C=FR,O=Status,CN=webrtc"
|
||||
time_format = initTimeFormat("YYYYMMddHHmmss")
|
||||
time_from = times.now().format(time_format)
|
||||
time_to = (times.now() + times.years(1)).format(time_format)
|
||||
|
||||
var write_cert: mbedtls_x509write_cert
|
||||
var serial_mpi: mbedtls_mpi
|
||||
mb_x509write_crt_init(write_cert)
|
||||
mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256);
|
||||
mb_x509write_crt_set_subject_key(write_cert, issuer_key)
|
||||
mb_x509write_crt_set_issuer_key(write_cert, issuer_key)
|
||||
mb_x509write_crt_set_subject_name(write_cert, name)
|
||||
mb_x509write_crt_set_issuer_name(write_cert, name)
|
||||
mb_x509write_crt_set_validity(write_cert, time_from, time_to)
|
||||
mb_x509write_crt_set_basic_constraints(write_cert, 0, -1)
|
||||
mb_x509write_crt_set_subject_key_identifier(write_cert)
|
||||
mb_x509write_crt_set_authority_key_identifier(write_cert)
|
||||
mb_mpi_init(serial_mpi)
|
||||
let serial_hex = mb_mpi_read_string(serial_mpi, 16)
|
||||
mb_x509write_crt_set_serial(write_cert, serial_mpi)
|
||||
let buf =
|
||||
try:
|
||||
mb_x509write_crt_pem(write_cert, 2048, mbedtls_ctr_drbg_random, random)
|
||||
except MbedTLSError as e:
|
||||
raise e
|
||||
var res: mbedtls_x509_crt
|
||||
mb_x509_crt_parse(res, buf)
|
||||
res
|
459
webrtc/sctp.nim
459
webrtc/sctp.nim
|
@ -1,5 +1,5 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2022 Status Research & Development GmbH
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
|
@ -8,14 +8,29 @@
|
|||
# those terms.
|
||||
|
||||
import tables, bitops, posix, strutils, sequtils
|
||||
import chronos, chronicles, stew/ranges/ptr_arith
|
||||
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils, endians2]
|
||||
import usrsctp
|
||||
import dtls/dtls
|
||||
import binary_serialization
|
||||
|
||||
export chronicles
|
||||
|
||||
logScope:
|
||||
topics = "webrtc sctp"
|
||||
|
||||
# Implementation of an Sctp client and server using the usrsctp library.
|
||||
# Usrsctp is usable as a single thread but it's not the intended way to
|
||||
# use it. There's a lot of callbacks calling each other in a synchronous
|
||||
# way where we want to be able to call asynchronous procedure, but cannot.
|
||||
|
||||
# TODO:
|
||||
# - Replace doAssert by a proper exception management
|
||||
# - Find a clean way to manage SCTP ports
|
||||
# - Unregister address when closing
|
||||
|
||||
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
|
||||
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
|
||||
|
||||
type
|
||||
SctpError* = object of CatchableError
|
||||
|
||||
|
@ -24,115 +39,172 @@ type
|
|||
Connected
|
||||
Closed
|
||||
|
||||
SctpConnection* = ref object
|
||||
SctpMessageParameters* = object
|
||||
protocolId*: uint32
|
||||
streamId*: uint16
|
||||
endOfRecord*: bool
|
||||
unordered*: bool
|
||||
|
||||
SctpMessage* = ref object
|
||||
data*: seq[byte]
|
||||
info: sctp_recvv_rn
|
||||
params*: SctpMessageParameters
|
||||
|
||||
SctpConn* = ref object
|
||||
conn*: DtlsConn
|
||||
state: SctpState
|
||||
connectEvent: AsyncEvent
|
||||
acceptEvent: AsyncEvent
|
||||
readLoop: Future[void]
|
||||
sctp: Sctp
|
||||
udp: DatagramTransport
|
||||
address: TransportAddress
|
||||
sctpSocket: ptr socket
|
||||
recvEvent: AsyncEvent
|
||||
dataRecv: seq[byte]
|
||||
dataRecv: AsyncQueue[SctpMessage]
|
||||
sentFuture: Future[void]
|
||||
|
||||
Sctp* = ref object
|
||||
dtls: Dtls
|
||||
udp: DatagramTransport
|
||||
connections: Table[TransportAddress, SctpConnection]
|
||||
connections: Table[TransportAddress, SctpConn]
|
||||
gotConnection: AsyncEvent
|
||||
timersHandler: Future[void]
|
||||
isServer: bool
|
||||
sockServer: ptr socket
|
||||
pendingConnections: seq[SctpConnection]
|
||||
sentFuture: Future[void]
|
||||
sentConnection: SctpConnection
|
||||
pendingConnections: seq[SctpConn]
|
||||
pendingConnections2: Table[SockAddr, SctpConn]
|
||||
sentAddress: TransportAddress
|
||||
sentFuture: Future[void]
|
||||
|
||||
const
|
||||
IPPROTO_SCTP = 132
|
||||
# These three objects are used for debugging/trace only
|
||||
SctpChunk = object
|
||||
chunkType: uint8
|
||||
flag: uint8
|
||||
length {.bin_value: it.data.len() + 4.}: uint16
|
||||
data {.bin_len: it.length - 4.}: seq[byte]
|
||||
|
||||
proc newSctpError(msg: string): ref SctpError =
|
||||
result = newException(SctpError, msg)
|
||||
SctpPacketHeader = object
|
||||
srcPort: uint16
|
||||
dstPort: uint16
|
||||
verifTag: uint32
|
||||
checksum: uint32
|
||||
|
||||
template usrsctpAwait(sctp: Sctp, body: untyped): untyped =
|
||||
sctp.sentFuture = nil
|
||||
SctpPacketStructure = object
|
||||
header: SctpPacketHeader
|
||||
chunks: seq[SctpChunk]
|
||||
|
||||
const IPPROTO_SCTP = 132
|
||||
|
||||
proc getSctpPacket(buffer: seq[byte]): SctpPacketStructure =
|
||||
# Only used for debugging/trace
|
||||
result.header = Binary.decode(buffer, SctpPacketHeader)
|
||||
var size = sizeof(SctpPacketStructure)
|
||||
while size < buffer.len:
|
||||
let chunk = Binary.decode(buffer[size..^1], SctpChunk)
|
||||
result.chunks.add(chunk)
|
||||
size.inc(chunk.length.int)
|
||||
while size mod 4 != 0:
|
||||
# padding; could use `size.inc(-size %% 4)` instead but it lacks clarity
|
||||
size.inc(1)
|
||||
|
||||
# -- Asynchronous wrapper --
|
||||
|
||||
template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
|
||||
# usrsctpAwait is template which set `sentFuture` to nil then calls (usually)
|
||||
# an usrsctp function. If during the synchronous run of the usrsctp function
|
||||
# `sendCallback` is called, then `sentFuture` is set and waited.
|
||||
self.sentFuture = nil
|
||||
when type(body) is void:
|
||||
body
|
||||
if sctp.sentFuture != nil: await sctp.sentFuture
|
||||
if self.sentFuture != nil: await self.sentFuture
|
||||
else:
|
||||
let res = body
|
||||
if sctp.sentFuture != nil: await sctp.sentFuture
|
||||
if self.sentFuture != nil: await self.sentFuture
|
||||
res
|
||||
|
||||
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
|
||||
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
|
||||
# -- SctpConn --
|
||||
|
||||
proc packetPretty(packet: cstring): string =
|
||||
let data = $packet
|
||||
let ctn = data[23..^16]
|
||||
result = data[1..14]
|
||||
if ctn.len > 30:
|
||||
result = result & ctn[0..14] & " ... " & ctn[^14..^1]
|
||||
else:
|
||||
result = result & ctn
|
||||
|
||||
proc new(T: typedesc[SctpConnection],
|
||||
sctp: Sctp,
|
||||
udp: DatagramTransport,
|
||||
address: TransportAddress,
|
||||
sctpSocket: ptr socket): T =
|
||||
T(sctp: sctp,
|
||||
proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
|
||||
T(conn: conn,
|
||||
sctp: sctp,
|
||||
state: Connecting,
|
||||
udp: udp,
|
||||
address: address,
|
||||
sctpSocket: sctpSocket,
|
||||
connectEvent: AsyncEvent(),
|
||||
recvEvent: AsyncEvent())
|
||||
acceptEvent: AsyncEvent(),
|
||||
dataRecv: newAsyncQueue[SctpMessage]() # TODO add some limit for backpressure?
|
||||
)
|
||||
|
||||
proc read*(self: SctpConnection): Future[seq[byte]] {.async.} =
|
||||
trace "Read"
|
||||
if self.dataRecv.len == 0:
|
||||
self.recvEvent.clear()
|
||||
await self.recvEvent.wait()
|
||||
let res = self.dataRecv
|
||||
self.dataRecv = @[]
|
||||
return res
|
||||
proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
|
||||
# Used by DataChannel, returns SctpMessage in order to get the stream
|
||||
# and protocol ids
|
||||
return await self.dataRecv.popFirst()
|
||||
|
||||
proc write*(self: SctpConnection, buf: seq[byte]) {.async.} =
|
||||
proc toFlags(params: SctpMessageParameters): uint16 =
|
||||
if params.endOfRecord:
|
||||
result = result or SCTP_EOR
|
||||
if params.unordered:
|
||||
result = result or SCTP_UNORDERED
|
||||
|
||||
proc write*(self: SctpConn, buf: seq[byte],
|
||||
sendParams = default(SctpMessageParameters)) {.async.} =
|
||||
# Used by DataChannel, writes buf on the Dtls connection.
|
||||
trace "Write", buf
|
||||
self.sctp.sentConnection = self
|
||||
self.sctp.sentAddress = self.address
|
||||
let sendvErr = self.sctp.usrsctpAwait:
|
||||
self.sctpSocket.usrsctp_sendv(unsafeAddr buf[0], buf.len.uint,
|
||||
nil, 0, nil, 0,
|
||||
SCTP_SENDV_NOINFO, 0)
|
||||
|
||||
proc close*(self: SctpConnection) {.async.} =
|
||||
self.sctp.usrsctpAwait: self.sctpSocket.usrsctp_close()
|
||||
var cpy = buf
|
||||
let sendvErr =
|
||||
if sendParams == default(SctpMessageParameters):
|
||||
# If writes is called by DataChannel, sendParams should never
|
||||
# be the default value. This split is useful for testing.
|
||||
self.usrsctpAwait:
|
||||
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
|
||||
nil, 0, SCTP_SENDV_NOINFO.cuint, 0)
|
||||
else:
|
||||
let sendInfo = sctp_sndinfo(
|
||||
snd_sid: sendParams.streamId,
|
||||
# TODO: swapBytes => htonl?
|
||||
snd_ppid: sendParams.protocolId.swapBytes(),
|
||||
snd_flags: sendParams.toFlags)
|
||||
self.usrsctpAwait:
|
||||
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
|
||||
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
|
||||
SCTP_SENDV_SNDINFO.cuint, 0)
|
||||
if sendvErr < 0:
|
||||
# TODO: throw an exception
|
||||
perror("usrsctp_sendv")
|
||||
|
||||
proc write*(self: SctpConn, s: string) {.async.} =
|
||||
await self.write(s.toBytes())
|
||||
|
||||
proc close*(self: SctpConn) {.async.} =
|
||||
self.usrsctpAwait:
|
||||
self.sctpSocket.usrsctp_close()
|
||||
|
||||
# -- usrsctp receive data callbacks --
|
||||
|
||||
proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
||||
# Callback procedure called when we receive data after
|
||||
# connection has been established.
|
||||
let
|
||||
conn = cast[SctpConn](data)
|
||||
events = usrsctp_get_events(sock)
|
||||
conn = cast[SctpConnection](data)
|
||||
|
||||
trace "Handle Upcall", events
|
||||
if conn.state == Connecting:
|
||||
if bitand(events, SCTP_EVENT_ERROR) != 0:
|
||||
warn "Cannot connect", address = conn.address
|
||||
conn.state = Closed
|
||||
elif bitand(events, SCTP_EVENT_WRITE) != 0:
|
||||
conn.state = Connected
|
||||
conn.connectEvent.fire()
|
||||
elif bitand(events, SCTP_EVENT_READ) != 0:
|
||||
if bitand(events, SCTP_EVENT_READ) != 0:
|
||||
var
|
||||
buffer = newSeq[byte](4096)
|
||||
message = SctpMessage(
|
||||
data: newSeq[byte](4096)
|
||||
)
|
||||
address: Sockaddr_storage
|
||||
rn: sctp_recvv_rn
|
||||
addressLen = sizeof(Sockaddr_storage).SockLen
|
||||
rnLen = sizeof(sctp_recvv_rn).SockLen
|
||||
infotype: uint
|
||||
flags: int
|
||||
let n = sock.usrsctp_recvv(cast[pointer](addr buffer[0]), buffer.len.uint,
|
||||
let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]),
|
||||
message.data.len.uint,
|
||||
cast[ptr SockAddr](addr address),
|
||||
cast[ptr SockLen](addr addressLen),
|
||||
cast[pointer](addr rn),
|
||||
cast[pointer](addr message.info),
|
||||
cast[ptr SockLen](addr rnLen),
|
||||
cast[ptr cuint](addr infotype),
|
||||
cast[ptr cint](addr flags))
|
||||
|
@ -140,91 +212,152 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
|||
perror("usrsctp_recvv")
|
||||
return
|
||||
elif n > 0:
|
||||
# It might be necessary to check if infotype == SCTP_RECVV_RCVINFO
|
||||
message.data.delete(n..<message.data.len())
|
||||
trace "message info from handle upcall", msginfo = message.info
|
||||
message.params = SctpMessageParameters(
|
||||
protocolId: message.info.recvv_rcvinfo.rcv_ppid.swapBytes(),
|
||||
streamId: message.info.recvv_rcvinfo.rcv_sid
|
||||
)
|
||||
if bitand(flags, MSG_NOTIFICATION) != 0:
|
||||
trace "Notification received", length = n
|
||||
else:
|
||||
conn.dataRecv = conn.dataRecv.concat(buffer[0..n])
|
||||
conn.recvEvent.fire()
|
||||
try:
|
||||
conn.dataRecv.addLastNoWait(message)
|
||||
except AsyncQueueFullError:
|
||||
trace "Queue full, dropping packet"
|
||||
elif bitand(events, SCTP_EVENT_WRITE) != 0:
|
||||
trace "sctp event write in the upcall"
|
||||
else:
|
||||
warn "Handle Upcall unexpected event", events
|
||||
|
||||
proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
||||
# Callback procedure called when accepting a connection.
|
||||
trace "Handle Accept"
|
||||
var
|
||||
sconn: Sockaddr_conn
|
||||
slen: Socklen = sizeof(Sockaddr_conn).uint32
|
||||
let
|
||||
sctp = cast[Sctp](data)
|
||||
sctpSocket = usrsctp_accept(sctp.sockServer, nil, nil)
|
||||
# TODO: check if sctpSocket != nil
|
||||
sctpSocket = usrsctp_accept(sctp.sockServer, cast[ptr SockAddr](addr sconn), addr slen)
|
||||
|
||||
doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1)
|
||||
let conn = SctpConnection.new(sctp, sctp.udp, sctp.sentAddress, sctpSocket)
|
||||
sctp.connections[sctp.sentAddress] = conn
|
||||
sctp.pendingConnections.add(conn)
|
||||
let conn = cast[SctpConn](sconn.sconn_addr)
|
||||
conn.sctpSocket = sctpSocket
|
||||
conn.state = Connected
|
||||
doAssert 0 == sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
|
||||
sctp.gotConnection.fire()
|
||||
var nodelay: uint32 = 1
|
||||
var recvinfo: uint32 = 1
|
||||
doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1)
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
|
||||
addr nodelay, sizeof(nodelay).SockLen)
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
|
||||
addr recvinfo, sizeof(recvinfo).SockLen)
|
||||
conn.acceptEvent.fire()
|
||||
|
||||
proc getOrCreateConnection(self: Sctp,
|
||||
udp: DatagramTransport,
|
||||
address: TransportAddress,
|
||||
sctpPort: uint16 = 5000): Future[SctpConnection] {.async.} =
|
||||
#TODO remove the = 5000
|
||||
if self.connections.hasKey(address):
|
||||
return self.connections[address]
|
||||
trace "Create Connection", address
|
||||
proc handleConnect(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
|
||||
# Callback procedure called when connecting
|
||||
trace "Handle Connect"
|
||||
let
|
||||
sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||
conn = SctpConnection.new(self, udp, address, sctpSocket)
|
||||
var on: int = 1
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP,
|
||||
SCTP_RECVRCVINFO,
|
||||
addr on,
|
||||
sizeof(on).SockLen)
|
||||
doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
|
||||
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, cast[pointer](conn))
|
||||
var sconn: Sockaddr_conn
|
||||
sconn.sconn_family = AF_CONN
|
||||
sconn.sconn_port = htons(sctpPort)
|
||||
sconn.sconn_addr = cast[pointer](self)
|
||||
self.sentConnection = conn
|
||||
self.sentAddress = address
|
||||
let connErr = self.usrsctpAwait:
|
||||
conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)))
|
||||
doAssert 0 == connErr or errno == EINPROGRESS, ($errno) # TODO raise
|
||||
self.connections[address] = conn
|
||||
return conn
|
||||
conn = cast[SctpConn](data)
|
||||
events = usrsctp_get_events(sock)
|
||||
|
||||
proc sendCallback(address: pointer,
|
||||
trace "Handle Upcall", events, state = conn.state
|
||||
if conn.state == Connecting:
|
||||
if bitand(events, SCTP_EVENT_ERROR) != 0:
|
||||
warn "Cannot connect", address = conn.address
|
||||
conn.state = Closed
|
||||
elif bitand(events, SCTP_EVENT_WRITE) != 0:
|
||||
conn.state = Connected
|
||||
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, data)
|
||||
conn.connectEvent.fire()
|
||||
else:
|
||||
warn "should be connecting", currentState = conn.state
|
||||
|
||||
# -- usrsctp send data callback --
|
||||
|
||||
proc sendCallback(ctx: pointer,
|
||||
buffer: pointer,
|
||||
length: uint,
|
||||
tos: uint8,
|
||||
set_df: uint8): cint {.cdecl.} =
|
||||
let data = usrsctp_dumppacket(buffer, length, SCTP_DUMP_OUTBOUND)
|
||||
if data != nil:
|
||||
trace "sendCallback", data = data.packetPretty(), length
|
||||
trace "sendCallback", sctpPacket = data.getSctpPacket(), length
|
||||
usrsctp_freedumpbuffer(data)
|
||||
let sctp = cast[Sctp](address)
|
||||
let sctpConn = cast[SctpConn](ctx)
|
||||
let buf = @(buffer.makeOpenArray(byte, int(length)))
|
||||
proc testSend() {.async.} =
|
||||
try:
|
||||
let
|
||||
buf = @(buffer.makeOpenArray(byte, int(length)))
|
||||
address = sctp.sentAddress
|
||||
trace "Send To", address
|
||||
await sendTo(sctp.udp, address, buf, int(length))
|
||||
trace "Send To", address = sctpConn.address
|
||||
await sctpConn.conn.write(buf)
|
||||
except CatchableError as exc:
|
||||
trace "Send Failed", message = exc.msg
|
||||
sctp.sentFuture = testSend()
|
||||
sctpConn.sentFuture = testSend()
|
||||
|
||||
# -- Sctp --
|
||||
|
||||
proc timersHandler() {.async.} =
|
||||
while true:
|
||||
await sleepAsync(500.milliseconds)
|
||||
usrsctp_handle_timers(500)
|
||||
|
||||
proc startServer*(self: Sctp, sctpPort: uint16 = 5000) =
|
||||
proc stopServer*(self: Sctp) =
|
||||
if not self.isServer:
|
||||
trace "Try to close a client"
|
||||
return
|
||||
self.isServer = false
|
||||
let pcs = self.pendingConnections
|
||||
self.pendingConnections = @[]
|
||||
for pc in pcs:
|
||||
pc.sctpSocket.usrsctp_close()
|
||||
self.sockServer.usrsctp_close()
|
||||
|
||||
proc init*(self: Sctp, dtls: Dtls, laddr: TransportAddress) =
|
||||
self.gotConnection = newAsyncEvent()
|
||||
self.timersHandler = timersHandler()
|
||||
self.dtls = dtls
|
||||
|
||||
usrsctp_init_nothreads(laddr.port.uint16, sendCallback, printf)
|
||||
discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
|
||||
discard usrsctp_sysctl_set_sctp_ecn_enable(1)
|
||||
usrsctp_register_address(cast[pointer](self))
|
||||
|
||||
proc stop*(self: Sctp) {.async.} =
|
||||
# TODO: close every connections
|
||||
discard self.usrsctpAwait usrsctp_finish()
|
||||
self.udp.close()
|
||||
|
||||
proc readLoopProc(res: SctpConn) {.async.} =
|
||||
while true:
|
||||
let
|
||||
msg = await res.conn.read()
|
||||
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
|
||||
if not data.isNil():
|
||||
trace "Receive data", remoteAddress = res.conn.raddr,
|
||||
sctpPacket = data.getSctpPacket()
|
||||
usrsctp_freedumpbuffer(data)
|
||||
usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0)
|
||||
|
||||
proc accept*(self: Sctp): Future[SctpConn] {.async.} =
|
||||
if not self.isServer:
|
||||
raise newException(SctpError, "Not a server")
|
||||
var res = SctpConn.new(await self.dtls.accept(), self)
|
||||
usrsctp_register_address(cast[pointer](res))
|
||||
res.readLoop = res.readLoopProc()
|
||||
res.acceptEvent.clear()
|
||||
await res.acceptEvent.wait()
|
||||
return res
|
||||
|
||||
proc listen*(self: Sctp, sctpPort: uint16 = 5000) =
|
||||
if self.isServer:
|
||||
trace "Try to start the server twice"
|
||||
return
|
||||
self.isServer = true
|
||||
trace "Listening", sctpPort
|
||||
doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2)
|
||||
doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0)
|
||||
doAssert 0 == usrsctp_sysctl_set_sctp_delayed_sack_time_default(0)
|
||||
let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||
var on: int = 1
|
||||
doAssert 0 == usrsctp_set_non_blocking(sock, 1)
|
||||
|
@ -237,75 +370,37 @@ proc startServer*(self: Sctp, sctpPort: uint16 = 5000) =
|
|||
doAssert 0 == sock.usrsctp_set_upcall(handleAccept, cast[pointer](self))
|
||||
self.sockServer = sock
|
||||
|
||||
proc closeServer(self: Sctp) =
|
||||
if not self.isServer:
|
||||
trace "Try to close a client"
|
||||
return
|
||||
self.isServer = false
|
||||
let pcs = self.pendingConnections
|
||||
self.pendingConnections = @[]
|
||||
for pc in pcs:
|
||||
pc.sctpSocket.usrsctp_close()
|
||||
self.sockServer.usrsctp_close()
|
||||
|
||||
proc new*(T: typedesc[Sctp], port: uint16 = 9899): T =
|
||||
logScope: topics = "webrtc sctp"
|
||||
let sctp = T(gotConnection: newAsyncEvent())
|
||||
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
||||
let
|
||||
msg = udp.getMessage()
|
||||
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
|
||||
if data != nil:
|
||||
if sctp.isServer:
|
||||
trace "onReceive (server)", data = data.packetPretty(), length = msg.len(), address
|
||||
else:
|
||||
trace "onReceive (client)", data = data.packetPretty(), length = msg.len(), address
|
||||
usrsctp_freedumpbuffer(data)
|
||||
|
||||
if sctp.isServer:
|
||||
sctp.sentAddress = address
|
||||
usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
|
||||
else:
|
||||
let conn = await sctp.getOrCreateConnection(udp, address)
|
||||
sctp.sentConnection = conn
|
||||
sctp.sentAddress = address
|
||||
usrsctp_conninput(cast[pointer](sctp), unsafeAddr msg[0], uint(msg.len), 0)
|
||||
let
|
||||
localAddr = TransportAddress(family: AddressFamily.IPv4, port: Port(port))
|
||||
laddr = initTAddress("127.0.0.1:" & $port)
|
||||
udp = newDatagramTransport(onReceive, local = laddr)
|
||||
trace "local address", localAddr, laddr
|
||||
sctp.udp = udp
|
||||
sctp.timersHandler = timersHandler()
|
||||
|
||||
usrsctp_init_nothreads(port, sendCallback, printf)
|
||||
discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
|
||||
discard usrsctp_sysctl_set_sctp_ecn_enable(1)
|
||||
usrsctp_register_address(cast[pointer](sctp))
|
||||
|
||||
return sctp
|
||||
|
||||
proc listen*(self: Sctp): Future[SctpConnection] {.async.} =
|
||||
if not self.isServer:
|
||||
raise newSctpError("Not a server")
|
||||
trace "Listening"
|
||||
if self.pendingConnections.len == 0:
|
||||
self.gotConnection.clear()
|
||||
await self.gotConnection.wait()
|
||||
let res = self.pendingConnections[0]
|
||||
self.pendingConnections.delete(0)
|
||||
return res
|
||||
|
||||
proc connect*(self: Sctp,
|
||||
address: TransportAddress,
|
||||
sctpPort: uint16 = 5000): Future[SctpConnection] {.async.} =
|
||||
trace "Connect", address
|
||||
let conn = await self.getOrCreateConnection(self.udp, address, sctpPort)
|
||||
try:
|
||||
await conn.connectEvent.wait()
|
||||
except CancelledError as exc:
|
||||
conn.sctpSocket.usrsctp_close()
|
||||
return nil
|
||||
if conn.state != Connected:
|
||||
raise newSctpError("Cannot connect to " & $address)
|
||||
sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
|
||||
let
|
||||
sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
|
||||
conn = SctpConn.new(await self.dtls.connect(address), self)
|
||||
|
||||
trace "Create Connection", address
|
||||
conn.sctpSocket = sctpSocket
|
||||
conn.state = Connected
|
||||
var nodelay: uint32 = 1
|
||||
var recvinfo: uint32 = 1
|
||||
doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
|
||||
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleConnect, cast[pointer](conn))
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
|
||||
addr nodelay, sizeof(nodelay).SockLen)
|
||||
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
|
||||
addr recvinfo, sizeof(recvinfo).SockLen)
|
||||
var sconn: Sockaddr_conn
|
||||
sconn.sconn_family = AF_CONN
|
||||
sconn.sconn_port = htons(sctpPort)
|
||||
sconn.sconn_addr = cast[pointer](conn)
|
||||
self.sentAddress = address
|
||||
usrsctp_register_address(cast[pointer](conn))
|
||||
conn.readLoop = conn.readLoopProc()
|
||||
let connErr = self.usrsctpAwait:
|
||||
conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)))
|
||||
doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno)
|
||||
conn.state = Connecting
|
||||
conn.connectEvent.clear()
|
||||
await conn.connectEvent.wait()
|
||||
# TODO: check connection state, if closed throw an exception
|
||||
self.connections[address] = conn
|
||||
return conn
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import bitops, strutils
|
||||
import chronos,
|
||||
chronicles,
|
||||
binary_serialization,
|
||||
stew/objects,
|
||||
stew/byteutils
|
||||
import stun_attributes
|
||||
|
||||
export binary_serialization
|
||||
|
||||
logScope:
|
||||
topics = "webrtc stun"
|
||||
|
||||
const
|
||||
msgHeaderSize = 20
|
||||
magicCookieSeq = @[ 0x21'u8, 0x12, 0xa4, 0x42 ]
|
||||
magicCookie = 0x2112a442
|
||||
BindingRequest = 0x0001'u16
|
||||
BindingResponse = 0x0101'u16
|
||||
|
||||
proc decode(T: typedesc[RawStunAttribute], cnt: seq[byte]): seq[RawStunAttribute] =
|
||||
const pad = @[0, 3, 2, 1]
|
||||
var padding = 0
|
||||
while padding < cnt.len():
|
||||
let attr = Binary.decode(cnt[padding ..^ 1], RawStunAttribute)
|
||||
result.add(attr)
|
||||
padding += 4 + attr.value.len()
|
||||
padding += pad[padding mod 4]
|
||||
|
||||
type
|
||||
# Stun Header
|
||||
# 0 1 2 3
|
||||
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
# |0 0| STUN Message Type | Message Length |
|
||||
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
# | Magic Cookie |
|
||||
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
# | |
|
||||
# | Transaction ID (96 bits) |
|
||||
# | |
|
||||
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
# Message type:
|
||||
# 0x0001: Binding Request
|
||||
# 0x0101: Binding Response
|
||||
# 0x0111: Binding Error Response
|
||||
# 0x0002: Shared Secret Request
|
||||
# 0x0102: Shared Secret Response
|
||||
# 0x0112: Shared Secret Error Response
|
||||
|
||||
RawStunMessage = object
|
||||
msgType: uint16
|
||||
# it.conten.len() + 8 Because the Fingerprint is added after the encoding
|
||||
length* {.bin_value: it.content.len().}: uint16
|
||||
magicCookie: uint32
|
||||
transactionId: array[12, byte]
|
||||
content* {.bin_len: it.length.}: seq[byte]
|
||||
|
||||
StunMessage* = object
|
||||
msgType*: uint16
|
||||
transactionId*: array[12, byte]
|
||||
attributes*: seq[RawStunAttribute]
|
||||
|
||||
Stun* = object
|
||||
|
||||
proc getAttribute(attrs: seq[RawStunAttribute], typ: uint16): Option[seq[byte]] =
|
||||
for attr in attrs:
|
||||
if attr.attributeType == typ:
|
||||
return some(attr.value)
|
||||
return none(seq[byte])
|
||||
|
||||
proc isMessage*(T: typedesc[Stun], msg: seq[byte]): bool =
|
||||
msg.len >= msgHeaderSize and msg[4..<8] == magicCookieSeq and bitand(0xC0'u8, msg[0]) == 0'u8
|
||||
|
||||
proc addLength(msgEncoded: var seq[byte], length: uint16) =
|
||||
let
|
||||
hi = (length div 256'u16).uint8
|
||||
lo = (length mod 256'u16).uint8
|
||||
msgEncoded[2] = msgEncoded[2] + hi
|
||||
if msgEncoded[3].int + lo.int >= 256:
|
||||
msgEncoded[2] = msgEncoded[2] + 1
|
||||
msgEncoded[3] = ((msgEncoded[3].int + lo.int) mod 256).uint8
|
||||
else:
|
||||
msgEncoded[3] = msgEncoded[3] + lo
|
||||
|
||||
proc decode*(T: typedesc[StunMessage], msg: seq[byte]): StunMessage =
|
||||
let smi = Binary.decode(msg, RawStunMessage)
|
||||
return T(msgType: smi.msgType,
|
||||
transactionId: smi.transactionId,
|
||||
attributes: RawStunAttribute.decode(smi.content))
|
||||
|
||||
proc encode*(msg: StunMessage, userOpt: Option[seq[byte]]): seq[byte] =
|
||||
const pad = @[0, 3, 2, 1]
|
||||
var smi = RawStunMessage(msgType: msg.msgType,
|
||||
magicCookie: magicCookie,
|
||||
transactionId: msg.transactionId)
|
||||
for attr in msg.attributes:
|
||||
smi.content.add(Binary.encode(attr))
|
||||
smi.content.add(newSeq[byte](pad[smi.content.len() mod 4]))
|
||||
|
||||
result = Binary.encode(smi)
|
||||
|
||||
if userOpt.isSome():
|
||||
let username = string.fromBytes(userOpt.get())
|
||||
let usersplit = username.split(":")
|
||||
if usersplit.len() == 2 and usersplit[0].startsWith("libp2p+webrtc+v1/"):
|
||||
result.addLength(24)
|
||||
result.add(Binary.encode(MessageIntegrity.encode(result, toBytes(usersplit[0]))))
|
||||
|
||||
result.addLength(8)
|
||||
result.add(Binary.encode(Fingerprint.encode(result)))
|
||||
|
||||
proc getResponse*(T: typedesc[Stun], msg: seq[byte],
|
||||
ta: TransportAddress): Option[seq[byte]] =
|
||||
if ta.family != AddressFamily.IPv4 and ta.family != AddressFamily.IPv6:
|
||||
return none(seq[byte])
|
||||
let sm =
|
||||
try:
|
||||
StunMessage.decode(msg)
|
||||
except CatchableError as exc:
|
||||
return none(seq[byte])
|
||||
|
||||
if sm.msgType != BindingRequest:
|
||||
return none(seq[byte])
|
||||
|
||||
var res = StunMessage(msgType: BindingResponse,
|
||||
transactionId: sm.transactionId)
|
||||
|
||||
var unknownAttr: seq[uint16]
|
||||
for attr in sm.attributes:
|
||||
let typ = attr.attributeType
|
||||
if typ.isRequired() and typ notin StunAttributeEnum:
|
||||
unknownAttr.add(typ)
|
||||
if unknownAttr.len() > 0:
|
||||
res.attributes.add(ErrorCode.encode(ECUnknownAttribute))
|
||||
res.attributes.add(UnknownAttribute.encode(unknownAttr))
|
||||
return some(res.encode(sm.attributes.getAttribute(AttrUsername.uint16)))
|
||||
|
||||
res.attributes.add(XorMappedAddress.encode(ta, sm.transactionId))
|
||||
return some(res.encode(sm.attributes.getAttribute(AttrUsername.uint16)))
|
||||
|
||||
proc new*(T: typedesc[Stun]): T =
|
||||
result = T()
|
|
@ -0,0 +1,228 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import std/sha1, sequtils, typetraits, std/md5
|
||||
import binary_serialization,
|
||||
stew/byteutils,
|
||||
chronos
|
||||
|
||||
# -- Utils --
|
||||
|
||||
proc createCrc32Table(): array[0..255, uint32] =
|
||||
for i in 0..255:
|
||||
var rem = i.uint32
|
||||
for j in 0..7:
|
||||
if (rem and 1) > 0:
|
||||
rem = (rem shr 1) xor 0xedb88320'u32
|
||||
else:
|
||||
rem = rem shr 1
|
||||
result[i] = rem
|
||||
|
||||
proc crc32(s: seq[byte]): uint32 =
|
||||
# CRC-32 is used for the fingerprint attribute
|
||||
# See https://datatracker.ietf.org/doc/html/rfc5389#section-15.5
|
||||
const crc32table = createCrc32Table()
|
||||
result = 0xffffffff'u32
|
||||
for c in s:
|
||||
result = (result shr 8) xor crc32table[(result and 0xff) xor c]
|
||||
result = not result
|
||||
|
||||
proc hmacSha1(key: seq[byte], msg: seq[byte]): seq[byte] =
|
||||
# HMAC-SHA1 is used for the message integrity attribute
|
||||
# See https://datatracker.ietf.org/doc/html/rfc5389#section-15.4
|
||||
let
|
||||
keyPadded =
|
||||
if len(key) > 64:
|
||||
@(secureHash(key.mapIt(it.chr)).distinctBase)
|
||||
elif key.len() < 64:
|
||||
key.concat(newSeq[byte](64 - key.len()))
|
||||
else:
|
||||
key
|
||||
innerHash = keyPadded.
|
||||
mapIt(it xor 0x36'u8).
|
||||
concat(msg).
|
||||
mapIt(it.chr).
|
||||
secureHash()
|
||||
outerHash = keyPadded.
|
||||
mapIt(it xor 0x5c'u8).
|
||||
concat(@(innerHash.distinctBase)).
|
||||
mapIt(it.chr).
|
||||
secureHash()
|
||||
return @(outerHash.distinctBase)
|
||||
|
||||
# -- Attributes --
|
||||
# There are obviously some attributes implementation that are missing,
|
||||
# it might be something to do eventually if we want to make this
|
||||
# repository work for other project than nim-libp2p
|
||||
#
|
||||
# Stun Attribute
|
||||
# 0 1 2 3
|
||||
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
# | Type | Length |
|
||||
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
# | Value (variable) ....
|
||||
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|
||||
type
|
||||
StunAttributeEncodingError* = object of CatchableError
|
||||
|
||||
RawStunAttribute* = object
|
||||
attributeType*: uint16
|
||||
length* {.bin_value: it.value.len.}: uint16
|
||||
value* {.bin_len: it.length.}: seq[byte]
|
||||
|
||||
StunAttributeEnum* = enum
|
||||
AttrMappedAddress = 0x0001
|
||||
AttrChangeRequest = 0x0003 # RFC5780 Nat Behavior Discovery
|
||||
AttrSourceAddress = 0x0004 # Deprecated
|
||||
AttrChangedAddress = 0x0005 # Deprecated
|
||||
AttrUsername = 0x0006
|
||||
AttrMessageIntegrity = 0x0008
|
||||
AttrErrorCode = 0x0009
|
||||
AttrUnknownAttributes = 0x000A
|
||||
AttrChannelNumber = 0x000C # RFC5766 TURN
|
||||
AttrLifetime = 0x000D # RFC5766 TURN
|
||||
AttrXORPeerAddress = 0x0012 # RFC5766 TURN
|
||||
AttrData = 0x0013 # RFC5766 TURN
|
||||
AttrRealm = 0x0014
|
||||
AttrNonce = 0x0015
|
||||
AttrXORRelayedAddress = 0x0016 # RFC5766 TURN
|
||||
AttrRequestedAddressFamily = 0x0017 # RFC6156
|
||||
AttrEvenPort = 0x0018 # RFC5766 TURN
|
||||
AttrRequestedTransport = 0x0019 # RFC5766 TURN
|
||||
AttrDontFragment = 0x001A # RFC5766 TURN
|
||||
AttrMessageIntegritySHA256 = 0x001C # RFC8489 STUN (v2)
|
||||
AttrPasswordAlgorithm = 0x001D # RFC8489 STUN (v2)
|
||||
AttrUserhash = 0x001E # RFC8489 STUN (v2)
|
||||
AttrXORMappedAddress = 0x0020
|
||||
AttrReservationToken = 0x0022 # RFC5766 TURN
|
||||
AttrPriority = 0x0024 # RFC5245 ICE
|
||||
AttrUseCandidate = 0x0025 # RFC5245 ICE
|
||||
AttrPadding = 0x0026 # RFC5780 Nat Behavior Discovery
|
||||
AttrResponsePort = 0x0027 # RFC5780 Nat Behavior Discovery
|
||||
AttrConnectionID = 0x002a # RFC6062 TURN Extensions
|
||||
AttrPasswordAlgorithms = 0x8002 # RFC8489 STUN (v2)
|
||||
AttrAlternateDomain = 0x8003 # RFC8489 STUN (v2)
|
||||
AttrSoftware = 0x8022
|
||||
AttrAlternateServer = 0x8023
|
||||
AttrCacheTimeout = 0x8027 # RFC5780 Nat Behavior Discovery
|
||||
AttrFingerprint = 0x8028
|
||||
AttrICEControlled = 0x8029 # RFC5245 ICE
|
||||
AttrICEControlling = 0x802A # RFC5245 ICE
|
||||
AttrResponseOrigin = 0x802b # RFC5780 Nat Behavior Discovery
|
||||
AttrOtherAddress = 0x802C # RFC5780 Nat Behavior Discovery
|
||||
AttrOrigin = 0x802F
|
||||
|
||||
proc isRequired*(typ: uint16): bool = typ <= 0x7FFF'u16
|
||||
proc isOptional*(typ: uint16): bool = typ >= 0x8000'u16
|
||||
|
||||
# Error Code
|
||||
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.6
|
||||
|
||||
type
|
||||
ErrorCodeEnum* = enum
|
||||
ECTryAlternate = 300
|
||||
ECBadRequest = 400
|
||||
ECUnauthenticated = 401
|
||||
ECUnknownAttribute = 420
|
||||
ECStaleNonce = 438
|
||||
ECServerError = 500
|
||||
ErrorCode* = object
|
||||
reserved1: uint16 # should be 0
|
||||
reserved2 {.bin_bitsize: 5.}: uint8 # should be 0
|
||||
class {.bin_bitsize: 3.}: uint8
|
||||
number: uint8
|
||||
reason: seq[byte]
|
||||
|
||||
proc encode*(T: typedesc[ErrorCode], code: ErrorCodeEnum, reason: string = ""): RawStunAttribute =
|
||||
let
|
||||
ec = T(class: (code.uint16 div 100'u16).uint8,
|
||||
number: (code.uint16 mod 100'u16).uint8,
|
||||
reason: reason.toBytes())
|
||||
value = Binary.encode(ec)
|
||||
result = RawStunAttribute(attributeType: AttrErrorCode.uint16,
|
||||
length: value.len().uint16,
|
||||
value: value)
|
||||
|
||||
# Unknown Attribute
|
||||
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.9
|
||||
|
||||
type
|
||||
UnknownAttribute* = object
|
||||
unknownAttr: seq[uint16]
|
||||
|
||||
proc encode*(T: typedesc[UnknownAttribute], unknownAttr: seq[uint16]): RawStunAttribute =
|
||||
let
|
||||
ua = T(unknownAttr: unknownAttr)
|
||||
value = Binary.encode(ua)
|
||||
result = RawStunAttribute(attributeType: AttrUnknownAttributes.uint16,
|
||||
length: value.len().uint16,
|
||||
value: value)
|
||||
|
||||
# Fingerprint
|
||||
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.5
|
||||
|
||||
type
|
||||
Fingerprint* = object
|
||||
crc32: uint32
|
||||
|
||||
proc encode*(T: typedesc[Fingerprint], msg: seq[byte]): RawStunAttribute =
|
||||
let value = Binary.encode(T(crc32: crc32(msg) xor 0x5354554e'u32))
|
||||
result = RawStunAttribute(attributeType: AttrFingerprint.uint16,
|
||||
length: value.len().uint16,
|
||||
value: value)
|
||||
|
||||
# Xor Mapped Address
|
||||
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.2
|
||||
|
||||
type
|
||||
MappedAddressFamily {.size: 1.} = enum
|
||||
MAFIPv4 = 0x01
|
||||
MAFIPv6 = 0x02
|
||||
|
||||
XorMappedAddress* = object
|
||||
reserved: uint8 # should be 0
|
||||
family: MappedAddressFamily
|
||||
port: uint16
|
||||
address: seq[byte]
|
||||
|
||||
proc encode*(T: typedesc[XorMappedAddress], ta: TransportAddress,
|
||||
tid: array[12, byte]): RawStunAttribute =
|
||||
const magicCookie = @[ 0x21'u8, 0x12, 0xa4, 0x42 ]
|
||||
let
|
||||
(address, family) =
|
||||
if ta.family == AddressFamily.IPv4:
|
||||
var s = newSeq[uint8](4)
|
||||
for i in 0..3:
|
||||
s[i] = ta.address_v4[i] xor magicCookie[i]
|
||||
(s, MAFIPv4)
|
||||
else:
|
||||
let magicCookieTid = magicCookie.concat(@tid)
|
||||
var s = newSeq[uint8](16)
|
||||
for i in 0..15:
|
||||
s[i] = ta.address_v6[i] xor magicCookieTid[i]
|
||||
(s, MAFIPv6)
|
||||
xma = T(family: family, port: ta.port.distinctBase xor 0x2112'u16, address: address)
|
||||
value = Binary.encode(xma)
|
||||
result = RawStunAttribute(attributeType: AttrXORMappedAddress.uint16,
|
||||
length: value.len().uint16,
|
||||
value: value)
|
||||
|
||||
# Message Integrity
|
||||
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.4
|
||||
|
||||
type
|
||||
MessageIntegrity* = object
|
||||
msgInt: seq[byte]
|
||||
|
||||
proc encode*(T: typedesc[MessageIntegrity], msg: seq[byte], key: seq[byte]): RawStunAttribute =
|
||||
let value = Binary.encode(T(msgInt: hmacSha1(key, msg)))
|
||||
result = RawStunAttribute(attributeType: AttrMessageIntegrity.uint16,
|
||||
length: value.len().uint16, value: value)
|
|
@ -0,0 +1,61 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import chronos, chronicles
|
||||
import ../udp_connection, stun
|
||||
|
||||
logScope:
|
||||
topics = "webrtc stun"
|
||||
|
||||
# TODO: Work fine when behaves like a server, need to implement the client side
|
||||
|
||||
type
|
||||
StunConn* = ref object
|
||||
conn: UdpConn
|
||||
laddr: TransportAddress
|
||||
dataRecv: AsyncQueue[(seq[byte], TransportAddress)]
|
||||
handlesFut: Future[void]
|
||||
closed: bool
|
||||
|
||||
proc handles(self: StunConn) {.async.} =
|
||||
while true:
|
||||
let (msg, raddr) = await self.conn.read()
|
||||
if Stun.isMessage(msg):
|
||||
let res = Stun.getResponse(msg, self.laddr)
|
||||
if res.isSome():
|
||||
await self.conn.write(raddr, res.get())
|
||||
else:
|
||||
self.dataRecv.addLastNoWait((msg, raddr))
|
||||
|
||||
proc init*(self: StunConn, conn: UdpConn, laddr: TransportAddress) =
|
||||
self.conn = conn
|
||||
self.laddr = laddr
|
||||
self.closed = false
|
||||
|
||||
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
|
||||
self.handlesFut = self.handles()
|
||||
|
||||
proc close*(self: StunConn) {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to close StunConn twice"
|
||||
return
|
||||
self.handlesFut.cancel() # check before?
|
||||
await self.conn.close()
|
||||
|
||||
proc write*(self: StunConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to write on an already closed StunConn"
|
||||
return
|
||||
await self.conn.write(raddr, msg)
|
||||
|
||||
proc read*(self: StunConn): Future[(seq[byte], TransportAddress)] {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to read on an already closed StunConn"
|
||||
return
|
||||
return await self.dataRecv.popFirst()
|
|
@ -0,0 +1,58 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import sequtils
|
||||
import chronos, chronicles
|
||||
|
||||
logScope:
|
||||
topics = "webrtc udp"
|
||||
|
||||
# UdpConn is a small wrapper of the chronos DatagramTransport.
|
||||
# It's the simplest solution we found to store the message and
|
||||
# the remote address used by the underlying protocols (dtls/sctp etc...)
|
||||
|
||||
type
|
||||
UdpConn* = ref object
|
||||
laddr*: TransportAddress
|
||||
udp: DatagramTransport
|
||||
dataRecv: AsyncQueue[(seq[byte], TransportAddress)]
|
||||
closed: bool
|
||||
|
||||
proc init*(self: UdpConn, laddr: TransportAddress) =
|
||||
self.laddr = laddr
|
||||
self.closed = false
|
||||
|
||||
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
|
||||
trace "UDP onReceive"
|
||||
let msg = udp.getMessage()
|
||||
self.dataRecv.addLastNoWait((msg, address))
|
||||
|
||||
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
|
||||
self.udp = newDatagramTransport(onReceive, local = laddr)
|
||||
|
||||
proc close*(self: UdpConn) {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to close UdpConn twice"
|
||||
return
|
||||
self.closed = true
|
||||
self.udp.close()
|
||||
|
||||
proc write*(self: UdpConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to write on an already closed UdpConn"
|
||||
return
|
||||
trace "UDP write", msg
|
||||
await self.udp.sendTo(raddr, msg)
|
||||
|
||||
proc read*(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} =
|
||||
if self.closed:
|
||||
debug "Try to read on an already closed UdpConn"
|
||||
return
|
||||
trace "UDP read"
|
||||
return await self.dataRecv.popFirst()
|
|
@ -4,12 +4,12 @@ import strformat, os
|
|||
import nativesockets
|
||||
|
||||
# C include directory
|
||||
const root = currentSourcePath.parentDir
|
||||
const root = currentSourcePath.parentDir.parentDir
|
||||
const usrsctpInclude = root/"usrsctp"/"usrsctplib"
|
||||
|
||||
{.passc: fmt"-I{usrsctpInclude}".}
|
||||
|
||||
# Generated @ 2022-11-23T14:21:00+01:00
|
||||
# Generated @ 2023-03-30T13:55:23+02:00
|
||||
# Command line:
|
||||
# /home/lchenut/.nimble/pkgs/nimterop-0.6.13/nimterop/toast --compile=./usrsctp/usrsctplib/netinet/sctp_input.c --compile=./usrsctp/usrsctplib/netinet/sctp_asconf.c --compile=./usrsctp/usrsctplib/netinet/sctp_pcb.c --compile=./usrsctp/usrsctplib/netinet/sctp_usrreq.c --compile=./usrsctp/usrsctplib/netinet/sctp_cc_functions.c --compile=./usrsctp/usrsctplib/netinet/sctp_auth.c --compile=./usrsctp/usrsctplib/netinet/sctp_userspace.c --compile=./usrsctp/usrsctplib/netinet/sctp_output.c --compile=./usrsctp/usrsctplib/netinet/sctp_callout.c --compile=./usrsctp/usrsctplib/netinet/sctp_crc32.c --compile=./usrsctp/usrsctplib/netinet/sctp_sysctl.c --compile=./usrsctp/usrsctplib/netinet/sctp_sha1.c --compile=./usrsctp/usrsctplib/netinet/sctp_timer.c --compile=./usrsctp/usrsctplib/netinet/sctputil.c --compile=./usrsctp/usrsctplib/netinet/sctp_bsd_addr.c --compile=./usrsctp/usrsctplib/netinet/sctp_peeloff.c --compile=./usrsctp/usrsctplib/netinet/sctp_indata.c --compile=./usrsctp/usrsctplib/netinet/sctp_ss_functions.c --compile=./usrsctp/usrsctplib/user_socket.c --compile=./usrsctp/usrsctplib/netinet6/sctp6_usrreq.c --compile=./usrsctp/usrsctplib/user_mbuf.c --compile=./usrsctp/usrsctplib/user_environment.c --compile=./usrsctp/usrsctplib/user_recv_thread.c --pnim --preprocess --noHeader --defines=SCTP_PROCESS_LEVEL_LOCKS --defines=SCTP_SIMPLE_ALLOCATOR --defines=__Userspace__ --defines=STDC_HEADERS=1 --defines=HAVE_SYS_TYPES_H=1 --defines=HAVE_SYS_STAT_H=1 --defines=HAVE_STDLIB_H=1 --defines=HAVE_STRING_H=1 --defines=HAVE_MEMORY_H=1 --defines=HAVE_STRINGS_H=1 --defines=HAVE_INTTYPES_H=1 --defines=HAVE_STDINT_H=1 --defines=HAVE_UNISTD_H=1 --defines=HAVE_DLFCN_H=1 --defines=LT_OBJDIR=".libs/" --defines=SCTP_DEBUG=1 --defines=INET=1 --defines=INET6=1 --defines=HAVE_SOCKET=1 --defines=HAVE_INET_ADDR=1 --defines=HAVE_STDATOMIC_H=1 --defines=HAVE_SYS_QUEUE_H=1 --defines=HAVE_LINUX_IF_ADDR_H=1 --defines=HAVE_LINUX_RTNETLINK_H=1 --defines=HAVE_NETINET_IP_ICMP_H=1 --defines=HAVE_NET_ROUTE_H=1 --defines=_GNU_SOURCE --replace=sockaddr=SockAddr --replace=SockAddr_storage=Sockaddr_storage --replace=SockAddr_in=Sockaddr_in --replace=SockAddr_conn=Sockaddr_conn --replace=socklen_t=SockLen --includeDirs=./usrsctp/usrsctplib ./usrsctp/usrsctplib/usrsctp.h
|
||||
|
||||
|
@ -47,30 +47,29 @@ const usrsctpInclude = root/"usrsctp"/"usrsctplib"
|
|||
{.passc: "-DHAVE_NETINET_IP_ICMP_H=1".}
|
||||
{.passc: "-DHAVE_NET_ROUTE_H=1".}
|
||||
{.passc: "-D_GNU_SOURCE".}
|
||||
{.passc: "-I./usrsctp/usrsctplib".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_input.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_asconf.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_pcb.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_usrreq.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_cc_functions.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_auth.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_userspace.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_output.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_callout.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_crc32.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_sysctl.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_sha1.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_timer.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctputil.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_bsd_addr.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_peeloff.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_indata.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet/sctp_ss_functions.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/user_socket.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/netinet6/sctp6_usrreq.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/user_mbuf.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/user_environment.c".}
|
||||
{.compile: "./usrsctp/usrsctplib/user_recv_thread.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_input.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_asconf.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_pcb.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_usrreq.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_cc_functions.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_auth.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_userspace.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_output.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_callout.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_crc32.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_sysctl.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_sha1.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_timer.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctputil.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_bsd_addr.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_peeloff.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_indata.c".}
|
||||
{.compile: usrsctpInclude / "netinet/sctp_ss_functions.c".}
|
||||
{.compile: usrsctpInclude / "user_socket.c".}
|
||||
{.compile: usrsctpInclude / "netinet6/sctp6_usrreq.c".}
|
||||
{.compile: usrsctpInclude / "user_mbuf.c".}
|
||||
{.compile: usrsctpInclude / "user_environment.c".}
|
||||
{.compile: usrsctpInclude / "user_recv_thread.c".}
|
||||
const
|
||||
MSG_NOTIFICATION* = 0x00002000
|
||||
AF_CONN* = 123
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# Nim-WebRTC
|
||||
# Copyright (c) 2024 Status Research & Development GmbH
|
||||
# Licensed under either of
|
||||
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
# at your option.
|
||||
# This file may not be copied, modified, or distributed except according to
|
||||
# those terms.
|
||||
|
||||
import chronos, chronicles
|
||||
|
||||
import udp_connection
|
||||
import stun/stun_connection
|
||||
import dtls/dtls
|
||||
import sctp, datachannel
|
||||
|
||||
logScope:
|
||||
topics = "webrtc"
|
||||
|
||||
type
|
||||
WebRTC* = ref object
|
||||
udp*: UdpConn
|
||||
stun*: StunConn
|
||||
dtls*: Dtls
|
||||
sctp*: Sctp
|
||||
port: int
|
||||
|
||||
proc new*(T: typedesc[WebRTC], address: TransportAddress): T =
|
||||
result = T(udp: UdpConn(), stun: StunConn(), dtls: Dtls(), sctp: Sctp())
|
||||
result.udp.init(address)
|
||||
result.stun.init(result.udp, address)
|
||||
result.dtls.init(result.stun, address)
|
||||
result.sctp.init(result.dtls, address)
|
||||
|
||||
proc listen*(self: WebRTC) =
|
||||
self.sctp.listen()
|
||||
|
||||
proc connect*(self: WebRTC, raddr: TransportAddress): Future[DataChannelConnection] {.async.} =
|
||||
let sctpConn = await self.sctp.connect(raddr) # TODO: Port?
|
||||
result = DataChannelConnection.new(sctpConn)
|
||||
|
||||
proc accept*(w: WebRTC): Future[DataChannelConnection] {.async.} =
|
||||
let sctpConn = await w.sctp.accept()
|
||||
result = DataChannelConnection.new(sctpConn)
|
Loading…
Reference in New Issue