nim-libp2p/libp2p/transports/webrtctransport.nim

436 lines
12 KiB
Nim
Raw Normal View History

2023-10-11 16:18:53 +02:00
# Nim-LibP2P
# Copyright (c) 2023 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
## WebRtc transport implementation
## For now, only support WebRtc direct (ie browser to server)
{.push raises: [].}
import std/[sequtils]
2023-10-13 18:07:52 +02:00
import stew/[endians2, byteutils, objects, results]
2023-10-11 16:18:53 +02:00
import chronos, chronicles
import transport,
../errors,
../wire,
../multicodec,
2023-10-13 18:07:52 +02:00
../protobuf/minprotobuf,
2023-10-11 16:18:53 +02:00
../connmanager,
2023-10-13 18:07:52 +02:00
../muxers/muxer,
2023-10-11 16:18:53 +02:00
../multiaddress,
../stream/connection,
../upgrademngrs/upgrade,
2023-10-13 18:07:52 +02:00
../protocols/secure/noise,
2023-10-11 16:18:53 +02:00
../utility
2023-10-13 18:07:52 +02:00
import webrtc/webrtc, webrtc/datachannel
2023-10-11 16:18:53 +02:00
# Chronicles topic attached to every log statement emitted by this module.
logScope:
  topics = "libp2p webrtctransport"
# Re-exported so importers of this module get `Transport` and the
# `results` helpers (`Opt`, `Result`) without extra imports.
export transport, results
const
  # Name under which the transport tracker is registered via `addTracker`
  # (see `setupWebRtcTransportTracker`).
  WebRtcTransportTrackerName* = "libp2p.webrtctransport"
# -- Message --
type
  # Control flags carried in protobuf field 1 of a WebRtcMessage; the explicit
  # ordinals are the values written on the wire by `encode`.
  MessageFlag = enum
    Fin = 0          # sender will not send any more data
    StopSending = 1  # currently ignored by `readOnce`
    ResetStream = 2  # currently ignored by `readOnce`
    FinAck = 3       # sent in reply to a received `Fin`
  # One framed message exchanged over a data channel: an optional control
  # flag (field 1) plus an optional payload (field 2).
  WebRtcMessage = object
    flag: Opt[MessageFlag]
    data: seq[byte]
proc decode(_: type WebRtcMessage, bytes: seq[byte]): Opt[WebRtcMessage] =
  ## Parses a protobuf-encoded `WebRtcMessage`.
  ##
  ## Field 1 holds the flag ordinal. An ordinal outside `MessageFlag` leaves
  ## `flag` unset instead of failing. Field 2 holds the payload. Any error
  ## reported by `getField` is propagated as `Opt.none`.
  var
    msg: WebRtcMessage
    rawFlag: uint32
    pb = initProtoBuffer(bytes)
  let hasFlag = ? pb.getField(1, rawFlag).toOpt()
  if hasFlag:
    var parsedFlag: MessageFlag
    if parsedFlag.checkedEnumAssign(rawFlag):
      msg.flag = Opt.some(parsedFlag)
  discard ? pb.getField(2, msg.data).toOpt()
  Opt.some(msg)
proc encode(msg: WebRtcMessage): seq[byte] =
  ## Serialises `msg` into protobuf bytes.
  ##
  ## Field 1 (flag ordinal) is written only when `msg.flag` is set. Field 2
  ## (payload) is written only when `msg.data` is non-empty.
  var buf = initProtoBuffer()
  msg.flag.withValue(flagValue):
    buf.write(1, uint32(flagValue))
  if msg.data.len != 0:
    buf.write(2, msg.data)
  buf.finish()
  result = buf.buffer
# -- Stream --
const MaxMessageSize = 16384 # 16KiB; `write` sizes payload chunks from this

type
  # Lifecycle of one direction (tx or rx) of a stream.
  WebRtcState = enum
    Sending
    Closing
    Closed

  # A libp2p `Connection` backed by a single WebRTC data channel.
  WebRtcStream = ref object of Connection
    dataChannel: DataChannelStream
    # Encoded messages waiting to be written, each with an optional future
    # completed by `sender` once that message has been written.
    sendQueue: seq[(seq[byte], Future[void])]
    sendLoop: Future[void]    # current `sender` task, if any
    readData: seq[byte]       # received payload not yet returned by `readOnce`
    txState, rxState: WebRtcState
proc new(
    _: type WebRtcStream,
    dataChannel: DataChannelStream,
    oaddr: Opt[MultiAddress],
    peerId: PeerId): WebRtcStream =
  ## Wraps `dataChannel` in a `WebRtcStream` carrying the given observed
  ## address and peer id, then runs the base `Connection` initialisation.
  result = WebRtcStream(
    dataChannel: dataChannel,
    observedAddr: oaddr,
    peerId: peerId)
  procCall Connection(result).initStream()
proc sender(s: WebRtcStream) {.async.} =
  ## Drains `s.sendQueue` in FIFO order (`send` inserts at the front and this
  ## loop pops from the back), writing each encoded message to the data
  ## channel.
  ##
  ## An attached completion future is completed after its write succeeds. If
  ## the write raises, the future is failed with that error, so a caller
  ## awaiting `write` does not hang forever.
  ## NOTE(review): intermediate chunks queued without a future still fail
  ## silently. The final chunk's future only reflects the last write.
  while s.sendQueue.len > 0:
    let (message, fut) = s.sendQueue.pop()
    try:
      await s.dataChannel.write(message)
      # The caller may already have cancelled the future (e.g. on timeout);
      # completing a finished future is a defect in chronos.
      if not fut.isNil and not fut.finished:
        fut.complete()
    except CancelledError as exc:
      if not fut.isNil and not fut.finished:
        fut.fail(exc)
      raise exc
    except CatchableError as exc:
      if not fut.isNil and not fut.finished:
        fut.fail(exc)
proc send(s: WebRtcStream, msg: WebRtcMessage, fut: Future[void] = nil) =
  ## Encodes `msg` and queues it, together with the optional completion
  ## future `fut`, at the front of `s.sendQueue`. Starts a new `sender` task
  ## when none is running.
  s.sendQueue.insert((msg.encode(), fut), 0)
  let loopIdle = s.sendLoop.isNil or s.sendLoop.finished
  if loopIdle:
    s.sendLoop = s.sender()
method write*(s: WebRtcStream, msg2: seq[byte]): Future[void] =
  ## Queues `msg2` for sending, split into messages of at most
  ## `MaxMessageSize - 16` payload bytes. The margin is presumably headroom
  ## for protobuf framing (confirm).
  ##
  ## The returned future is completed by `sender` once the last chunk has been
  ## written. It fails immediately with `LPStreamClosedError` when the write
  ## side is no longer `Sending`.
  # We need to make sure we send all of our data before another write
  # Otherwise, two concurrent writes could get intertwined
  # We avoid this by filling the s.sendQueue synchronously
  const chunkSize = MaxMessageSize - 16
  let retFuture = newFuture[void]("WebRtcStream.write")
  if s.txState != Sending:
    retFuture.fail(newLPStreamClosedError())
    return retFuture

  # Walk the input with an offset. Re-slicing the remainder on every
  # iteration would copy the tail each time, which is quadratic in size.
  var offset = 0
  while msg2.len - offset > chunkSize:
    s.send(WebRtcMessage(data: msg2[offset ..< offset + chunkSize]))
    offset += chunkSize
  # The last (possibly empty) chunk carries the future handed to the caller.
  s.send(WebRtcMessage(data: msg2[offset .. ^1]), retFuture)
  return retFuture
proc actuallyClose(s: WebRtcStream) {.async.} =
  ## Closes the underlying `Connection` only once both directions are
  ## `Closed` and no received data is left unread. Otherwise it does nothing.
  let
    bothSidesClosed = s.rxState == Closed and s.txState == Closed
    fullyDrained = s.readData.len == 0
  if not (bothSidesClosed and fullyDrained):
    return
  #TODO add support to DataChannel
  #await s.dataChannel.close()
  await procCall Connection(s).closeImpl()
method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.async.} =
  ## Copies up to `nbytes` of received payload into `pbytes` and returns the
  ## number of bytes copied.
  ##
  ## When no data is buffered, messages are read from the data channel until
  ## one carries payload. `Fin` marks the read side `Closed` and answers with
  ## `FinAck`. `FinAck` marks the write side `Closed`. Returns 0 when
  ## end-of-stream is reached while waiting for data. Raises
  ## `LPStreamEOFError` on later calls, once the read side is closed and all
  ## buffered data has been consumed.
  # Only report EOF once the buffer is drained. A `Fin` message may itself
  # carry payload, which must still be returned to the reader.
  if s.rxState == Closed and s.readData.len == 0:
    raise newLPStreamEOFError()

  while s.readData.len == 0:
    if s.rxState == Closed:
      await s.actuallyClose()
      return 0

    let
      #TODO handle exceptions
      message = await s.dataChannel.read()
      decoded = WebRtcMessage.decode(message).tryGet()

    decoded.flag.withValue(flag):
      case flag:
      of Fin:
        # Peer won't send any more data
        s.rxState = Closed
        s.send(WebRtcMessage(flag: Opt.some(FinAck)))
      of FinAck:
        # Peer acknowledged our Fin: our write side is done
        s.txState = Closed
        await s.actuallyClose()
      else: discard # StopSending / ResetStream are not handled yet

    s.readData = decoded.data

  result = min(nbytes, s.readData.len)
  copyMem(pbytes, addr s.readData[0], result)
  s.readData = s.readData[result..^1]
method closeImpl*(s: WebRtcStream) {.async.} =
  ## Starts a graceful close of the write side. It queues a `Fin` message,
  ## moves `txState` to `Closing`, then waits on `join()`.
  ## NOTE(review): `txState` only becomes `Closed` when `readOnce` handles the
  ## peer's `FinAck`. If nothing is reading, this may wait indefinitely.
  let finMessage = WebRtcMessage(flag: Opt.some(Fin))
  s.send(finMessage)
  s.txState = Closing
  await s.join() #TODO ??
# -- Connection --
# Connection-level wrapper around a nim-webrtc `DataChannelConnection`.
# Individual streams are obtained from it with `getStream`.
type WebRtcConnection = ref object of Connection
  connection: DataChannelConnection
# Not implemented yet: currently a no-op, so neither the wrapper nor the
# underlying `DataChannelConnection` is closed here.
method close*(conn: WebRtcConnection) {.async.} =
  #TODO
  discard
proc new(
    _: type WebRtcConnection,
    conn: DataChannelConnection,
    observedAddr: Opt[MultiAddress]
  ): WebRtcConnection =
  ## Wraps `conn` in a `WebRtcConnection` with the given observed address and
  ## runs the base `Connection` initialisation.
  result = WebRtcConnection(connection: conn, observedAddr: observedAddr)
  procCall Connection(result).initStream()
proc getStream*(conn: WebRtcConnection,
                direction: Direction): Future[WebRtcStream] {.async.} =
  ## Obtains a data channel on `conn` and wraps it in a `WebRtcStream` that
  ## shares the connection's observed address and peer id.
  ## `Direction.In` accepts the peer's next channel. `Direction.Out` opens a
  ## new one.
  let dataChannel =
    case direction
    of Direction.In:
      await conn.connection.accept()
    of Direction.Out:
      await conn.connection.openStream(0) #TODO don't hardcode stream id (should be in nim-webrtc)
  return WebRtcStream.new(dataChannel, conn.observedAddr, conn.peerId)
# -- Muxer --
# Muxer exposing the data channels of one WebRtcConnection as libp2p streams.
type WebRtcMuxer = ref object of Muxer
  webRtcConn: WebRtcConnection
  # NOTE(review): appears never to be assigned in this module (`upgrade`
  # leaves it nil), yet `close` cancels it.
  handleFut: Future[void]
method newStream*(m: WebRtcMuxer, name: string = "", lazy: bool = false): Future[Connection] {.async, gcsafe.} =
  ## Opens an outbound stream on the underlying WebRTC connection.
  ## `name` and `lazy` are accepted for interface compatibility but unused.
  let outbound = await m.webRtcConn.getStream(Direction.Out)
  result = outbound
proc handleStream(m: WebRtcMuxer, chann: WebRtcStream) {.async.} =
  ## Runs the muxer's stream handler on an inbound stream.
  ##
  ## The handler is expected to close the stream itself; this is checked with
  ## `doAssert`. If the handler raises, the error is logged at trace level
  ## and the stream is closed here.
  try:
    await m.streamHandler(chann)
    trace "finished handling stream"
    doAssert(chann.closed, "connection not closed by handler!")
  except CatchableError as exc:
    trace "Exception in webrtc stream handler", msg = exc.msg
    await chann.close()
#TODO add atEof
method handle*(m: WebRtcMuxer): Future[void] {.async, gcsafe.} =
  ## Accept loop. Each inbound data channel is wrapped as a stream and handed
  ## to `handleStream` in its own task. The loop only ends through an
  ## exception or cancellation, after which the connection is closed.
  try:
    #while not m.webRtcConn.atEof:
    while true:
      let inbound = await m.webRtcConn.getStream(Direction.In)
      asyncSpawn m.handleStream(inbound)
  finally:
    await m.webRtcConn.close()
method close*(m: WebRtcMuxer) {.async, gcsafe.} =
  ## Cancels the recorded accept-loop future, if any, and closes the
  ## underlying WebRTC connection.
  # `upgrade` builds the muxer without setting `handleFut`, so it may be nil.
  # Calling `cancel` on a nil future would dereference nil.
  if not m.handleFut.isNil:
    m.handleFut.cancel()
  await m.webRtcConn.close()
# -- Upgrader --
type WebRtcUpgrade = ref object of Upgrade

method upgrade*(
    self: WebRtcUpgrade,
    conn: Connection,
    direction: Direction,
    peerId: Opt[PeerId]): Future[Muxer] {.async.} =
  ## Turns a `WebRtcConnection` into a `WebRtcMuxer`. It first runs a Noise
  ## handshake, as initiator, on a dedicated outbound stream to authenticate
  ## the peer, then closes that stream.
  ##
  ## `conn` must be a `WebRtcConnection` (the conversion below is checked).
  ## `direction` is currently unused. Requires a `Noise` instance among
  ## `self.secureManagers`.
  let webRtcConn = WebRtcConnection(conn)
  result = WebRtcMuxer(webRtcConn: webRtcConn)

  # Noise handshake
  let noiseHandler = self.secureManagers.filterIt(it of Noise)
  # `assert` is compiled out under -d:danger, which would leave the
  # `noiseHandler[0]` access below unchecked; `doAssert` always runs.
  doAssert noiseHandler.len > 0,
    "WebRtcUpgrade requires a Noise secure manager"

  let
    stream = await webRtcConn.getStream(Out) #TODO add channelId: 0
    secureStream = await noiseHandler[0].handshake(
      stream,
      initiator = true, # we are always the initiator in webrtc-direct
      peerId = peerId
      #TODO: add prelude data
    )

  # Peer proved its identity, we can close this
  await secureStream.close()
  await stream.close()
# -- Transport --
type
  # Transport listening on one nim-webrtc server per handled address.
  WebRtcTransport* = ref object of Transport
    connectionsTimeout: Duration
    servers: seq[WebRtc]
    # One pending `accept` per entry of `servers` (same index).
    acceptFuts: seq[Future[DataChannelConnection]]
    # Live connections, split by direction.
    clients: array[Direction, seq[DataChannelConnection]]

  # Counters reported through the libp2p tracker facility.
  WebRtcTransportTracker* = ref object of TrackerBase
    opened*, closed*: uint64

  WebRtcTransportError* = object of transport.TransportError
proc setupWebRtcTransportTracker(): WebRtcTransportTracker {.gcsafe, raises: [].}

proc getWebRtcTransportTracker(): WebRtcTransportTracker {.gcsafe.} =
  ## Returns the registered tracker, creating and registering one on first use.
  result = cast[WebRtcTransportTracker](getTracker(WebRtcTransportTrackerName))
  if isNil(result):
    result = setupWebRtcTransportTracker()

proc dumpTracking(): string {.gcsafe.} =
  ## Human-readable summary of the opened/closed counters.
  let tracker = getWebRtcTransportTracker()
  result = "Opened webrtc transports: " & $tracker.opened & "\n" &
           "Closed webrtc transports: " & $tracker.closed

proc leakTransport(): bool {.gcsafe.} =
  ## True when the opened and closed counters differ.
  let tracker = getWebRtcTransportTracker()
  result = tracker.opened != tracker.closed

proc setupWebRtcTransportTracker(): WebRtcTransportTracker =
  ## Creates a zeroed tracker and wires its dump/leak callbacks. It then
  ## registers the tracker under `WebRtcTransportTrackerName`.
  result = new WebRtcTransportTracker
  result.opened = 0
  result.closed = 0
  result.dump = dumpTracking
  result.isLeaked = leakTransport
  addTracker(WebRtcTransportTrackerName, result)
proc new*(
    T: typedesc[WebRtcTransport],
    upgrade: Upgrade,
    connectionsTimeout = 10.minutes): T {.public.} =
  ## Creates a WebRTC transport.
  ##
  ## Only the secure managers of `upgrade` are reused: they are handed to an
  ## internal `WebRtcUpgrade`. `connectionsTimeout` (default 10 minutes) is
  ## stored on the transport.
  T(
    connectionsTimeout: connectionsTimeout,
    upgrader: WebRtcUpgrade(secureManagers: upgrade.secureManagers))
method start*(
    self: WebRtcTransport,
    addrs: seq[MultiAddress]) {.async.} =
  ## listen on the transport
  ##
  ## For each handled address, a `WebRtc` server is created from its first
  ## two components (presumably IP + UDP port, to be confirmed) and starts
  ## listening. `self.addrs[i]` is then replaced by the server's UDP address
  ## extended with `/webrtc-direct/cert-hash/<base64 sha2-256 multihash of
  ## the local DTLS certificate>`. Unhandled addresses are skipped.
  if self.running:
    warn "WebRtc transport already running"
    return

  await procCall Transport(self).start(addrs)
  trace "Starting WebRtc transport"
  inc getWebRtcTransportTracker().opened

  for idx, listenAddr in addrs:
    if not self.handles(listenAddr):
      trace "Invalid address detected, skipping!", address = listenAddr
      continue

    let server = WebRtc.new(initTAddress(listenAddr[0..1].tryGet()).tryGet())
    server.listen()
    self.servers &= server

    # Build the advertised address, including the hash of the DTLS certificate.
    let
      localCert = server.dtlsLocalCertificate()
      certHash = MultiHash.digest("sha2-256", localCert).get().data.buffer
      encodedCertHash = MultiBase.encode("base64", certHash).get()
      udpPart = MultiAddress.init(server.udp.laddr, IPPROTO_UDP).tryGet()
      webRtcDirectPart = MultiAddress.init(multiCodec("webrtc-direct")).tryGet()
      certHashPart = MultiAddress.init(multiCodec("cert-hash"), encodedCertHash).tryGet()
    self.addrs[idx] = (udpPart & webRtcDirectPart & certHashPart).tryGet()

    trace "Listening on", address = self.addrs[idx]
proc connHandler(self: WebRtcTransport,
                 client: DataChannelConnection,
                 observedAddr: Opt[MultiAddress],
                 dir: Direction): Future[Connection] {.async.} =
  ## Wraps an established `DataChannelConnection` in a `WebRtcConnection` and
  ## records the client in `self.clients[dir]`. It also spawns a task that
  ## removes the client from that list once the wrapping connection is closed.
  trace "Handling webrtc connection", address = $observedAddr,
                                      dir = $dir,
                                      clients = self.clients[Direction.In].len +
                                      self.clients[Direction.Out].len

  let conn: Connection =
    WebRtcConnection.new(
      conn = client,
      # dir = dir,
      observedAddr = observedAddr
      # timeout = self.connectionsTimeout
    )

  proc onClose() {.async.} =
    try:
      # Wait until the wrapping connection has been closed.
      await conn.join()

      # NOTE(review): `DataChannelConnection` appears not to expose
      # `remoteAddress` yet (there is a TODO to add it). Chronicles only
      # type-checks trace arguments when trace logging is enabled, so log
      # the observed address instead to keep trace builds compiling.
      trace "Cleaning up client", addrs = $observedAddr,
                                  conn

      self.clients[dir].keepItIf( it != client )
      #TODO
      #await allFuturesThrowing(
      #  conn.close(), client.closeWait())

      trace "Cleaned up client", addrs = $observedAddr,
                                 conn

    except CatchableError as exc:
      debug "Error cleaning up client", errMsg = exc.msg, conn

  self.clients[dir].add(client)
  asyncSpawn onClose()

  return conn
method accept*(self: WebRtcTransport): Future[Connection] {.async, gcsafe.} =
  ## Waits for the next incoming data-channel connection on any server and
  ## hands it to `connHandler`.
  ##
  ## Raises a transport-closed error when not running. Returns nil when there
  ## are no servers, or when `connHandler` fails with a non-cancellation
  ## error (which is logged).
  if not self.running:
    raise newTransportClosedError()

  #TODO handle errors
  # Lazily arm one pending accept per server.
  if self.acceptFuts.len <= 0:
    self.acceptFuts = self.servers.mapIt(it.accept())

  if self.acceptFuts.len <= 0:
    return

  # Re-arm the slot of whichever server produced a connection.
  let
    doneFut = await one(self.acceptFuts)
    slot = self.acceptFuts.find(doneFut)
  self.acceptFuts[slot] = self.servers[slot].accept()

  let dcConn = await doneFut
  try:
    #TODO add remoteAddress to DataChannelConnection
    #let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet() #TODO add /webrtc-direct
    # NOTE(review): placeholder observed address until the real remote
    # address is available.
    let observedAddr = MultiAddress.init("/ip4/127.0.0.1").tryGet()
    return await self.connHandler(dcConn, Opt.some(observedAddr), Direction.In)
  except CancelledError as exc:
    #TODO
    #transp.close()
    raise exc
  except CatchableError as exc:
    debug "Failed to handle connection", exc = exc.msg
    #TODO
    #transp.close()
method handles*(t: WebRtcTransport, address: MultiAddress): bool {.gcsafe.} =
  ## True when the base transport accepts `address`, its protocol list is
  ## `isOk`, and it matches the `WebRtcDirect2` multiaddress pattern.
  procCall Transport(t).handles(address) and
    address.protocols.isOk and
    WebRtcDirect2.match(address)