package clock

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// WithTimeout is the mock-clock analogue of context.WithTimeout. It is
// equivalent to m.WithDeadline(parent, m.Now().Add(timeout)), so the timeout
// is measured against the mock's current time rather than the wall clock.
func (m *Mock) WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	return m.WithDeadline(parent, m.Now().Add(timeout))
}

// WithDeadline is the mock-clock analogue of context.WithDeadline.
//
// If parent already has a deadline earlier than deadline, the new deadline
// can never take effect, so context.WithCancel(parent) is returned instead.
// Otherwise the returned context is canceled by the first of:
//   - the mock clock reaching deadline (Err reports context.DeadlineExceeded),
//   - the returned CancelFunc being called (Err reports context.Canceled),
//   - parent being canceled (Err reports parent.Err()).
func (m *Mock) WithDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) {
	if cur, ok := parent.Deadline(); ok && cur.Before(deadline) {
		// The current deadline is already sooner than the new one.
		return context.WithCancel(parent)
	}
	ctx := &timerCtx{clock: m, parent: parent, deadline: deadline, done: make(chan struct{})}
	propagateCancel(parent, ctx)
	dur := m.Until(deadline)
	if dur <= 0 {
		ctx.cancel(context.DeadlineExceeded) // deadline has already passed
		return ctx, func() {}
	}
	ctx.Lock()
	defer ctx.Unlock()
	// propagateCancel may already have canceled ctx (parent done); arming a
	// timer for an already-canceled context would be pointless.
	if ctx.err == nil {
		ctx.timer = m.AfterFunc(dur, func() {
			ctx.cancel(context.DeadlineExceeded)
		})
	}
	return ctx, func() { ctx.cancel(context.Canceled) }
}

// propagateCancel arranges for child to be canceled when parent is.
//
// If parent is already done, child is canceled synchronously, so the caller
// never observes a child whose Err is nil while its parent is canceled.
// Otherwise a goroutine waits for whichever of parent or child finishes first.
func propagateCancel(parent context.Context, child *timerCtx) {
	done := parent.Done()
	if done == nil {
		return // parent is never canceled
	}
	select {
	case <-done:
		// Parent already canceled: propagate now instead of racing a goroutine.
		child.cancel(parent.Err())
		return
	default:
	}
	go func() {
		select {
		case <-done:
			child.cancel(parent.Err())
		case <-child.Done():
		}
	}()
}

// timerCtx is the context.Context returned by Mock.WithDeadline. It is
// canceled by its deadline timer, by its CancelFunc, or by its parent.
type timerCtx struct {
	sync.Mutex // guards err and timer

	clock    Clock           // used by String to report time remaining
	parent   context.Context // source of Value lookups and upstream cancellation
	deadline time.Time       // fixed at construction; returned by Deadline
	done     chan struct{}   // closed exactly once, by the first cancel
	err      error           // nil until canceled; then the cancellation cause
	timer    *Timer          // pending deadline timer, or nil once stopped/unset
}

// cancel records err as the cancellation cause, closes the Done channel and
// stops the deadline timer. Only the first call has any effect; later calls
// return immediately.
func (c *timerCtx) cancel(err error) {
	c.Lock()
	defer c.Unlock()
	if c.err != nil {
		return // already canceled
	}
	c.err = err
	close(c.done)
	if c.timer != nil {
		c.timer.Stop()
		c.timer = nil
	}
}

// Deadline always reports the deadline passed to WithDeadline.
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
	return c.deadline, true
}

// Done returns a channel that is closed when the context is canceled.
func (c *timerCtx) Done() <-chan struct{} {
	return c.done
}

// Err returns nil until the context is canceled, then the cancellation cause.
// The lock is required because cancel may write err concurrently (from the
// deadline timer, the parent-propagation goroutine, or the CancelFunc).
func (c *timerCtx) Err() error {
	c.Lock()
	defer c.Unlock()
	return c.err
}

// Value delegates to the parent context; timerCtx stores no values itself.
func (c *timerCtx) Value(key interface{}) interface{} {
	return c.parent.Value(key)
}

// String describes the context by its deadline and the time remaining until
// it according to the mock clock.
func (c *timerCtx) String() string {
	return fmt.Sprintf("clock.WithDeadline(%s [%s])", c.deadline, c.deadline.Sub(c.clock.Now()))
}