216 lines
6.2 KiB
Go
216 lines
6.2 KiB
Go
package async
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
|
|
orderedmap "github.com/wk8/go-ordered-map/v2"
|
|
)
|
|
|
|
// ErrTaskOverwritten is the error handed to a task's result function when the
// task was replaced by a newer task of the same type before it could run.
var ErrTaskOverwritten = errors.New("task overwritten")
|
|
|
|
// Scheduler ensures that only one task of a type is running at a time.
|
|
type Scheduler struct {
|
|
queue *orderedmap.OrderedMap[TaskType, *taskContext]
|
|
queueMutex sync.Mutex
|
|
|
|
context context.Context
|
|
cancelFn context.CancelFunc
|
|
doNotDeleteCurrentTask bool
|
|
}
|
|
|
|
// ReplacementPolicy selects what Enqueue does when a task of the same type is
// already queued or running. It is an alias for int.
type ReplacementPolicy = int
|
|
|
|
const (
|
|
// ReplacementPolicyCancelOld for when the task arguments might change the result
|
|
ReplacementPolicyCancelOld ReplacementPolicy = iota
|
|
// ReplacementPolicyIgnoreNew for when the task arguments doesn't change the result
|
|
ReplacementPolicyIgnoreNew
|
|
)
|
|
|
|
type TaskType struct {
|
|
ID int64
|
|
Policy ReplacementPolicy
|
|
}
|
|
|
|
type taskFunction func(context.Context) (interface{}, error)
|
|
type resultFunction func(interface{}, TaskType, error)
|
|
|
|
type taskContext struct {
|
|
taskType TaskType
|
|
policy ReplacementPolicy
|
|
|
|
taskFn taskFunction
|
|
resFn resultFunction
|
|
}
|
|
|
|
func NewScheduler() *Scheduler {
|
|
return &Scheduler{
|
|
queue: orderedmap.New[TaskType, *taskContext](),
|
|
}
|
|
}
|
|
|
|
// Enqueue provides a queue of task types allowing only one task at a time of the corresponding type. The running task is the first one in the queue (s.queue.Oldest())
|
|
//
|
|
// Schedule policy for new tasks
|
|
// - pushed at the back of the queue (s.queue.PushBack()) if none of the same time already scheduled
|
|
// - overwrite the queued one of the same type, depending on the policy
|
|
// - In case of ReplacementPolicyIgnoreNew, the new task will be ignored
|
|
// - In case of ReplacementPolicyCancelOld, the old running task will be canceled or if not yet run overwritten and the new one will be executed when its turn comes.
|
|
//
|
|
// The task function (taskFn) might not be executed if
|
|
// - the task is ignored
|
|
// - the task is overwritten. The result function (resFn) will be called with ErrTaskOverwritten
|
|
//
|
|
// The result function (resFn) will always be called if the task is not ignored
|
|
func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn resultFunction) (ignored bool) {
|
|
s.queueMutex.Lock()
|
|
defer s.queueMutex.Unlock()
|
|
|
|
taskRunning := s.queue.Len() > 0
|
|
existingTask, typeInQueue := s.queue.Get(taskType)
|
|
|
|
// we need wrap the original resFn to ensure it is called only once
|
|
// otherwise, there's a chance that it will be called twice if we
|
|
// call Stop() quickly after Enqueue while task is running
|
|
var invokeResFnOnce sync.Once
|
|
onceResFn := func(res interface{}, taskType TaskType, err error) {
|
|
invokeResFnOnce.Do(func() {
|
|
resFn(res, taskType, err)
|
|
})
|
|
}
|
|
newTask := &taskContext{
|
|
taskType: taskType,
|
|
policy: taskType.Policy,
|
|
taskFn: taskFn,
|
|
resFn: onceResFn,
|
|
}
|
|
|
|
if taskRunning {
|
|
if typeInQueue {
|
|
if s.queue.Oldest().Value.taskType == taskType {
|
|
// If same task type is running
|
|
if existingTask.policy == ReplacementPolicyCancelOld {
|
|
// If a previous task is running, cancel it
|
|
if s.cancelFn != nil {
|
|
s.cancelFn()
|
|
s.cancelFn = nil
|
|
} else {
|
|
// In case of multiple tasks of the same type, the previous one is overwritten
|
|
go func() {
|
|
existingTask.resFn(nil, existingTask.taskType, ErrTaskOverwritten)
|
|
}()
|
|
}
|
|
|
|
s.doNotDeleteCurrentTask = true
|
|
|
|
// Add it again to refresh the order of the task
|
|
s.queue.Delete(taskType)
|
|
s.queue.Set(taskType, newTask)
|
|
} else {
|
|
ignored = true
|
|
}
|
|
} else {
|
|
// if other task type is running
|
|
// notify the queued one that it is overwritten or ignored
|
|
if existingTask.policy == ReplacementPolicyCancelOld {
|
|
oldResFn := existingTask.resFn
|
|
go func() {
|
|
oldResFn(nil, existingTask.taskType, ErrTaskOverwritten)
|
|
}()
|
|
// Overwrite the queued one of the same type
|
|
existingTask.taskFn = taskFn
|
|
existingTask.resFn = onceResFn
|
|
} else {
|
|
ignored = true
|
|
}
|
|
}
|
|
} else {
|
|
// Policy does not matter for the fist enqueued task of a type
|
|
s.queue.Set(taskType, newTask)
|
|
}
|
|
} else {
|
|
// If no task is running add and run it. The worker will take care of scheduling new tasks added while running
|
|
s.queue.Set(taskType, newTask)
|
|
existingTask = newTask
|
|
s.runTask(existingTask, taskFn, func(res interface{}, runningTask *taskContext, err error) {
|
|
s.finishedTask(res, runningTask, onceResFn, err)
|
|
})
|
|
}
|
|
|
|
return ignored
|
|
}
|
|
|
|
func (s *Scheduler) runTask(tc *taskContext, taskFn taskFunction, resFn func(interface{}, *taskContext, error)) {
|
|
thisContext, thisCancelFn := context.WithCancel(context.Background())
|
|
s.cancelFn = thisCancelFn
|
|
s.context = thisContext
|
|
|
|
go func() {
|
|
res, err := taskFn(thisContext)
|
|
|
|
// Release context resources
|
|
thisCancelFn()
|
|
|
|
if errors.Is(err, context.Canceled) {
|
|
resFn(res, tc, fmt.Errorf("task canceled: %w", err))
|
|
} else {
|
|
resFn(res, tc, err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// finishedTask is the only one that can remove a task from the queue
|
|
// if the current running task completed (doNotDeleteCurrentTask is true)
|
|
func (s *Scheduler) finishedTask(finishedRes interface{}, doneTask *taskContext, finishedResFn resultFunction, finishedErr error) {
|
|
s.queueMutex.Lock()
|
|
|
|
current := s.queue.Oldest()
|
|
// Delete current task if not overwritten
|
|
if s.doNotDeleteCurrentTask {
|
|
s.doNotDeleteCurrentTask = false
|
|
} else {
|
|
// current maybe nil if Stop() is called
|
|
if current != nil {
|
|
s.queue.Delete(current.Value.taskType)
|
|
}
|
|
}
|
|
|
|
// Run next task
|
|
if pair := s.queue.Oldest(); pair != nil {
|
|
nextTask := pair.Value
|
|
s.runTask(nextTask, nextTask.taskFn, func(res interface{}, runningTask *taskContext, err error) {
|
|
s.finishedTask(res, runningTask, runningTask.resFn, err)
|
|
})
|
|
} else {
|
|
s.cancelFn = nil
|
|
}
|
|
s.queueMutex.Unlock()
|
|
|
|
// Report result
|
|
finishedResFn(finishedRes, doneTask.taskType, finishedErr)
|
|
}
|
|
|
|
func (s *Scheduler) Stop() {
|
|
s.queueMutex.Lock()
|
|
defer s.queueMutex.Unlock()
|
|
|
|
if s.cancelFn != nil {
|
|
s.cancelFn()
|
|
s.cancelFn = nil
|
|
}
|
|
|
|
// Empty the queue so the running task will not be restarted
|
|
for pair := s.queue.Oldest(); pair != nil; pair = pair.Next() {
|
|
// Notify the queued one that they are canceled
|
|
if pair.Value.policy == ReplacementPolicyCancelOld {
|
|
go func(val *taskContext) {
|
|
val.resFn(nil, val.taskType, context.Canceled)
|
|
}(pair.Value)
|
|
}
|
|
s.queue.Delete(pair.Value.taskType)
|
|
}
|
|
}
|