diff --git a/README.md b/README.md new file mode 100644 index 0000000..23aed32 --- /dev/null +++ b/README.md @@ -0,0 +1,28 @@ +# Nim-Webrtc + +![Stability: experimental](https://img.shields.io/badge/stability-experimental-orange.svg) +[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) +[![License: Apache](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +A simple WebRTC stack first implemented for [libp2p WebRTC direct transport](https://github.com/libp2p/specs/blob/master/webrtc/webrtc-direct.md). +It uses a wrapper from two different C libraries: + - [usrsctp]() for the SCTP stack + - [mbedtls]() for the DTLS stack + +## Usage + +## Installation + +## TODO + +## License + +Licensed and distributed under either of + +* MIT license: [LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT + +or + +* Apache License, Version 2.0, ([LICENSE-APACHEv2](LICENSE-APACHEv2) or http://www.apache.org/licenses/LICENSE-2.0) + +at your option. This file may not be copied, modified, or distributed except according to those terms. diff --git a/build.sh b/build_usrsctp.sh similarity index 97% rename from build.sh rename to build_usrsctp.sh index c729c7e..04f1629 100755 --- a/build.sh +++ b/build_usrsctp.sh @@ -11,7 +11,7 @@ fi cd "${root}/usrsctp" && ./bootstrap && ./configure && make && cd - # add prelude -cat "${root}/prelude.nim" > "${outputFile}" +cat "${root}/prelude_usrsctp.nim" > "${outputFile}" # assemble list of C files to be compiled for file in `find ${root}/usrsctp/usrsctplib -name '*.c'`; do diff --git a/examples/ping.nim b/examples/ping.nim new file mode 100644 index 0000000..ea11c12 --- /dev/null +++ b/examples/ping.nim @@ -0,0 +1,24 @@ +import chronos, stew/byteutils +import ../webrtc/udp_connection +import ../webrtc/stun/stun_connection +import ../webrtc/dtls/dtls +import ../webrtc/sctp + +proc main() {.async.} = + let laddr = initTAddress("127.0.0.1:4244") + let udp = UdpConn() + udp.init(laddr) + let stun = StunConn() + stun.init(udp, laddr) + let dtls = Dtls() + dtls.init(stun, laddr) + let sctp = Sctp() + sctp.init(dtls, laddr) + let conn = await sctp.connect(initTAddress("127.0.0.1:4242"), sctpPort = 13) + while true: + await conn.write("ping".toBytes) + let msg = await conn.read() + echo "Received: ", string.fromBytes(msg.data) + await sleepAsync(1.seconds) + +waitFor(main()) diff --git a/examples/pong.nim b/examples/pong.nim new file mode 100644 index 0000000..b614b59 --- /dev/null +++ b/examples/pong.nim @@ -0,0 +1,30 @@ +import chronos, stew/byteutils +import ../webrtc/udp_connection +import ../webrtc/stun/stun_connection +import ../webrtc/dtls/dtls +import ../webrtc/sctp + +proc sendPong(conn: SctpConn) {.async.} = + var i = 0 + while true: + let msg = await conn.read() + echo "Received: ", string.fromBytes(msg.data) + await conn.write(("pong " & $i).toBytes) + i.inc() + +proc main() {.async.} = + let laddr = initTAddress("127.0.0.1:4242") + let udp = UdpConn() + udp.init(laddr) + let stun = StunConn() + stun.init(udp, laddr) + let dtls = Dtls() + dtls.init(stun, laddr) + let sctp = Sctp() + sctp.init(dtls, laddr) + sctp.listen(13) + while true: + let conn = await sctp.accept() + asyncSpawn conn.sendPong() + +waitFor(main()) diff --git a/examples/sctp_both.nim b/examples/sctp_both.nim new file mode 100644 index 0000000..bb861fa --- /dev/null +++ b/examples/sctp_both.nim @@ -0,0 +1,25 @@ +import chronos, stew/byteutils +import ../webrtc/sctp as sc + +let sctp = Sctp.new(port = 4242) +proc serv(fut: Future[void]) {.async.} = + sctp.startServer(13) + fut.complete() + let conn = await sctp.listen() + echo "await read()" + let msg = await conn.read() + echo "read() finished" + echo "Receive: ", string.fromBytes(msg) + await conn.close() + sctp.stopServer() + +proc main() {.async.} = + let fut = Future[void]() + asyncSpawn serv(fut) + await fut + let address = TransportAddress(initTAddress("127.0.0.1:4242")) + let conn = await sctp.connect(address, sctpPort = 13) + await conn.write("test".toBytes) + await conn.close() + +waitFor(main()) diff --git a/examples/sctp_client.nim b/examples/sctp_client.nim new file mode 100644 index 0000000..b49b1d8 --- /dev/null +++ b/examples/sctp_client.nim @@ -0,0 +1,14 @@ +import chronos, stew/byteutils +import ../webrtc/sctp + +proc main() {.async.} = + let + sctp = Sctp.new(port = 4244) + address = TransportAddress(initTAddress("127.0.0.1:4242")) + conn = await sctp.connect(address, sctpPort = 13) + await conn.write("test".toBytes) + let msg = await conn.read() + echo "Client read() finished ; receive: ", string.fromBytes(msg) + await conn.close() + +waitFor(main()) diff --git a/examples/sctp_server.nim b/examples/sctp_server.nim new file mode 100644 index 0000000..429e247 --- /dev/null +++ b/examples/sctp_server.nim @@ -0,0 +1,13 @@ +import chronos, stew/byteutils +import ../webrtc/sctp + +proc main() {.async.} = + let sctp = Sctp.new(port = 4242) + sctp.startServer(13) + let conn = await sctp.listen() + let msg = await conn.read() + echo "Receive: ", string.fromBytes(msg) + await conn.close() + sctp.stopServer() + +waitFor(main()) diff --git a/prelude.nim b/prelude_usrsctp.nim similarity index 100% rename from prelude.nim rename to prelude_usrsctp.nim diff --git a/tests/testdatachannel.nim b/tests/testdatachannel.nim new file mode 100644 index 0000000..cf1a6a0 --- /dev/null +++ b/tests/testdatachannel.nim @@ -0,0 +1,25 @@ +import ../webrtc/datachannel +import chronos/unittest2/asynctests +import binary_serialization + +suite "DataChannel encoding": + test "DataChannelOpenMessage": + let msg = @[ + 0x03'u8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72] + check msg == Binary.encode(Binary.decode(msg, DataChannelMessage)) + check Binary.decode(msg, DataChannelMessage).openMessage == + DataChannelOpenMessage( + channelType: Reliable, + priority: 0, + reliabilityParameter: 0, + labelLength: 3, + protocolLength: 3, + label: @[102, 111, 111], + protocol: @[98, 97, 114] + ) + + test "DataChannelAck": + let msg = @[0x02'u8] + check msg == Binary.encode(Binary.decode(msg, DataChannelMessage)) + check Binary.decode(msg, DataChannelMessage).messageType == Ack diff --git a/tests/teststun.nim b/tests/teststun.nim new file mode 100644 index 0000000..c850484 --- /dev/null +++ b/tests/teststun.nim @@ -0,0 +1,14 @@ +import ../webrtc/stun +import ./asyncunit +import binary_serialization + +suite "Stun suite": + test "Stun encoding/decoding with padding": + let msg = @[ 0x00'u8, 0x01, 0x00, 0xa4, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d, 0x2b, 0x00, 0x06, 0x00, 0x63, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2b, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2b, 0x76, 0x31, 0x2f, 0x62, 0x71, 0x36, 0x67, 0x69, 0x43, 0x75, 0x4a, 0x38, 0x6e, 0x78, 0x59, 0x46, 0x4a, 0x36, 0x43, 0x63, 0x67, 0x45, 0x59, 0x58, 0x58, 0x2f, 0x78, 0x51, 0x58, 0x56, 0x4c, 0x74, 0x39, 0x71, 0x7a, 0x3a, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2b, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2b, 0x76, 0x31, 0x2f, 0x62, 0x71, 0x36, 0x67, 0x69, 0x43, 0x75, 0x4a, 0x38, 0x6e, 0x78, 0x59, 0x46, 0x4a, 0x36, 0x43, 0x63, 0x67, 0x45, 0x59, 0x58, 0x58, 0x2f, 0x78, 0x51, 0x58, 0x56, 0x4c, 0x74, 0x39, 0x71, 0x7a, 0x00, 0xc0, 0x57, 0x00, 0x04, 0x00, 0x00, 0x03, 0xe7, 0x80, 0x2a, 0x00, 0x08, 0x86, 0x63, 0xfd, 0x45, 0xa9, 0xe5, 0x4c, 0xdb, 0x00, 0x24, 0x00, 0x04, 0x6e, 0x00, 0x1e, 0xff, 0x00, 0x08, 0x00, 0x14, 0x16, 0xff, 0x70, 0x8d, 0x97, 0x0b, 0xd6, 0xa3, 0x5b, 0xac, 0x8f, 0x4c, 0x85, 0xe6, 0xa6, 0xac, 0xaa, 0x7a, 0x68, 0x27, 0x80, 0x28, 0x00, 0x04, 0x79, 0x5e, 0x03, 0xd8 ] + check msg == encode(StunMessage.decode(msg)) + + test "Error while decoding": + let msgLengthFailed = @[ 0x00'u8, 0x01, 0x00, 0xa4, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d ] + expect AssertionDefect: discard StunMessage.decode(msgLengthFailed) + let msgAttrFailed = @[ 0x00'u8, 0x01, 0x00, 0x08, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d, 0x2b, 0x28, 0x00, 0x05, 0x79, 0x5e, 0x03, 0xd8 ] + expect AssertionDefect: discard StunMessage.decode(msgAttrFailed) diff --git a/webrtc.nimble b/webrtc.nimble index b6d70e0..c30bde2 100644 --- a/webrtc.nimble +++ b/webrtc.nimble @@ -3,8 +3,13 @@ version = "0.0.1" author = "Status Research & Development GmbH" description = "Webrtc stack" license = "MIT" -#installDirs = @["usrsctp"] +installDirs = @["usrsctp", "webrtc"] requires "nim >= 1.2.0", "chronicles >= 0.10.2", - "chronos >= 3.0.6" + "chronos >= 3.0.6", + "https://github.com/status-im/nim-binary-serialization.git", + "https://github.com/status-im/nim-mbedtls.git" + +proc runTest(filename: string) = + discard diff --git a/webrtc/datachannel.nim b/webrtc/datachannel.nim new file mode 100644 index 0000000..5505a45 --- /dev/null +++ b/webrtc/datachannel.nim @@ -0,0 +1,227 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import tables + +import chronos, + chronicles, + binary_serialization + +import sctp + +export binary_serialization + +logScope: + topics = "webrtc datachannel" + +# Implementation of the DataChannel protocol, mostly following +# https://www.rfc-editor.org/rfc/rfc8831.html and +# https://www.rfc-editor.org/rfc/rfc8832.html + +type + DataChannelProtocolIds* {.size: 4.} = enum + WebRtcDcep = 50 + WebRtcString = 51 + WebRtcBinary = 53 + WebRtcStringEmpty = 56 + WebRtcBinaryEmpty = 57 + + DataChannelMessageType* {.size: 1.} = enum + Reserved = 0x00 + Ack = 0x02 + Open = 0x03 + + DataChannelMessage* = object + case messageType*: DataChannelMessageType + of Open: openMessage*: DataChannelOpenMessage + else: discard + + DataChannelType {.size: 1.} = enum + Reliable = 0x00 + PartialReliableRexmit = 0x01 + PartialReliableTimed = 0x02 + ReliableUnordered = 0x80 + PartialReliableRexmitUnordered = 0x81 + PartialReliableTimedUnorderd = 0x82 + + DataChannelOpenMessage* = object + channelType*: DataChannelType + priority*: uint16 + reliabilityParameter*: uint32 + labelLength* {.bin_value: it.label.len.}: uint16 + protocolLength* {.bin_value: it.protocol.len.}: uint16 + label* {.bin_len: it.labelLength.}: seq[byte] + protocol* {.bin_len: it.protocolLength.}: seq[byte] + +proc ordered(t: DataChannelType): bool = + t in [Reliable, PartialReliableRexmit, PartialReliableTimed] + +type + #TODO handle closing + DataChannelStream* = ref object + id: uint16 + conn: SctpConn + reliability: DataChannelType + reliabilityParameter: uint32 + receivedData: AsyncQueue[seq[byte]] + acked: bool + + #TODO handle closing + DataChannelConnection* = ref object + readLoopFut: Future[void] + streams: Table[uint16, DataChannelStream] + streamId: uint16 + conn*: SctpConn + incomingStreams: AsyncQueue[DataChannelStream] + +proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} = + let x = await stream.receivedData.popFirst() + trace "read", length=x.len(), id=stream.id + return x + +proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} = + trace "write", length=buf.len(), id=stream.id + var + sendInfo = SctpMessageParameters( + streamId: stream.id, + endOfRecord: true, + protocolId: uint32(WebRtcBinary) + ) + + if stream.acked: + sendInfo.unordered = not stream.reliability.ordered + #TODO add reliability params + + if buf.len == 0: + trace "Datachannel write empty" + sendInfo.protocolId = uint32(WebRtcBinaryEmpty) + await stream.conn.write(@[0'u8], sendInfo) + else: + await stream.conn.write(buf, sendInfo) + +proc sendControlMessage(stream: DataChannelStream, msg: DataChannelMessage) {.async.} = + let + encoded = Binary.encode(msg) + sendInfo = SctpMessageParameters( + streamId: stream.id, + endOfRecord: true, + protocolId: uint32(WebRtcDcep) + ) + trace "send control message", msg + + await stream.conn.write(encoded, sendInfo) + +proc openStream*( + conn: DataChannelConnection, + noiseHandshake: bool, + reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} = + let streamId: uint16 = + if not noiseHandshake: + let res = conn.streamId + conn.streamId += 2 + res + else: + 0 + + trace "open stream", streamId + if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0: + raise newException(ValueError, "reliabilityParameter should be 0") + + if streamId in conn.streams: + raise newException(ValueError, "streamId already used") + + #TODO: we should request more streams when required + # https://github.com/sctplab/usrsctp/blob/a0cbf4681474fab1e89d9e9e2d5c3694fce50359/programs/rtcweb.c#L304C16-L304C16 + + var stream = DataChannelStream( + id: streamId, conn: conn.conn, + reliability: reliability, + reliabilityParameter: reliabilityParameter, + receivedData: newAsyncQueue[seq[byte]]() + ) + + conn.streams[streamId] = stream + + let + msg = DataChannelMessage( + messageType: Open, + openMessage: DataChannelOpenMessage( + channelType: reliability, + reliabilityParameter: reliabilityParameter + ) + ) + await stream.sendControlMessage(msg) + return stream + +proc handleData(conn: DataChannelConnection, msg: SctpMessage) = + let streamId = msg.params.streamId + trace "handle data message", streamId, ppid = msg.params.protocolId, data = msg.data + + if streamId notin conn.streams: + raise newException(ValueError, "got data for unknown streamid") + + let stream = conn.streams[streamId] + + #TODO handle string vs binary + if msg.params.protocolId in [uint32(WebRtcStringEmpty), uint32(WebRtcBinaryEmpty)]: + # PPID indicate empty message + stream.receivedData.addLastNoWait(@[]) + else: + stream.receivedData.addLastNoWait(msg.data) + +proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} = + let + decoded = Binary.decode(msg.data, DataChannelMessage) + streamId = msg.params.streamId + + trace "handle control message", decoded, streamId = msg.params.streamId + if decoded.messageType == Ack: + if streamId notin conn.streams: + raise newException(ValueError, "got ack for unknown streamid") + conn.streams[streamId].acked = true + elif decoded.messageType == Open: + if streamId in conn.streams: + raise newException(ValueError, "got open for already existing streamid") + + let stream = DataChannelStream( + id: streamId, conn: conn.conn, + reliability: decoded.openMessage.channelType, + reliabilityParameter: decoded.openMessage.reliabilityParameter, + receivedData: newAsyncQueue[seq[byte]]() + ) + + conn.streams[streamId] = stream + conn.incomingStreams.addLastNoWait(stream) + + await stream.sendControlMessage(DataChannelMessage(messageType: Ack)) + +proc readLoop(conn: DataChannelConnection) {.async.} = + try: + while true: + let message = await conn.conn.read() + # TODO: check the protocolId + if message.params.protocolId == uint32(WebRtcDcep): + #TODO should we really await? + await conn.handleControl(message) + else: + conn.handleData(message) + + except CatchableError as exc: + discard + +proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} = + return await conn.incomingStreams.popFirst() + +proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection = + result = DataChannelConnection( + conn: conn, + incomingStreams: newAsyncQueue[DataChannelStream](), + streamId: 1'u16 # TODO: Serveur == 1, client == 2 + ) + result.readLoopFut = result.readLoop() diff --git a/webrtc/dtls/dtls.nim b/webrtc/dtls/dtls.nim new file mode 100644 index 0000000..b09a888 --- /dev/null +++ b/webrtc/dtls/dtls.nim @@ -0,0 +1,381 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import times, deques, tables, sequtils +import chronos, chronicles +import ./utils, ../stun/stun_connection + +import mbedtls/ssl +import mbedtls/ssl_cookie +import mbedtls/ssl_cache +import mbedtls/pk +import mbedtls/md +import mbedtls/entropy +import mbedtls/ctr_drbg +import mbedtls/rsa +import mbedtls/x509 +import mbedtls/x509_crt +import mbedtls/bignum +import mbedtls/error +import mbedtls/net_sockets +import mbedtls/timing + +logScope: + topics = "webrtc dtls" + +# Implementation of a DTLS client and a DTLS Server by using the mbedtls library. +# Multiple things here are unintuitive partly because of the callbacks +# used by mbedtls and that those callbacks cannot be async. +# +# TODO: +# - Check the viability of the add/pop first/last of the asyncqueue with the limit. +# There might be some errors (or crashes) with some edge cases with the no wait option +# - Not critical - Check how to make a better use of MBEDTLS_ERR_SSL_WANT_WRITE +# - Not critical - May be interesting to split Dtls and DtlsConn into two files + +# This limit is arbitrary, it could be interesting to make it configurable. +const PendingHandshakeLimit = 1024 + +# -- DtlsConn -- +# A Dtls connection to a specific IP address recovered by the receiving part of +# the Udp "connection" + +type + DtlsError* = object of CatchableError + DtlsConn* = ref object + conn: StunConn + laddr: TransportAddress + raddr*: TransportAddress + dataRecv: AsyncQueue[seq[byte]] + sendFuture: Future[void] + closed: bool + closeEvent: AsyncEvent + + timer: mbedtls_timing_delay_context + + ssl: mbedtls_ssl_context + config: mbedtls_ssl_config + cookie: mbedtls_ssl_cookie_ctx + cache: mbedtls_ssl_cache_context + + ctr_drbg: mbedtls_ctr_drbg_context + entropy: mbedtls_entropy_context + + localCert: seq[byte] + remoteCert: seq[byte] + +proc init(self: DtlsConn, conn: StunConn, laddr: TransportAddress) = + self.conn = conn + self.laddr = laddr + self.dataRecv = newAsyncQueue[seq[byte]]() + self.closed = false + self.closeEvent = newAsyncEvent() + +proc join(self: DtlsConn) {.async.} = + await self.closeEvent.wait() + +proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} = + var shouldRead = isServer + while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: + if shouldRead: + if isServer: + case self.raddr.family + of AddressFamily.IPv4: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4) + of AddressFamily.IPv6: + mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6) + else: + raise newException(DtlsError, "Remote address isn't an IP address") + let tmp = await self.dataRecv.popFirst() + self.dataRecv.addFirstNoWait(tmp) + self.sendFuture = nil + let res = mb_ssl_handshake_step(self.ssl) + if not self.sendFuture.isNil(): + await self.sendFuture + shouldRead = false + if res == MBEDTLS_ERR_SSL_WANT_WRITE: + continue + elif res == MBEDTLS_ERR_SSL_WANT_READ: + shouldRead = true + continue + elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: + mb_ssl_session_reset(self.ssl) + shouldRead = isServer + continue + elif res != 0: + raise newException(DtlsError, $(res.mbedtls_high_level_strerr())) + +proc close*(self: DtlsConn) {.async.} = + if self.closed: + debug "Try to close DtlsConn twice" + return + + self.closed = true + self.sendFuture = nil + # TODO: proc mbedtls_ssl_close_notify => template mb_ssl_close_notify in nim-mbedtls + let x = mbedtls_ssl_close_notify(addr self.ssl) + if not self.sendFuture.isNil(): + await self.sendFuture + self.closeEvent.fire() + +proc write*(self: DtlsConn, msg: seq[byte]) {.async.} = + if self.closed: + debug "Try to write on an already closed DtlsConn" + return + var buf = msg + try: + let sendFuture = newFuture[void]("DtlsConn write") + self.sendFuture = nil + let write = mb_ssl_write(self.ssl, buf) + if not self.sendFuture.isNil(): + await self.sendFuture + trace "Dtls write", msgLen = msg.len(), actuallyWrote = write + except MbedTLSError as exc: + trace "Dtls write error", errorMsg = exc.msg + raise exc + +proc read*(self: DtlsConn): Future[seq[byte]] {.async.} = + if self.closed: + debug "Try to read on an already closed DtlsConn" + return + var res = newSeq[byte](8192) + while true: + let tmp = await self.dataRecv.popFirst() + self.dataRecv.addFirstNoWait(tmp) + # TODO: Find a clear way to use the template `mb_ssl_read` without + # messing up things with exception + let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint) + if length == MBEDTLS_ERR_SSL_WANT_READ: + continue + if length < 0: + raise newException(DtlsError, $(length.cint.mbedtls_high_level_strerr())) + res.setLen(length) + return res + +# -- Dtls -- +# The Dtls object read every messages from the UdpConn/StunConn and, if the address +# is not yet stored in the Table `Connection`, adds it to the `pendingHandshake` queue +# to be accepted later, if the address is stored, add the message received to the +# corresponding DtlsConn `dataRecv` queue. + +type + Dtls* = ref object of RootObj + connections: Table[TransportAddress, DtlsConn] + pendingHandshakes: AsyncQueue[(TransportAddress, seq[byte])] + conn: StunConn + laddr: TransportAddress + started: bool + readLoop: Future[void] + ctr_drbg: mbedtls_ctr_drbg_context + entropy: mbedtls_entropy_context + + serverPrivKey: mbedtls_pk_context + serverCert: mbedtls_x509_crt + localCert: seq[byte] + +proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])], + raddr: TransportAddress, buf: seq[byte]) = + for kv in aq.mitems(): + if kv[0] == raddr: + kv[1] = buf + return + aq.addLastNoWait((raddr, buf)) + +proc init*(self: Dtls, conn: StunConn, laddr: TransportAddress) = + if self.started: + warn "Already started" + return + + proc readLoop() {.async.} = + while true: + let (buf, raddr) = await self.conn.read() + if self.connections.hasKey(raddr): + self.connections[raddr].dataRecv.addLastNoWait(buf) + else: + self.pendingHandshakes.updateOrAdd(raddr, buf) + + self.connections = initTable[TransportAddress, DtlsConn]() + self.pendingHandshakes = newAsyncQueue[(TransportAddress, seq[byte])](PendingHandshakeLimit) + self.conn = conn + self.laddr = laddr + self.started = true + self.readLoop = readLoop() + + mb_ctr_drbg_init(self.ctr_drbg) + mb_entropy_init(self.entropy) + mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, self.entropy, nil, 0) + + self.serverPrivKey = self.ctr_drbg.generateKey() + self.serverCert = self.ctr_drbg.generateCertificate(self.serverPrivKey) + self.localCert = newSeq[byte](self.serverCert.raw.len) + copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len) + +proc stop*(self: Dtls) {.async.} = + if not self.started: + warn "Already stopped" + return + + await allFutures(toSeq(self.connections.values()).mapIt(it.close())) + self.readLoop.cancel() + self.started = false + +# -- Remote / Local certificate getter -- + +proc remoteCertificate*(conn: DtlsConn): seq[byte] = + conn.remoteCert + +proc localCertificate*(conn: DtlsConn): seq[byte] = + conn.localCert + +proc localCertificate*(self: Dtls): seq[byte] = + self.localCert + +# -- MbedTLS Callbacks -- + +proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt, + state: cint, pflags: ptr uint32): cint {.cdecl.} = + # verify is the procedure called by mbedtls when receiving the remote + # certificate. It's usually used to verify the validity of the certificate. + # We use this procedure to store the remote certificate as it's mandatory + # to have it for the Prologue of the Noise protocol, aswell as the localCertificate. + var self = cast[DtlsConn](ctx) + let cert = pcert[] + + self.remoteCert = newSeq[byte](cert.raw.len) + copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len) + return 0 + +proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + # dtlsSend is the procedure called by mbedtls when data needs to be sent. + # As the StunConn's write proc is asynchronous and dtlsSend cannot be async, + # we store the future of this write and await it after the end of the + # function (see write or dtlsHanshake for example). + var self = cast[DtlsConn](ctx) + var toWrite = newSeq[byte](len) + if len > 0: + copyMem(addr toWrite[0], buf, len) + trace "dtls send", len + self.sendFuture = self.conn.write(self.raddr, toWrite) + result = len.cint + +proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + # dtlsRecv is the procedure called by mbedtls when data needs to be received. + # As we cannot asynchronously await for data to be received, we use a data received + # queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await + # when the mbedtls proc resumed (see read or dtlsHandshake for example) + let self = cast[DtlsConn](ctx) + if self.dataRecv.len() == 0: + return MBEDTLS_ERR_SSL_WANT_READ + + var dataRecv = self.dataRecv.popFirstNoWait() + copyMem(buf, addr dataRecv[0], dataRecv.len()) + result = dataRecv.len().cint + trace "dtls receive", len, result + +# -- Dtls Accept / Connect procedures -- + +proc removeConnection(self: Dtls, conn: DtlsConn, raddr: TransportAddress) {.async.} = + await conn.join() + self.connections.del(raddr) + +proc accept*(self: Dtls): Future[DtlsConn] {.async.} = + var + selfvar = self + res = DtlsConn() + + res.init(self.conn, self.laddr) + mb_ssl_init(res.ssl) + mb_ssl_config_init(res.config) + mb_ssl_cookie_init(res.cookie) + mb_ssl_cache_init(res.cache) + + res.ctr_drbg = self.ctr_drbg + res.entropy = self.entropy + + var pkey = self.serverPrivKey + var srvcert = self.serverCert + res.localCert = self.localCert + + mb_ssl_config_defaults(res.config, + MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT) + mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds + mb_ssl_conf_ca_chain(res.config, srvcert.next, nil) + mb_ssl_conf_own_cert(res.config, srvcert, pkey) + mb_ssl_cookie_setup(res.cookie, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_dtls_cookies(res.config, res.cookie) + mb_ssl_set_timer_cb(res.ssl, res.timer) + mb_ssl_setup(res.ssl, res.config) + mb_ssl_session_reset(res.ssl) + mb_ssl_set_verify(res.ssl, verify, res) + mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) + mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil) + while true: + let (raddr, buf) = await self.pendingHandshakes.popFirst() + try: + res.raddr = raddr + res.dataRecv.addLastNoWait(buf) + self.connections[raddr] = res + await res.dtlsHandshake(true) + asyncSpawn self.removeConnection(res, raddr) + break + except CatchableError as exc: + trace "Handshake fail", remoteAddress = raddr, error = exc.msg + self.connections.del(raddr) + continue + return res + +proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} = + var + selfvar = self + res = DtlsConn() + + res.init(self.conn, self.laddr) + mb_ssl_init(res.ssl) + mb_ssl_config_init(res.config) + + res.ctr_drbg = self.ctr_drbg + res.entropy = self.entropy + + var pkey = res.ctr_drbg.generateKey() + var srvcert = res.ctr_drbg.generateCertificate(pkey) + res.localCert = newSeq[byte](srvcert.raw.len) + copyMem(addr res.localCert[0], srvcert.raw.p, srvcert.raw.len) + + mb_ctr_drbg_init(res.ctr_drbg) + mb_entropy_init(res.entropy) + mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0) + + mb_ssl_config_defaults(res.config, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT) + mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg) + mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds + mb_ssl_conf_ca_chain(res.config, srvcert.next, nil) + mb_ssl_set_timer_cb(res.ssl, res.timer) + mb_ssl_setup(res.ssl, res.config) + mb_ssl_set_verify(res.ssl, verify, res) + mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL) + mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil) + + res.raddr = raddr + self.connections[raddr] = res + + try: + await res.dtlsHandshake(false) + asyncSpawn self.removeConnection(res, raddr) + except CatchableError as exc: + trace "Handshake fail", remoteAddress = raddr, error = exc.msg + self.connections.del(raddr) + raise exc + + return res diff --git a/webrtc/dtls/utils.nim b/webrtc/dtls/utils.nim new file mode 100644 index 0000000..06fb990 --- /dev/null +++ b/webrtc/dtls/utils.nim @@ -0,0 +1,96 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import std/times + +import stew/byteutils + +import mbedtls/pk +import mbedtls/rsa +import mbedtls/ctr_drbg +import mbedtls/x509_crt +import mbedtls/bignum +import mbedtls/md + +import chronicles + +# This sequence is used for debugging. +const mb_ssl_states* = @[ + "MBEDTLS_SSL_HELLO_REQUEST", + "MBEDTLS_SSL_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_HELLO", + "MBEDTLS_SSL_SERVER_CERTIFICATE", + "MBEDTLS_SSL_SERVER_KEY_EXCHANGE", + "MBEDTLS_SSL_CERTIFICATE_REQUEST", + "MBEDTLS_SSL_SERVER_HELLO_DONE", + "MBEDTLS_SSL_CLIENT_CERTIFICATE", + "MBEDTLS_SSL_CLIENT_KEY_EXCHANGE", + "MBEDTLS_SSL_CERTIFICATE_VERIFY", + "MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC", + "MBEDTLS_SSL_CLIENT_FINISHED", + "MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC", + "MBEDTLS_SSL_SERVER_FINISHED", + "MBEDTLS_SSL_FLUSH_BUFFERS", + "MBEDTLS_SSL_HANDSHAKE_WRAPUP", + "MBEDTLS_SSL_NEW_SESSION_TICKET", + "MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT", + "MBEDTLS_SSL_HELLO_RETRY_REQUEST", + "MBEDTLS_SSL_ENCRYPTED_EXTENSIONS", + "MBEDTLS_SSL_END_OF_EARLY_DATA", + "MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY", + "MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED", + "MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO", + "MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST", + "MBEDTLS_SSL_HANDSHAKE_OVER", + "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET", + "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH" +] + +template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context = + var res: mbedtls_pk_context + mb_pk_init(res) + discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) + mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 2048, 65537) + let x = mb_pk_rsa(res) + res + +template generateCertificate*(random: mbedtls_ctr_drbg_context, + issuer_key: mbedtls_pk_context): mbedtls_x509_crt = + let + # To be honest, I have no clue what to put here as a name + name = "C=FR,O=Status,CN=webrtc" + time_format = initTimeFormat("YYYYMMddHHmmss") + time_from = times.now().format(time_format) + time_to = (times.now() + times.years(1)).format(time_format) + + var write_cert: mbedtls_x509write_cert + var serial_mpi: mbedtls_mpi + mb_x509write_crt_init(write_cert) + mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256); + mb_x509write_crt_set_subject_key(write_cert, issuer_key) + mb_x509write_crt_set_issuer_key(write_cert, issuer_key) + mb_x509write_crt_set_subject_name(write_cert, name) + mb_x509write_crt_set_issuer_name(write_cert, name) + mb_x509write_crt_set_validity(write_cert, time_from, time_to) + mb_x509write_crt_set_basic_constraints(write_cert, 0, -1) + mb_x509write_crt_set_subject_key_identifier(write_cert) + mb_x509write_crt_set_authority_key_identifier(write_cert) + mb_mpi_init(serial_mpi) + let serial_hex = mb_mpi_read_string(serial_mpi, 16) + mb_x509write_crt_set_serial(write_cert, serial_mpi) + let buf = + try: + mb_x509write_crt_pem(write_cert, 2048, mbedtls_ctr_drbg_random, random) + except MbedTLSError as e: + raise e + var res: mbedtls_x509_crt + mb_x509_crt_parse(res, buf) + res diff --git a/webrtc/sctp.nim b/webrtc/sctp.nim index 4aa9c13..31a4b8b 100644 --- a/webrtc/sctp.nim +++ b/webrtc/sctp.nim @@ -1,5 +1,5 @@ # Nim-WebRTC -# Copyright (c) 2022 Status Research & Development GmbH +# Copyright (c) 2024 Status Research & Development GmbH # Licensed under either of # * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) # * MIT license ([LICENSE-MIT](LICENSE-MIT)) @@ -8,14 +8,29 @@ # those terms. import tables, bitops, posix, strutils, sequtils -import chronos, chronicles, stew/ranges/ptr_arith +import chronos, chronicles, stew/[ranges/ptr_arith, byteutils, endians2] import usrsctp +import dtls/dtls +import binary_serialization export chronicles logScope: topics = "webrtc sctp" +# Implementation of an Sctp client and server using the usrsctp library. +# Usrsctp is usable as a single thread but it's not the intended way to +# use it. There's a lot of callbacks calling each other in a synchronous +# way where we want to be able to call asynchronous procedure, but cannot. + +# TODO: +# - Replace doAssert by a proper exception management +# - Find a clean way to manage SCTP ports +# - Unregister address when closing + +proc perror(error: cstring) {.importc, cdecl, header: "".} +proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "", gcsafe.} + type SctpError* = object of CatchableError @@ -24,115 +39,172 @@ type Connected Closed - SctpConnection* = ref object + SctpMessageParameters* = object + protocolId*: uint32 + streamId*: uint16 + endOfRecord*: bool + unordered*: bool + + SctpMessage* = ref object + data*: seq[byte] + info: sctp_recvv_rn + params*: SctpMessageParameters + + SctpConn* = ref object + conn*: DtlsConn state: SctpState connectEvent: AsyncEvent + acceptEvent: AsyncEvent + readLoop: Future[void] sctp: Sctp udp: DatagramTransport address: TransportAddress sctpSocket: ptr socket - recvEvent: AsyncEvent - dataRecv: seq[byte] + dataRecv: AsyncQueue[SctpMessage] + sentFuture: Future[void] Sctp* = ref object + dtls: Dtls udp: DatagramTransport - connections: Table[TransportAddress, SctpConnection] + connections: Table[TransportAddress, SctpConn] gotConnection: AsyncEvent timersHandler: Future[void] isServer: bool sockServer: ptr socket - pendingConnections: seq[SctpConnection] - sentFuture: Future[void] - sentConnection: SctpConnection + pendingConnections: seq[SctpConn] + pendingConnections2: Table[SockAddr, SctpConn] sentAddress: TransportAddress + sentFuture: Future[void] -const - IPPROTO_SCTP = 132 + # These three objects are used for debugging/trace only + SctpChunk = object + chunkType: uint8 + flag: uint8 + length {.bin_value: it.data.len() + 4.}: uint16 + data {.bin_len: it.length - 4.}: seq[byte] -proc newSctpError(msg: string): ref SctpError = - result = newException(SctpError, msg) + SctpPacketHeader = object + srcPort: uint16 + dstPort: uint16 + verifTag: uint32 + checksum: uint32 -template usrsctpAwait(sctp: Sctp, body: untyped): untyped = - sctp.sentFuture = nil + SctpPacketStructure = object + header: SctpPacketHeader + chunks: seq[SctpChunk] + +const IPPROTO_SCTP = 132 + +proc getSctpPacket(buffer: seq[byte]): SctpPacketStructure = + # Only used for debugging/trace + result.header = Binary.decode(buffer, SctpPacketHeader) + var size = sizeof(SctpPacketStructure) + while size < buffer.len: + let chunk = Binary.decode(buffer[size..^1], SctpChunk) + result.chunks.add(chunk) + size.inc(chunk.length.int) + while size mod 4 != 0: + # padding; could use `size.inc(-size %% 4)` instead but it lacks clarity + size.inc(1) + +# -- Asynchronous wrapper -- + +template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped = + # usrsctpAwait is template which set `sentFuture` to nil then calls (usually) + # an usrsctp function. If during the synchronous run of the usrsctp function + # `sendCallback` is called, then `sentFuture` is set and waited. + self.sentFuture = nil when type(body) is void: body - if sctp.sentFuture != nil: await sctp.sentFuture + if self.sentFuture != nil: await self.sentFuture else: let res = body - if sctp.sentFuture != nil: await sctp.sentFuture + if self.sentFuture != nil: await self.sentFuture res -proc perror(error: cstring) {.importc, cdecl, header: "".} -proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "", gcsafe.} +# -- SctpConn -- -proc packetPretty(packet: cstring): string = - let data = $packet - let ctn = data[23..^16] - result = data[1..14] - if ctn.len > 30: - result = result & ctn[0..14] & " ... " & ctn[^14..^1] - else: - result = result & ctn - -proc new(T: typedesc[SctpConnection], - sctp: Sctp, - udp: DatagramTransport, - address: TransportAddress, - sctpSocket: ptr socket): T = - T(sctp: sctp, +proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T = + T(conn: conn, + sctp: sctp, state: Connecting, - udp: udp, - address: address, - sctpSocket: sctpSocket, connectEvent: AsyncEvent(), - recvEvent: AsyncEvent()) + acceptEvent: AsyncEvent(), + dataRecv: newAsyncQueue[SctpMessage]() # TODO add some limit for backpressure? + ) -proc read*(self: SctpConnection): Future[seq[byte]] {.async.} = - trace "Read" - if self.dataRecv.len == 0: - self.recvEvent.clear() - await self.recvEvent.wait() - let res = self.dataRecv - self.dataRecv = @[] - return res +proc read*(self: SctpConn): Future[SctpMessage] {.async.} = + # Used by DataChannel, returns SctpMessage in order to get the stream + # and protocol ids + return await self.dataRecv.popFirst() -proc write*(self: SctpConnection, buf: seq[byte]) {.async.} = +proc toFlags(params: SctpMessageParameters): uint16 = + if params.endOfRecord: + result = result or SCTP_EOR + if params.unordered: + result = result or SCTP_UNORDERED + +proc write*(self: SctpConn, buf: seq[byte], + sendParams = default(SctpMessageParameters)) {.async.} = + # Used by DataChannel, writes buf on the Dtls connection. trace "Write", buf - self.sctp.sentConnection = self self.sctp.sentAddress = self.address - let sendvErr = self.sctp.usrsctpAwait: - self.sctpSocket.usrsctp_sendv(unsafeAddr buf[0], buf.len.uint, - nil, 0, nil, 0, - SCTP_SENDV_NOINFO, 0) -proc close*(self: SctpConnection) {.async.} = - self.sctp.usrsctpAwait: self.sctpSocket.usrsctp_close() + var cpy = buf + let sendvErr = + if sendParams == default(SctpMessageParameters): + # If writes is called by DataChannel, sendParams should never + # be the default value. This split is useful for testing. + self.usrsctpAwait: + self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0, + nil, 0, SCTP_SENDV_NOINFO.cuint, 0) + else: + let sendInfo = sctp_sndinfo( + snd_sid: sendParams.streamId, + # TODO: swapBytes => htonl? + snd_ppid: sendParams.protocolId.swapBytes(), + snd_flags: sendParams.toFlags) + self.usrsctpAwait: + self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0, + cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen, + SCTP_SENDV_SNDINFO.cuint, 0) + if sendvErr < 0: + # TODO: throw an exception + perror("usrsctp_sendv") + +proc write*(self: SctpConn, s: string) {.async.} = + await self.write(s.toBytes()) + +proc close*(self: SctpConn) {.async.} = + self.usrsctpAwait: + self.sctpSocket.usrsctp_close() + +# -- usrsctp receive data callbacks -- proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called when we receive data after + # connection has been established. let + conn = cast[SctpConn](data) events = usrsctp_get_events(sock) - conn = cast[SctpConnection](data) + trace "Handle Upcall", events - if conn.state == Connecting: - if bitand(events, SCTP_EVENT_ERROR) != 0: - warn "Cannot connect", address = conn.address - conn.state = Closed - elif bitand(events, SCTP_EVENT_WRITE) != 0: - conn.state = Connected - conn.connectEvent.fire() - elif bitand(events, SCTP_EVENT_READ) != 0: + if bitand(events, SCTP_EVENT_READ) != 0: var - buffer = newSeq[byte](4096) + message = SctpMessage( + data: newSeq[byte](4096) + ) address: Sockaddr_storage rn: sctp_recvv_rn addressLen = sizeof(Sockaddr_storage).SockLen rnLen = sizeof(sctp_recvv_rn).SockLen infotype: uint flags: int - let n = sock.usrsctp_recvv(cast[pointer](addr buffer[0]), buffer.len.uint, + let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]), + message.data.len.uint, cast[ptr SockAddr](addr address), cast[ptr SockLen](addr addressLen), - cast[pointer](addr rn), + cast[pointer](addr message.info), cast[ptr SockLen](addr rnLen), cast[ptr cuint](addr infotype), cast[ptr cint](addr flags)) @@ -140,91 +212,152 @@ proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = perror("usrsctp_recvv") return elif n > 0: + # It might be necessary to check if infotype == SCTP_RECVV_RCVINFO + message.data.delete(n..= msgHeaderSize and msg[4..<8] == magicCookieSeq and bitand(0xC0'u8, msg[0]) == 0'u8 + +proc addLength(msgEncoded: var seq[byte], length: uint16) = + let + hi = (length div 256'u16).uint8 + lo = (length mod 256'u16).uint8 + msgEncoded[2] = msgEncoded[2] + hi + if msgEncoded[3].int + lo.int >= 256: + msgEncoded[2] = msgEncoded[2] + 1 + msgEncoded[3] = ((msgEncoded[3].int + lo.int) mod 256).uint8 + else: + msgEncoded[3] = msgEncoded[3] + lo + +proc decode*(T: typedesc[StunMessage], msg: seq[byte]): StunMessage = + let smi = Binary.decode(msg, RawStunMessage) + return T(msgType: smi.msgType, + transactionId: smi.transactionId, + attributes: RawStunAttribute.decode(smi.content)) + +proc encode*(msg: StunMessage, userOpt: Option[seq[byte]]): seq[byte] = + const pad = @[0, 3, 2, 1] + var smi = RawStunMessage(msgType: msg.msgType, + magicCookie: magicCookie, + transactionId: msg.transactionId) + for attr in msg.attributes: + smi.content.add(Binary.encode(attr)) + smi.content.add(newSeq[byte](pad[smi.content.len() mod 4])) + + result = Binary.encode(smi) + + if userOpt.isSome(): + let username = string.fromBytes(userOpt.get()) + let usersplit = username.split(":") + if usersplit.len() == 2 and usersplit[0].startsWith("libp2p+webrtc+v1/"): + result.addLength(24) + result.add(Binary.encode(MessageIntegrity.encode(result, toBytes(usersplit[0])))) + + result.addLength(8) + result.add(Binary.encode(Fingerprint.encode(result))) + +proc getResponse*(T: typedesc[Stun], msg: seq[byte], + ta: TransportAddress): Option[seq[byte]] = + if ta.family != AddressFamily.IPv4 and ta.family != AddressFamily.IPv6: + return none(seq[byte]) + let sm = + try: + StunMessage.decode(msg) + except CatchableError as exc: + return none(seq[byte]) + + if sm.msgType != BindingRequest: + return none(seq[byte]) + + var res = StunMessage(msgType: BindingResponse, + transactionId: sm.transactionId) + + var unknownAttr: seq[uint16] + for attr in sm.attributes: + let typ = attr.attributeType + if typ.isRequired() and typ notin StunAttributeEnum: + unknownAttr.add(typ) + if unknownAttr.len() > 0: + res.attributes.add(ErrorCode.encode(ECUnknownAttribute)) + res.attributes.add(UnknownAttribute.encode(unknownAttr)) + return some(res.encode(sm.attributes.getAttribute(AttrUsername.uint16))) + + res.attributes.add(XorMappedAddress.encode(ta, sm.transactionId)) + return some(res.encode(sm.attributes.getAttribute(AttrUsername.uint16))) + +proc new*(T: typedesc[Stun]): T = + result = T() diff --git a/webrtc/stun/stun_attributes.nim b/webrtc/stun/stun_attributes.nim new file mode 100644 index 0000000..11e3c0e --- /dev/null +++ b/webrtc/stun/stun_attributes.nim @@ -0,0 +1,228 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import std/sha1, sequtils, typetraits, std/md5 +import binary_serialization, + stew/byteutils, + chronos + +# -- Utils -- + +proc createCrc32Table(): array[0..255, uint32] = + for i in 0..255: + var rem = i.uint32 + for j in 0..7: + if (rem and 1) > 0: + rem = (rem shr 1) xor 0xedb88320'u32 + else: + rem = rem shr 1 + result[i] = rem + +proc crc32(s: seq[byte]): uint32 = + # CRC-32 is used for the fingerprint attribute + # See https://datatracker.ietf.org/doc/html/rfc5389#section-15.5 + const crc32table = createCrc32Table() + result = 0xffffffff'u32 + for c in s: + result = (result shr 8) xor crc32table[(result and 0xff) xor c] + result = not result + +proc hmacSha1(key: seq[byte], msg: seq[byte]): seq[byte] = + # HMAC-SHA1 is used for the message integrity attribute + # See https://datatracker.ietf.org/doc/html/rfc5389#section-15.4 + let + keyPadded = + if len(key) > 64: + @(secureHash(key.mapIt(it.chr)).distinctBase) + elif key.len() < 64: + key.concat(newSeq[byte](64 - key.len())) + else: + key + innerHash = keyPadded. + mapIt(it xor 0x36'u8). + concat(msg). + mapIt(it.chr). + secureHash() + outerHash = keyPadded. + mapIt(it xor 0x5c'u8). + concat(@(innerHash.distinctBase)). + mapIt(it.chr). + secureHash() + return @(outerHash.distinctBase) + +# -- Attributes -- +# There are obviously some attributes implementation that are missing, +# it might be something to do eventually if we want to make this +# repository work for other project than nim-libp2p +# +# Stun Attribute +# 0 1 2 3 +# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +# | Type | Length | +# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +# | Value (variable) .... +# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type + StunAttributeEncodingError* = object of CatchableError + + RawStunAttribute* = object + attributeType*: uint16 + length* {.bin_value: it.value.len.}: uint16 + value* {.bin_len: it.length.}: seq[byte] + + StunAttributeEnum* = enum + AttrMappedAddress = 0x0001 + AttrChangeRequest = 0x0003 # RFC5780 Nat Behavior Discovery + AttrSourceAddress = 0x0004 # Deprecated + AttrChangedAddress = 0x0005 # Deprecated + AttrUsername = 0x0006 + AttrMessageIntegrity = 0x0008 + AttrErrorCode = 0x0009 + AttrUnknownAttributes = 0x000A + AttrChannelNumber = 0x000C # RFC5766 TURN + AttrLifetime = 0x000D # RFC5766 TURN + AttrXORPeerAddress = 0x0012 # RFC5766 TURN + AttrData = 0x0013 # RFC5766 TURN + AttrRealm = 0x0014 + AttrNonce = 0x0015 + AttrXORRelayedAddress = 0x0016 # RFC5766 TURN + AttrRequestedAddressFamily = 0x0017 # RFC6156 + AttrEvenPort = 0x0018 # RFC5766 TURN + AttrRequestedTransport = 0x0019 # RFC5766 TURN + AttrDontFragment = 0x001A # RFC5766 TURN + AttrMessageIntegritySHA256 = 0x001C # RFC8489 STUN (v2) + AttrPasswordAlgorithm = 0x001D # RFC8489 STUN (v2) + AttrUserhash = 0x001E # RFC8489 STUN (v2) + AttrXORMappedAddress = 0x0020 + AttrReservationToken = 0x0022 # RFC5766 TURN + AttrPriority = 0x0024 # RFC5245 ICE + AttrUseCandidate = 0x0025 # RFC5245 ICE + AttrPadding = 0x0026 # RFC5780 Nat Behavior Discovery + AttrResponsePort = 0x0027 # RFC5780 Nat Behavior Discovery + AttrConnectionID = 0x002a # RFC6062 TURN Extensions + AttrPasswordAlgorithms = 0x8002 # RFC8489 STUN (v2) + AttrAlternateDomain = 0x8003 # RFC8489 STUN (v2) + AttrSoftware = 0x8022 + AttrAlternateServer = 0x8023 + AttrCacheTimeout = 0x8027 # RFC5780 Nat Behavior Discovery + AttrFingerprint = 0x8028 + AttrICEControlled = 0x8029 # RFC5245 ICE + AttrICEControlling = 0x802A # RFC5245 ICE + AttrResponseOrigin = 0x802b # RFC5780 Nat Behavior Discovery + AttrOtherAddress = 0x802C # RFC5780 Nat Behavior Discovery + AttrOrigin = 0x802F + +proc isRequired*(typ: uint16): bool = typ <= 0x7FFF'u16 +proc isOptional*(typ: uint16): bool = typ >= 0x8000'u16 + +# Error Code +# https://datatracker.ietf.org/doc/html/rfc5389#section-15.6 + +type + ErrorCodeEnum* = enum + ECTryAlternate = 300 + ECBadRequest = 400 + ECUnauthenticated = 401 + ECUnknownAttribute = 420 + ECStaleNonce = 438 + ECServerError = 500 + ErrorCode* = object + reserved1: uint16 # should be 0 + reserved2 {.bin_bitsize: 5.}: uint8 # should be 0 + class {.bin_bitsize: 3.}: uint8 + number: uint8 + reason: seq[byte] + +proc encode*(T: typedesc[ErrorCode], code: ErrorCodeEnum, reason: string = ""): RawStunAttribute = + let + ec = T(class: (code.uint16 div 100'u16).uint8, + number: (code.uint16 mod 100'u16).uint8, + reason: reason.toBytes()) + value = Binary.encode(ec) + result = RawStunAttribute(attributeType: AttrErrorCode.uint16, + length: value.len().uint16, + value: value) + +# Unknown Attribute +# https://datatracker.ietf.org/doc/html/rfc5389#section-15.9 + +type + UnknownAttribute* = object + unknownAttr: seq[uint16] + +proc encode*(T: typedesc[UnknownAttribute], unknownAttr: seq[uint16]): RawStunAttribute = + let + ua = T(unknownAttr: unknownAttr) + value = Binary.encode(ua) + result = RawStunAttribute(attributeType: AttrUnknownAttributes.uint16, + length: value.len().uint16, + value: value) + +# Fingerprint +# https://datatracker.ietf.org/doc/html/rfc5389#section-15.5 + +type + Fingerprint* = object + crc32: uint32 + +proc encode*(T: typedesc[Fingerprint], msg: seq[byte]): RawStunAttribute = + let value = Binary.encode(T(crc32: crc32(msg) xor 0x5354554e'u32)) + result = RawStunAttribute(attributeType: AttrFingerprint.uint16, + length: value.len().uint16, + value: value) + +# Xor Mapped Address +# https://datatracker.ietf.org/doc/html/rfc5389#section-15.2 + +type + MappedAddressFamily {.size: 1.} = enum + MAFIPv4 = 0x01 + MAFIPv6 = 0x02 + + XorMappedAddress* = object + reserved: uint8 # should be 0 + family: MappedAddressFamily + port: uint16 + address: seq[byte] + +proc encode*(T: typedesc[XorMappedAddress], ta: TransportAddress, + tid: array[12, byte]): RawStunAttribute = + const magicCookie = @[ 0x21'u8, 0x12, 0xa4, 0x42 ] + let + (address, family) = + if ta.family == AddressFamily.IPv4: + var s = newSeq[uint8](4) + for i in 0..3: + s[i] = ta.address_v4[i] xor magicCookie[i] + (s, MAFIPv4) + else: + let magicCookieTid = magicCookie.concat(@tid) + var s = newSeq[uint8](16) + for i in 0..15: + s[i] = ta.address_v6[i] xor magicCookieTid[i] + (s, MAFIPv6) + xma = T(family: family, port: ta.port.distinctBase xor 0x2112'u16, address: address) + value = Binary.encode(xma) + result = RawStunAttribute(attributeType: AttrXORMappedAddress.uint16, + length: value.len().uint16, + value: value) + +# Message Integrity +# https://datatracker.ietf.org/doc/html/rfc5389#section-15.4 + +type + MessageIntegrity* = object + msgInt: seq[byte] + +proc encode*(T: typedesc[MessageIntegrity], msg: seq[byte], key: seq[byte]): RawStunAttribute = + let value = Binary.encode(T(msgInt: hmacSha1(key, msg))) + result = RawStunAttribute(attributeType: AttrMessageIntegrity.uint16, + length: value.len().uint16, value: value) diff --git a/webrtc/stun/stun_connection.nim b/webrtc/stun/stun_connection.nim new file mode 100644 index 0000000..a7289ec --- /dev/null +++ b/webrtc/stun/stun_connection.nim @@ -0,0 +1,61 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import chronos, chronicles +import ../udp_connection, stun + +logScope: + topics = "webrtc stun" + +# TODO: Work fine when behaves like a server, need to implement the client side + +type + StunConn* = ref object + conn: UdpConn + laddr: TransportAddress + dataRecv: AsyncQueue[(seq[byte], TransportAddress)] + handlesFut: Future[void] + closed: bool + +proc handles(self: StunConn) {.async.} = + while true: + let (msg, raddr) = await self.conn.read() + if Stun.isMessage(msg): + let res = Stun.getResponse(msg, self.laddr) + if res.isSome(): + await self.conn.write(raddr, res.get()) + else: + self.dataRecv.addLastNoWait((msg, raddr)) + +proc init*(self: StunConn, conn: UdpConn, laddr: TransportAddress) = + self.conn = conn + self.laddr = laddr + self.closed = false + + self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]() + self.handlesFut = self.handles() + +proc close*(self: StunConn) {.async.} = + if self.closed: + debug "Try to close StunConn twice" + return + self.handlesFut.cancel() # check before? + await self.conn.close() + +proc write*(self: StunConn, raddr: TransportAddress, msg: seq[byte]) {.async.} = + if self.closed: + debug "Try to write on an already closed StunConn" + return + await self.conn.write(raddr, msg) + +proc read*(self: StunConn): Future[(seq[byte], TransportAddress)] {.async.} = + if self.closed: + debug "Try to read on an already closed StunConn" + return + return await self.dataRecv.popFirst() diff --git a/webrtc/udp_connection.nim b/webrtc/udp_connection.nim new file mode 100644 index 0000000..a096231 --- /dev/null +++ b/webrtc/udp_connection.nim @@ -0,0 +1,58 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import sequtils +import chronos, chronicles + +logScope: + topics = "webrtc udp" + +# UdpConn is a small wrapper of the chronos DatagramTransport. +# It's the simplest solution we found to store the message and +# the remote address used by the underlying protocols (dtls/sctp etc...) + +type + UdpConn* = ref object + laddr*: TransportAddress + udp: DatagramTransport + dataRecv: AsyncQueue[(seq[byte], TransportAddress)] + closed: bool + +proc init*(self: UdpConn, laddr: TransportAddress) = + self.laddr = laddr + self.closed = false + + proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} = + trace "UDP onReceive" + let msg = udp.getMessage() + self.dataRecv.addLastNoWait((msg, address)) + + self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]() + self.udp = newDatagramTransport(onReceive, local = laddr) + +proc close*(self: UdpConn) {.async.} = + if self.closed: + debug "Try to close UdpConn twice" + return + self.closed = true + self.udp.close() + +proc write*(self: UdpConn, raddr: TransportAddress, msg: seq[byte]) {.async.} = + if self.closed: + debug "Try to write on an already closed UdpConn" + return + trace "UDP write", msg + await self.udp.sendTo(raddr, msg) + +proc read*(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} = + if self.closed: + debug "Try to read on an already closed UdpConn" + return + trace "UDP read" + return await self.dataRecv.popFirst() diff --git a/webrtc/usrsctp.nim b/webrtc/usrsctp.nim index 3aa7180..c3e9d69 100644 --- a/webrtc/usrsctp.nim +++ b/webrtc/usrsctp.nim @@ -4,12 +4,12 @@ import strformat, os import nativesockets # C include directory -const root = currentSourcePath.parentDir +const root = currentSourcePath.parentDir.parentDir const usrsctpInclude = root/"usrsctp"/"usrsctplib" {.passc: fmt"-I{usrsctpInclude}".} -# Generated @ 2022-11-23T14:21:00+01:00 +# Generated @ 2023-03-30T13:55:23+02:00 # Command line: # /home/lchenut/.nimble/pkgs/nimterop-0.6.13/nimterop/toast --compile=./usrsctp/usrsctplib/netinet/sctp_input.c --compile=./usrsctp/usrsctplib/netinet/sctp_asconf.c --compile=./usrsctp/usrsctplib/netinet/sctp_pcb.c --compile=./usrsctp/usrsctplib/netinet/sctp_usrreq.c --compile=./usrsctp/usrsctplib/netinet/sctp_cc_functions.c --compile=./usrsctp/usrsctplib/netinet/sctp_auth.c --compile=./usrsctp/usrsctplib/netinet/sctp_userspace.c --compile=./usrsctp/usrsctplib/netinet/sctp_output.c --compile=./usrsctp/usrsctplib/netinet/sctp_callout.c --compile=./usrsctp/usrsctplib/netinet/sctp_crc32.c --compile=./usrsctp/usrsctplib/netinet/sctp_sysctl.c --compile=./usrsctp/usrsctplib/netinet/sctp_sha1.c --compile=./usrsctp/usrsctplib/netinet/sctp_timer.c --compile=./usrsctp/usrsctplib/netinet/sctputil.c --compile=./usrsctp/usrsctplib/netinet/sctp_bsd_addr.c --compile=./usrsctp/usrsctplib/netinet/sctp_peeloff.c --compile=./usrsctp/usrsctplib/netinet/sctp_indata.c --compile=./usrsctp/usrsctplib/netinet/sctp_ss_functions.c --compile=./usrsctp/usrsctplib/user_socket.c --compile=./usrsctp/usrsctplib/netinet6/sctp6_usrreq.c --compile=./usrsctp/usrsctplib/user_mbuf.c --compile=./usrsctp/usrsctplib/user_environment.c --compile=./usrsctp/usrsctplib/user_recv_thread.c --pnim --preprocess --noHeader --defines=SCTP_PROCESS_LEVEL_LOCKS --defines=SCTP_SIMPLE_ALLOCATOR --defines=__Userspace__ --defines=STDC_HEADERS=1 --defines=HAVE_SYS_TYPES_H=1 --defines=HAVE_SYS_STAT_H=1 --defines=HAVE_STDLIB_H=1 --defines=HAVE_STRING_H=1 --defines=HAVE_MEMORY_H=1 --defines=HAVE_STRINGS_H=1 --defines=HAVE_INTTYPES_H=1 --defines=HAVE_STDINT_H=1 --defines=HAVE_UNISTD_H=1 --defines=HAVE_DLFCN_H=1 --defines=LT_OBJDIR=".libs/" --defines=SCTP_DEBUG=1 --defines=INET=1 --defines=INET6=1 --defines=HAVE_SOCKET=1 --defines=HAVE_INET_ADDR=1 --defines=HAVE_STDATOMIC_H=1 --defines=HAVE_SYS_QUEUE_H=1 --defines=HAVE_LINUX_IF_ADDR_H=1 --defines=HAVE_LINUX_RTNETLINK_H=1 --defines=HAVE_NETINET_IP_ICMP_H=1 --defines=HAVE_NET_ROUTE_H=1 --defines=_GNU_SOURCE --replace=sockaddr=SockAddr --replace=SockAddr_storage=Sockaddr_storage --replace=SockAddr_in=Sockaddr_in --replace=SockAddr_conn=Sockaddr_conn --replace=socklen_t=SockLen --includeDirs=./usrsctp/usrsctplib ./usrsctp/usrsctplib/usrsctp.h @@ -47,30 +47,29 @@ const usrsctpInclude = root/"usrsctp"/"usrsctplib" {.passc: "-DHAVE_NETINET_IP_ICMP_H=1".} {.passc: "-DHAVE_NET_ROUTE_H=1".} {.passc: "-D_GNU_SOURCE".} -{.passc: "-I./usrsctp/usrsctplib".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_input.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_asconf.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_pcb.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_usrreq.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_cc_functions.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_auth.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_userspace.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_output.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_callout.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_crc32.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_sysctl.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_sha1.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_timer.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctputil.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_bsd_addr.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_peeloff.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_indata.c".} -{.compile: "./usrsctp/usrsctplib/netinet/sctp_ss_functions.c".} -{.compile: "./usrsctp/usrsctplib/user_socket.c".} -{.compile: "./usrsctp/usrsctplib/netinet6/sctp6_usrreq.c".} -{.compile: "./usrsctp/usrsctplib/user_mbuf.c".} -{.compile: "./usrsctp/usrsctplib/user_environment.c".} -{.compile: "./usrsctp/usrsctplib/user_recv_thread.c".} +{.compile: usrsctpInclude / "netinet/sctp_input.c".} +{.compile: usrsctpInclude / "netinet/sctp_asconf.c".} +{.compile: usrsctpInclude / "netinet/sctp_pcb.c".} +{.compile: usrsctpInclude / "netinet/sctp_usrreq.c".} +{.compile: usrsctpInclude / "netinet/sctp_cc_functions.c".} +{.compile: usrsctpInclude / "netinet/sctp_auth.c".} +{.compile: usrsctpInclude / "netinet/sctp_userspace.c".} +{.compile: usrsctpInclude / "netinet/sctp_output.c".} +{.compile: usrsctpInclude / "netinet/sctp_callout.c".} +{.compile: usrsctpInclude / "netinet/sctp_crc32.c".} +{.compile: usrsctpInclude / "netinet/sctp_sysctl.c".} +{.compile: usrsctpInclude / "netinet/sctp_sha1.c".} +{.compile: usrsctpInclude / "netinet/sctp_timer.c".} +{.compile: usrsctpInclude / "netinet/sctputil.c".} +{.compile: usrsctpInclude / "netinet/sctp_bsd_addr.c".} +{.compile: usrsctpInclude / "netinet/sctp_peeloff.c".} +{.compile: usrsctpInclude / "netinet/sctp_indata.c".} +{.compile: usrsctpInclude / "netinet/sctp_ss_functions.c".} +{.compile: usrsctpInclude / "user_socket.c".} +{.compile: usrsctpInclude / "netinet6/sctp6_usrreq.c".} +{.compile: usrsctpInclude / "user_mbuf.c".} +{.compile: usrsctpInclude / "user_environment.c".} +{.compile: usrsctpInclude / "user_recv_thread.c".} const MSG_NOTIFICATION* = 0x00002000 AF_CONN* = 123 diff --git a/webrtc/webrtc.nim b/webrtc/webrtc.nim new file mode 100644 index 0000000..f67a231 --- /dev/null +++ b/webrtc/webrtc.nim @@ -0,0 +1,44 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import chronos, chronicles + +import udp_connection +import stun/stun_connection +import dtls/dtls +import sctp, datachannel + +logScope: + topics = "webrtc" + +type + WebRTC* = ref object + udp*: UdpConn + stun*: StunConn + dtls*: Dtls + sctp*: Sctp + port: int + +proc new*(T: typedesc[WebRTC], address: TransportAddress): T = + result = T(udp: UdpConn(), stun: StunConn(), dtls: Dtls(), sctp: Sctp()) + result.udp.init(address) + result.stun.init(result.udp, address) + result.dtls.init(result.stun, address) + result.sctp.init(result.dtls, address) + +proc listen*(self: WebRTC) = + self.sctp.listen() + +proc connect*(self: WebRTC, raddr: TransportAddress): Future[DataChannelConnection] {.async.} = + let sctpConn = await self.sctp.connect(raddr) # TODO: Port? + result = DataChannelConnection.new(sctpConn) + +proc accept*(w: WebRTC): Future[DataChannelConnection] {.async.} = + let sctpConn = await w.sctp.accept() + result = DataChannelConnection.new(sctpConn)