Cleanup purchasing state machine (#422)

* [state machine] Allow querying of state properties

* [purchasing] use new state machine

* [state machine] remove old state machine implementation

* [purchasing] remove duplication in error handling
This commit is contained in:
markspanbroek 2023-06-05 10:48:06 +02:00 committed by GitHub
parent 3e7ce137a4
commit 3181361658
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 173 additions and 379 deletions

View File

@ -44,10 +44,10 @@ func new*(_: type Purchase,
return purchase
proc start*(purchase: Purchase) =
purchase.switch(PurchasePending())
purchase.start(PurchasePending())
proc load*(purchase: Purchase) =
purchase.switch(PurchaseUnknown())
purchase.start(PurchaseUnknown())
proc wait*(purchase: Purchase) {.async.} =
await purchase.future
@ -63,3 +63,8 @@ func error*(purchase: Purchase): ?(ref CatchableError) =
some purchase.future.error
else:
none (ref CatchableError)
func state*(purchase: Purchase): ?string =
proc description(state: State): string =
$state
purchase.query(description)

View File

@ -1,21 +1,18 @@
import ../utils/statemachine
import ../utils/asyncstatemachine
import ../market
import ../clock
import ../errors
export market
export clock
export statemachine
export asyncstatemachine
type
Purchase* = ref object of StateMachine
Purchase* = ref object of Machine
future*: Future[void]
market*: Market
clock*: Clock
requestId*: RequestId
request*: ?StorageRequest
PurchaseState* = ref object of AsyncState
PurchaseState* = ref object of State
PurchaseError* = object of CodexError
method description*(state: PurchaseState): string {.base.} =
raiseAssert "description not implemented for state"

View File

@ -1,20 +1,14 @@
import ../statemachine
import ./errorhandling
import ./error
type PurchaseCancelled* = ref object of PurchaseState
type PurchaseCancelled* = ref object of ErrorHandlingState
method enterAsync*(state: PurchaseCancelled) {.async.} =
without purchase =? (state.context as Purchase):
raiseAssert "invalid state"
try:
await purchase.market.withdrawFunds(purchase.requestId)
except CatchableError as error:
state.switch(PurchaseErrored(error: error))
return
let error = newException(Timeout, "Purchase cancelled due to timeout")
state.switch(PurchaseErrored(error: error))
method description*(state: PurchaseCancelled): string =
method `$`*(state: PurchaseCancelled): string =
"cancelled"
method run*(state: PurchaseCancelled, machine: Machine): Future[?State] {.async.} =
let purchase = Purchase(machine)
await purchase.market.withdrawFunds(purchase.requestId)
let error = newException(Timeout, "Purchase cancelled due to timeout")
return some State(PurchaseErrored(error: error))

View File

@ -3,11 +3,9 @@ import ../statemachine
type PurchaseErrored* = ref object of PurchaseState
error*: ref CatchableError
method enter*(state: PurchaseErrored) =
without purchase =? (state.context as Purchase):
raiseAssert "invalid state"
purchase.future.fail(state.error)
method description*(state: PurchaseErrored): string =
method `$`*(state: PurchaseErrored): string =
"errored"
method run*(state: PurchaseErrored, machine: Machine): Future[?State] {.async.} =
let purchase = Purchase(machine)
purchase.future.fail(state.error)

View File

@ -0,0 +1,9 @@
import pkg/questionable
import ../statemachine
import ./error
type
ErrorHandlingState* = ref object of PurchaseState
method onError*(state: ErrorHandlingState, error: ref CatchableError): ?State =
some State(PurchaseErrored(error: error))

View File

@ -4,9 +4,9 @@ import ./error
type
PurchaseFailed* = ref object of PurchaseState
method enter*(state: PurchaseFailed) =
let error = newException(PurchaseError, "Purchase failed")
state.switch(PurchaseErrored(error: error))
method description*(state: PurchaseFailed): string =
method `$`*(state: PurchaseFailed): string =
"failed"
method run*(state: PurchaseFailed, machine: Machine): Future[?State] {.async.} =
let error = newException(PurchaseError, "Purchase failed")
return some State(PurchaseErrored(error: error))

View File

@ -2,11 +2,9 @@ import ../statemachine
type PurchaseFinished* = ref object of PurchaseState
method enter*(state: PurchaseFinished) =
without purchase =? (state.context as Purchase):
raiseAssert "invalid state"
purchase.future.complete()
method description*(state: PurchaseFinished): string =
method `$`*(state: PurchaseFinished): string =
"finished"
method run*(state: PurchaseFinished, machine: Machine): Future[?State] {.async.} =
let purchase = Purchase(machine)
purchase.future.complete()

View File

@ -1,21 +1,15 @@
import ../statemachine
import ./errorhandling
import ./submitted
import ./error
type PurchasePending* = ref object of PurchaseState
type PurchasePending* = ref object of ErrorHandlingState
method enterAsync(state: PurchasePending) {.async.} =
without purchase =? (state.context as Purchase) and
request =? purchase.request:
raiseAssert "invalid state"
try:
await purchase.market.requestStorage(request)
except CatchableError as error:
state.switch(PurchaseErrored(error: error))
return
state.switch(PurchaseSubmitted())
method description*(state: PurchasePending): string =
method `$`*(state: PurchasePending): string =
"pending"
method run*(state: PurchasePending, machine: Machine): Future[?State] {.async.} =
let purchase = Purchase(machine)
let request = !purchase.request
await purchase.market.requestStorage(request)
return some State(PurchaseSubmitted())

View File

@ -1,13 +1,16 @@
import ../statemachine
import ./errorhandling
import ./error
import ./finished
import ./failed
type PurchaseStarted* = ref object of PurchaseState
type PurchaseStarted* = ref object of ErrorHandlingState
method enterAsync*(state: PurchaseStarted) {.async.} =
without purchase =? (state.context as Purchase):
raiseAssert "invalid state"
method `$`*(state: PurchaseStarted): string =
"started"
method run*(state: PurchaseStarted, machine: Machine): Future[?State] {.async.} =
let purchase = Purchase(machine)
let clock = purchase.clock
let market = purchase.market
@ -18,17 +21,11 @@ method enterAsync*(state: PurchaseStarted) {.async.} =
let subscription = await market.subscribeRequestFailed(purchase.requestId, callback)
let ended = clock.waitUntil(await market.getRequestEnd(purchase.requestId))
try:
let fut = await one(ended, failed)
await subscription.unsubscribe()
if fut.id == failed.id:
ended.cancel()
state.switch(PurchaseFailed())
return some State(PurchaseFailed())
else:
failed.cancel()
state.switch(PurchaseFinished())
await subscription.unsubscribe()
except CatchableError as error:
state.switch(PurchaseErrored(error: error))
method description*(state: PurchaseStarted): string =
"started"
return some State(PurchaseFinished())

View File

@ -1,15 +1,17 @@
import ../statemachine
import ./errorhandling
import ./error
import ./started
import ./cancelled
type PurchaseSubmitted* = ref object of PurchaseState
type PurchaseSubmitted* = ref object of ErrorHandlingState
method enterAsync(state: PurchaseSubmitted) {.async.} =
without purchase =? (state.context as Purchase) and
request =? purchase.request:
raiseAssert "invalid state"
method `$`*(state: PurchaseSubmitted): string =
"submitted"
method run*(state: PurchaseSubmitted, machine: Machine): Future[?State] {.async.} =
let purchase = Purchase(machine)
let request = !purchase.request
let market = purchase.market
let clock = purchase.clock
@ -28,13 +30,6 @@ method enterAsync(state: PurchaseSubmitted) {.async.} =
try:
await wait().withTimeout()
except Timeout:
state.switch(PurchaseCancelled())
return
except CatchableError as error:
state.switch(PurchaseErrored(error: error))
return
return some State(PurchaseCancelled())
state.switch(PurchaseStarted())
method description*(state: PurchaseSubmitted): string =
"submitted"
return some State(PurchaseStarted())

View File

@ -1,4 +1,5 @@
import ../statemachine
import ./errorhandling
import ./submitted
import ./started
import ./cancelled
@ -6,13 +7,13 @@ import ./finished
import ./failed
import ./error
type PurchaseUnknown* = ref object of PurchaseState
type PurchaseUnknown* = ref object of ErrorHandlingState
method enterAsync(state: PurchaseUnknown) {.async.} =
without purchase =? (state.context as Purchase):
raiseAssert "invalid state"
method `$`*(state: PurchaseUnknown): string =
"unknown"
try:
method run*(state: PurchaseUnknown, machine: Machine): Future[?State] {.async.} =
let purchase = Purchase(machine)
if (request =? await purchase.market.getRequest(purchase.requestId)) and
(requestState =? await purchase.market.requestState(purchase.requestId)):
@ -20,18 +21,12 @@ method enterAsync(state: PurchaseUnknown) {.async.} =
case requestState
of RequestState.New:
state.switch(PurchaseSubmitted())
return some State(PurchaseSubmitted())
of RequestState.Started:
state.switch(PurchaseStarted())
return some State(PurchaseStarted())
of RequestState.Cancelled:
state.switch(PurchaseCancelled())
return some State(PurchaseCancelled())
of RequestState.Finished:
state.switch(PurchaseFinished())
return some State(PurchaseFinished())
of RequestState.Failed:
state.switch(PurchaseFailed())
except CatchableError as error:
state.switch(PurchaseErrored(error: error))
method description*(state: PurchaseUnknown): string =
"unknown"
return some State(PurchaseFailed())

View File

@ -57,7 +57,7 @@ func `%`*(id: RequestId | SlotId | Nonce | AvailabilityId): JsonNode =
func `%`*(purchase: Purchase): JsonNode =
%*{
"state": (purchase.state as PurchaseState).?description |? "none",
"state": purchase.state |? "none",
"error": purchase.error.?msg,
"request": purchase.request,
}

View File

@ -13,6 +13,7 @@ type
scheduling: Future[void]
started: bool
State* = ref object of RootObj
Query*[T] = proc(state: State): T
Event* = proc(state: State): ?State {.gcsafe, upraises:[].}
logScope:
@ -26,6 +27,12 @@ proc transition(_: type Event, previous, next: State): Event =
if state == previous:
return some next
proc query*[T](machine: Machine, query: Query[T]): ?T =
if machine.state == nil:
none T
else:
some query(machine.state)
proc schedule*(machine: Machine, event: Event) =
if not machine.started:
return
@ -90,4 +97,5 @@ proc stop*(machine: Machine) =
if not machine.running.isNil:
machine.running.cancel()
machine.state = nil
machine.started = false

View File

@ -1,130 +0,0 @@
import std/typetraits
import pkg/chronicles
import pkg/questionable
import pkg/chronos
import ./optionalcast
## Implementation of the the state pattern:
## https://en.wikipedia.org/wiki/State_pattern
##
## Define your own state machine and state types:
##
## type
## Light = ref object of StateMachine
## color: string
## LightState = ref object of State
##
## let light = Light(color: "yellow")
##
## Define the states:
##
## type
## On = ref object of LightState
## Off = ref object of LightState
##
## Perform actions on state entry and exit:
##
## method enter(state: On) =
## echo light.color, " light switched on"
##
## method exit(state: On) =
## echo light.color, " light no longer switched on"
##
## light.switch(On()) # prints: 'light switched on'
## light.switch(Off()) # prints: 'light no longer switched on'
##
## Allow behaviour to change based on the current state:
##
## method description*(state: LightState): string {.base.} =
## return "a light"
##
## method description*(state: On): string =
## if light =? (state.context as Light):
## return "a " & light.color & " light"
##
## method description*(state: Off): string =
## return "a dark light"
##
## proc description*(light: Light): string =
## if state =? (light.state as LightState):
## return state.description
##
## light.switch(On())
## echo light.description # prints: 'a yellow light'
## light.switch(Off())
## echo light.description # prints 'a dark light'
export questionable
export optionalcast
type
StateMachine* = ref object of RootObj
state: ?State
State* = ref object of RootObj
context: ?StateMachine
method `$`*(state: State): string {.base.} =
(typeof state).name
method enter(state: State) {.base.} =
discard
method exit(state: State) {.base.} =
discard
func state*(machine: StateMachine): ?State =
machine.state
func context*(state: State): ?StateMachine =
state.context
proc switch*(machine: StateMachine, newState: State) =
if state =? machine.state:
state.exit()
state.context = StateMachine.none
machine.state = newState.some
newState.context = machine.some
newState.enter()
proc switch*(oldState, newState: State) =
if context =? oldState.context:
context.switch(newState)
type
AsyncState* = ref object of State
activeTransition: ?Future[void]
StateMachineAsync* = ref object of StateMachine
method enterAsync(state: AsyncState) {.base, async.} =
discard
method exitAsync(state: AsyncState) {.base, async.} =
discard
method enter(state: AsyncState) =
asyncSpawn state.enterAsync()
method exit(state: AsyncState) =
asyncSpawn state.exitAsync()
proc switchAsync*(machine: StateMachineAsync, newState: AsyncState) {.async.} =
if state =? (machine.state as AsyncState):
trace "Switching sales state", `from` = $state, to = $newState
if activeTransition =? state.activeTransition and
not activeTransition.completed:
await activeTransition.cancelAndWait()
# should wait for exit before switch. could add a transition option during
# switch if we don't need to wait
await state.exitAsync()
state.context = none StateMachine
else:
trace "Switching state", `from` = "no state", to = $newState
machine.state = some State(newState)
newState.context = some StateMachine(machine)
newState.activeTransition = some newState.enterAsync()
proc switchAsync*(oldState, newState: AsyncState) {.async.} =
if context =? oldState.context:
await StateMachineAsync(context).switchAsync(newState)

View File

@ -3,7 +3,12 @@ import pkg/asynctest
import pkg/chronos
import pkg/stint
import pkg/codex/purchasing
import pkg/codex/purchasing/states/[finished, error, started, submitted, unknown]
import pkg/codex/purchasing/states/finished
import pkg/codex/purchasing/states/started
import pkg/codex/purchasing/states/submitted
import pkg/codex/purchasing/states/unknown
import pkg/codex/purchasing/states/cancelled
import pkg/codex/purchasing/states/failed
import ./helpers/mockmarket
import ./helpers/mockclock
import ./helpers/eventually
@ -31,11 +36,11 @@ suite "Purchasing":
test "submits a storage request when asked":
discard await purchasing.purchase(request)
let submitted = market.requested[0]
check submitted.ask.slots == request.ask.slots
check submitted.ask.slotSize == request.ask.slotSize
check submitted.ask.duration == request.ask.duration
check submitted.ask.reward == request.ask.reward
check eventually market.requested.len > 0
check market.requested[0].ask.slots == request.ask.slots
check market.requested[0].ask.slotSize == request.ask.slotSize
check market.requested[0].ask.duration == request.ask.duration
check market.requested[0].ask.reward == request.ask.reward
test "remembers purchases":
let purchase1 = await purchasing.purchase(request)
@ -49,11 +54,13 @@ suite "Purchasing":
test "can change default value for proof probability":
purchasing.proofProbability = 42.u256
discard await purchasing.purchase(request)
check eventually market.requested.len > 0
check market.requested[0].ask.proofProbability == 42.u256
test "can override proof probability per request":
request.ask.proofProbability = 42.u256
discard await purchasing.purchase(request)
check eventually market.requested.len > 0
check market.requested[0].ask.proofProbability == 42.u256
test "has a default value for request expiration interval":
@ -63,25 +70,30 @@ suite "Purchasing":
purchasing.requestExpiryInterval = 42.u256
let start = getTime().toUnix()
discard await purchasing.purchase(request)
check eventually market.requested.len > 0
check market.requested[0].expiry == (start + 42).u256
test "can override expiry time per request":
let expiry = (getTime().toUnix() + 42).u256
request.expiry = expiry
discard await purchasing.purchase(request)
check eventually market.requested.len > 0
check market.requested[0].expiry == expiry
test "includes a random nonce in every storage request":
discard await purchasing.purchase(request)
discard await purchasing.purchase(request)
check eventually market.requested.len > 0
check market.requested[0].nonce != market.requested[1].nonce
test "sets client address in request":
discard await purchasing.purchase(request)
check eventually market.requested.len > 0
check market.requested[0].client == await market.getSigner()
test "succeeds when request is finished":
let purchase = await purchasing.purchase(request)
check eventually market.requested.len > 0
let request = market.requested[0]
let requestEnd = getTime().toUnix() + 42
market.requestEnds[request.id] = requestEnd
@ -92,6 +104,7 @@ suite "Purchasing":
test "fails when request times out":
let purchase = await purchasing.purchase(request)
check eventually market.requested.len > 0
let request = market.requested[0]
clock.set(request.expiry.truncate(int64))
expect PurchaseTimeout:
@ -99,6 +112,7 @@ suite "Purchasing":
test "checks that funds were withdrawn when purchase times out":
let purchase = await purchasing.purchase(request)
check eventually market.requested.len > 0
let request = market.requested[0]
clock.set(request.expiry.truncate(int64))
expect PurchaseTimeout:
@ -150,20 +164,20 @@ suite "Purchasing state machine":
market.requestEnds[request2.id] = clock.now() - 1
await purchasing.load()
check purchasing.getPurchase(PurchaseId(request1.id)).?finished == false.some
check purchasing.getPurchase(PurchaseId(request2.id)).?finished == true.some
check purchasing.getPurchase(PurchaseId(request3.id)).?finished == true.some
check purchasing.getPurchase(PurchaseId(request4.id)).?finished == true.some
check purchasing.getPurchase(PurchaseId(request5.id)).?finished == true.some
check purchasing.getPurchase(PurchaseId(request5.id)).?error.isSome
check eventually purchasing.getPurchase(PurchaseId(request1.id)).?finished == false.some
check eventually purchasing.getPurchase(PurchaseId(request2.id)).?finished == true.some
check eventually purchasing.getPurchase(PurchaseId(request3.id)).?finished == true.some
check eventually purchasing.getPurchase(PurchaseId(request4.id)).?finished == true.some
check eventually purchasing.getPurchase(PurchaseId(request5.id)).?finished == true.some
check eventually purchasing.getPurchase(PurchaseId(request5.id)).?error.isSome
test "moves to PurchaseSubmitted when request state is New":
let request = StorageRequest.example
let purchase = Purchase.new(request, market, clock)
market.requested = @[request]
market.requestState[request.id] = RequestState.New
purchase.switch(PurchaseUnknown())
check (purchase.state as PurchaseSubmitted).isSome
let next = await PurchaseUnknown().run(purchase)
check !next of PurchaseSubmitted
test "moves to PurchaseStarted when request state is Started":
let request = StorageRequest.example
@ -171,69 +185,51 @@ suite "Purchasing state machine":
market.requestEnds[request.id] = clock.now() + request.ask.duration.truncate(int64)
market.requested = @[request]
market.requestState[request.id] = RequestState.Started
purchase.switch(PurchaseUnknown())
check (purchase.state as PurchaseStarted).isSome
let next = await PurchaseUnknown().run(purchase)
check !next of PurchaseStarted
test "moves to PurchaseErrored when request state is Cancelled":
test "moves to PurchaseCancelled when request state is Cancelled":
let request = StorageRequest.example
let purchase = Purchase.new(request, market, clock)
market.requested = @[request]
market.requestState[request.id] = RequestState.Cancelled
purchase.switch(PurchaseUnknown())
check (purchase.state as PurchaseErrored).isSome
check purchase.error.?msg == "Purchase cancelled due to timeout".some
let next = await PurchaseUnknown().run(purchase)
check !next of PurchaseCancelled
test "moves to PurchaseFinished when request state is Finished":
let request = StorageRequest.example
let purchase = Purchase.new(request, market, clock)
market.requested = @[request]
market.requestState[request.id] = RequestState.Finished
purchase.switch(PurchaseUnknown())
check (purchase.state as PurchaseFinished).isSome
let next = await PurchaseUnknown().run(purchase)
check !next of PurchaseFinished
test "moves to PurchaseErrored when request state is Failed":
test "moves to PurchaseFailed when request state is Failed":
let request = StorageRequest.example
let purchase = Purchase.new(request, market, clock)
market.requested = @[request]
market.requestState[request.id] = RequestState.Failed
purchase.switch(PurchaseUnknown())
check (purchase.state as PurchaseErrored).isSome
check purchase.error.?msg == "Purchase failed".some
let next = await PurchaseUnknown().run(purchase)
check !next of PurchaseFailed
test "moves to PurchaseErrored state once RequestFailed emitted":
let me = await market.getSigner()
test "moves to PurchaseFailed state once RequestFailed emitted":
let request = StorageRequest.example
market.requested = @[request]
market.activeRequests[me] = @[request.id]
market.requestState[request.id] = RequestState.Started
let purchase = Purchase.new(request, market, clock)
market.requestEnds[request.id] = clock.now() + request.ask.duration.truncate(int64)
await purchasing.load()
let future = PurchaseStarted().run(purchase)
# emit mock contract failure event
market.emitRequestFailed(request.id)
# must allow time for the callback to trigger the completion of the future
await sleepAsync(chronos.milliseconds(10))
# now check the result
let purchase = purchasing.getPurchase(PurchaseId(request.id))
let state = purchase.?state
check (state as PurchaseErrored).isSome
check (!purchase).error.?msg == "Purchase failed".some
let next = await future
check !next of PurchaseFailed
test "moves to PurchaseFinished state once request finishes":
let me = await market.getSigner()
let request = StorageRequest.example
market.requested = @[request]
market.activeRequests[me] = @[request.id]
market.requestState[request.id] = RequestState.Started
let purchase = Purchase.new(request, market, clock)
market.requestEnds[request.id] = clock.now() + request.ask.duration.truncate(int64)
await purchasing.load()
let future = PurchaseStarted().run(purchase)
# advance the clock to the end of the request
clock.advance(request.ask.duration.truncate(int64))
# now check the result
proc requestState: ?PurchaseState =
purchasing.getPurchase(PurchaseId(request.id)).?state as PurchaseState
check eventually (requestState() as PurchaseFinished).isSome
let next = await future
check !next of PurchaseFinished

View File

@ -1,5 +1,3 @@
import ./utils/teststatemachine
import ./utils/teststatemachineasync
import ./utils/testoptionalcast
import ./utils/testkeyutils
import ./utils/testasyncstatemachine

View File

@ -113,3 +113,21 @@ suite "async state machines":
machine.schedule(moveToNextStateEvent)
check eventually cancellations == [0, 1, 0, 0]
check errors == [0, 0, 0, 0]
test "queries properties of the current state":
proc description(state: State): string =
$state
machine.start(State2.new())
check eventually machine.query(description) == some "State2"
machine.schedule(moveToNextStateEvent)
check eventually machine.query(description) == some "State3"
test "stops handling queries when stopped":
proc description(state: State): string =
$state
machine.start(State2.new())
check eventually machine.query(description).isSome
machine.stop()
check machine.query(description).isNone

View File

@ -1,48 +0,0 @@
import std/unittest
import pkg/questionable
import codex/utils/statemachine
type
Light = ref object of StateMachine
On = ref object of State
Off = ref object of State
var enteredOn: bool
var exitedOn: bool
method enter(state: On) =
enteredOn = true
method exit(state: On) =
exitedOn = true
suite "state machines":
setup:
enteredOn = false
exitedOn = false
test "calls `enter` when entering state":
Light().switch(On())
check enteredOn
test "calls `exit` when exiting state":
let light = Light()
light.switch(On())
check not exitedOn
light.switch(Off())
check exitedOn
test "allows access to state machine from state":
let light = Light()
let on = On()
check not isSome on.context
light.switch(on)
check on.context == some StateMachine(light)
test "removes access to state machine when state exited":
let light = Light()
let on = On()
light.switch(on)
light.switch(Off())
check not isSome on.context

View File

@ -1,30 +0,0 @@
import pkg/asynctest
import pkg/chronos
import pkg/questionable
import codex/utils/statemachine
type
AsyncMachine = ref object of StateMachineAsync
LongRunningStart = ref object of AsyncState
LongRunningFinish = ref object of AsyncState
LongRunningError = ref object of AsyncState
Callback = proc(): Future[void] {.gcsafe.}
proc triggerIn(time: Duration, cb: Callback) {.async.} =
await sleepAsync(time)
await cb()
method enterAsync(state: LongRunningStart) {.async.} =
proc cb() {.async.} =
await state.switchAsync(LongRunningFinish())
asyncSpawn triggerIn(500.milliseconds, cb)
await sleepAsync(1.seconds)
await state.switchAsync(LongRunningError())
suite "async state machines":
test "can cancel a state":
let am = AsyncMachine()
await am.switchAsync(LongRunningStart())
await sleepAsync(2.seconds)
check (am.state as LongRunningFinish).isSome