prevent http `closeWait` future from being cancelled (#486)

* simplify `closeWait` implementations
  * remove redundant cancellation callbacks
  * use `noCancel` to avoid forgetting the right future flags
* add a few missing raises trackers
* enforce `OwnCancelSchedule` on manually created futures that don't raise `CancelledError`
* ensure cancellations don't reach internal futures
This commit is contained in:
Jacek Sieka 2024-01-04 16:17:42 +01:00 committed by GitHub
parent 41f77d261e
commit e15dc3b41f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 152 additions and 176 deletions

View File

@ -43,7 +43,7 @@ proc closeWait*(bstream: HttpBodyReader) {.async: (raises: []).} =
## Close and free resource allocated by body reader.
if bstream.bstate == HttpState.Alive:
bstream.bstate = HttpState.Closing
var res = newSeq[Future[void]]()
var res = newSeq[Future[void].Raising([])]()
# We closing streams in reversed order because stream at position [0], uses
# data from stream at position [1].
for index in countdown((len(bstream.streams) - 1), 0):
@ -68,7 +68,7 @@ proc closeWait*(bstream: HttpBodyWriter) {.async: (raises: []).} =
## Close and free all the resources allocated by body writer.
if bstream.bstate == HttpState.Alive:
bstream.bstate = HttpState.Closing
var res = newSeq[Future[void]]()
var res = newSeq[Future[void].Raising([])]()
for index in countdown(len(bstream.streams) - 1, 0):
res.add(bstream.streams[index].closeWait())
await noCancel(allFutures(res))

View File

@ -294,7 +294,7 @@ proc new*(t: typedesc[HttpSessionRef],
if HttpClientFlag.Http11Pipeline in flags:
sessionWatcher(res)
else:
Future[void].Raising([]).init("session.watcher.placeholder")
nil
res
proc getTLSFlags(flags: HttpClientFlags): set[TLSFlags] =
@ -607,7 +607,7 @@ proc closeWait(conn: HttpClientConnectionRef) {.async: (raises: []).} =
conn.state = HttpClientConnectionState.Closing
let pending =
block:
var res: seq[Future[void]]
var res: seq[Future[void].Raising([])]
if not(isNil(conn.reader)) and not(conn.reader.closed()):
res.add(conn.reader.closeWait())
if not(isNil(conn.writer)) and not(conn.writer.closed()):
@ -847,14 +847,14 @@ proc sessionWatcher(session: HttpSessionRef) {.async: (raises: []).} =
break
proc closeWait*(request: HttpClientRequestRef) {.async: (raises: []).} =
var pending: seq[FutureBase]
var pending: seq[Future[void].Raising([])]
if request.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}:
request.state = HttpReqRespState.Closing
if not(isNil(request.writer)):
if not(request.writer.closed()):
pending.add(FutureBase(request.writer.closeWait()))
pending.add(request.writer.closeWait())
request.writer = nil
pending.add(FutureBase(request.releaseConnection()))
pending.add(request.releaseConnection())
await noCancel(allFutures(pending))
request.session = nil
request.error = nil
@ -862,14 +862,14 @@ proc closeWait*(request: HttpClientRequestRef) {.async: (raises: []).} =
untrackCounter(HttpClientRequestTrackerName)
proc closeWait*(response: HttpClientResponseRef) {.async: (raises: []).} =
var pending: seq[FutureBase]
var pending: seq[Future[void].Raising([])]
if response.state notin {HttpReqRespState.Closing, HttpReqRespState.Closed}:
response.state = HttpReqRespState.Closing
if not(isNil(response.reader)):
if not(response.reader.closed()):
pending.add(FutureBase(response.reader.closeWait()))
pending.add(response.reader.closeWait())
response.reader = nil
pending.add(FutureBase(response.releaseConnection()))
pending.add(response.releaseConnection())
await noCancel(allFutures(pending))
response.session = nil
response.error = nil

View File

@ -523,15 +523,13 @@ proc closeWait*(ab: AsyncEventQueue): Future[void] {.
{FutureFlag.OwnCancelSchedule})
proc continuation(udata: pointer) {.gcsafe.} =
retFuture.complete()
proc cancellation(udata: pointer) {.gcsafe.} =
# We are not going to change the state of `retFuture` to cancelled, so we
# will prevent the entire sequence of Futures from being cancelled.
discard
# Ignore cancellation requests - we'll complete the future soon enough
retFuture.cancelCallback = nil
ab.close()
# Schedule `continuation` to be called only after all the `reader`
# notifications will be scheduled and processed.
retFuture.cancelCallback = cancellation
callSoon(continuation)
retFuture

View File

@ -34,6 +34,19 @@ type
FutureFlag* {.pure.} = enum
OwnCancelSchedule
## When OwnCancelSchedule is set, the owner of the future is responsible
## for implementing cancellation in one of 3 ways:
##
## * ensure that cancellation requests never reach the future by means of
## not exposing it to user code, `await` and `tryCancel`
## * set `cancelCallback` to `nil` to stop cancellation propagation - this
## is appropriate when it is expected that the future will be completed
## in a regular way "soon"
## * set `cancelCallback` to a handler that implements cancellation in an
## operation-specific way
##
## If `cancelCallback` is not set and the future gets cancelled, a
## `Defect` will be raised.
FutureFlags* = set[FutureFlag]
@ -104,6 +117,12 @@ proc internalInitFutureBase*(fut: FutureBase, loc: ptr SrcLoc,
fut.internalState = state
fut.internalLocation[LocationKind.Create] = loc
fut.internalFlags = flags
if FutureFlag.OwnCancelSchedule in flags:
# Owners must replace `cancelCallback` with `nil` if they want to ignore
# cancellations
fut.internalCancelcb = proc(_: pointer) =
raiseAssert "Cancellation request for non-cancellable future"
if state != FutureState.Pending:
fut.internalLocation[LocationKind.Finish] = loc

View File

@ -1013,6 +1013,7 @@ proc cancelAndWait*(future: FutureBase, loc: ptr SrcLoc): Future[void] {.
if future.finished():
retFuture.complete()
else:
retFuture.cancelCallback = nil
cancelSoon(future, continuation, cast[pointer](retFuture), loc)
retFuture
@ -1057,6 +1058,7 @@ proc noCancel*[F: SomeFuture](future: F): auto = # async: (raw: true, raises: as
if future.finished():
completeFuture()
else:
retFuture.cancelCallback = nil
future.addCallback(continuation)
retFuture

View File

@ -18,45 +18,6 @@ proc makeNoRaises*(): NimNode {.compileTime.} =
ident"void"
macro Raising*[T](F: typedesc[Future[T]], E: varargs[typedesc]): untyped =
## Given a Future type instance, return a type storing `{.raises.}`
## information
##
## Note; this type may change in the future
E.expectKind(nnkBracket)
let raises = if E.len == 0:
makeNoRaises()
else:
nnkTupleConstr.newTree(E.mapIt(it))
nnkBracketExpr.newTree(
ident "InternalRaisesFuture",
nnkDotExpr.newTree(F, ident"T"),
raises
)
template init*[T, E](
F: type InternalRaisesFuture[T, E], fromProc: static[string] = ""): F =
## Creates a new pending future.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
let res = F()
internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Pending, {})
res
template init*[T, E](
F: type InternalRaisesFuture[T, E], fromProc: static[string] = "",
flags: static[FutureFlags]): F =
## Creates a new pending future.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
let res = F()
internalInitFutureBase(
res, getSrcLocation(fromProc), FutureState.Pending, flags)
res
proc dig(n: NimNode): NimNode {.compileTime.} =
# Dig through the layers of type to find the raises list
if n.eqIdent("void"):
@ -87,6 +48,58 @@ proc members(tup: NimNode): seq[NimNode] {.compileTime.} =
for t in tup.members():
result.add(t)
macro hasException(raises: typedesc, ident: static string): bool =
newLit(raises.members.anyIt(it.eqIdent(ident)))
macro Raising*[T](F: typedesc[Future[T]], E: varargs[typedesc]): untyped =
## Given a Future type instance, return a type storing `{.raises.}`
## information
##
## Note; this type may change in the future
E.expectKind(nnkBracket)
let raises = if E.len == 0:
makeNoRaises()
else:
nnkTupleConstr.newTree(E.mapIt(it))
nnkBracketExpr.newTree(
ident "InternalRaisesFuture",
nnkDotExpr.newTree(F, ident"T"),
raises
)
template init*[T, E](
F: type InternalRaisesFuture[T, E], fromProc: static[string] = ""): F =
## Creates a new pending future.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
when not hasException(type(E), "CancelledError"):
static:
raiseAssert "Manually created futures must either own cancellation schedule or raise CancelledError"
let res = F()
internalInitFutureBase(res, getSrcLocation(fromProc), FutureState.Pending, {})
res
template init*[T, E](
F: type InternalRaisesFuture[T, E], fromProc: static[string] = "",
flags: static[FutureFlags]): F =
## Creates a new pending future.
##
## Specifying ``fromProc``, which is a string specifying the name of the proc
## that this future belongs to, is a good habit as it helps with debugging.
let res = F()
when not hasException(type(E), "CancelledError"):
static:
doAssert FutureFlag.OwnCancelSchedule in flags,
"Manually created futures must either own cancellation schedule or raise CancelledError"
internalInitFutureBase(
res, getSrcLocation(fromProc), FutureState.Pending, flags)
res
proc containsSignature(members: openArray[NimNode], typ: NimNode): bool {.compileTime.} =
let typHash = signatureHash(typ)

View File

@ -77,7 +77,7 @@ type
udata: pointer
error*: ref AsyncStreamError
bytesCount*: uint64
future: Future[void]
future: Future[void].Raising([])
AsyncStreamWriter* = ref object of RootRef
wsource*: AsyncStreamWriter
@ -88,7 +88,7 @@ type
error*: ref AsyncStreamError
udata: pointer
bytesCount*: uint64
future: Future[void]
future: Future[void].Raising([])
AsyncStream* = object of RootObj
reader*: AsyncStreamReader
@ -897,44 +897,27 @@ proc close*(rw: AsyncStreamRW) =
rw.future.addCallback(continuation)
rw.future.cancelSoon()
proc closeWait*(rw: AsyncStreamRW): Future[void] {.
async: (raw: true, raises: []).} =
proc closeWait*(rw: AsyncStreamRW): Future[void] {.async: (raises: []).} =
## Close and frees resources of stream ``rw``.
const FutureName =
when rw is AsyncStreamReader:
"async.stream.reader.closeWait"
else:
"async.stream.writer.closeWait"
let retFuture = Future[void].Raising([]).init(FutureName)
if rw.closed():
retFuture.complete()
return retFuture
proc continuation(udata: pointer) {.gcsafe, raises:[].} =
retFuture.complete()
if not rw.closed():
rw.close()
if rw.future.finished():
retFuture.complete()
else:
rw.future.addCallback(continuation, cast[pointer](retFuture))
retFuture
await noCancel(rw.join())
proc startReader(rstream: AsyncStreamReader) =
rstream.state = Running
if not isNil(rstream.readerLoop):
rstream.future = rstream.readerLoop(rstream)
else:
rstream.future = newFuture[void]("async.stream.empty.reader")
rstream.future = Future[void].Raising([]).init(
"async.stream.empty.reader", {FutureFlag.OwnCancelSchedule})
proc startWriter(wstream: AsyncStreamWriter) =
wstream.state = Running
if not isNil(wstream.writerLoop):
wstream.future = wstream.writerLoop(wstream)
else:
wstream.future = newFuture[void]("async.stream.empty.writer")
wstream.future = Future[void].Raising([]).init(
"async.stream.empty.writer", {FutureFlag.OwnCancelSchedule})
proc init*(child, wsource: AsyncStreamWriter, loop: StreamWriterLoop,
queueSize = AsyncStreamDefaultQueueSize) =

View File

@ -73,7 +73,7 @@ when defined(windows) or defined(nimdoc):
udata*: pointer # User-defined pointer
flags*: set[ServerFlags] # Flags
bufferSize*: int # Size of internal transports' buffer
loopFuture*: Future[void] # Server's main Future
loopFuture*: Future[void].Raising([]) # Server's main Future
domain*: Domain # Current server domain (IPv4 or IPv6)
apending*: bool
asock*: AsyncFD # Current AcceptEx() socket
@ -92,7 +92,7 @@ else:
udata*: pointer # User-defined pointer
flags*: set[ServerFlags] # Flags
bufferSize*: int # Size of internal transports' buffer
loopFuture*: Future[void] # Server's main Future
loopFuture*: Future[void].Raising([]) # Server's main Future
errorCode*: OSErrorCode # Current error code
dualstack*: DualStackType # IPv4/IPv6 dualstack parameters

View File

@ -44,7 +44,7 @@ type
remote: TransportAddress # Remote address
udata*: pointer # User-driven pointer
function: DatagramCallback # Receive data callback
future: Future[void] # Transport's life future
future: Future[void].Raising([]) # Transport's life future
raddr: Sockaddr_storage # Reader address storage
ralen: SockLen # Reader address length
waddr: Sockaddr_storage # Writer address storage
@ -359,7 +359,8 @@ when defined(windows):
res.queue = initDeque[GramVector]()
res.udata = udata
res.state = {ReadPaused, WritePaused}
res.future = newFuture[void]("datagram.transport")
res.future = Future[void].Raising([]).init(
"datagram.transport", {FutureFlag.OwnCancelSchedule})
res.rovl.data = CompletionData(cb: readDatagramLoop,
udata: cast[pointer](res))
res.wovl.data = CompletionData(cb: writeDatagramLoop,
@ -568,7 +569,8 @@ else:
res.queue = initDeque[GramVector]()
res.udata = udata
res.state = {ReadPaused, WritePaused}
res.future = newFuture[void]("datagram.transport")
res.future = Future[void].Raising([]).init(
"datagram.transport", {FutureFlag.OwnCancelSchedule})
GC_ref(res)
# Start tracking transport
trackCounter(DgramTransportTrackerName)
@ -840,31 +842,16 @@ proc join*(transp: DatagramTransport): Future[void] {.
return retFuture
proc closed*(transp: DatagramTransport): bool {.inline.} =
## Returns ``true`` if transport in closed state.
{ReadClosed, WriteClosed} * transp.state != {}
proc closeWait*(transp: DatagramTransport): Future[void] {.
async: (raw: true, raises: []).} =
async: (raises: []).} =
## Close transport ``transp`` and release all resources.
let retFuture = newFuture[void](
"datagram.transport.closeWait", {FutureFlag.OwnCancelSchedule})
if {ReadClosed, WriteClosed} * transp.state != {}:
retFuture.complete()
return retFuture
proc continuation(udata: pointer) {.gcsafe.} =
retFuture.complete()
proc cancellation(udata: pointer) {.gcsafe.} =
# We are not going to change the state of `retFuture` to cancelled, so we
# will prevent the entire sequence of Futures from being cancelled.
discard
if not transp.closed():
transp.close()
if transp.future.finished():
retFuture.complete()
else:
transp.future.addCallback(continuation, cast[pointer](retFuture))
retFuture.cancelCallback = cancellation
retFuture
await noCancel(transp.join())
proc send*(transp: DatagramTransport, pbytes: pointer,
nbytes: int): Future[void] {.
@ -1020,7 +1007,3 @@ proc getMessage*(transp: DatagramTransport): seq[byte] {.
proc getUserData*[T](transp: DatagramTransport): T {.inline.} =
## Obtain user data stored in ``transp`` object.
cast[T](transp.udata)
proc closed*(transp: DatagramTransport): bool {.inline.} =
## Returns ``true`` if transport in closed state.
{ReadClosed, WriteClosed} * transp.state != {}

View File

@ -76,7 +76,7 @@ when defined(windows):
offset: int # Reading buffer offset
error: ref TransportError # Current error
queue: Deque[StreamVector] # Writer queue
future: Future[void] # Stream life future
future: Future[void].Raising([]) # Stream life future
# Windows specific part
rwsabuf: WSABUF # Reader WSABUF
wwsabuf: WSABUF # Writer WSABUF
@ -103,7 +103,7 @@ else:
offset: int # Reading buffer offset
error: ref TransportError # Current error
queue: Deque[StreamVector] # Writer queue
future: Future[void] # Stream life future
future: Future[void].Raising([]) # Stream life future
case kind*: TransportKind
of TransportKind.Socket:
domain: Domain # Socket transport domain (IPv4/IPv6)
@ -598,7 +598,8 @@ when defined(windows):
transp.buffer = newSeq[byte](bufsize)
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("stream.socket.transport")
transp.future = Future[void].Raising([]).init(
"stream.socket.transport", {FutureFlag.OwnCancelSchedule})
GC_ref(transp)
transp
@ -619,7 +620,8 @@ when defined(windows):
transp.flags = flags
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("stream.pipe.transport")
transp.future = Future[void].Raising([]).init(
"stream.pipe.transport", {FutureFlag.OwnCancelSchedule})
GC_ref(transp)
transp
@ -1457,7 +1459,8 @@ else:
transp.buffer = newSeq[byte](bufsize)
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("socket.stream.transport")
transp.future = Future[void].Raising([]).init(
"socket.stream.transport", {FutureFlag.OwnCancelSchedule})
GC_ref(transp)
transp
@ -1473,7 +1476,8 @@ else:
transp.buffer = newSeq[byte](bufsize)
transp.state = {ReadPaused, WritePaused}
transp.queue = initDeque[StreamVector]()
transp.future = newFuture[void]("pipe.stream.transport")
transp.future = Future[void].Raising([]).init(
"pipe.stream.transport", {FutureFlag.OwnCancelSchedule})
GC_ref(transp)
transp
@ -1806,6 +1810,9 @@ proc connect*(address: TransportAddress,
if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay)
connect(address, bufferSize, child, localAddress, mappedFlags, dualstack)
proc closed*(server: StreamServer): bool =
server.status == ServerStatus.Closed
proc close*(server: StreamServer) =
## Release ``server`` resources.
##
@ -1832,22 +1839,11 @@ proc close*(server: StreamServer) =
else:
server.sock.closeSocket(continuation)
proc closeWait*(server: StreamServer): Future[void] {.
async: (raw: true, raises: []).} =
proc closeWait*(server: StreamServer): Future[void] {.async: (raises: []).} =
## Close server ``server`` and release all resources.
let retFuture = newFuture[void](
"stream.server.closeWait", {FutureFlag.OwnCancelSchedule})
proc continuation(udata: pointer) =
retFuture.complete()
if not server.closed():
server.close()
if not(server.loopFuture.finished()):
server.loopFuture.addCallback(continuation, cast[pointer](retFuture))
else:
retFuture.complete()
retFuture
await noCancel(server.join())
proc getBacklogSize(backlog: int): cint =
doAssert(backlog >= 0 and backlog <= high(int32))
@ -2058,7 +2054,9 @@ proc createStreamServer*(host: TransportAddress,
sres.init = init
sres.bufferSize = bufferSize
sres.status = Starting
sres.loopFuture = newFuture[void]("stream.transport.server")
sres.loopFuture = asyncloop.init(
Future[void].Raising([]), "stream.transport.server",
{FutureFlag.OwnCancelSchedule})
sres.udata = udata
sres.dualstack = dualstack
if localAddress.family == AddressFamily.None:
@ -2630,6 +2628,23 @@ proc join*(transp: StreamTransport): Future[void] {.
retFuture.complete()
return retFuture
proc closed*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport in closed state.
({ReadClosed, WriteClosed} * transp.state != {})
proc finished*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport in finished (EOF) state.
({ReadEof, WriteEof} * transp.state != {})
proc failed*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport in error state.
({ReadError, WriteError} * transp.state != {})
proc running*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport is still pending.
({ReadClosed, ReadEof, ReadError,
WriteClosed, WriteEof, WriteError} * transp.state == {})
proc close*(transp: StreamTransport) =
## Closes and frees resources of transport ``transp``.
##
@ -2672,31 +2687,11 @@ proc close*(transp: StreamTransport) =
elif transp.kind == TransportKind.Socket:
closeSocket(transp.fd, continuation)
proc closeWait*(transp: StreamTransport): Future[void] {.
async: (raw: true, raises: []).} =
proc closeWait*(transp: StreamTransport): Future[void] {.async: (raises: []).} =
## Close and frees resources of transport ``transp``.
let retFuture = newFuture[void](
"stream.transport.closeWait", {FutureFlag.OwnCancelSchedule})
if {ReadClosed, WriteClosed} * transp.state != {}:
retFuture.complete()
return retFuture
proc continuation(udata: pointer) {.gcsafe.} =
retFuture.complete()
proc cancellation(udata: pointer) {.gcsafe.} =
# We are not going to change the state of `retFuture` to cancelled, so we
# will prevent the entire sequence of Futures from being cancelled.
discard
if not transp.closed():
transp.close()
if transp.future.finished():
retFuture.complete()
else:
transp.future.addCallback(continuation, cast[pointer](retFuture))
retFuture.cancelCallback = cancellation
retFuture
await noCancel(transp.join())
proc shutdownWait*(transp: StreamTransport): Future[void] {.
async: (raw: true, raises: [TransportError, CancelledError]).} =
@ -2756,23 +2751,6 @@ proc shutdownWait*(transp: StreamTransport): Future[void] {.
callSoon(continuation, nil)
retFuture
proc closed*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport in closed state.
({ReadClosed, WriteClosed} * transp.state != {})
proc finished*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport in finished (EOF) state.
({ReadEof, WriteEof} * transp.state != {})
proc failed*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport in error state.
({ReadError, WriteError} * transp.state != {})
proc running*(transp: StreamTransport): bool {.inline.} =
## Returns ``true`` if transport is still pending.
({ReadClosed, ReadEof, ReadError,
WriteClosed, WriteEof, WriteError} * transp.state == {})
proc fromPipe2*(fd: AsyncFD, child: StreamTransport = nil,
bufferSize = DefaultStreamBufferSize
): Result[StreamTransport, OSErrorCode] =