222 lines
6.4 KiB
Go
Raw Normal View History

2023-07-05 10:59:39 -03:00
package async
import (
"context"
"errors"
"fmt"
"sync"
orderedmap "github.com/wk8/go-ordered-map/v2"
"github.com/status-im/status-go/common"
)
// ErrTaskOverwritten is the error handed to the result function of a queued
// task that Enqueue replaces with a newer task of the same type.
var ErrTaskOverwritten = errors.New("task overwritten")
// Scheduler ensures that only one task of a type is running at a time.
type Scheduler struct {
// queue holds at most one entry per TaskType, in insertion order. Enqueue
// treats the oldest entry as the running task.
queue *orderedmap.OrderedMap[TaskType, *taskContext]
// queueMutex is held by Enqueue, finishedTask and Stop while they read or
// modify the fields of the scheduler.
queueMutex sync.Mutex
// context is the context passed to the most recently started task function.
context context.Context
// cancelFn cancels the most recently started task. It is reset to nil after
// Enqueue or Stop has invoked it and when finishedTask finds the queue empty.
cancelFn context.CancelFunc
// doNotDeleteCurrentTask is set by Enqueue when it replaces the task at the
// head of the queue; finishedTask then skips one head removal and clears it.
doNotDeleteCurrentTask bool
}
// ReplacementPolicy selects what Enqueue does with a new task when a task of
// the same type is already queued. It is declared as an alias of int.
type ReplacementPolicy = int
const (
// ReplacementPolicyCancelOld for when the task arguments might change the result
ReplacementPolicyCancelOld ReplacementPolicy = iota
// ReplacementPolicyIgnoreNew for when the task arguments don't change the result
ReplacementPolicyIgnoreNew
)
// TaskType identifies a kind of task and is the key type of the scheduler
// queue. Both fields take part in the key comparison.
type TaskType struct {
// ID distinguishes one task type from another.
ID int64
// Policy is the replacement policy; Enqueue copies it into the queued task.
Policy ReplacementPolicy
}
// taskFunction is the work of a task. runTask calls it with a cancelable
// context and forwards the returned value and error to the result callback.
type taskFunction func(context.Context) (interface{}, error)
// resultFunction receives the outcome of a task: the result value, the type
// of the task, and the error (nil, a task error, ErrTaskOverwritten or a
// cancellation error).
type resultFunction func(interface{}, TaskType, error)
// taskContext is the queue entry describing one enqueued task.
type taskContext struct {
// taskType is the key under which the task was enqueued.
taskType TaskType
// policy is a copy of taskType.Policy taken when the task was enqueued.
policy ReplacementPolicy
// taskFn is the function to run; Enqueue may replace it while the task is
// still waiting in the queue.
taskFn taskFunction
// resFn is the wrapped result function that invokes the caller's callback
// at most once.
resFn resultFunction
}
// NewScheduler returns a Scheduler with an empty task queue. No task is
// running until Enqueue is called.
func NewScheduler() *Scheduler {
	s := new(Scheduler)
	s.queue = orderedmap.New[TaskType, *taskContext]()
	return s
}
// Enqueue schedules taskFn under taskType. The queue keeps a single entry per
// task type and the entry at the head (s.queue.Oldest()) is treated as the
// running task.
//
// A new task is handled as follows:
//   - the queue is empty: the task is queued and started right away;
//   - no task of the same type is queued: the task is appended to the queue;
//   - a task of the same type is queued with ReplacementPolicyIgnoreNew: the
//     new task is dropped and Enqueue returns true;
//   - a task of the same type is queued with ReplacementPolicyCancelOld: if it
//     is at the head of the queue it is canceled (or, when there is nothing
//     left to cancel, notified with ErrTaskOverwritten) and the new task is
//     moved to the back of the queue; otherwise the waiting entry is notified
//     with ErrTaskOverwritten and overwritten in place.
//
// taskFn is not executed when the task is ignored or overwritten. resFn is
// invoked at most once for an enqueued task and never for an ignored one.
func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn resultFunction) (ignored bool) {
	s.queueMutex.Lock()
	defer s.queueMutex.Unlock()

	// Guard the caller's callback with a sync.Once: a task that is running
	// when Stop() is called can otherwise be reported twice, once by Stop and
	// once when the task function returns.
	var once sync.Once
	notifyOnce := func(res interface{}, tt TaskType, err error) {
		once.Do(func() {
			resFn(res, tt, err)
		})
	}

	fresh := &taskContext{
		taskType: taskType,
		policy:   taskType.Policy,
		taskFn:   taskFn,
		resFn:    notifyOnce,
	}

	queued, found := s.queue.Get(taskType)

	switch {
	case s.queue.Len() == 0:
		// Nothing is running: queue the task and start it. Tasks added while
		// it runs are picked up by finishedTask.
		s.queue.Set(taskType, fresh)
		s.runTask(fresh, taskFn, func(res interface{}, done *taskContext, err error) {
			s.finishedTask(res, done, notifyOnce, err)
		})

	case !found:
		// The policy is irrelevant for the first queued task of a type.
		s.queue.Set(taskType, fresh)

	case queued.policy != ReplacementPolicyCancelOld:
		// A task of this type is already queued or running and wins.
		ignored = true

	case s.queue.Oldest().Value.taskType == taskType:
		// The task of this type sits at the head of the queue.
		if cancel := s.cancelFn; cancel != nil {
			// It is running: cancel it.
			cancel()
			s.cancelFn = nil
		} else {
			// The head was already replaced once and has not started yet:
			// report the replaced entry as overwritten.
			go func() {
				defer common.LogOnPanic()
				queued.resFn(nil, queued.taskType, ErrTaskOverwritten)
			}()
		}
		s.doNotDeleteCurrentTask = true
		// Delete and re-insert so the new task goes to the back of the queue.
		s.queue.Delete(taskType)
		s.queue.Set(taskType, fresh)

	default:
		// Another task type is at the head; the entry of this type is still
		// waiting. Report it as overwritten and reuse its slot in the queue.
		previousResFn := queued.resFn
		go func() {
			defer common.LogOnPanic()
			previousResFn(nil, queued.taskType, ErrTaskOverwritten)
		}()
		queued.taskFn = taskFn
		queued.resFn = notifyOnce
	}

	return ignored
}
// runTask starts taskFn in a new goroutine with a fresh cancelable context and
// records that context and its cancel function on the scheduler. Both callers
// in this file invoke it while holding queueMutex.
//
// When taskFn returns, the context is released and resFn is called with the
// result, tc and the error. An error matching context.Canceled is wrapped
// with a "task canceled" prefix first.
func (s *Scheduler) runTask(tc *taskContext, taskFn taskFunction, resFn func(interface{}, *taskContext, error)) {
	ctx, cancel := context.WithCancel(context.Background())
	s.cancelFn = cancel
	s.context = ctx

	go func() {
		defer common.LogOnPanic()

		res, err := taskFn(ctx)

		// Release context resources
		cancel()

		if errors.Is(err, context.Canceled) {
			err = fmt.Errorf("task canceled: %w", err)
		}
		resFn(res, tc, err)
	}()
}
// finishedTask is called when a task function has returned. Under queueMutex
// it removes the head of the queue, unless doNotDeleteCurrentTask is set, in
// which case it only clears the flag. It then starts the task that is now at
// the head, or clears cancelFn when the queue is empty.
//
// The result is reported through finishedResFn after the mutex is released.
func (s *Scheduler) finishedTask(finishedRes interface{}, doneTask *taskContext, finishedResFn resultFunction, finishedErr error) {
	s.queueMutex.Lock()

	head := s.queue.Oldest()
	switch {
	case s.doNotDeleteCurrentTask:
		// Enqueue replaced the head while it was running; keep the
		// replacement in the queue.
		s.doNotDeleteCurrentTask = false
	case head != nil:
		// head is nil when Stop() emptied the queue in the meantime.
		s.queue.Delete(head.Value.taskType)
	}

	// Run next task
	next := s.queue.Oldest()
	if next == nil {
		s.cancelFn = nil
	} else {
		upcoming := next.Value
		s.runTask(upcoming, upcoming.taskFn, func(res interface{}, runningTask *taskContext, err error) {
			s.finishedTask(res, runningTask, runningTask.resFn, err)
		})
	}

	s.queueMutex.Unlock()

	// Report result
	finishedResFn(finishedRes, doneTask.taskType, finishedErr)
}
// Stop cancels the running task, if any, and removes every entry from the
// queue so that no queued task is started afterwards.
//
// Each removed entry whose policy is ReplacementPolicyCancelOld has its result
// function called with context.Canceled from a separate goroutine.
func (s *Scheduler) Stop() {
	s.queueMutex.Lock()
	defer s.queueMutex.Unlock()

	if s.cancelFn != nil {
		s.cancelFn()
		s.cancelFn = nil
	}

	// Empty the queue so the running task will not be restarted.
	// The successor is fetched before the pair is deleted: deleting a pair
	// unlinks it from the ordered map, after which its Next() returns nil and
	// the loop would stop after the first entry, leaving the rest queued.
	for pair := s.queue.Oldest(); pair != nil; {
		next := pair.Next()

		// Notify the queued one that they are canceled
		if pair.Value.policy == ReplacementPolicyCancelOld {
			go func(val *taskContext) {
				defer common.LogOnPanic()
				val.resFn(nil, val.taskType, context.Canceled)
			}(pair.Value)
		}
		s.queue.Delete(pair.Key)

		pair = next
	}
}