Fix a lot of small bugs

This commit is contained in:
Ludovic Chenut 2024-02-15 16:10:20 +01:00
parent 03ff023e94
commit afe2b08129
No known key found for this signature in database
GPG Key ID: D9A59B1907F1D50C
2 changed files with 47 additions and 15 deletions

View File

@ -122,8 +122,8 @@ type
sendQueue: seq[(seq[byte], Future[void])] sendQueue: seq[(seq[byte], Future[void])]
sendLoop: Future[void] sendLoop: Future[void]
readData: seq[byte] readData: seq[byte]
txState: WebRtcState txState: WebRtcState # Transmission
rxState: WebRtcState rxState: WebRtcState # Reception
proc new( proc new(
_: type WebRtcStream, _: type WebRtcStream,
@ -176,7 +176,7 @@ method write*(s: WebRtcStream, msg2: seq[byte]): Future[void] =
return retFuture return retFuture
proc actuallyClose(s: WebRtcStream) {.async.} = proc actuallyClose(s: WebRtcStream) {.async.} =
debug "stream closed" debug "stream closed", rxState=s.rxState, txState=s.txState
if s.rxState == Closed and s.txState == Closed and s.readData.len == 0: if s.rxState == Closed and s.txState == Closed and s.readData.len == 0:
#TODO add support to DataChannel #TODO add support to DataChannel
#await s.dataChannel.close() #await s.dataChannel.close()
@ -186,7 +186,9 @@ method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.a
if s.rxState == Closed: if s.rxState == Closed:
raise newLPStreamEOFError() raise newLPStreamEOFError()
while s.readData.len == 0: while s.readData.len == 0 or nbytes == 0:
# Check if there's no data left in readData or if nbytes is equal to 0
# in order to read an eventual Fin or FinAck
if s.rxState == Closed: if s.rxState == Closed:
await s.actuallyClose() await s.actuallyClose()
return 0 return 0
@ -196,6 +198,8 @@ method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.a
message = await s.rawStream.readLp(MaxMessageSize) message = await s.rawStream.readLp(MaxMessageSize)
decoded = WebRtcMessage.decode(message).tryGet() decoded = WebRtcMessage.decode(message).tryGet()
s.readData = s.readData.concat(decoded.data)
decoded.flag.withValue(flag): decoded.flag.withValue(flag):
case flag: case flag:
of Fin: of Fin:
@ -205,10 +209,10 @@ method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.a
of FinAck: of FinAck:
s.txState = Closed s.txState = Closed
await s.actuallyClose() await s.actuallyClose()
if nbytes == 0:
return 0
else: discard else: discard
s.readData = decoded.data
result = min(nbytes, s.readData.len) result = min(nbytes, s.readData.len)
copyMem(pbytes, addr s.readData[0], result) copyMem(pbytes, addr s.readData[0], result)
s.readData = s.readData[result..^1] s.readData = s.readData[result..^1]
@ -216,7 +220,8 @@ method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.a
method closeImpl*(s: WebRtcStream) {.async.} = method closeImpl*(s: WebRtcStream) {.async.} =
s.send(WebRtcMessage(flag: Opt.some(Fin))) s.send(WebRtcMessage(flag: Opt.some(Fin)))
s.txState = Closing s.txState = Closing
await s.join() #TODO ?? while s.txState != Closed:
discard await s.readOnce(nil, 0)
# -- Connection -- # -- Connection --
type WebRtcConnection = ref object of Connection type WebRtcConnection = ref object of Connection
@ -236,13 +241,15 @@ proc new(
co co
proc getStream*(conn: WebRtcConnection, proc getStream*(conn: WebRtcConnection,
direction: Direction): Future[WebRtcStream] {.async.} = direction: Direction,
noiseHandshake: bool = false): Future[WebRtcStream] {.async.} =
var datachannel = var datachannel =
case direction: case direction:
of Direction.In: of Direction.In:
await conn.connection.accept() await conn.connection.accept()
of Direction.Out: of Direction.Out:
await conn.connection.openStream(0) #TODO don't hardcode stream id (should be in nim-webrtc) #TODO don't hardcode stream id (should be in nim-webrtc)
await conn.connection.openStream(noiseHandshake)
return WebRtcStream.new(datachannel, conn.observedAddr, conn.peerId) return WebRtcStream.new(datachannel, conn.observedAddr, conn.peerId)
# -- Muxer -- # -- Muxer --
@ -278,7 +285,10 @@ method close*(m: WebRtcMuxer) {.async, gcsafe.} =
await m.webRtcConn.close() await m.webRtcConn.close()
# -- Upgrader -- # -- Upgrader --
type WebRtcUpgrade = ref object of Upgrade type
WebRtcStreamHandler = proc(conn: Connection): Future[void] {.gcsafe, raises: [].}
WebRtcUpgrade = ref object of Upgrade
streamHandler: WebRtcStreamHandler
method upgrade*( method upgrade*(
self: WebRtcUpgrade, self: WebRtcUpgrade,
@ -287,7 +297,7 @@ method upgrade*(
peerId: Opt[PeerId]): Future[Muxer] {.async.} = peerId: Opt[PeerId]): Future[Muxer] {.async.} =
let webRtcConn = WebRtcConnection(conn) let webRtcConn = WebRtcConnection(conn)
result = WebRtcMuxer(webRtcConn: webRtcConn) result = WebRtcMuxer(connection: conn, webRtcConn: webRtcConn)
# Noise handshake # Noise handshake
let noiseHandler = self.secureManagers.filterIt(it of Noise) let noiseHandler = self.secureManagers.filterIt(it of Noise)
@ -300,7 +310,7 @@ method upgrade*(
echo "=> ", ((Noise)noiseHandler[0]).commonPrologue echo "=> ", ((Noise)noiseHandler[0]).commonPrologue
let let
stream = await webRtcConn.getStream(Out) #TODO add channelId: 0 stream = await webRtcConn.getStream(Out, true) #TODO add channelId: 0
secureStream = await noiseHandler[0].handshake( secureStream = await noiseHandler[0].handshake(
stream, stream,
initiator = true, # we are always the initiator in webrtc-direct initiator = true, # we are always the initiator in webrtc-direct
@ -311,6 +321,9 @@ method upgrade*(
await secureStream.close() await secureStream.close()
await stream.close() await stream.close()
result.streamHandler = self.streamHandler
result.handler = result.handle()
# -- Transport -- # -- Transport --
type type
WebRtcTransport* = ref object of Transport WebRtcTransport* = ref object of Transport
@ -354,9 +367,24 @@ proc new*(
upgrade: Upgrade, upgrade: Upgrade,
connectionsTimeout = 10.minutes): T {.public.} = connectionsTimeout = 10.minutes): T {.public.} =
let upgrader = WebRtcUpgrade(ms: upgrade.ms, secureManagers: upgrade.secureManagers)
upgrader.streamHandler = proc(conn: Connection)
{.async, gcsafe, raises: [].} =
# TODO: replace echo by trace and find why it fails compiling
echo "Starting stream handler"#, conn
try:
await upgrader.ms.handle(conn) # handle incoming connection
except CancelledError as exc:
raise exc
except CatchableError as exc:
echo "exception in stream handler", exc.msg#, conn, msg = exc.msg
finally:
await conn.closeWithEOF()
echo "Stream handler done"#, conn
let let
transport = T( transport = T(
upgrader: WebRtcUpgrade(secureManagers: upgrade.secureManagers), upgrader: upgrader,
connectionsTimeout: connectionsTimeout) connectionsTimeout: connectionsTimeout)
return transport return transport

View File

@ -6,10 +6,13 @@ proc echoHandler(conn: Connection, proto: string) {.async.} =
while true: while true:
try: try:
echo "\e[35;1m => Echo Handler <=\e[0m" echo "\e[35;1m => Echo Handler <=\e[0m"
let msg = string.fromBytes(await conn.readLp(1024)) var xx = newSeq[byte](1024)
let aa = await conn.readOnce(addr xx[0], 1024)
xx = xx[0..<aa]
let msg = string.fromBytes(xx)
echo " => Echo Handler Receive: ", msg, " <=" echo " => Echo Handler Receive: ", msg, " <="
echo " => Echo Handler Try Send: ", msg & "1", " <=" echo " => Echo Handler Try Send: ", msg & "1", " <="
await conn.writeLp(msg & "1") await conn.write(msg & "1")
except CatchableError as e: except CatchableError as e:
echo " => Echo Handler Error: ", e.msg, " <=" echo " => Echo Handler Error: ", e.msg, " <="
break break
@ -20,6 +23,7 @@ proc main {.async.} =
.withAddress(MultiAddress.init("/ip4/127.0.0.1/udp/4242/webrtc-direct/certhash/uEiDDq4_xNyDorZBH3TlGazyJdOWSwvo4PUo5YHFMrvDE8g").tryGet()) #TODO the certhash shouldn't be necessary .withAddress(MultiAddress.init("/ip4/127.0.0.1/udp/4242/webrtc-direct/certhash/uEiDDq4_xNyDorZBH3TlGazyJdOWSwvo4PUo5YHFMrvDE8g").tryGet()) #TODO the certhash shouldn't be necessary
.withRng(crypto.newRng()) .withRng(crypto.newRng())
.withMplex() .withMplex()
.withYamux()
.withTransport(proc (upgr: Upgrade): Transport = WebRtcTransport.new(upgr)) .withTransport(proc (upgr: Upgrade): Transport = WebRtcTransport.new(upgr))
.withNoise() .withNoise()
.build() .build()