diff --git a/libp2p/transports/webrtctransport.nim b/libp2p/transports/webrtctransport.nim index d5d839204..19d09df4d 100644 --- a/libp2p/transports/webrtctransport.nim +++ b/libp2p/transports/webrtctransport.nim @@ -122,8 +122,8 @@ type sendQueue: seq[(seq[byte], Future[void])] sendLoop: Future[void] readData: seq[byte] - txState: WebRtcState - rxState: WebRtcState + txState: WebRtcState # Transmission + rxState: WebRtcState # Reception proc new( _: type WebRtcStream, @@ -176,7 +176,7 @@ method write*(s: WebRtcStream, msg2: seq[byte]): Future[void] = return retFuture proc actuallyClose(s: WebRtcStream) {.async.} = - debug "stream closed" + debug "stream closed", rxState=s.rxState, txState=s.txState if s.rxState == Closed and s.txState == Closed and s.readData.len == 0: #TODO add support to DataChannel #await s.dataChannel.close() @@ -186,7 +186,9 @@ method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.a if s.rxState == Closed: raise newLPStreamEOFError() - while s.readData.len == 0: + while s.readData.len == 0 or nbytes == 0: + # Check if there's no data left in readData or if nbytes is equal to 0 + # in order to read an eventual Fin or FinAck if s.rxState == Closed: await s.actuallyClose() return 0 @@ -196,6 +198,8 @@ method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.a message = await s.rawStream.readLp(MaxMessageSize) decoded = WebRtcMessage.decode(message).tryGet() + s.readData = s.readData.concat(decoded.data) + decoded.flag.withValue(flag): case flag: of Fin: @@ -205,10 +209,10 @@ method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.a of FinAck: s.txState = Closed await s.actuallyClose() + if nbytes == 0: + return 0 else: discard - s.readData = decoded.data - result = min(nbytes, s.readData.len) copyMem(pbytes, addr s.readData[0], result) s.readData = s.readData[result..^1] @@ -216,7 +220,8 @@ method readOnce*(s: WebRtcStream, pbytes: pointer, nbytes: int): Future[int] {.a method closeImpl*(s: WebRtcStream) {.async.} = s.send(WebRtcMessage(flag: Opt.some(Fin))) s.txState = Closing - await s.join() #TODO ?? + while s.txState != Closed: + discard await s.readOnce(nil, 0) # -- Connection -- type WebRtcConnection = ref object of Connection @@ -236,13 +241,15 @@ proc new( co proc getStream*(conn: WebRtcConnection, - direction: Direction): Future[WebRtcStream] {.async.} = + direction: Direction, + noiseHandshake: bool = false): Future[WebRtcStream] {.async.} = var datachannel = case direction: of Direction.In: await conn.connection.accept() of Direction.Out: - await conn.connection.openStream(0) #TODO don't hardcode stream id (should be in nim-webrtc) + #TODO don't hardcode stream id (should be in nim-webrtc) + await conn.connection.openStream(noiseHandshake) return WebRtcStream.new(datachannel, conn.observedAddr, conn.peerId) # -- Muxer -- @@ -278,7 +285,10 @@ method close*(m: WebRtcMuxer) {.async, gcsafe.} = await m.webRtcConn.close() # -- Upgrader -- -type WebRtcUpgrade = ref object of Upgrade +type + WebRtcStreamHandler = proc(conn: Connection): Future[void] {.gcsafe, raises: [].} + WebRtcUpgrade = ref object of Upgrade + streamHandler: WebRtcStreamHandler method upgrade*( self: WebRtcUpgrade, @@ -287,7 +297,7 @@ method upgrade*( peerId: Opt[PeerId]): Future[Muxer] {.async.} = let webRtcConn = WebRtcConnection(conn) - result = WebRtcMuxer(webRtcConn: webRtcConn) + result = WebRtcMuxer(connection: conn, webRtcConn: webRtcConn) # Noise handshake let noiseHandler = self.secureManagers.filterIt(it of Noise) @@ -300,7 +310,7 @@ method upgrade*( echo "=> ", ((Noise)noiseHandler[0]).commonPrologue let - stream = await webRtcConn.getStream(Out) #TODO add channelId: 0 + stream = await webRtcConn.getStream(Out, true) #TODO add channelId: 0 secureStream = await noiseHandler[0].handshake( stream, initiator = true, # we are always the initiator in webrtc-direct @@ -311,6 +321,9 @@ method upgrade*( await secureStream.close() await stream.close() + result.streamHandler = self.streamHandler + result.handler = result.handle() + # -- Transport -- type WebRtcTransport* = ref object of Transport @@ -354,9 +367,24 @@ proc new*( upgrade: Upgrade, connectionsTimeout = 10.minutes): T {.public.} = + let upgrader = WebRtcUpgrade(ms: upgrade.ms, secureManagers: upgrade.secureManagers) + upgrader.streamHandler = proc(conn: Connection) + {.async, gcsafe, raises: [].} = + # TODO: replace echo by trace and find why it fails compiling + echo "Starting stream handler"#, conn + try: + await upgrader.ms.handle(conn) # handle incoming connection + except CancelledError as exc: + raise exc + except CatchableError as exc: + echo "exception in stream handler", exc.msg#, conn, msg = exc.msg + finally: + await conn.closeWithEOF() + echo "Stream handler done"#, conn + let transport = T( - upgrader: WebRtcUpgrade(secureManagers: upgrade.secureManagers), + upgrader: upgrader, connectionsTimeout: connectionsTimeout) return transport diff --git a/testwebrtc.nim b/testwebrtc.nim index efedbd56a..461a910ac 100644 --- a/testwebrtc.nim +++ b/testwebrtc.nim @@ -6,10 +6,13 @@ proc echoHandler(conn: Connection, proto: string) {.async.} = while true: try: echo "\e[35;1m => Echo Handler <=\e[0m" - let msg = string.fromBytes(await conn.readLp(1024)) + var xx = newSeq[byte](1024) + let aa = await conn.readOnce(addr xx[0], 1024) + xx = xx[0.. Echo Handler Receive: ", msg, " <=" echo " => Echo Handler Try Send: ", msg & "1", " <=" - await conn.writeLp(msg & "1") + await conn.write(msg & "1") except CatchableError as e: echo " => Echo Handler Error: ", e.msg, " <=" break @@ -20,6 +23,7 @@ proc main {.async.} = .withAddress(MultiAddress.init("/ip4/127.0.0.1/udp/4242/webrtc-direct/certhash/uEiDDq4_xNyDorZBH3TlGazyJdOWSwvo4PUo5YHFMrvDE8g").tryGet()) #TODO the certhash shouldn't be necessary .withRng(crypto.newRng()) .withMplex() + .withYamux() .withTransport(proc (upgr: Upgrade): Transport = WebRtcTransport.new(upgr)) .withNoise() .build()