diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 54d0da5da..5f6faa809 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -58,6 +58,8 @@ type initiator*: bool # initiated remotely or locally flag isOpen*: bool # has channel been opened closedLocal*: bool # has channel been closed locally + remoteReset*: bool # has channel been remotely reset + localReset*: bool # has channel been reset locally msgCode*: MessageType # cached in/out message code closeCode*: MessageType # cached in/out close code resetCode*: MessageType # cached in/out reset code @@ -103,6 +105,7 @@ proc reset*(s: LPChannel) {.async, gcsafe.} = s.isClosed = true s.closedLocal = true + s.localReset = not s.remoteReset trace "Resetting channel", s, len = s.len @@ -168,6 +171,14 @@ method readOnce*(s: LPChannel, ## channels are blocked - in particular, this means that reading from one ## channel must not be done from within a callback / read handler of another ## or the reads will lock each other. + if s.remoteReset: + raise newLPStreamResetError() + if s.localReset: + raise newLPStreamClosedError() + if s.atEof(): + raise newLPStreamRemoteClosedError() + if s.conn.closed: + raise newLPStreamConnDownError() try: let bytes = await procCall BufferStream(s).readOnce(pbytes, nbytes) when defined(libp2p_network_protocols_metrics): @@ -184,13 +195,17 @@ method readOnce*(s: LPChannel, # data has been lost in s.readBuf and there's no way to gracefully recover / # use the channel any more await s.reset() - raise exc + raise newLPStreamConnDownError(exc) proc prepareWrite(s: LPChannel, msg: seq[byte]): Future[void] {.async.} = # prepareWrite is the slow path of writing a message - see conditions in # write - if s.closedLocal or s.conn.closed: + if s.remoteReset: + raise newLPStreamResetError() + if s.closedLocal: raise newLPStreamClosedError() + if s.conn.closed: + raise newLPStreamConnDownError() if msg.len == 0: return @@ -235,7 +250,7 @@ proc completeWrite( trace "exception in lpchannel write handler", s, msg = exc.msg await s.reset() await s.conn.close() - raise exc + raise newLPStreamConnDownError(exc) finally: s.writes -= 1 diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 90838120d..fc0294c2d 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -183,6 +183,7 @@ method handle*(m: Mplex) {.async, gcsafe.} = of MessageType.CloseIn, MessageType.CloseOut: await channel.pushEof() of MessageType.ResetIn, MessageType.ResetOut: + channel.remoteReset = true await channel.reset() except CancelledError: debug "Unexpected cancellation in mplex handler", m diff --git a/libp2p/muxers/yamux/yamux.nim b/libp2p/muxers/yamux/yamux.nim index 83deb1009..a5273794d 100644 --- a/libp2p/muxers/yamux/yamux.nim +++ b/libp2p/muxers/yamux/yamux.nim @@ -153,6 +153,7 @@ type sendQueue: seq[ToSend] recvQueue: seq[byte] isReset: bool + remoteReset: bool closedRemotely: Future[void] closedLocally: bool receivedData: AsyncEvent @@ -194,23 +195,25 @@ method closeImpl*(channel: YamuxChannel) {.async, gcsafe.} = await channel.actuallyClose() proc reset(channel: YamuxChannel, isLocal: bool = false) {.async.} = - if not channel.isReset: - trace "Reset channel" - channel.isReset = true - for (d, s, fut) in channel.sendQueue: - fut.fail(newLPStreamEOFError()) - channel.sendQueue = @[] - channel.recvQueue = @[] - channel.sendWindow = 0 - if not channel.closedLocally: - if isLocal: - try: await channel.conn.write(YamuxHeader.data(channel.id, 0, {Rst})) - except LPStreamEOFError as exc: discard - except LPStreamClosedError as exc: discard - await channel.close() - if not channel.closedRemotely.done(): - await channel.remoteClosed() - channel.receivedData.fire() + if channel.isReset: + return + trace "Reset channel" + channel.isReset = true + channel.remoteReset = not isLocal + for (d, s, fut) in channel.sendQueue: + fut.fail(newLPStreamEOFError()) + channel.sendQueue = @[] + channel.recvQueue = @[] + channel.sendWindow = 0 + if not channel.closedLocally: + if isLocal: + try: await channel.conn.write(YamuxHeader.data(channel.id, 0, {Rst})) + except LPStreamEOFError as exc: discard + except LPStreamClosedError as exc: discard + await channel.close() + if not channel.closedRemotely.done(): + await channel.remoteClosed() + channel.receivedData.fire() if not isLocal: # If we reset locally, we want to flush up to a maximum of recvWindow # bytes. We use the recvWindow in the proc cleanupChann. @@ -235,7 +238,15 @@ method readOnce*( nbytes: int): Future[int] {.async.} = - if channel.returnedEof: raise newLPStreamEOFError() + if channel.isReset: + raise if channel.remoteReset: + newLPStreamResetError() + elif channel.closedLocally: + newLPStreamClosedError() + else: + newLPStreamConnDownError() + if channel.returnedEof: + raise newLPStreamRemoteClosedError() if channel.recvQueue.len == 0: channel.receivedData.clear() await channel.closedRemotely or channel.receivedData.wait() @@ -313,8 +324,9 @@ proc trySend(channel: YamuxChannel) {.async.} = channel.sendWindow.dec(toSend) try: await channel.conn.write(sendBuffer) except CatchableError as exc: + let connDown = newLPStreamConnDownError(exc) for fut in futures.items(): - fut.fail(exc) + fut.fail(connDown) await channel.reset() break for fut in futures.items(): @@ -323,8 +335,11 @@ proc trySend(channel: YamuxChannel) {.async.} = method write*(channel: YamuxChannel, msg: seq[byte]): Future[void] = result = newFuture[void]("Yamux Send") + if channel.remoteReset: + result.fail(newLPStreamResetError()) + return result if channel.closedLocally or channel.isReset: - result.fail(newLPStreamEOFError()) + result.fail(newLPStreamClosedError()) return result if msg.len == 0: result.complete() @@ -396,8 +411,9 @@ method close*(m: Yamux) {.async.} = m.isClosed = true trace "Closing yamux" - for channel in m.channels.values: - await channel.reset() + let channels = toSeq(m.channels.values()) + for channel in channels: + await channel.reset(true) await m.connection.write(YamuxHeader.goAway(NormalTermination)) await m.connection.close() trace "Closed yamux" @@ -453,8 +469,9 @@ method handle*(m: Yamux) {.async, gcsafe.} = m.flushed[header.streamId].dec(int(header.length)) if m.flushed[header.streamId] < 0: raise newException(YamuxError, "Peer exhausted the recvWindow after reset") - var buffer = newSeqUninitialized[byte](header.length) - await m.connection.readExactly(addr buffer[0], int(header.length)) + if header.length > 0: + var buffer = newSeqUninitialized[byte](header.length) + await m.connection.readExactly(addr buffer[0], int(header.length)) continue let channel = m.channels[header.streamId] diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 6eb83ed0c..68cf862c7 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -79,7 +79,7 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} = &"Only one concurrent push allowed for stream {s.shortLog()}") if s.isClosed or s.pushedEof: - raise newLPStreamEOFError() + raise newLPStreamClosedError() if data.len == 0: return # Don't push 0-length buffers, these signal EOF diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 6857da3f4..fb9401a9a 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -59,7 +59,18 @@ type LPStreamWriteError* = object of LPStreamError par*: ref CatchableError LPStreamEOFError* = object of LPStreamError - LPStreamClosedError* = object of LPStreamError + +# X | Read | Write +# Local close | Works | LPStreamClosedError +# Remote close | LPStreamRemoteClosedError | Works +# Local reset | LPStreamClosedError | LPStreamClosedError +# Remote reset | LPStreamResetError | LPStreamResetError +# Connection down | LPStreamConnDown | LPStreamConnDownError + + LPStreamResetError* = object of LPStreamEOFError + LPStreamClosedError* = object of LPStreamEOFError + LPStreamRemoteClosedError* = object of LPStreamEOFError + LPStreamConnDownError* = object of LPStreamEOFError InvalidVarintError* = object of LPStreamError MaxSizeError* = object of LPStreamError @@ -119,9 +130,22 @@ proc newLPStreamIncorrectDefect*(m: string): ref LPStreamIncorrectDefect = proc newLPStreamEOFError*(): ref LPStreamEOFError = result = newException(LPStreamEOFError, "Stream EOF!") +proc newLPStreamResetError*(): ref LPStreamResetError = + result = newException(LPStreamResetError, "Stream Reset!") + proc newLPStreamClosedError*(): ref LPStreamClosedError = result = newException(LPStreamClosedError, "Stream Closed!") +proc newLPStreamRemoteClosedError*(): ref LPStreamRemoteClosedError = + result = newException(LPStreamRemoteClosedError, "Stream Remotely Closed!") + +proc newLPStreamConnDownError*( + parentException: ref Exception = nil): ref LPStreamConnDownError = + result = newException( + LPStreamConnDownError, + "Stream Underlying Connection Closed!", + parentException) + func shortLog*(s: LPStream): auto = if s.isNil: "LPStream(nil)" else: $s.oid @@ -165,6 +189,8 @@ proc readExactly*(s: LPStream, ## Waits for `nbytes` to be available, then read ## them and return them if s.atEof: + var ch: char + discard await s.readOnce(addr ch, 1) raise newLPStreamEOFError() if nbytes == 0: @@ -183,6 +209,10 @@ proc readExactly*(s: LPStream, if read == 0: doAssert s.atEof() trace "couldn't read all bytes, stream EOF", s, nbytes, read + # Re-readOnce to raise a more specific error than EOF + # Raise EOF if it doesn't raise anything(shouldn't happen) + discard await s.readOnce(addr pbuffer[read], nbytes - read) + warn "Read twice while at EOF" raise newLPStreamEOFError() if read < nbytes: @@ -200,8 +230,7 @@ proc readLine*(s: LPStream, while true: var ch: char - if (await readOnce(s, addr ch, 1)) == 0: - raise newLPStreamEOFError() + await readExactly(s, addr ch, 1) if sep[state] == ch: inc(state) @@ -224,8 +253,7 @@ proc readVarint*(conn: LPStream): Future[uint64] {.async, gcsafe, public.} = buffer: array[10, byte] for i in 0..