package async

import (
	"context"
	"errors"
	"fmt"
	"sync"

	orderedmap "github.com/wk8/go-ordered-map/v2"
)

// ErrTaskOverwritten is passed to the result function of a queued task that
// was replaced by a newer task of the same type before it could run.
var ErrTaskOverwritten = errors.New("task overwritten")

// Scheduler ensures that only one task of a type is running at a time.
//
// Tasks wait in an insertion-ordered queue keyed by TaskType; the oldest entry
// of that queue is treated as the task that is currently running.
type Scheduler struct {
	// queue holds at most one entry per TaskType, in the order the tasks will
	// run. queueMutex is held by Enqueue, finishedTask and Stop while they
	// touch queue, context, cancelFn and doNotDeleteCurrentTask.
	queue      *orderedmap.OrderedMap[TaskType, *taskContext]
	queueMutex sync.Mutex

	// context and cancelFn are set by runTask each time a task is started.
	// cancelFn is reset to nil once it has been called by Enqueue or Stop, or
	// when finishedTask finds no further task to start.
	// doNotDeleteCurrentTask is set by Enqueue after it replaced the entry of
	// the running task; it makes the next finishedTask call keep the oldest
	// queue entry instead of deleting it.
	context                context.Context
	cancelFn               context.CancelFunc
	doNotDeleteCurrentTask bool
}

// ReplacementPolicy selects what Enqueue does with a new task when a task of
// the same type is already queued or running.
type ReplacementPolicy = int

const (
	// ReplacementPolicyCancelOld is for tasks whose arguments might change the
	// result: the older task is canceled or overwritten in favor of the new one.
	ReplacementPolicyCancelOld ReplacementPolicy = iota
	// ReplacementPolicyIgnoreNew is for tasks whose arguments don't change the
	// result: the new task is ignored while an older one is queued or running.
	ReplacementPolicyIgnoreNew
)

// TaskType identifies a kind of task. It is used as the key of the scheduler
// queue, so two values denote the same type only if both ID and Policy match.
type TaskType struct {
	ID     int64
	Policy ReplacementPolicy
}

// taskFunction is the work of a task. The context it receives is canceled when
// the running task is replaced under ReplacementPolicyCancelOld or when Stop is
// called. It returns the task result and an error.
type taskFunction func(context.Context) (interface{}, error)
// resultFunction receives the outcome of a task: the result, the type of the
// task and the error, if any.
type resultFunction func(interface{}, TaskType, error)

// taskContext is a queue entry: the type of a task together with the functions
// that run it and report its outcome.
type taskContext struct {
	// taskType is the queue key of the entry; policy is a copy of
	// taskType.Policy taken when the entry is created.
	taskType TaskType
	policy   ReplacementPolicy

	// taskFn does the work and resFn reports the outcome. Enqueue may replace
	// both on an entry that is still waiting for its turn.
	taskFn taskFunction
	resFn  resultFunction
}

// NewScheduler returns a Scheduler with an empty task queue. No task is
// running until the first call to Enqueue.
func NewScheduler() *Scheduler {
	scheduler := new(Scheduler)
	scheduler.queue = orderedmap.New[TaskType, *taskContext]()
	return scheduler
}

// Enqueue schedules taskFn under taskType and reports whether the request was
// ignored. The queue keeps at most one entry per task type, and its oldest
// entry (s.queue.Oldest()) is the task considered to be running.
//
// A new request is handled as follows:
//   - the queue is empty: the task is stored and started right away
//   - no entry of the same type is queued: the task is pushed to the back
//   - an entry of the same type exists with ReplacementPolicyIgnoreNew: the
//     new request is ignored
//   - an entry of the same type exists with ReplacementPolicyCancelOld: a
//     running task gets its context canceled, a task that has not started yet
//     is overwritten; the new task is executed when its turn comes
//
// The task function (taskFn) is not executed if the request is ignored or if
// it is overwritten before it starts. An overwritten task has its result
// function called with ErrTaskOverwritten.
//
// The result function (resFn) is wrapped so that it runs at most once per
// request; it is never called for an ignored request.
func (s *Scheduler) Enqueue(taskType TaskType, taskFn taskFunction, resFn resultFunction) (ignored bool) {
	s.queueMutex.Lock()
	defer s.queueMutex.Unlock()

	// resFn can be reached from two places when Stop() follows Enqueue closely
	// while the task is running, so guard it to fire only once.
	var once sync.Once
	guardedResFn := func(res interface{}, tt TaskType, err error) {
		once.Do(func() {
			resFn(res, tt, err)
		})
	}
	fresh := &taskContext{
		taskType: taskType,
		policy:   taskType.Policy,
		taskFn:   taskFn,
		resFn:    guardedResFn,
	}

	// Nothing is running: store the task and start it. Tasks enqueued while it
	// runs are started from finishedTask.
	if s.queue.Len() == 0 {
		s.queue.Set(taskType, fresh)
		s.runTask(fresh, taskFn, func(res interface{}, runningTask *taskContext, err error) {
			s.finishedTask(res, runningTask, guardedResFn, err)
		})
		return false
	}

	// The policy does not matter for the first queued task of a type.
	queued, found := s.queue.Get(taskType)
	if !found {
		s.queue.Set(taskType, fresh)
		return false
	}

	// Same type already present and it must not be replaced.
	if queued.policy != ReplacementPolicyCancelOld {
		return true
	}

	// A task of another type is running: swap the functions of the waiting
	// entry in place (it keeps its position) and tell the old one it lost.
	if s.queue.Oldest().Value.taskType != taskType {
		previousResFn := queued.resFn
		go func() {
			previousResFn(nil, queued.taskType, ErrTaskOverwritten)
		}()
		queued.taskFn = taskFn
		queued.resFn = guardedResFn
		return false
	}

	// The oldest entry has the same type as the new task.
	if s.cancelFn == nil {
		// The running task was already canceled, so the oldest entry is a
		// replacement that has not started yet: it is overwritten in turn.
		go func() {
			queued.resFn(nil, queued.taskType, ErrTaskOverwritten)
		}()
	} else {
		// Cancel the task that is currently running.
		s.cancelFn()
		s.cancelFn = nil
	}

	// The canceled task still has to pass through finishedTask, which must not
	// remove the entry that now belongs to the new task.
	s.doNotDeleteCurrentTask = true

	// Delete and re-add so the new task moves to the back of the queue.
	s.queue.Delete(taskType)
	s.queue.Set(taskType, fresh)
	return false
}

// runTask starts taskFn in a new goroutine with a fresh cancelable context and
// records that context and its cancel function on the scheduler. When taskFn
// returns, resFn is called with its result, tc and its error; an error that
// matches context.Canceled is wrapped with a "task canceled" prefix first.
//
// The callers in this file (Enqueue and finishedTask) hold queueMutex while
// calling runTask.
func (s *Scheduler) runTask(tc *taskContext, taskFn taskFunction, resFn func(interface{}, *taskContext, error)) {
	ctx, cancel := context.WithCancel(context.Background())
	s.context, s.cancelFn = ctx, cancel

	go func() {
		res, err := taskFn(ctx)

		// The task is done; release the resources held by its context.
		cancel()

		if errors.Is(err, context.Canceled) {
			err = fmt.Errorf("task canceled: %w", err)
		}
		resFn(res, tc, err)
	}()
}

// finishedTask is called once a task function has returned. While holding the
// queue lock it removes the oldest queue entry and starts the next queued
// task, if there is one. If Enqueue replaced the running task's entry
// (doNotDeleteCurrentTask is true), the oldest entry is kept and only the flag
// is cleared. The result of the finished task is reported through
// finishedResFn after the lock has been released.
func (s *Scheduler) finishedTask(finishedRes interface{}, doneTask *taskContext, finishedResFn resultFunction, finishedErr error) {
	s.queueMutex.Lock()

	if s.doNotDeleteCurrentTask {
		// The oldest entry was put there by Enqueue as a replacement; keep it.
		s.doNotDeleteCurrentTask = false
	} else if head := s.queue.Oldest(); head != nil {
		// head is nil when Stop() has emptied the queue in the meantime.
		s.queue.Delete(head.Value.taskType)
	}

	// Start whatever is now first in line, or note that nothing is running.
	upcoming := s.queue.Oldest()
	if upcoming == nil {
		s.cancelFn = nil
	} else {
		next := upcoming.Value
		s.runTask(next, next.taskFn, func(res interface{}, runningTask *taskContext, err error) {
			s.finishedTask(res, runningTask, runningTask.resFn, err)
		})
	}
	s.queueMutex.Unlock()

	// Report outside the lock.
	finishedResFn(finishedRes, doneTask.taskType, finishedErr)
}

// Stop cancels the context of the running task, if it has not been canceled
// already, and removes every entry from the queue so that no further task is
// started. Entries whose policy is ReplacementPolicyCancelOld have their result
// function called with context.Canceled from a separate goroutine; entries with
// any other policy are dropped without notification.
func (s *Scheduler) Stop() {
	s.queueMutex.Lock()
	defer s.queueMutex.Unlock()

	if s.cancelFn != nil {
		s.cancelFn()
		s.cancelFn = nil
	}

	// Empty the queue so the running task will not be restarted
	for pair := s.queue.Oldest(); pair != nil; {
		// Read the successor before deleting: once a pair has been removed
		// from the ordered map its Next() returns nil, which would end the
		// loop after the first entry and leave the rest of the queue in place.
		next := pair.Next()

		// Notify the queued one that they are canceled
		if pair.Value.policy == ReplacementPolicyCancelOld {
			go func(val *taskContext) {
				val.resFn(nil, val.taskType, context.Canceled)
			}(pair.Value)
		}
		s.queue.Delete(pair.Value.taskType)

		pair = next
	}
}