fix(asyncstatemachine): fixes not awaiting or asyncSpawning futures (#1033)

- adds a break in scheduler when CancelledError is caught
- tracks asyncSpawned state.run, so that it can be cancelled during stop
- removes usages of `then`
- ensures that no exceptions are leaked from async procs
This commit is contained in:
Eric 2024-12-13 09:35:39 +07:00 committed by GitHub
parent 19af79786e
commit 7c804b0ec9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 27 additions and 31 deletions

View File

@ -1,12 +1,11 @@
import std/sugar import std/sugar
import pkg/questionable import pkg/questionable
import pkg/chronos import pkg/chronos
import pkg/upraises
import ../logutils import ../logutils
import ./then import ./then
import ./trackedfutures import ./trackedfutures
push: {.upraises:[].} {.push raises:[].}
type type
Machine* = ref object of RootObj Machine* = ref object of RootObj
@ -17,7 +16,7 @@ type
trackedFutures: TrackedFutures trackedFutures: TrackedFutures
State* = ref object of RootObj State* = ref object of RootObj
Query*[T] = proc(state: State): T Query*[T] = proc(state: State): T
Event* = proc(state: State): ?State {.gcsafe, upraises:[].} Event* = proc(state: State): ?State {.gcsafe, raises:[].}
logScope: logScope:
topics = "statemachine" topics = "statemachine"
@ -58,29 +57,31 @@ proc onError(machine: Machine, error: ref CatchableError): Event =
return proc (state: State): ?State = return proc (state: State): ?State =
state.onError(error) state.onError(error)
proc run(machine: Machine, state: State) {.async.} = proc run(machine: Machine, state: State) {.async: (raises:[]).} =
if next =? await state.run(machine): try:
machine.schedule(Event.transition(state, next)) if next =? await state.run(machine):
machine.schedule(Event.transition(state, next))
except CancelledError:
discard # do not propagate
except CatchableError as e:
machine.schedule(machine.onError(e))
proc scheduler(machine: Machine) {.async.} = proc scheduler(machine: Machine) {.async: (raises: []).} =
var running: Future[void] var running: Future[void].Raising([])
while machine.started: while machine.started:
let event = await machine.scheduled.get().track(machine) try:
if next =? event(machine.state): let event = await machine.scheduled.get()
if not running.isNil and not running.finished: if next =? event(machine.state):
trace "cancelling current state", state = $machine.state if not running.isNil and not running.finished:
await running.cancelAndWait() trace "cancelling current state", state = $machine.state
let fromState = if machine.state.isNil: "<none>" else: $machine.state await running.cancelAndWait()
machine.state = next let fromState = if machine.state.isNil: "<none>" else: $machine.state
debug "enter state", state = fromState & " => " & $machine.state machine.state = next
running = machine.run(machine.state) debug "enter state", state = fromState & " => " & $machine.state
running running = machine.run(machine.state)
.track(machine) asyncSpawn running.track(machine)
.cancelled(proc() = trace "state.run cancelled, swallowing", state = $machine.state) except CancelledError:
.catch(proc(err: ref CatchableError) = break # do not propagate bc it is asyncSpawned
trace "error caught in state.run, calling state.onError", state = $machine.state
machine.schedule(machine.onError(err))
)
proc start*(machine: Machine, initialState: State) = proc start*(machine: Machine, initialState: State) =
if machine.started: if machine.started:
@ -90,13 +91,8 @@ proc start*(machine: Machine, initialState: State) =
machine.scheduled = newAsyncQueue[Event]() machine.scheduled = newAsyncQueue[Event]()
machine.started = true machine.started = true
try: asyncSpawn machine.scheduler().track(machine)
discard machine.scheduler().track(machine) machine.schedule(Event.transition(machine.state, initialState))
machine.schedule(Event.transition(machine.state, initialState))
except CancelledError as e:
discard
except CatchableError as e:
error("Error in scheduler", error = e.msg)
proc stop*(machine: Machine) {.async.} = proc stop*(machine: Machine) {.async.} =
if not machine.started: if not machine.started: