diff --git a/libp2p/muxers/yamux/yamux.nim b/libp2p/muxers/yamux/yamux.nim index 667f60fba..45b262d5d 100644 --- a/libp2p/muxers/yamux/yamux.nim +++ b/libp2p/muxers/yamux/yamux.nim @@ -150,7 +150,10 @@ type conn: Connection isSrc: bool opened: bool - isSending: bool + + trySendFut: Future[void].Raising([CancelledError, LPStreamError]) + trySendEvent: AsyncEvent + sendQueue: seq[ToSend] recvQueue: seq[byte] isReset: bool @@ -233,7 +236,7 @@ proc reset(channel: YamuxChannel, isLocal: bool = false) {.async: (raises: []).} channel.recvQueue = @[] channel.sendWindow = 0 if not channel.closedLocally: - if isLocal and not channel.isSending: + if isLocal: try: await channel.conn.write(YamuxHeader.data(channel.id, 0, {Rst})) except CancelledError, LPStreamError: @@ -316,13 +319,13 @@ proc setMaxRecvWindow*(channel: YamuxChannel, maxRecvWindow: int) = proc trySend( channel: YamuxChannel ) {.async: (raises: [CancelledError, LPStreamError]).} = - if channel.isSending: - return - channel.isSending = true - defer: - channel.isSending = false - - while channel.sendQueue.len != 0: + channel.trySendEvent.clear() + while true: + if channel.sendQueue.len() == 0 or channel.sendWindow == 0: + await channel.trySendEvent.wait() + channel.trySendEvent.clear() + if channel.isReset or channel.closed(): + return channel.sendQueue.keepItIf(not (it.fut.cancelled() and it.sent == 0)) if channel.sendWindow == 0: trace "trying to send while the sendWindow is empty" @@ -331,7 +334,8 @@ proc trySend( maxSendQueueSize = channel.maxSendQueueSize, currentQueueSize = channel.lengthSendQueueWithLimit() await channel.reset(isLocal = true) - break + break + continue let bytesAvailable = channel.lengthSendQueue() @@ -403,7 +407,7 @@ method write*( channel.sendQueue.add((msg, 0, result)) when defined(libp2p_yamux_metrics): libp2p_yamux_send_queue.observe(channel.lengthSendQueue().int64) - asyncSpawn channel.trySend() + channel.trySendEvent.fire() proc open(channel: YamuxChannel) {.async: (raises: [CancelledError, LPStreamError]).} = ## Open a yamux channel by sending a window update with Syn or Ack flag @@ -442,6 +446,8 @@ proc lenBySrc(m: Yamux, isSrc: bool): int = proc cleanupChannel(m: Yamux, channel: YamuxChannel) {.async: (raises: []).} = try: await channel.join() + if not channel.trySendFut.finished(): + await channel.trySendFut.cancelAndWait() except CancelledError: discard m.channels.del(channel.id) @@ -490,6 +496,8 @@ proc createStream( stream.shortAgent = m.connection.shortAgent m.channels[id] = stream asyncSpawn m.cleanupChannel(stream) + stream.trySendEvent = newAsyncEvent() + stream.trySendFut = stream.trySend() trace "created channel", id, pid = m.connection.peerId when defined(libp2p_yamux_metrics): libp2p_yamux_channels.set(m.lenBySrc(isSrc).int64, [$isSrc, $stream.peerId]) @@ -587,7 +595,7 @@ method handle*(m: Yamux) {.async: (raises: []).} = if header.msgType == WindowUpdate: channel.sendWindow += int(header.length) - await channel.trySend() + channel.trySendEvent.fire() else: if header.length.int > channel.recvWindow.int: # check before allocating the buffer diff --git a/tests/testyamux.nim b/tests/testyamux.nim index 5e071e567..2628ed0d9 100644 --- a/tests/testyamux.nim +++ b/tests/testyamux.nim @@ -192,6 +192,7 @@ suite "Yamux": asyncTest "Saturate until reset": mSetup() let writerBlocker = newBlockerFut() + let readerBlocker = newBlockerFut() yamuxb.streamHandler = proc(conn: Connection) {.async: (raises: []).} = await writerBlocker try: @@ -201,6 +202,7 @@ suite "Yamux": except CancelledError, LPStreamError: return finally: + readerBlocker.complete() await conn.close() let streamA = await yamuxa.newStream() @@ -213,7 +215,9 @@ suite "Yamux": for i in 0 .. 3: expect(LPStreamEOFError): await wrFut[i] + await sleepAsync(50.millis) # waiting for reset to be send writerBlocker.complete() + await readerBlocker await streamA.close() asyncTest "Increase window size":