fix scheduler panic after quick stop (#5724)
* fix_: scheduler panic after quick stop
* test_: fix failed test TestScheduler_Enqueue_ValidateOrder
This commit is contained in:
parent
9e5fa3f22c
commit
9cdfd6fb42
|
@ -11,6 +11,7 @@ import (
|
||||||
|
|
||||||
var ErrTaskOverwritten = errors.New("task overwritten")
|
var ErrTaskOverwritten = errors.New("task overwritten")
|
||||||
|
|
||||||
|
// Scheduler ensures that only one task of a type is running at a time.
|
||||||
type Scheduler struct {
|
type Scheduler struct {
|
||||||
queue *orderedmap.OrderedMap[TaskType, *taskContext]
|
queue *orderedmap.OrderedMap[TaskType, *taskContext]
|
||||||
queueMutex sync.Mutex
|
queueMutex sync.Mutex
|
||||||
|
@ -71,11 +72,20 @@ func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn result
|
||||||
taskRunning := s.queue.Len() > 0
|
taskRunning := s.queue.Len() > 0
|
||||||
existingTask, typeInQueue := s.queue.Get(taskType)
|
existingTask, typeInQueue := s.queue.Get(taskType)
|
||||||
|
|
||||||
|
// we need to wrap the original resFn to ensure it is called only once
|
||||||
|
// otherwise, there's a chance that it will be called twice if we
|
||||||
|
// call Stop() quickly after Enqueue while task is running
|
||||||
|
var invokeResFnOnce sync.Once
|
||||||
|
onceResFn := func(res interface{}, taskType TaskType, err error) {
|
||||||
|
invokeResFnOnce.Do(func() {
|
||||||
|
resFn(res, taskType, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
newTask := &taskContext{
|
newTask := &taskContext{
|
||||||
taskType: taskType,
|
taskType: taskType,
|
||||||
policy: taskType.Policy,
|
policy: taskType.Policy,
|
||||||
taskFn: taskFn,
|
taskFn: taskFn,
|
||||||
resFn: resFn,
|
resFn: onceResFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
if taskRunning {
|
if taskRunning {
|
||||||
|
@ -106,12 +116,13 @@ func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn result
|
||||||
// if other task type is running
|
// if other task type is running
|
||||||
// notify the queued one that it is overwritten or ignored
|
// notify the queued one that it is overwritten or ignored
|
||||||
if existingTask.policy == ReplacementPolicyCancelOld {
|
if existingTask.policy == ReplacementPolicyCancelOld {
|
||||||
|
oldResFn := existingTask.resFn
|
||||||
go func() {
|
go func() {
|
||||||
existingTask.resFn(nil, existingTask.taskType, ErrTaskOverwritten)
|
oldResFn(nil, existingTask.taskType, ErrTaskOverwritten)
|
||||||
}()
|
}()
|
||||||
// Overwrite the queued one of the same type
|
// Overwrite the queued one of the same type
|
||||||
existingTask.taskFn = taskFn
|
existingTask.taskFn = taskFn
|
||||||
existingTask.resFn = resFn
|
existingTask.resFn = onceResFn
|
||||||
} else {
|
} else {
|
||||||
ignored = true
|
ignored = true
|
||||||
}
|
}
|
||||||
|
@ -125,7 +136,7 @@ func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn result
|
||||||
s.queue.Set(taskType, newTask)
|
s.queue.Set(taskType, newTask)
|
||||||
existingTask = newTask
|
existingTask = newTask
|
||||||
s.runTask(existingTask, taskFn, func(res interface{}, runningTask *taskContext, err error) {
|
s.runTask(existingTask, taskFn, func(res interface{}, runningTask *taskContext, err error) {
|
||||||
s.finishedTask(res, runningTask, resFn, err)
|
s.finishedTask(res, runningTask, onceResFn, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,13 +167,15 @@ func (s *Scheduler) runTask(tc *taskContext, taskFn taskFunction, resFn func(int
|
||||||
func (s *Scheduler) finishedTask(finishedRes interface{}, doneTask *taskContext, finishedResFn resultFunction, finishedErr error) {
|
func (s *Scheduler) finishedTask(finishedRes interface{}, doneTask *taskContext, finishedResFn resultFunction, finishedErr error) {
|
||||||
s.queueMutex.Lock()
|
s.queueMutex.Lock()
|
||||||
|
|
||||||
// We always have a running task
|
|
||||||
current := s.queue.Oldest()
|
current := s.queue.Oldest()
|
||||||
// Delete current task if not overwritten
|
// Delete current task if not overwritten
|
||||||
if s.doNotDeleteCurrentTask {
|
if s.doNotDeleteCurrentTask {
|
||||||
s.doNotDeleteCurrentTask = false
|
s.doNotDeleteCurrentTask = false
|
||||||
} else {
|
} else {
|
||||||
s.queue.Delete(current.Value.taskType)
|
// current may be nil if Stop() is called
|
||||||
|
if current != nil {
|
||||||
|
s.queue.Delete(current.Value.taskType)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run next task
|
// Run next task
|
||||||
|
|
|
@ -287,6 +287,7 @@ func TestScheduler_Enqueue_ValidateOrder(t *testing.T) {
|
||||||
|
|
||||||
if p.taskType.Policy == ReplacementPolicyCancelOld && ctx.Err() != nil && errors.Is(ctx.Err(), context.Canceled) {
|
if p.taskType.Policy == ReplacementPolicyCancelOld && ctx.Err() != nil && errors.Is(ctx.Err(), context.Canceled) {
|
||||||
taskCanceledChan <- p
|
taskCanceledChan <- p
|
||||||
|
t.Logf("task canceled, task seq: %d, task type: %+v", currentIndex, p.taskType)
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,13 +296,16 @@ func TestScheduler_Enqueue_ValidateOrder(t *testing.T) {
|
||||||
return nil, errors.New("test error")
|
return nil, errors.New("test error")
|
||||||
}
|
}
|
||||||
taskSuccessChan <- p
|
taskSuccessChan <- p
|
||||||
|
t.Logf("task executed successfully, task seq: %d, task type: %+v", currentIndex, p.taskType)
|
||||||
return 10 * (currentIndex + 1), nil
|
return 10 * (currentIndex + 1), nil
|
||||||
}, func(res interface{}, taskType TaskType, err error) {
|
}, func(res interface{}, taskType TaskType, err error) {
|
||||||
require.Equal(t, p.taskType, taskType)
|
require.Equal(t, p.taskType, taskType)
|
||||||
resChan <- p
|
resChan <- p
|
||||||
|
t.Logf("response invoked, task seq: %d, task type: %+v, result: %+v", currentIndex, taskType, res)
|
||||||
})
|
})
|
||||||
|
|
||||||
if ignored {
|
if ignored {
|
||||||
|
t.Logf("task ignored, task seq: %d, task type: %+v", currentIndex, p.taskType)
|
||||||
ignoredCount++
|
ignoredCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -355,7 +359,7 @@ func TestScheduler_Enqueue_ValidateOrder(t *testing.T) {
|
||||||
require.Equal(t, 1, taskFailedCount[testTask3], "expected one task call for type: %d had %d", 3, taskSuccessCount[testTask3])
|
require.Equal(t, 1, taskFailedCount[testTask3], "expected one task call for type: %d had %d", 3, taskSuccessCount[testTask3])
|
||||||
|
|
||||||
require.Equal(t, 2, resChanCount[testTask1], "expected two task call for type: %d had %d", 1, taskSuccessCount[testTask1])
|
require.Equal(t, 2, resChanCount[testTask1], "expected two task call for type: %d had %d", 1, taskSuccessCount[testTask1])
|
||||||
require.Equal(t, 2, resChanCount[testTask2], "expected tow task call for type: %d had %d", 2, taskSuccessCount[testTask2])
|
require.Equal(t, 2, resChanCount[testTask2], "expected two task call for type: %d had %d", 2, taskSuccessCount[testTask2])
|
||||||
require.Equal(t, 1, resChanCount[testTask3], "expected one task call for type: %d had %d", 3, taskSuccessCount[testTask3])
|
require.Equal(t, 1, resChanCount[testTask3], "expected one task call for type: %d had %d", 3, taskSuccessCount[testTask3])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -392,3 +396,34 @@ func TestScheduler_Enqueue_InResult(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScheduler_Enqueue_Quick_Stop(t *testing.T) {
|
||||||
|
scheduler := NewScheduler()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
longRunningTask := func(ctx context.Context) (interface{}, error) {
|
||||||
|
defer wg.Done()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// we should reach this branch rather than the timeout branch, because Stop() cancels the context quickly
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
return "task completed", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resFn := func(res interface{}, taskType TaskType, err error) {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
scheduler.Enqueue(TaskType{ID: 1, Policy: ReplacementPolicyCancelOld}, longRunningTask, resFn)
|
||||||
|
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
scheduler.Stop()
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue