From d1ba2ee0bc691964067955ffd9d4a4de3ad83476 Mon Sep 17 00:00:00 2001 From: Ludovic Chenut Date: Tue, 25 Apr 2023 11:56:30 +0200 Subject: [PATCH] Stun done --- webrtc/stun.nim | 58 +++++++++++++++++++++++++++++---------- webrtc/stunattributes.nim | 21 ++++++++++++++ webrtc/webrtc.nim | 12 ++++++-- 3 files changed, 75 insertions(+), 16 deletions(-) diff --git a/webrtc/stun.nim b/webrtc/stun.nim index 9e575ff..344195c 100644 --- a/webrtc/stun.nim +++ b/webrtc/stun.nim @@ -1,8 +1,18 @@ -import bitops +# Nim-WebRTC +# Copyright (c) 2023 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import bitops, strutils import chronos, chronicles, binary_serialization, - stew/objects + stew/objects, + stew/byteutils import stunattributes export binary_serialization @@ -18,13 +28,13 @@ const BindingResponse = 0x0101'u16 proc decode(T: typedesc[RawStunAttribute], cnt: seq[byte]): seq[RawStunAttribute] = - const val = @[0, 3, 2, 1] + const pad = @[0, 3, 2, 1] var padding = 0 while padding < cnt.len(): let attr = Binary.decode(cnt[padding ..^ 1], RawStunAttribute) result.add(attr) padding += 4 + attr.value.len() - padding += val[padding mod 4] + padding += pad[padding mod 4] type # Stun Header @@ -50,7 +60,7 @@ type RawStunMessage = object msgType: uint16 # it.conten.len() + 8 Because the Fingerprint is added after the encoding - length* {.bin_value: it.content.len() + 8.}: uint16 + length* {.bin_value: it.content.len().}: uint16 magicCookie: uint32 transactionId: array[12, byte] content* {.bin_len: it.length.}: seq[byte] @@ -71,36 +81,56 @@ proc getAttribute(attrs: seq[RawStunAttribute], typ: uint16): Option[seq[byte]] proc isMessage*(T: typedesc[Stun], msg: seq[byte]): bool = msg.len >= msgHeaderSize and msg[4..<8] == magicCookieSeq and bitand(0xC0'u8, msg[0]) == 0'u8 +proc addLength(msgEncoded: var seq[byte], length: uint16) = + let + hi = (length div 256'u16).uint8 + lo = (length mod 256'u16).uint8 + msgEncoded[2] = msgEncoded[2] + hi + if msgEncoded[3].int + lo.int >= 256: + msgEncoded[2] = msgEncoded[2] + 1 + msgEncoded[3] = ((msgEncoded[3].int + lo.int) mod 256).uint8 + else: + msgEncoded[3] = msgEncoded[3] + lo + proc decode*(T: typedesc[StunMessage], msg: seq[byte]): StunMessage = let smi = Binary.decode(msg, RawStunMessage) return T(msgType: smi.msgType, transactionId: smi.transactionId, attributes: RawStunAttribute.decode(smi.content)) -proc encode*(msg: StunMessage): seq[byte] = - const val = @[0, 3, 2, 1] +proc encode*(msg: StunMessage, userOpt: Option[seq[byte]]): seq[byte] = + const pad = @[0, 3, 2, 1] var smi = RawStunMessage(msgType: msg.msgType, magicCookie: magicCookie, transactionId: msg.transactionId) for attr in msg.attributes: smi.content.add(Binary.encode(attr)) - smi.content.add(newSeq[byte](val[smi.content.len() mod 4])) + smi.content.add(newSeq[byte](pad[smi.content.len() mod 4])) result = Binary.encode(smi) + + if userOpt.isSome(): + let username = string.fromBytes(userOpt.get()) + let usersplit = username.split(":") + if usersplit.len() == 2 and usersplit[0].startsWith("libp2p+webrtc+v1/"): + result.addLength(24) + result.add(Binary.encode(MessageIntegrity.encode(result, toBytes(usersplit[0])))) + + result.addLength(8) result.add(Binary.encode(Fingerprint.encode(result))) proc getResponse*(T: typedesc[Stun], msg: seq[byte], - ta: TransportAddress): Option[StunMessage] = + ta: TransportAddress): Option[seq[byte]] = if ta.family != AddressFamily.IPv4 and ta.family != AddressFamily.IPv6: - return none(StunMessage) + return none(seq[byte]) let sm = try: StunMessage.decode(msg) except CatchableError as exc: - return none(StunMessage) + return none(seq[byte]) if sm.msgType != BindingRequest: - return none(StunMessage) + return none(seq[byte]) var res = StunMessage(msgType: BindingResponse, transactionId: sm.transactionId) @@ -113,10 +143,10 @@ proc getResponse*(T: typedesc[Stun], msg: seq[byte], if unknownAttr.len() > 0: res.attributes.add(ErrorCode.encode(ECUnknownAttribute)) res.attributes.add(UnknownAttribute.encode(unknownAttr)) - return some(res) + return some(res.encode(sm.attributes.getAttribute(AttrUsername.uint16))) res.attributes.add(XorMappedAddress.encode(ta, sm.transactionId)) - return some(res) + return some(res.encode(sm.attributes.getAttribute(AttrUsername.uint16))) proc new*(T: typedesc[Stun]): T = result = T() diff --git a/webrtc/stunattributes.nim b/webrtc/stunattributes.nim index e57a829..bf2b471 100644 --- a/webrtc/stunattributes.nim +++ b/webrtc/stunattributes.nim @@ -1,3 +1,12 @@ +# Nim-WebRTC +# Copyright (c) 2023 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + import sequtils, typetraits import binary_serialization, stew/byteutils, @@ -150,3 +159,15 @@ proc encode*(T: typedesc[XorMappedAddress], ta: TransportAddress, result = RawStunAttribute(attributeType: AttrXORMappedAddress.uint16, length: value.len().uint16, value: value) + +# Message Integrity + +type + MessageIntegrity* = object + msgInt: seq[byte] + +proc encode*(T: typedesc[MessageIntegrity], msg: seq[byte], key: seq[byte]): RawStunAttribute = + let value = Binary.encode(T(msgInt: hmacSha1(key, msg))) + result = RawStunAttribute(attributeType: AttrMessageIntegrity.uint16, + length: value.len().uint16, + value: value) diff --git a/webrtc/webrtc.nim b/webrtc/webrtc.nim index 05f451b..ce8eb0b 100644 --- a/webrtc/webrtc.nim +++ b/webrtc/webrtc.nim @@ -1,3 +1,12 @@ +# Nim-WebRTC +# Copyright (c) 2023 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + import chronos, chronicles import stun @@ -17,9 +26,8 @@ proc new*(T: typedesc[WebRTC], port: uint16 = 42657): T = msg = udp.getMessage() if Stun.isMessage(msg): let res = Stun.getResponse(msg, address) - echo res if res.isSome(): - await udp.sendTo(address, res.get().encode()) + await udp.sendTo(address, res.get()) trace "onReceive", isStun = Stun.isMessage(msg) if not fut.completed(): fut.complete()