From e54567223b7a4f1a9a41389b76d44cf3406582ff Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Thu, 1 Oct 2020 01:14:21 -0400 Subject: [PATCH] lib/retry: Refactor to reduce the interface surface Reduce Jitter to one function Rename NewRetryWaiter Fix a bug in calculateWait where maxWait was applied before jitter, which would make it possible to wait longer than maxWait. --- agent/auto-config/auto_config.go | 34 ++-- agent/auto-config/auto_config_test.go | 2 +- agent/auto-config/auto_encrypt.go | 28 ++- agent/auto-config/auto_encrypt_test.go | 2 +- agent/auto-config/config.go | 6 +- agent/consul/replication.go | 33 ++- lib/retry/retry.go | 201 +++++++------------ lib/retry/retry_test.go | 268 +++++++++++-------------- 8 files changed, 243 insertions(+), 331 deletions(-) diff --git a/agent/auto-config/auto_config.go b/agent/auto-config/auto_config.go index 15988256f5..7d2f8e9617 100644 --- a/agent/auto-config/auto_config.go +++ b/agent/auto-config/auto_config.go @@ -85,7 +85,11 @@ func New(config Config) (*AutoConfig, error) { } if config.Waiter == nil { - config.Waiter = retry.NewRetryWaiter(1, 0, 10*time.Minute, retry.NewJitterRandomStagger(25)) + config.Waiter = &retry.Waiter{ + MinFailures: 1, + MaxWait: 10 * time.Minute, + Jitter: retry.NewJitter(25), + } } return &AutoConfig{ @@ -306,23 +310,21 @@ func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) (*pbautoconf. return nil, err } - // this resets the failures so that we will perform immediate request - wait := ac.acConfig.Waiter.Success() + ac.acConfig.Waiter.Reset() for { - select { - case <-wait: - if resp, err := ac.getInitialConfigurationOnce(ctx, csr, key); err == nil && resp != nil { - return resp, nil - } else if err != nil { - ac.logger.Error(err.Error()) - } else { - ac.logger.Error("No error returned when fetching configuration from the servers but no response was either") - } + resp, err := ac.getInitialConfigurationOnce(ctx, csr, key) + switch { + case err == nil && resp != nil: + return resp, nil + case err != nil: + ac.logger.Error(err.Error()) + default: + ac.logger.Error("No error returned when fetching configuration from the servers but no response was either") + } - wait = ac.acConfig.Waiter.Failed() - case <-ctx.Done(): - ac.logger.Info("interrupted during initial auto configuration", "err", ctx.Err()) - return nil, ctx.Err() + if err := ac.acConfig.Waiter.Wait(ctx); err != nil { + ac.logger.Info("interrupted during initial auto configuration", "err", err) + return nil, err } } } diff --git a/agent/auto-config/auto_config_test.go b/agent/auto-config/auto_config_test.go index 74f9d46209..b8ab0caf45 100644 --- a/agent/auto-config/auto_config_test.go +++ b/agent/auto-config/auto_config_test.go @@ -413,7 +413,7 @@ func TestInitialConfiguration_retries(t *testing.T) { mcfg.Config.Loader = loader.Load // reduce the retry wait times to make this test run faster - mcfg.Config.Waiter = retry.NewWaiter(2, 0, 1*time.Millisecond, nil) + mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond} indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", "secret") diff --git a/agent/auto-config/auto_encrypt.go b/agent/auto-config/auto_encrypt.go index 2290bb332b..f2bf424aa5 100644 --- a/agent/auto-config/auto_encrypt.go +++ b/agent/auto-config/auto_encrypt.go @@ -16,23 +16,21 @@ func (ac *AutoConfig) autoEncryptInitialCerts(ctx context.Context) (*structs.Sig return nil, err } - // this resets the failures so that we will perform immediate request - wait := 
ac.acConfig.Waiter.Success() + ac.acConfig.Waiter.Reset() for { - select { - case <-wait: - if resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key); err == nil && resp != nil { - return resp, nil - } else if err != nil { - ac.logger.Error(err.Error()) - } else { - ac.logger.Error("No error returned when fetching certificates from the servers but no response was either") - } + resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key) + switch { + case err == nil && resp != nil: + return resp, nil + case err != nil: + ac.logger.Error(err.Error()) + default: + ac.logger.Error("No error returned when fetching certificates from the servers but no response was either") + } - wait = ac.acConfig.Waiter.Failed() - case <-ctx.Done(): - ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", ctx.Err()) - return nil, ctx.Err() + if err := ac.acConfig.Waiter.Wait(ctx); err != nil { + ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", err) + return nil, err } } } diff --git a/agent/auto-config/auto_encrypt_test.go b/agent/auto-config/auto_encrypt_test.go index 2a04173826..7b929a7c3f 100644 --- a/agent/auto-config/auto_encrypt_test.go +++ b/agent/auto-config/auto_encrypt_test.go @@ -248,7 +248,7 @@ func TestAutoEncrypt_InitialCerts(t *testing.T) { resp.VerifyServerHostname = true }) - mcfg.Config.Waiter = retry.NewRetryWaiter(2, 0, 1*time.Millisecond, nil) + mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond} ac := AutoConfig{ config: &config.RuntimeConfig{ diff --git a/agent/auto-config/config.go b/agent/auto-config/config.go index 34726097bd..34f7484e69 100644 --- a/agent/auto-config/config.go +++ b/agent/auto-config/config.go @@ -68,12 +68,12 @@ type Config struct { // known servers during fallback operations. ServerProvider ServerProvider - // Waiter is a RetryWaiter to be used during retrieval of the - // initial configuration. When a round of requests fails we will + // Waiter is used during retrieval of the initial configuration. + // When around of requests fails we will // wait and eventually make another round of requests (1 round // is trying the RPC once against each configured server addr). The // waiting implements some backoff to prevent from retrying these RPCs - // to often. This field is not required and if left unset a waiter will + // too often. This field is not required and if left unset a waiter will // be used that has a max wait duration of 10 minutes and a randomized // jitter of 25% of the wait time. Setting this is mainly useful for // testing purposes to allow testing out the retrying functionality without diff --git a/agent/consul/replication.go b/agent/consul/replication.go index 3764caf227..910f18871c 100644 --- a/agent/consul/replication.go +++ b/agent/consul/replication.go @@ -18,8 +18,6 @@ const ( // replicationMaxRetryWait is the maximum number of seconds to wait between // failed blocking queries when backing off. 
replicationDefaultMaxRetryWait = 120 * time.Second - - replicationDefaultRate = 1 ) type ReplicatorDelegate interface { @@ -36,7 +34,7 @@ type ReplicatorConfig struct { // The number of replication rounds that can be done in a burst Burst int // Minimum number of RPC failures to ignore before backing off - MinFailures int + MinFailures uint // Maximum wait time between failing RPCs MaxRetryWait time.Duration // Where to send our logs @@ -71,12 +69,11 @@ func NewReplicator(config *ReplicatorConfig) (*Replicator, error) { if maxWait == 0 { maxWait = replicationDefaultMaxRetryWait } - - minFailures := config.MinFailures - if minFailures < 0 { - minFailures = 0 + waiter := &retry.Waiter{ + MinFailures: config.MinFailures, + MaxWait: maxWait, + Jitter: retry.NewJitter(10), } - waiter := retry.NewRetryWaiter(minFailures, 0*time.Second, maxWait, retry.NewJitterRandomStagger(10)) return &Replicator{ limiter: limiter, waiter: waiter, @@ -100,10 +97,8 @@ func (r *Replicator) Run(ctx context.Context) error { // Perform a single round of replication index, exit, err := r.delegate.Replicate(ctx, atomic.LoadUint64(&r.lastRemoteIndex), r.logger) if exit { - // the replication function told us to exit return nil } - if err != nil { // reset the lastRemoteIndex when there is an RPC failure. This should cause a full sync to be done during // the next round of replication @@ -112,18 +107,16 @@ func (r *Replicator) Run(ctx context.Context) error { if r.suppressErrorLog != nil && !r.suppressErrorLog(err) { r.logger.Warn("replication error (will retry if still leader)", "error", err) } - } else { - atomic.StoreUint64(&r.lastRemoteIndex, index) - r.logger.Debug("replication completed through remote index", "index", index) + + if err := r.waiter.Wait(ctx); err != nil { + return nil + } + continue } - select { - case <-ctx.Done(): - return nil - // wait some amount of time to prevent churning through many replication rounds while replication is failing - case <-r.waiter.WaitIfErr(err): - // do nothing - } + atomic.StoreUint64(&r.lastRemoteIndex, index) + r.logger.Debug("replication completed through remote index", "index", index) + r.waiter.Reset() } } diff --git a/lib/retry/retry.go b/lib/retry/retry.go index e344ff28fd..17bbeaf83f 100644 --- a/lib/retry/retry.go +++ b/lib/retry/retry.go @@ -1,9 +1,9 @@ package retry import ( + "context" + "math/rand" "time" - - "github.com/hashicorp/consul/lib" ) const ( @@ -11,153 +11,96 @@ const ( defaultMaxWait = 2 * time.Minute ) -// Interface used for offloading jitter calculations from the RetryWaiter -type Jitter interface { - AddJitter(baseTime time.Duration) time.Duration -} +// Jitter should return a new wait duration optionally with some time added or +// removed to create some randomness in wait time. +type Jitter func(baseTime time.Duration) time.Duration -// Calculates a random jitter between 0 and up to a specific percentage of the baseTime -type JitterRandomStagger struct { - // int64 because we are going to be doing math against an int64 to represent nanoseconds - percent int64 -} - -// Creates a new JitterRandomStagger -func NewJitterRandomStagger(percent int) *JitterRandomStagger { +// NewJitter returns a new random Jitter that is up to percent longer than the +// original wait time. 
+func NewJitter(percent int64) Jitter { if percent < 0 { percent = 0 } - return &JitterRandomStagger{ - percent: int64(percent), + return func(baseTime time.Duration) time.Duration { + if percent == 0 { + return baseTime + } + max := (int64(baseTime) * percent) / 100 + if max < 0 { // overflow + return baseTime + } + return baseTime + time.Duration(rand.Int63n(max)) } } -// Implments the Jitter interface -func (j *JitterRandomStagger) AddJitter(baseTime time.Duration) time.Duration { - if j.percent == 0 { - return baseTime - } - - // time.Duration is actually a type alias for int64 which is why casting - // to the duration type and then dividing works - return baseTime + lib.RandomStagger((baseTime*time.Duration(j.percent))/100) -} - -// RetryWaiter will record failed and successful operations and provide -// a channel to wait on before a failed operation can be retried. +// Waiter records the number of failures and performs exponential backoff when +// when there are consecutive failures. type Waiter struct { + // MinFailures before exponential backoff starts. Any failures before + // MinFailures is reached will wait MinWait time. MinFailures uint - MinWait time.Duration - MaxWait time.Duration - Jitter Jitter - failures uint + // MinWait time. Returned after the first failure. + MinWait time.Duration + // MaxWait time. + MaxWait time.Duration + // Jitter to add to each wait time. + Jitter Jitter + // Factor is the multiplier to use when calculating the delay. Defaults to + // 1 second. + Factor time.Duration + failures uint } -// Creates a new RetryWaiter -func NewRetryWaiter(minFailures int, minWait, maxWait time.Duration, jitter Jitter) *Waiter { - if minFailures < 0 { - minFailures = defaultMinFailures +// delay calculates the time to wait based on the number of failures +func (w *Waiter) delay() time.Duration { + if w.failures <= w.MinFailures { + return w.MinWait + } + factor := w.Factor + if factor == 0 { + factor = time.Second } - if maxWait <= 0 { - maxWait = defaultMaxWait + shift := w.failures - w.MinFailures - 1 + waitTime := w.MaxWait + if shift < 31 { + waitTime = (1 << shift) * factor } - - if minWait <= 0 { - minWait = 0 * time.Nanosecond + if w.Jitter != nil { + waitTime = w.Jitter(waitTime) } - - return &Waiter{ - MinFailures: uint(minFailures), - MinWait: minWait, - MaxWait: maxWait, - failures: 0, - Jitter: jitter, + if w.MaxWait != 0 && waitTime > w.MaxWait { + return w.MaxWait } -} - -// calculates the necessary wait time before the -// next operation should be allowed. -func (rw *Waiter) calculateWait() time.Duration { - waitTime := rw.MinWait - if rw.failures > rw.MinFailures { - shift := rw.failures - rw.MinFailures - 1 - waitTime = rw.MaxWait - if shift < 31 { - waitTime = (1 << shift) * time.Second - } - if waitTime > rw.MaxWait { - waitTime = rw.MaxWait - } - - if rw.Jitter != nil { - waitTime = rw.Jitter.AddJitter(waitTime) - } + if waitTime < w.MinWait { + return w.MinWait } - - if waitTime < rw.MinWait { - waitTime = rw.MinWait - } - return waitTime } -// calculates the waitTime and returns a chan -// that will become selectable once that amount -// of time has elapsed. -func (rw *Waiter) wait() <-chan struct{} { - waitTime := rw.calculateWait() - ch := make(chan struct{}) - if waitTime > 0 { - time.AfterFunc(waitTime, func() { close(ch) }) - } else { - // if there should be 0 wait time then we ensure - // that the chan will be immediately selectable - close(ch) +// Reset the failure count to 0. 
+func (w *Waiter) Reset() { + w.failures = 0 +} + +// Failures returns the count of consecutive failures. +func (w *Waiter) Failures() int { + return int(w.failures) +} + +// Wait increase the number of failures by one, and then blocks until the context +// is cancelled, or until the wait time is reached. +// The wait time increases exponentially as the number of failures increases. +// Wait will return ctx.Err() if the context is cancelled. +func (w *Waiter) Wait(ctx context.Context) error { + w.failures++ + timer := time.NewTimer(w.delay()) + select { + case <-ctx.Done(): + timer.Stop() + return ctx.Err() + case <-timer.C: + return nil } - return ch -} - -// Marks that an operation is successful which resets the failure count. -// The chan that is returned will be immediately selectable -func (rw *Waiter) Success() <-chan struct{} { - rw.Reset() - return rw.wait() -} - -// Marks that an operation failed. The chan returned will be selectable -// once the calculated retry wait amount of time has elapsed -func (rw *Waiter) Failed() <-chan struct{} { - rw.failures += 1 - ch := rw.wait() - return ch -} - -// Resets the internal failure counter. -func (rw *Waiter) Reset() { - rw.failures = 0 -} - -// Failures returns the current number of consecutive failures recorded. -func (rw *Waiter) Failures() int { - return int(rw.failures) -} - -// WaitIf is a convenice method to record whether the last -// operation was a success or failure and return a chan that -// will be selectablw when the next operation can be done. -func (rw *Waiter) WaitIf(failure bool) <-chan struct{} { - if failure { - return rw.Failed() - } - return rw.Success() -} - -// WaitIfErr is a convenience method to record whether the last -// operation was a success or failure based on whether the err -// is nil and then return a chan that will be selectable when -// the next operation can be done. 
-func (rw *Waiter) WaitIfErr(err error) <-chan struct{} { - return rw.WaitIf(err != nil) } diff --git a/lib/retry/retry_test.go b/lib/retry/retry_test.go index e1cc776e76..92983e6d16 100644 --- a/lib/retry/retry_test.go +++ b/lib/retry/retry_test.go @@ -1,184 +1,160 @@ package retry import ( - "fmt" + "context" + "math" "testing" "time" "github.com/stretchr/testify/require" ) -func TestJitterRandomStagger(t *testing.T) { - t.Parallel() - - t.Run("0 percent", func(t *testing.T) { - t.Parallel() - jitter := NewJitterRandomStagger(0) +func TestJitter(t *testing.T) { + repeat(t, "0 percent", func(t *testing.T) { + jitter := NewJitter(0) for i := 0; i < 10; i++ { baseTime := time.Duration(i) * time.Second - require.Equal(t, baseTime, jitter.AddJitter(baseTime)) + require.Equal(t, baseTime, jitter(baseTime)) } }) - t.Run("10 percent", func(t *testing.T) { - t.Parallel() - jitter := NewJitterRandomStagger(10) - for i := 0; i < 10; i++ { - baseTime := 5000 * time.Millisecond - maxTime := 5500 * time.Millisecond - newTime := jitter.AddJitter(baseTime) - require.True(t, newTime > baseTime) - require.True(t, newTime <= maxTime) - } + repeat(t, "10 percent", func(t *testing.T) { + jitter := NewJitter(10) + baseTime := 5000 * time.Millisecond + maxTime := 5500 * time.Millisecond + newTime := jitter(baseTime) + require.True(t, newTime > baseTime) + require.True(t, newTime <= maxTime) }) - t.Run("100 percent", func(t *testing.T) { - t.Parallel() - jitter := NewJitterRandomStagger(100) - for i := 0; i < 10; i++ { - baseTime := 1234 * time.Millisecond - maxTime := 2468 * time.Millisecond - newTime := jitter.AddJitter(baseTime) - require.True(t, newTime > baseTime) - require.True(t, newTime <= maxTime) + repeat(t, "100 percent", func(t *testing.T) { + jitter := NewJitter(100) + baseTime := 1234 * time.Millisecond + maxTime := 2468 * time.Millisecond + newTime := jitter(baseTime) + require.True(t, newTime > baseTime) + require.True(t, newTime <= maxTime) + }) + + repeat(t, "overflow", func(t *testing.T) { + jitter := NewJitter(100) + baseTime := time.Duration(math.MaxInt64) - 2*time.Hour + newTime := jitter(baseTime) + require.Equal(t, baseTime, newTime) + }) +} + +func repeat(t *testing.T, name string, fn func(t *testing.T)) { + t.Run(name, func(t *testing.T) { + for i := 0; i < 1000; i++ { + fn(t) } }) } -func TestRetryWaiter_calculateWait(t *testing.T) { - t.Parallel() - - t.Run("Defaults", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 0, nil) - - require.Equal(t, 0*time.Nanosecond, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 1*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 2*time.Second, rw.calculateWait()) - rw.failures = 31 - require.Equal(t, defaultMaxWait, rw.calculateWait()) - }) - - t.Run("Minimum Wait", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 5*time.Second, 0, nil) - - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 8*time.Second, rw.calculateWait()) - }) - - t.Run("Minimum Failures", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(5, 0, 0, nil) - require.Equal(t, 0*time.Nanosecond, rw.calculateWait()) - rw.failures += 5 - require.Equal(t, 0*time.Nanosecond, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 1*time.Second, 
rw.calculateWait()) - }) - - t.Run("Maximum Wait", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 5*time.Second, nil) - require.Equal(t, 0*time.Nanosecond, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 1*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 2*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 4*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures = 31 - require.Equal(t, 5*time.Second, rw.calculateWait()) - }) -} - -func TestRetryWaiter_WaitChans(t *testing.T) { - t.Parallel() - - t.Run("Minimum Wait - Success", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil) - - select { - case <-time.After(200 * time.Millisecond): - case <-rw.Success(): - require.Fail(t, "minimum wait not respected") +func TestWaiter_Delay(t *testing.T) { + t.Run("zero value", func(t *testing.T) { + w := &Waiter{} + for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} { + w.failures = uint(i) + require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i) } }) - t.Run("Minimum Wait - WaitIf", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil) - - select { - case <-time.After(200 * time.Millisecond): - case <-rw.WaitIf(false): - require.Fail(t, "minimum wait not respected") + t.Run("with minimum wait", func(t *testing.T) { + w := &Waiter{MinWait: 5 * time.Second} + for i, expected := range []time.Duration{5, 5, 5, 5, 8, 16, 32, 64, 128} { + w.failures = uint(i) + require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i) } }) - t.Run("Minimum Wait - WaitIfErr", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil) - - select { - case <-time.After(200 * time.Millisecond): - case <-rw.WaitIfErr(nil): - require.Fail(t, "minimum wait not respected") + t.Run("with maximum wait", func(t *testing.T) { + w := &Waiter{MaxWait: 20 * time.Second} + for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 20, 20, 20} { + w.failures = uint(i) + require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i) } }) - t.Run("Maximum Wait - Failed", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil) - - select { - case <-time.After(500 * time.Millisecond): - require.Fail(t, "maximum wait not respected") - case <-rw.Failed(): + t.Run("with minimum failures", func(t *testing.T) { + w := &Waiter{MinFailures: 4} + for i, expected := range []time.Duration{0, 0, 0, 0, 0, 1, 2, 4, 8, 16} { + w.failures = uint(i) + require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i) } }) - t.Run("Maximum Wait - WaitIf", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil) - - select { - case <-time.After(500 * time.Millisecond): - require.Fail(t, "maximum wait not respected") - case <-rw.WaitIf(true): + t.Run("with factor", func(t *testing.T) { + w := &Waiter{Factor: time.Millisecond} + for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} { + w.failures = uint(i) + require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i) } }) - t.Run("Maximum Wait - WaitIfErr", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil) - - select { - case <-time.After(500 * time.Millisecond): - require.Fail(t, "maximum wait not respected") - case 
<-rw.WaitIfErr(fmt.Errorf("Fake Error")): + t.Run("with all settings", func(t *testing.T) { + w := &Waiter{ + MinFailures: 2, + MinWait: 4 * time.Millisecond, + MaxWait: 20 * time.Millisecond, + Factor: time.Millisecond, + } + for i, expected := range []time.Duration{4, 4, 4, 4, 4, 4, 8, 16, 20, 20, 20} { + w.failures = uint(i) + require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i) } }) } + +func TestWaiter_Wait(t *testing.T) { + ctx := context.Background() + + t.Run("first failure", func(t *testing.T) { + w := &Waiter{MinWait: time.Millisecond, Factor: 1} + elapsed, err := runWait(ctx, w) + require.NoError(t, err) + assertApproximateDuration(t, elapsed, time.Millisecond) + require.Equal(t, w.failures, uint(1)) + }) + + t.Run("max failures", func(t *testing.T) { + w := &Waiter{ + MaxWait: 100 * time.Millisecond, + failures: 200, + } + elapsed, err := runWait(ctx, w) + require.NoError(t, err) + assertApproximateDuration(t, elapsed, 100*time.Millisecond) + require.Equal(t, w.failures, uint(201)) + }) + + t.Run("context deadline", func(t *testing.T) { + w := &Waiter{failures: 200, MinWait: time.Second} + ctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond) + t.Cleanup(cancel) + + elapsed, err := runWait(ctx, w) + require.Equal(t, err, context.DeadlineExceeded) + assertApproximateDuration(t, elapsed, 5*time.Millisecond) + require.Equal(t, w.failures, uint(201)) + }) +} + +func runWait(ctx context.Context, w *Waiter) (time.Duration, error) { + before := time.Now() + err := w.Wait(ctx) + return time.Since(before), err +} + +func assertApproximateDuration(t *testing.T, actual time.Duration, expected time.Duration) { + t.Helper() + delta := 20 * time.Millisecond + min, max := expected-delta, expected+delta + if min < 0 { + min = 0 + } + if actual < min || actual > max { + t.Fatalf("expected %v to be between %v and %v", actual, min, max) + } +}
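
For reference, a minimal sketch of how a caller drives the refactored waiter after this change. Only the retry.Waiter fields, retry.NewJitter, and the Reset/Failures/Wait methods come from this patch; the tryOnce helper, the failure threshold, and the timeout are hypothetical placeholders for illustration.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/hashicorp/consul/lib/retry"
)

// tryOnce stands in for one round of RPCs against the servers; it is a
// hypothetical placeholder, not part of this patch.
func tryOnce(attempt int) error {
	if attempt < 3 {
		return errors.New("temporary failure")
	}
	return nil
}

func main() {
	// Struct fields replace the old NewRetryWaiter constructor arguments.
	// The jitter adds up to 25% to each delay, and MaxWait now caps the
	// delay after jitter is applied, so the wait can never exceed MaxWait.
	waiter := &retry.Waiter{
		MinFailures: 1,
		MaxWait:     10 * time.Minute,
		Jitter:      retry.NewJitter(25),
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	waiter.Reset() // replaces the old Success(): clears the failure count
	for attempt := 0; ; attempt++ {
		if err := tryOnce(attempt); err == nil {
			fmt.Printf("succeeded after %d failures\n", waiter.Failures())
			return
		}
		// Wait replaces Failed()/WaitIfErr(): it records the failure and
		// blocks for the backoff delay, returning ctx.Err() on cancellation.
		if err := waiter.Wait(ctx); err != nil {
			fmt.Println("giving up:", err)
			return
		}
	}
}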