package async

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// Markers pushed onto the tests' call channels to record which callback ran.
const (
	noActionPerformed = "no action performed"
	taskCalled        = "task called"
	taskResultCalled  = "task result called"
)

// TestScheduler_Enqueue_Simple enqueues a single task for each replacement
// policy, once failing and once succeeding, and checks that the task body runs
// before its result callback and that the callback receives the task's result,
// error and type.
//
// NOTE(review): several require.* calls in this file run inside Enqueue
// callbacks, which presumably execute on scheduler goroutines rather than the
// test goroutine; testing.T.FailNow (used by require) must not be called from
// such goroutines. Kept as-is to avoid adding a dependency — consider
// testify/assert in callbacks.
func TestScheduler_Enqueue_Simple(t *testing.T) {
	s := NewScheduler()
	callChan := make(chan string, 10)

	testFunction := func(policy ReplacementPolicy, failTest bool) {
		testTask := TaskType{1, policy}
		ignored := s.Enqueue(testTask, func(ctx context.Context) (interface{}, error) {
			callChan <- taskCalled
			if failTest {
				return nil, errors.New("test error")
			}
			return 123, nil
		}, func(res interface{}, taskType TaskType, err error) {
			if failTest {
				require.Error(t, err)
				require.Nil(t, res)
			} else {
				require.NoError(t, err)
				require.Equal(t, 123, res)
			}
			require.Equal(t, testTask, taskType)
			callChan <- taskResultCalled
		})
		require.False(t, ignored)

		// Expect exactly: task body first, then its result callback.
		lastRes := noActionPerformed
		done := false
		for !done {
			select {
			case callRes := <-callChan:
				switch callRes {
				case taskCalled:
					require.Equal(t, noActionPerformed, lastRes)
				case taskResultCalled:
					require.Equal(t, taskCalled, lastRes)
					done = true
				default:
					require.Fail(t, "unexpected result", `"%s" for policy %d`, callRes, policy)
				}
				lastRes = callRes
			case <-time.After(1 * time.Second):
				require.Fail(t, "test not completed in time", `last result: "%s" for policy %d`, lastRes, policy)
			}
		}
		require.Equal(t, taskResultCalled, lastRes)
	}

	// Failing variant first, then the succeeding one. Range over the slice
	// values (not its indices) so the actual policy constants are exercised.
	for _, failTest := range []bool{true, false} {
		for _, policy := range []ReplacementPolicy{ReplacementPolicyCancelOld, ReplacementPolicyIgnoreNew} {
			testFunction(policy, failTest)
		}
	}
}

// TestScheduler_Enqueue_VerifyReplacementPolicyCancelOld validates that a
// running task is cancelled when a new task of the same type is enqueued, and
// that a third Enqueue overwrites the second (not yet started) task.
//
// The required ordering of observable events is listed in testStages; stage 4
// (the first task's cancelled result callback) is checked with a flag instead
// of being part of that ordered list.
func TestScheduler_Enqueue_VerifyReplacementPolicyCancelOld(t *testing.T) {
	s := NewScheduler()

	type testStage string
	const (
		stage1FirstTaskStarted                testStage = "First task started"
		stage2ThirdEnqueueOverwroteSecondTask testStage = "Third Enqueue overwrote second task"
		stage3ExitingFirstCancelledTask       testStage = "Exiting first cancelled task"
		stage5ThirdTaskRunning                testStage = "Third task running"
		stage6ThirdTaskResponse               testStage = "Third task response"
	)

	testStages := []testStage{
		stage1FirstTaskStarted,
		stage2ThirdEnqueueOverwroteSecondTask,
		stage3ExitingFirstCancelledTask,
		stage5ThirdTaskRunning,
		stage6ThirdTaskResponse,
	}

	callChan := make(chan testStage, len(testStages))

	var firstRunWG, secondRunWG, thirdRunWG sync.WaitGroup
	firstRunWG.Add(1)
	secondRunWG.Add(1)
	thirdRunWG.Add(1)

	// NOTE(review): written from the first task's result callback and read at
	// the end of the test with no synchronization visible in this file; relies
	// on the scheduler's internal ordering. Verify under -race.
	stage4AsyncFirstTaskCanceledResponse := false

	testTask := TaskType{1, ReplacementPolicyCancelOld}
	for i := 0; i < 2; i++ {
		currentIndex := i
		ignored := s.Enqueue(testTask, func(workCtx context.Context) (interface{}, error) {
			callChan <- stage1FirstTaskStarted
			// Mark the first task running so the second Enqueue cancels it.
			firstRunWG.Done()
			// Wait for the first task to be cancelled by the second Enqueue.
			select {
			case <-workCtx.Done():
				require.ErrorIs(t, workCtx.Err(), context.Canceled)
				// Unblock the third Enqueue call.
				secondRunWG.Done()
				// Keep this task alive until the third Enqueue has overwritten
				// the second task, which never got to run.
				thirdRunWG.Wait()
				callChan <- stage3ExitingFirstCancelledTask
			case <-time.After(1 * time.Second):
				require.Fail(t, "task not cancelled in time")
			}
			return nil, workCtx.Err()
		}, func(res interface{}, taskType TaskType, err error) {
			switch currentIndex {
			case 0:
				// First task was cancelled by the second Enqueue call.
				stage4AsyncFirstTaskCanceledResponse = true
				require.ErrorIs(t, err, context.Canceled)
			case 1:
				callChan <- stage2ThirdEnqueueOverwroteSecondTask
				// Release the first task so the third one can run, and check
				// that the third Enqueue overwrote the second task.
				thirdRunWG.Done()
				require.ErrorIs(t, err, ErrTaskOverwritten)
			}
		})
		require.False(t, ignored)

		// Wait for the first task to start running.
		firstRunWG.Wait()
	}

	// Wait for the first task to observe its cancellation before enqueueing
	// the third task.
	secondRunWG.Wait()

	ignored := s.Enqueue(testTask, func(ctx context.Context) (interface{}, error) {
		callChan <- stage5ThirdTaskRunning
		return 123, errors.New("test error")
	}, func(res interface{}, taskType TaskType, err error) {
		require.Error(t, err)
		require.Equal(t, testTask, taskType)
		require.Equal(t, 123, res)
		callChan <- stage6ThirdTaskResponse
	})
	require.False(t, ignored)

	lastStage := testStage(noActionPerformed)
	for _, expectedStage := range testStages {
		select {
		case callRes := <-callChan:
			require.Equal(t, expectedStage, callRes, "task stage out of order; expected %s, got %s", expectedStage, callRes)
			lastStage = callRes
		case <-time.After(1 * time.Second):
			require.Fail(t, "test not completed in time", `last result: "%s" for cancel task policy`, lastStage)
		}
	}

	require.True(t, stage4AsyncFirstTaskCanceledResponse)
}

// TestScheduler_Enqueue_VerifyReplacementPolicyIgnoreNew checks that, while a
// ReplacementPolicyIgnoreNew task is still pending, a second Enqueue of the
// same type reports ignored and neither of its callbacks is invoked.
func TestScheduler_Enqueue_VerifyReplacementPolicyIgnoreNew(t *testing.T) {
	s := NewScheduler()
	callChan := make(chan string, 10)
	workloadWG := sync.WaitGroup{}
	// Both counters are incremented before the matching channel send, so the
	// receives below make them safe to read afterwards.
	taskCallCount := 0
	resultCallCount := 0

	workloadWG.Add(1)
	testTask := TaskType{1, ReplacementPolicyIgnoreNew}
	ignored := s.Enqueue(testTask, func(workCtx context.Context) (interface{}, error) {
		// Block until the second Enqueue has been attempted.
		workloadWG.Wait()
		require.NoError(t, workCtx.Err())
		taskCallCount++
		callChan <- taskCalled
		return 123, nil
	}, func(res interface{}, taskType TaskType, err error) {
		require.NoError(t, err)
		require.Equal(t, testTask, taskType)
		require.Equal(t, 123, res)
		resultCallCount++
		callChan <- taskResultCalled
	})
	require.False(t, ignored)

	ignored = s.Enqueue(testTask, func(ctx context.Context) (interface{}, error) {
		require.Fail(t, "unexpected call")
		return nil, errors.New("unexpected call")
	}, func(res interface{}, taskType TaskType, err error) {
		require.Fail(t, "unexpected result call")
	})
	require.True(t, ignored)
	workloadWG.Done()

	lastRes := noActionPerformed
	done := false
	for !done {
		select {
		case callRes := <-callChan:
			switch callRes {
			case taskCalled:
				require.Equal(t, noActionPerformed, lastRes)
			case taskResultCalled:
				require.Equal(t, taskCalled, lastRes)
				done = true
			default:
				require.Fail(t, "unexpected result", `"%s" for ignore task policy`, callRes)
			}
			lastRes = callRes
		case <-time.After(1 * time.Second):
			require.Fail(t, "test not completed in time", `last result: "%s" for ignore task policy`, lastRes)
		}
	}

	require.Equal(t, 1, resultCallCount)
	require.Equal(t, 1, taskCallCount)
	require.Equal(t, taskResultCalled, lastRes)
}

// TestScheduler_Enqueue_ValidateOrder enqueues a mixed sequence of task types
// while every task body is blocked, then releases them and checks, per task
// type, how many tasks succeeded, were cancelled, failed, or delivered a
// result callback.
func TestScheduler_Enqueue_ValidateOrder(t *testing.T) {
	s := NewScheduler()
	waitEnqueueAll := sync.WaitGroup{}

	type failType bool
	const (
		fail failType = true
		pass failType = false
	)

	// enqueueParams describes one Enqueue call: the task type (which carries
	// the replacement policy), whether the task body should fail, and its
	// position in enqueueSequence.
	type enqueueParams struct {
		taskType   TaskType
		taskAction failType
		callIndex  int
	}

	testTask1 := TaskType{1, ReplacementPolicyCancelOld}
	testTask2 := TaskType{2, ReplacementPolicyCancelOld}
	testTask3 := TaskType{3, ReplacementPolicyIgnoreNew}

	// callIndex is filled in by the enqueue loop below. The trailing comments
	// count the task-body events (success/cancel/fail) each entry produces.
	enqueueSequence := []enqueueParams{
		{testTask1, pass, 0}, // 1 task event
		{testTask2, pass, 0}, // 0 task event
		{testTask3, fail, 0}, // 1 task event
		{testTask3, pass, 0}, // 0 task event
		{testTask2, pass, 0}, // 1 task event
		{testTask1, pass, 0}, // 1 task event
		{testTask3, fail, 0}, // 0 run event
	}
	// Sum of the per-entry task events annotated above.
	const taskEventCount = 4

	taskSuccessChan := make(chan enqueueParams, len(enqueueSequence))
	taskCanceledChan := make(chan enqueueParams, len(enqueueSequence))
	taskFailedChan := make(chan enqueueParams, len(enqueueSequence))
	resChan := make(chan enqueueParams, len(enqueueSequence))

	seenTaskTypes := make(map[TaskType]bool)
	ignoredCount := 0

	waitEnqueueAll.Add(1)
	for i := 0; i < len(enqueueSequence); i++ {
		enqueueSequence[i].callIndex = i

		p := enqueueSequence[i]
		currentIndex := i

		ignored := s.Enqueue(p.taskType, func(ctx context.Context) (interface{}, error) {
			// Hold every task until the whole sequence has been enqueued.
			waitEnqueueAll.Wait()

			if p.taskType.Policy == ReplacementPolicyCancelOld && errors.Is(ctx.Err(), context.Canceled) {
				taskCanceledChan <- p
				t.Logf("task canceled, task seq: %d, task type: %+v", currentIndex, p.taskType)
				return nil, ctx.Err()
			}

			if p.taskAction == fail {
				taskFailedChan <- p
				return nil, errors.New("test error")
			}

			taskSuccessChan <- p
			t.Logf("task executed successfully, task seq: %d, task type: %+v", currentIndex, p.taskType)
			return 10 * (currentIndex + 1), nil
		}, func(res interface{}, taskType TaskType, err error) {
			require.Equal(t, p.taskType, taskType)
			resChan <- p
			t.Logf("response invoked, task seq: %d, task type: %+v, result: %+v", currentIndex, taskType, res)
		})

		if ignored {
			t.Logf("task ignored, task seq: %d, task type: %+v", currentIndex, p.taskType)
			ignoredCount++
		}

		if !seenTaskTypes[p.taskType] {
			// The first Enqueue of a task type is never ignored.
			require.False(t, ignored)
			seenTaskTypes[p.taskType] = true
		} else {
			// Repeats are ignored only under ReplacementPolicyIgnoreNew.
			require.Equal(t, p.taskType.Policy == ReplacementPolicyIgnoreNew, ignored)
		}
	}

	waitEnqueueAll.Done()

	taskSuccessCount := make(map[TaskType]int)
	taskCanceledCount := make(map[TaskType]int)
	taskFailedCount := make(map[TaskType]int)
	resChanCount := make(map[TaskType]int)

	// Every non-ignored Enqueue produces one result event, plus the task events.
	expectedEventsCount := len(enqueueSequence) - ignoredCount + taskEventCount
	for i := 0; i < expectedEventsCount; i++ {
		select {
		case p := <-taskSuccessChan:
			taskSuccessCount[p.taskType]++
		case p := <-taskCanceledChan:
			taskCanceledCount[p.taskType]++
		case p := <-taskFailedChan:
			taskFailedCount[p.taskType]++
		case p := <-resChan:
			resChanCount[p.taskType]++
		case <-time.After(1 * time.Second):
			require.Fail(t, "test not completed in time")
		}
	}

	// checkCount reports the counter name and task type on mismatch.
	checkCount := func(counts map[TaskType]int, kind string, taskType TaskType, expected int) {
		t.Helper()
		require.Equal(t, expected, counts[taskType], "unexpected %s count for task type %d", kind, taskType.ID)
	}

	checkCount(taskSuccessCount, "success", testTask1, 1)
	checkCount(taskSuccessCount, "success", testTask2, 1)
	checkCount(taskSuccessCount, "success", testTask3, 0)

	checkCount(taskCanceledCount, "canceled", testTask1, 1)
	checkCount(taskCanceledCount, "canceled", testTask2, 0)
	checkCount(taskCanceledCount, "canceled", testTask3, 0)

	checkCount(taskFailedCount, "failed", testTask1, 0)
	checkCount(taskFailedCount, "failed", testTask2, 0)
	checkCount(taskFailedCount, "failed", testTask3, 1)

	checkCount(resChanCount, "result callback", testTask1, 2)
	checkCount(resChanCount, "result callback", testTask2, 2)
	checkCount(resChanCount, "result callback", testTask3, 1)
}

// TestScheduler_Enqueue_InResult checks that Enqueue can be called from inside
// a result callback: each callback enqueues the next task of the same type,
// and the six task/result invocations must arrive in order 0..5.
func TestScheduler_Enqueue_InResult(t *testing.T) {
	s := NewScheduler()
	callChan := make(chan int, 6)

	s.Enqueue(TaskType{ID: 1, Policy: ReplacementPolicyCancelOld},
		func(ctx context.Context) (interface{}, error) {
			callChan <- 0
			return nil, nil
		}, func(res interface{}, taskType TaskType, err error) {
			callChan <- 1
			s.Enqueue(TaskType{1, ReplacementPolicyCancelOld}, func(ctx context.Context) (interface{}, error) {
				callChan <- 2
				return nil, nil
			}, func(res interface{}, taskType TaskType, err error) {
				callChan <- 3
				s.Enqueue(TaskType{1, ReplacementPolicyCancelOld}, func(ctx context.Context) (interface{}, error) {
					callChan <- 4
					return nil, nil
				}, func(res interface{}, taskType TaskType, err error) {
					callChan <- 5
				})
			})
		},
	)

	for i := 0; i < 6; i++ {
		select {
		case res := <-callChan:
			require.Equal(t, i, res)
		case <-time.After(1 * time.Second):
			require.Fail(t, "test not completed in time")
		}
	}
}

// TestScheduler_Enqueue_Quick_Stop checks that Stop cancels the context of a
// running long task, so the task returns early and its result callback
// receives context.Canceled, without panicking.
func TestScheduler_Enqueue_Quick_Stop(t *testing.T) {
	scheduler := NewScheduler()

	// One Done from the task body and one from the result callback.
	var wg sync.WaitGroup
	wg.Add(2)

	longRunningTask := func(ctx context.Context) (interface{}, error) {
		defer wg.Done()
		select {
		case <-ctx.Done():
			// Expected path: Stop() cancels the context well before the timeout.
			return nil, ctx.Err()
		case <-time.After(10 * time.Second):
			return "task completed", nil
		}
	}

	resFn := func(res interface{}, taskType TaskType, err error) {
		require.Error(t, err)
		require.ErrorIs(t, err, context.Canceled)
		wg.Done()
	}

	scheduler.Enqueue(TaskType{ID: 1, Policy: ReplacementPolicyCancelOld}, longRunningTask, resFn)

	require.NotPanics(t, func() {
		scheduler.Stop()
		wg.Wait()
	})
}