diff --git a/chronos/transports/common.nim b/chronos/transports/common.nim index 412f1625..78cae403 100644 --- a/chronos/transports/common.nim +++ b/chronos/transports/common.nim @@ -109,6 +109,7 @@ type WritePending, # Writer operation pending (Windows) WritePaused, # Writer operations paused WriteClosed, # Writer operations closed + WriteEof, # Remote peer disconnected WriteError # Write error var diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index c0d4cc00..e92bb91c 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -245,6 +245,12 @@ proc setupStreamServerTracker(): StreamServerTracker {.gcsafe.} = result.isLeaked = leakServer addTracker(StreamServerTrackerName, result) +proc completePendingWriteQueue(queue: var Deque[StreamVector], + v: int) {.inline.} = + while len(queue) > 0: + var vector = queue.popFirst() + vector.writer.complete(v) + when defined(windows): template zeroOvelappedOffset(t: untyped) = @@ -274,6 +280,11 @@ when defined(windows): (t).wwsabuf.buf = cast[cstring](v.buf) (t).wwsabuf.len = cast[int32](v.buflen) + proc isConnResetError(err: OSErrorCode): bool {.inline.} = + result = (err == OSErrorCode(WSAECONNRESET)) or + (err == OSErrorCode(WSAECONNABORTED)) or + (err == OSErrorCode(ERROR_PIPE_NOT_CONNECTED)) + proc writeStreamLoop(udata: pointer) {.gcsafe, nimcall.} = var bytesCount: int32 var ovl = cast[PtrCustomOverlapped](udata) @@ -318,8 +329,16 @@ when defined(windows): break else: let v = transp.queue.popFirst() - transp.state.incl(WriteError) - v.writer.fail(getTransportOsError(err)) + if isConnResetError(err): + # Soft error happens which indicates that remote peer got + # disconnected, complete all pending writes in queue with 0. + transp.state.incl(WriteEof) + v.writer.complete(0) + completePendingWriteQueue(transp.queue, 0) + break + else: + transp.state.incl(WriteError) + v.writer.fail(getTransportOsError(err)) else: ## Initiation transp.state.incl(WritePending) @@ -343,8 +362,16 @@ when defined(windows): transp.queue.addFirst(vector) else: transp.state.excl(WritePending) - transp.state = transp.state + {WritePaused, WriteError} - vector.writer.fail(getTransportOsError(err)) + if isConnResetError(err): + # Soft error happens which indicates that remote peer got + # disconnected, complete all pending writes in queue with 0. + transp.state.incl({WritePaused, WriteEof}) + vector.writer.complete(0) + completePendingWriteQueue(transp.queue, 0) + break + else: + transp.state.incl({WritePaused, WriteError}) + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) else: @@ -372,8 +399,16 @@ when defined(windows): transp.queue.addFirst(vector) else: transp.state.excl(WritePending) - transp.state = transp.state + {WritePaused, WriteError} - vector.writer.fail(getTransportOsError(err)) + if isConnResetError(err): + # Soft error happens which indicates that remote peer got + # disconnected, complete all pending writes in queue with 0. + transp.state.incl({WritePaused, WriteEof}) + vector.writer.complete(0) + completePendingWriteQueue(transp.queue, 0) + break + else: + transp.state.incl({WritePaused, WriteError}) + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) elif transp.kind == TransportKind.Pipe: @@ -401,8 +436,16 @@ when defined(windows): vector.writer.complete(0) else: transp.state.excl(WritePending) - transp.state = transp.state + {WritePaused, WriteError} - vector.writer.fail(getTransportOsError(err)) + if isConnResetError(err): + # Soft error happens which indicates that remote peer got + # disconnected, complete all pending writes in queue with 0. + transp.state.incl({WritePaused, WriteEof}) + vector.writer.complete(0) + completePendingWriteQueue(transp.queue, 0) + break + else: + transp.state.incl({WritePaused, WriteError}) + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) break @@ -877,6 +920,10 @@ else: (v).buflen = int(n) (v).writer = (t) + proc isConnResetError(err: OSErrorCode): bool {.inline.} = + result = (err == OSErrorCode(ECONNRESET)) or + (err == OSErrorCode(EPIPE)) + proc writeStreamLoop(udata: pointer) {.gcsafe.} = var cdata = cast[ptr CompletionData](udata) var transp = cast[StreamTransport](cdata.udata) @@ -904,7 +951,15 @@ else: if int(err) == EINTR: continue else: - vector.writer.fail(getTransportOsError(err)) + if isConnResetError(err): + # Soft error happens which indicates that remote peer got + # disconnected, complete all pending writes in queue with 0. + transp.state.incl({WriteEof, WritePaused}) + vector.writer.complete(0) + completePendingWriteQueue(transp.queue, 0) + transp.fd.removeWriter() + else: + vector.writer.fail(getTransportOsError(err)) else: var nbytes = cast[int](vector.buf) let res = sendfile(int(fd), cast[int](vector.buflen), @@ -923,7 +978,15 @@ else: if int(err) == EINTR: continue else: - vector.writer.fail(getTransportOsError(err)) + if isConnResetError(err): + # Soft error happens which indicates that remote peer got + # disconnected, complete all pending writes in queue with 0. + transp.state.incl({WriteEof, WritePaused}) + vector.writer.complete(0) + completePendingWriteQueue(transp.queue, 0) + transp.fd.removeWriter() + else: + vector.writer.fail(getTransportOsError(err)) break else: transp.state.incl(WritePaused) @@ -1475,7 +1538,7 @@ proc readUntil*(transp: StreamTransport, pbytes: pointer, nbytes: int, ## Read data from the transport ``transp`` until separator ``sep`` is found. ## ## On success, the data and separator will be removed from the internal - ## buffer (consumed). Returned data will NOT include the separator at the end. + ## buffer (consumed). Returned data will include the separator at the end. ## ## If EOF is received, and `sep` was not found, procedure will raise ## ``TransportIncompleteError``. @@ -1579,7 +1642,7 @@ proc readLine*(transp: StreamTransport, limit = 0, await fut proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} = - ## Read all bytes (n == -1) or exactly `n` bytes from transport ``transp``. + ## Read all bytes (n <= 0) or exactly `n` bytes from transport ``transp``. ## ## This procedure allocates buffer seq[byte] and return it as result. checkClosed(transp) @@ -1594,7 +1657,7 @@ proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} = if transp.offset > 0: let s = len(result) let o = s + transp.offset - if n < 0: + if n <= 0: # grabbing all incoming data, until EOF result.setLen(o) copyMem(cast[pointer](addr result[s]), addr(transp.buffer[0]), @@ -1637,7 +1700,7 @@ proc consume*(transp: StreamTransport, n = -1): Future[int] {.async.} = break if transp.offset > 0: - if n == -1: + if n <= 0: # consume all incoming data, until EOF result += transp.offset transp.offset = 0 diff --git a/tests/teststream.nim b/tests/teststream.nim index 9ebbfd19..6597c2d6 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -692,6 +692,28 @@ suite "Stream Transport test suite": except: discard + proc testWriteConnReset(address: TransportAddress): Future[int] {.async.} = + proc client(server: StreamServer, transp: StreamTransport) {.async.} = + await transp.closeWait() + var n = 10 + var server = createStreamServer(address, client, {ReuseAddr}) + server.start() + var msg = "HELLO" + var ntransp = await connect(address) + while true: + var res = await ntransp.write(msg) + if res == 0: + result = 1 + break + else: + dec(n) + if n == 0: + break + + server.stop() + await ntransp.closeWait() + await server.closeWait() + for i in 0..