diff --git a/services/wallet/async/scheduler.go b/services/wallet/async/scheduler.go index c20c24d69..b750595ca 100644 --- a/services/wallet/async/scheduler.go +++ b/services/wallet/async/scheduler.go @@ -11,6 +11,7 @@ import ( var ErrTaskOverwritten = errors.New("task overwritten") +// Scheduler ensures that only one task of a type is running at a time. type Scheduler struct { queue *orderedmap.OrderedMap[TaskType, *taskContext] queueMutex sync.Mutex @@ -71,11 +72,20 @@ func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn result taskRunning := s.queue.Len() > 0 existingTask, typeInQueue := s.queue.Get(taskType) + // we need wrap the original resFn to ensure it is called only once + // otherwise, there's a chance that it will be called twice if we + // call Stop() quickly after Enqueue while task is running + var invokeResFnOnce sync.Once + onceResFn := func(res interface{}, taskType TaskType, err error) { + invokeResFnOnce.Do(func() { + resFn(res, taskType, err) + }) + } newTask := &taskContext{ taskType: taskType, policy: taskType.Policy, taskFn: taskFn, - resFn: resFn, + resFn: onceResFn, } if taskRunning { @@ -106,12 +116,13 @@ func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn result // if other task type is running // notify the queued one that it is overwritten or ignored if existingTask.policy == ReplacementPolicyCancelOld { + oldResFn := existingTask.resFn go func() { - existingTask.resFn(nil, existingTask.taskType, ErrTaskOverwritten) + oldResFn(nil, existingTask.taskType, ErrTaskOverwritten) }() // Overwrite the queued one of the same type existingTask.taskFn = taskFn - existingTask.resFn = resFn + existingTask.resFn = onceResFn } else { ignored = true } @@ -125,7 +136,7 @@ func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn result s.queue.Set(taskType, newTask) existingTask = newTask s.runTask(existingTask, taskFn, func(res interface{}, runningTask *taskContext, err error) { - s.finishedTask(res, runningTask, resFn, err) + s.finishedTask(res, runningTask, onceResFn, err) }) } @@ -156,13 +167,15 @@ func (s *Scheduler) runTask(tc *taskContext, taskFn taskFunction, resFn func(int func (s *Scheduler) finishedTask(finishedRes interface{}, doneTask *taskContext, finishedResFn resultFunction, finishedErr error) { s.queueMutex.Lock() - // We always have a running task current := s.queue.Oldest() // Delete current task if not overwritten if s.doNotDeleteCurrentTask { s.doNotDeleteCurrentTask = false } else { - s.queue.Delete(current.Value.taskType) + // current maybe nil if Stop() is called + if current != nil { + s.queue.Delete(current.Value.taskType) + } } // Run next task diff --git a/services/wallet/async/scheduler_test.go b/services/wallet/async/scheduler_test.go index 4f55b621a..5532b7491 100644 --- a/services/wallet/async/scheduler_test.go +++ b/services/wallet/async/scheduler_test.go @@ -287,6 +287,7 @@ func TestScheduler_Enqueue_ValidateOrder(t *testing.T) { if p.taskType.Policy == ReplacementPolicyCancelOld && ctx.Err() != nil && errors.Is(ctx.Err(), context.Canceled) { taskCanceledChan <- p + t.Logf("task canceled, task seq: %d, task type: %+v", currentIndex, p.taskType) return nil, ctx.Err() } @@ -295,13 +296,16 @@ func TestScheduler_Enqueue_ValidateOrder(t *testing.T) { return nil, errors.New("test error") } taskSuccessChan <- p + t.Logf("task executed successfully, task seq: %d, task type: %+v", currentIndex, p.taskType) return 10 * (currentIndex + 1), nil }, func(res interface{}, taskType TaskType, err error) { require.Equal(t, p.taskType, taskType) resChan <- p + t.Logf("response invoked, task seq: %d, task type: %+v, result: %+v", currentIndex, taskType, res) }) if ignored { + t.Logf("task ignored, task seq: %d, task type: %+v", currentIndex, p.taskType) ignoredCount++ } @@ -355,7 +359,7 @@ func TestScheduler_Enqueue_ValidateOrder(t *testing.T) { require.Equal(t, 1, taskFailedCount[testTask3], "expected one task call for type: %d had %d", 3, taskSuccessCount[testTask3]) require.Equal(t, 2, resChanCount[testTask1], "expected two task call for type: %d had %d", 1, taskSuccessCount[testTask1]) - require.Equal(t, 2, resChanCount[testTask2], "expected tow task call for type: %d had %d", 2, taskSuccessCount[testTask2]) + require.Equal(t, 2, resChanCount[testTask2], "expected two task call for type: %d had %d", 2, taskSuccessCount[testTask2]) require.Equal(t, 1, resChanCount[testTask3], "expected one task call for type: %d had %d", 3, taskSuccessCount[testTask3]) } @@ -392,3 +396,34 @@ func TestScheduler_Enqueue_InResult(t *testing.T) { } } } + +func TestScheduler_Enqueue_Quick_Stop(t *testing.T) { + scheduler := NewScheduler() + + var wg sync.WaitGroup + wg.Add(2) + + longRunningTask := func(ctx context.Context) (interface{}, error) { + defer wg.Done() + select { + case <-ctx.Done(): + // we should reach here rather than other condition branch as Stop() canceled the context quickly + return nil, ctx.Err() + case <-time.After(10 * time.Second): + return "task completed", nil + } + } + + resFn := func(res interface{}, taskType TaskType, err error) { + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + wg.Done() + } + + scheduler.Enqueue(TaskType{ID: 1, Policy: ReplacementPolicyCancelOld}, longRunningTask, resFn) + + require.NotPanics(t, func() { + scheduler.Stop() + wg.Wait() + }) +}