diff --git a/agent/auto-config/auto_config.go b/agent/auto-config/auto_config.go index 335f0f9872..7d2f8e9617 100644 --- a/agent/auto-config/auto_config.go +++ b/agent/auto-config/auto_config.go @@ -7,13 +7,14 @@ import ( "sync" "time" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/token" - "github.com/hashicorp/consul/lib" + "github.com/hashicorp/consul/lib/retry" "github.com/hashicorp/consul/logging" "github.com/hashicorp/consul/proto/pbautoconf" - "github.com/hashicorp/go-hclog" ) // AutoConfig is all the state necessary for being able to parse a configuration @@ -24,7 +25,7 @@ type AutoConfig struct { acConfig Config logger hclog.Logger cache Cache - waiter *lib.RetryWaiter + waiter *retry.Waiter config *config.RuntimeConfig autoConfigResponse *pbautoconf.AutoConfigResponse autoConfigSource config.Source @@ -84,7 +85,11 @@ func New(config Config) (*AutoConfig, error) { } if config.Waiter == nil { - config.Waiter = lib.NewRetryWaiter(1, 0, 10*time.Minute, lib.NewJitterRandomStagger(25)) + config.Waiter = &retry.Waiter{ + MinFailures: 1, + MaxWait: 10 * time.Minute, + Jitter: retry.NewJitter(25), + } } return &AutoConfig{ @@ -305,23 +310,21 @@ func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) (*pbautoconf. return nil, err } - // this resets the failures so that we will perform immediate request - wait := ac.acConfig.Waiter.Success() + ac.acConfig.Waiter.Reset() for { - select { - case <-wait: - if resp, err := ac.getInitialConfigurationOnce(ctx, csr, key); err == nil && resp != nil { - return resp, nil - } else if err != nil { - ac.logger.Error(err.Error()) - } else { - ac.logger.Error("No error returned when fetching configuration from the servers but no response was either") - } + resp, err := ac.getInitialConfigurationOnce(ctx, csr, key) + switch { + case err == nil && resp != nil: + return resp, nil + case err != nil: + ac.logger.Error(err.Error()) + default: + ac.logger.Error("No error returned when fetching configuration from the servers but no response was either") + } - wait = ac.acConfig.Waiter.Failed() - case <-ctx.Done(): - ac.logger.Info("interrupted during initial auto configuration", "err", ctx.Err()) - return nil, ctx.Err() + if err := ac.acConfig.Waiter.Wait(ctx); err != nil { + ac.logger.Info("interrupted during initial auto configuration", "err", err) + return nil, err } } } diff --git a/agent/auto-config/auto_config_test.go b/agent/auto-config/auto_config_test.go index e3469862a4..b8ab0caf45 100644 --- a/agent/auto-config/auto_config_test.go +++ b/agent/auto-config/auto_config_test.go @@ -11,6 +11,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/hashicorp/consul/agent/cache" cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/config" @@ -18,13 +21,11 @@ import ( "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/token" - "github.com/hashicorp/consul/lib" + "github.com/hashicorp/consul/lib/retry" "github.com/hashicorp/consul/proto/pbautoconf" "github.com/hashicorp/consul/proto/pbconfig" "github.com/hashicorp/consul/sdk/testutil" - "github.com/hashicorp/consul/sdk/testutil/retry" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" + testretry "github.com/hashicorp/consul/sdk/testutil/retry" ) type configLoader struct { @@ -412,7 +413,7 @@ 
func TestInitialConfiguration_retries(t *testing.T) { mcfg.Config.Loader = loader.Load // reduce the retry wait times to make this test run faster - mcfg.Config.Waiter = lib.NewRetryWaiter(2, 0, 1*time.Millisecond, nil) + mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond} indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", "secret") @@ -927,7 +928,7 @@ func TestRootsUpdate(t *testing.T) { // however there is no deterministic way to know once its been written outside of maybe a filesystem // event notifier. That seems a little heavy handed just for this and especially to do in any sort // of cross platform way. - retry.Run(t, func(r *retry.R) { + testretry.Run(t, func(r *testretry.R) { resp, err := testAC.ac.readPersistedAutoConfig() require.NoError(r, err) require.Equal(r, secondRoots.ActiveRootID, resp.CARoots.GetActiveRootID()) @@ -972,7 +973,7 @@ func TestCertUpdate(t *testing.T) { // persisting these to disk happens after all the things we would wait for in assertCertUpdated // will have fired. There is no deterministic way to know once its been written so we wrap // this in a retry. - retry.Run(t, func(r *retry.R) { + testretry.Run(t, func(r *testretry.R) { resp, err := testAC.ac.readPersistedAutoConfig() require.NoError(r, err) @@ -1099,7 +1100,7 @@ func TestFallback(t *testing.T) { // persisting these to disk happens after the RPC we waited on above will have fired // There is no deterministic way to know once its been written so we wrap this in a retry. - retry.Run(t, func(r *retry.R) { + testretry.Run(t, func(r *testretry.R) { resp, err := testAC.ac.readPersistedAutoConfig() require.NoError(r, err) diff --git a/agent/auto-config/auto_encrypt.go b/agent/auto-config/auto_encrypt.go index 2290bb332b..f2bf424aa5 100644 --- a/agent/auto-config/auto_encrypt.go +++ b/agent/auto-config/auto_encrypt.go @@ -16,23 +16,21 @@ func (ac *AutoConfig) autoEncryptInitialCerts(ctx context.Context) (*structs.Sig return nil, err } - // this resets the failures so that we will perform immediate request - wait := ac.acConfig.Waiter.Success() + ac.acConfig.Waiter.Reset() for { - select { - case <-wait: - if resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key); err == nil && resp != nil { - return resp, nil - } else if err != nil { - ac.logger.Error(err.Error()) - } else { - ac.logger.Error("No error returned when fetching certificates from the servers but no response was either") - } + resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key) + switch { + case err == nil && resp != nil: + return resp, nil + case err != nil: + ac.logger.Error(err.Error()) + default: + ac.logger.Error("No error returned when fetching certificates from the servers but no response was either") + } - wait = ac.acConfig.Waiter.Failed() - case <-ctx.Done(): - ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", ctx.Err()) - return nil, ctx.Err() + if err := ac.acConfig.Waiter.Wait(ctx); err != nil { + ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", err) + return nil, err } } } diff --git a/agent/auto-config/auto_encrypt_test.go b/agent/auto-config/auto_encrypt_test.go index 867db9441f..7b929a7c3f 100644 --- a/agent/auto-config/auto_encrypt_test.go +++ b/agent/auto-config/auto_encrypt_test.go @@ -11,16 +11,17 @@ import ( "testing" "time" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/hashicorp/consul/agent/cache" cachetype 
"github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/structs" - "github.com/hashicorp/consul/lib" + "github.com/hashicorp/consul/lib/retry" "github.com/hashicorp/consul/sdk/testutil" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) func TestAutoEncrypt_generateCSR(t *testing.T) { @@ -247,7 +248,7 @@ func TestAutoEncrypt_InitialCerts(t *testing.T) { resp.VerifyServerHostname = true }) - mcfg.Config.Waiter = lib.NewRetryWaiter(2, 0, 1*time.Millisecond, nil) + mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond} ac := AutoConfig{ config: &config.RuntimeConfig{ diff --git a/agent/auto-config/config.go b/agent/auto-config/config.go index c812cae6a4..34f7484e69 100644 --- a/agent/auto-config/config.go +++ b/agent/auto-config/config.go @@ -5,12 +5,13 @@ import ( "net" "time" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/token" - "github.com/hashicorp/consul/lib" - "github.com/hashicorp/go-hclog" + "github.com/hashicorp/consul/lib/retry" ) // DirectRPC is the interface that needs to be satisifed for AutoConfig to be able to perform @@ -67,17 +68,17 @@ type Config struct { // known servers during fallback operations. ServerProvider ServerProvider - // Waiter is a RetryWaiter to be used during retrieval of the - // initial configuration. When a round of requests fails we will + // Waiter is used during retrieval of the initial configuration. + // When around of requests fails we will // wait and eventually make another round of requests (1 round // is trying the RPC once against each configured server addr). The // waiting implements some backoff to prevent from retrying these RPCs - // to often. This field is not required and if left unset a waiter will + // too often. This field is not required and if left unset a waiter will // be used that has a max wait duration of 10 minutes and a randomized // jitter of 25% of the wait time. Setting this is mainly useful for // testing purposes to allow testing out the retrying functionality without // having the test take minutes/hours to complete. - Waiter *lib.RetryWaiter + Waiter *retry.Waiter // Loader merges source with the existing FileSources and returns the complete // RuntimeConfig. diff --git a/agent/consul/replication.go b/agent/consul/replication.go index 39c9af4f15..910f18871c 100644 --- a/agent/consul/replication.go +++ b/agent/consul/replication.go @@ -7,18 +7,17 @@ import ( "time" metrics "github.com/armon/go-metrics" - "github.com/hashicorp/consul/lib" - "github.com/hashicorp/consul/logging" "github.com/hashicorp/go-hclog" "golang.org/x/time/rate" + + "github.com/hashicorp/consul/lib/retry" + "github.com/hashicorp/consul/logging" ) const ( // replicationMaxRetryWait is the maximum number of seconds to wait between // failed blocking queries when backing off. 
replicationDefaultMaxRetryWait = 120 * time.Second - - replicationDefaultRate = 1 ) type ReplicatorDelegate interface { @@ -35,7 +34,7 @@ type ReplicatorConfig struct { // The number of replication rounds that can be done in a burst Burst int // Minimum number of RPC failures to ignore before backing off - MinFailures int + MinFailures uint // Maximum wait time between failing RPCs MaxRetryWait time.Duration // Where to send our logs @@ -46,7 +45,7 @@ type ReplicatorConfig struct { type Replicator struct { limiter *rate.Limiter - waiter *lib.RetryWaiter + waiter *retry.Waiter delegate ReplicatorDelegate logger hclog.Logger lastRemoteIndex uint64 @@ -70,12 +69,11 @@ func NewReplicator(config *ReplicatorConfig) (*Replicator, error) { if maxWait == 0 { maxWait = replicationDefaultMaxRetryWait } - - minFailures := config.MinFailures - if minFailures < 0 { - minFailures = 0 + waiter := &retry.Waiter{ + MinFailures: config.MinFailures, + MaxWait: maxWait, + Jitter: retry.NewJitter(10), } - waiter := lib.NewRetryWaiter(minFailures, 0*time.Second, maxWait, lib.NewJitterRandomStagger(10)) return &Replicator{ limiter: limiter, waiter: waiter, @@ -99,10 +97,8 @@ func (r *Replicator) Run(ctx context.Context) error { // Perform a single round of replication index, exit, err := r.delegate.Replicate(ctx, atomic.LoadUint64(&r.lastRemoteIndex), r.logger) if exit { - // the replication function told us to exit return nil } - if err != nil { // reset the lastRemoteIndex when there is an RPC failure. This should cause a full sync to be done during // the next round of replication @@ -111,18 +107,16 @@ func (r *Replicator) Run(ctx context.Context) error { if r.suppressErrorLog != nil && !r.suppressErrorLog(err) { r.logger.Warn("replication error (will retry if still leader)", "error", err) } - } else { - atomic.StoreUint64(&r.lastRemoteIndex, index) - r.logger.Debug("replication completed through remote index", "index", index) + + if err := r.waiter.Wait(ctx); err != nil { + return nil + } + continue } - select { - case <-ctx.Done(): - return nil - // wait some amount of time to prevent churning through many replication rounds while replication is failing - case <-r.waiter.WaitIfErr(err): - // do nothing - } + atomic.StoreUint64(&r.lastRemoteIndex, index) + r.logger.Debug("replication completed through remote index", "index", index) + r.waiter.Reset() } } diff --git a/lib/retry.go b/lib/retry.go deleted file mode 100644 index 59cb91c753..0000000000 --- a/lib/retry.go +++ /dev/null @@ -1,156 +0,0 @@ -package lib - -import ( - "time" -) - -const ( - defaultMinFailures = 0 - defaultMaxWait = 2 * time.Minute -) - -// Interface used for offloading jitter calculations from the RetryWaiter -type Jitter interface { - AddJitter(baseTime time.Duration) time.Duration -} - -// Calculates a random jitter between 0 and up to a specific percentage of the baseTime -type JitterRandomStagger struct { - // int64 because we are going to be doing math against an int64 to represent nanoseconds - percent int64 -} - -// Creates a new JitterRandomStagger -func NewJitterRandomStagger(percent int) *JitterRandomStagger { - if percent < 0 { - percent = 0 - } - - return &JitterRandomStagger{ - percent: int64(percent), - } -} - -// Implments the Jitter interface -func (j *JitterRandomStagger) AddJitter(baseTime time.Duration) time.Duration { - if j.percent == 0 { - return baseTime - } - - // time.Duration is actually a type alias for int64 which is why casting - // to the duration type and then dividing works - return baseTime + 
RandomStagger((baseTime*time.Duration(j.percent))/100) -} - -// RetryWaiter will record failed and successful operations and provide -// a channel to wait on before a failed operation can be retried. -type RetryWaiter struct { - minFailures uint - minWait time.Duration - maxWait time.Duration - jitter Jitter - failures uint -} - -// Creates a new RetryWaiter -func NewRetryWaiter(minFailures int, minWait, maxWait time.Duration, jitter Jitter) *RetryWaiter { - if minFailures < 0 { - minFailures = defaultMinFailures - } - - if maxWait <= 0 { - maxWait = defaultMaxWait - } - - if minWait <= 0 { - minWait = 0 * time.Nanosecond - } - - return &RetryWaiter{ - minFailures: uint(minFailures), - minWait: minWait, - maxWait: maxWait, - failures: 0, - jitter: jitter, - } -} - -// calculates the necessary wait time before the -// next operation should be allowed. -func (rw *RetryWaiter) calculateWait() time.Duration { - waitTime := rw.minWait - if rw.failures > rw.minFailures { - shift := rw.failures - rw.minFailures - 1 - waitTime = rw.maxWait - if shift < 31 { - waitTime = (1 << shift) * time.Second - } - if waitTime > rw.maxWait { - waitTime = rw.maxWait - } - - if rw.jitter != nil { - waitTime = rw.jitter.AddJitter(waitTime) - } - } - - if waitTime < rw.minWait { - waitTime = rw.minWait - } - - return waitTime -} - -// calculates the waitTime and returns a chan -// that will become selectable once that amount -// of time has elapsed. -func (rw *RetryWaiter) wait() <-chan struct{} { - waitTime := rw.calculateWait() - ch := make(chan struct{}) - if waitTime > 0 { - time.AfterFunc(waitTime, func() { close(ch) }) - } else { - // if there should be 0 wait time then we ensure - // that the chan will be immediately selectable - close(ch) - } - return ch -} - -// Marks that an operation is successful which resets the failure count. -// The chan that is returned will be immediately selectable -func (rw *RetryWaiter) Success() <-chan struct{} { - rw.Reset() - return rw.wait() -} - -// Marks that an operation failed. The chan returned will be selectable -// once the calculated retry wait amount of time has elapsed -func (rw *RetryWaiter) Failed() <-chan struct{} { - rw.failures += 1 - ch := rw.wait() - return ch -} - -// Resets the internal failure counter -func (rw *RetryWaiter) Reset() { - rw.failures = 0 -} - -// WaitIf is a convenice method to record whether the last -// operation was a success or failure and return a chan that -// will be selectablw when the next operation can be done. -func (rw *RetryWaiter) WaitIf(failure bool) <-chan struct{} { - if failure { - return rw.Failed() - } - return rw.Success() -} - -// WaitIfErr is a convenience method to record whether the last -// operation was a success or failure based on whether the err -// is nil and then return a chan that will be selectable when -// the next operation can be done. -func (rw *RetryWaiter) WaitIfErr(err error) <-chan struct{} { - return rw.WaitIf(err != nil) -} diff --git a/lib/retry/retry.go b/lib/retry/retry.go new file mode 100644 index 0000000000..17bbeaf83f --- /dev/null +++ b/lib/retry/retry.go @@ -0,0 +1,106 @@ +package retry + +import ( + "context" + "math/rand" + "time" +) + +const ( + defaultMinFailures = 0 + defaultMaxWait = 2 * time.Minute +) + +// Jitter should return a new wait duration optionally with some time added or +// removed to create some randomness in wait time. 
+type Jitter func(baseTime time.Duration) time.Duration + +// NewJitter returns a new random Jitter that is up to percent longer than the +// original wait time. +func NewJitter(percent int64) Jitter { + if percent < 0 { + percent = 0 + } + + return func(baseTime time.Duration) time.Duration { + if percent == 0 { + return baseTime + } + max := (int64(baseTime) * percent) / 100 + if max < 0 { // overflow + return baseTime + } + return baseTime + time.Duration(rand.Int63n(max)) + } +} + +// Waiter records the number of failures and performs exponential backoff when +// there are consecutive failures. +type Waiter struct { + // MinFailures before exponential backoff starts. Any failures before + // MinFailures is reached will wait MinWait time. + MinFailures uint + // MinWait time. Returned after the first failure. + MinWait time.Duration + // MaxWait time. + MaxWait time.Duration + // Jitter to add to each wait time. + Jitter Jitter + // Factor is the multiplier to use when calculating the delay. Defaults to + // 1 second. + Factor time.Duration + failures uint +} + +// delay calculates the time to wait based on the number of failures +func (w *Waiter) delay() time.Duration { + if w.failures <= w.MinFailures { + return w.MinWait + } + factor := w.Factor + if factor == 0 { + factor = time.Second + } + + shift := w.failures - w.MinFailures - 1 + waitTime := w.MaxWait + if shift < 31 { + waitTime = (1 << shift) * factor + } + if w.Jitter != nil { + waitTime = w.Jitter(waitTime) + } + if w.MaxWait != 0 && waitTime > w.MaxWait { + return w.MaxWait + } + if waitTime < w.MinWait { + return w.MinWait + } + return waitTime +} + +// Reset the failure count to 0. +func (w *Waiter) Reset() { + w.failures = 0 +} + +// Failures returns the count of consecutive failures. +func (w *Waiter) Failures() int { + return int(w.failures) +} + +// Wait increases the number of failures by one, and then blocks until the context +// is cancelled, or until the wait time is reached. +// The wait time increases exponentially as the number of failures increases. +// Wait will return ctx.Err() if the context is cancelled. 
+func (w *Waiter) Wait(ctx context.Context) error { + w.failures++ + timer := time.NewTimer(w.delay()) + select { + case <-ctx.Done(): + timer.Stop() + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/lib/retry/retry_test.go b/lib/retry/retry_test.go new file mode 100644 index 0000000000..92983e6d16 --- /dev/null +++ b/lib/retry/retry_test.go @@ -0,0 +1,160 @@ +package retry + +import ( + "context" + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestJitter(t *testing.T) { + repeat(t, "0 percent", func(t *testing.T) { + jitter := NewJitter(0) + for i := 0; i < 10; i++ { + baseTime := time.Duration(i) * time.Second + require.Equal(t, baseTime, jitter(baseTime)) + } + }) + + repeat(t, "10 percent", func(t *testing.T) { + jitter := NewJitter(10) + baseTime := 5000 * time.Millisecond + maxTime := 5500 * time.Millisecond + newTime := jitter(baseTime) + require.True(t, newTime > baseTime) + require.True(t, newTime <= maxTime) + }) + + repeat(t, "100 percent", func(t *testing.T) { + jitter := NewJitter(100) + baseTime := 1234 * time.Millisecond + maxTime := 2468 * time.Millisecond + newTime := jitter(baseTime) + require.True(t, newTime > baseTime) + require.True(t, newTime <= maxTime) + }) + + repeat(t, "overflow", func(t *testing.T) { + jitter := NewJitter(100) + baseTime := time.Duration(math.MaxInt64) - 2*time.Hour + newTime := jitter(baseTime) + require.Equal(t, baseTime, newTime) + }) +} + +func repeat(t *testing.T, name string, fn func(t *testing.T)) { + t.Run(name, func(t *testing.T) { + for i := 0; i < 1000; i++ { + fn(t) + } + }) +} + +func TestWaiter_Delay(t *testing.T) { + t.Run("zero value", func(t *testing.T) { + w := &Waiter{} + for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} { + w.failures = uint(i) + require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i) + } + }) + + t.Run("with minimum wait", func(t *testing.T) { + w := &Waiter{MinWait: 5 * time.Second} + for i, expected := range []time.Duration{5, 5, 5, 5, 8, 16, 32, 64, 128} { + w.failures = uint(i) + require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i) + } + }) + + t.Run("with maximum wait", func(t *testing.T) { + w := &Waiter{MaxWait: 20 * time.Second} + for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 20, 20, 20} { + w.failures = uint(i) + require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i) + } + }) + + t.Run("with minimum failures", func(t *testing.T) { + w := &Waiter{MinFailures: 4} + for i, expected := range []time.Duration{0, 0, 0, 0, 0, 1, 2, 4, 8, 16} { + w.failures = uint(i) + require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i) + } + }) + + t.Run("with factor", func(t *testing.T) { + w := &Waiter{Factor: time.Millisecond} + for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} { + w.failures = uint(i) + require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i) + } + }) + + t.Run("with all settings", func(t *testing.T) { + w := &Waiter{ + MinFailures: 2, + MinWait: 4 * time.Millisecond, + MaxWait: 20 * time.Millisecond, + Factor: time.Millisecond, + } + for i, expected := range []time.Duration{4, 4, 4, 4, 4, 4, 8, 16, 20, 20, 20} { + w.failures = uint(i) + require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i) + } + }) +} + +func TestWaiter_Wait(t *testing.T) { + ctx := context.Background() + + t.Run("first failure", func(t *testing.T) { + w := &Waiter{MinWait: time.Millisecond, Factor: 
1} + elapsed, err := runWait(ctx, w) + require.NoError(t, err) + assertApproximateDuration(t, elapsed, time.Millisecond) + require.Equal(t, w.failures, uint(1)) + }) + + t.Run("max failures", func(t *testing.T) { + w := &Waiter{ + MaxWait: 100 * time.Millisecond, + failures: 200, + } + elapsed, err := runWait(ctx, w) + require.NoError(t, err) + assertApproximateDuration(t, elapsed, 100*time.Millisecond) + require.Equal(t, w.failures, uint(201)) + }) + + t.Run("context deadline", func(t *testing.T) { + w := &Waiter{failures: 200, MinWait: time.Second} + ctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond) + t.Cleanup(cancel) + + elapsed, err := runWait(ctx, w) + require.Equal(t, err, context.DeadlineExceeded) + assertApproximateDuration(t, elapsed, 5*time.Millisecond) + require.Equal(t, w.failures, uint(201)) + }) +} + +func runWait(ctx context.Context, w *Waiter) (time.Duration, error) { + before := time.Now() + err := w.Wait(ctx) + return time.Since(before), err +} + +func assertApproximateDuration(t *testing.T, actual time.Duration, expected time.Duration) { + t.Helper() + delta := 20 * time.Millisecond + min, max := expected-delta, expected+delta + if min < 0 { + min = 0 + } + if actual < min || actual > max { + t.Fatalf("expected %v to be between %v and %v", actual, min, max) + } +} diff --git a/lib/retry_test.go b/lib/retry_test.go deleted file mode 100644 index 325b5b9526..0000000000 --- a/lib/retry_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package lib - -import ( - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestJitterRandomStagger(t *testing.T) { - t.Parallel() - - t.Run("0 percent", func(t *testing.T) { - t.Parallel() - jitter := NewJitterRandomStagger(0) - for i := 0; i < 10; i++ { - baseTime := time.Duration(i) * time.Second - require.Equal(t, baseTime, jitter.AddJitter(baseTime)) - } - }) - - t.Run("10 percent", func(t *testing.T) { - t.Parallel() - jitter := NewJitterRandomStagger(10) - for i := 0; i < 10; i++ { - baseTime := 5000 * time.Millisecond - maxTime := 5500 * time.Millisecond - newTime := jitter.AddJitter(baseTime) - require.True(t, newTime > baseTime) - require.True(t, newTime <= maxTime) - } - }) - - t.Run("100 percent", func(t *testing.T) { - t.Parallel() - jitter := NewJitterRandomStagger(100) - for i := 0; i < 10; i++ { - baseTime := 1234 * time.Millisecond - maxTime := 2468 * time.Millisecond - newTime := jitter.AddJitter(baseTime) - require.True(t, newTime > baseTime) - require.True(t, newTime <= maxTime) - } - }) -} - -func TestRetryWaiter_calculateWait(t *testing.T) { - t.Parallel() - - t.Run("Defaults", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 0, nil) - - require.Equal(t, 0*time.Nanosecond, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 1*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 2*time.Second, rw.calculateWait()) - rw.failures = 31 - require.Equal(t, defaultMaxWait, rw.calculateWait()) - }) - - t.Run("Minimum Wait", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 5*time.Second, 0, nil) - - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 8*time.Second, rw.calculateWait()) - }) - - t.Run("Minimum Failures", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(5, 0, 0, nil) 
- require.Equal(t, 0*time.Nanosecond, rw.calculateWait()) - rw.failures += 5 - require.Equal(t, 0*time.Nanosecond, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 1*time.Second, rw.calculateWait()) - }) - - t.Run("Maximum Wait", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 5*time.Second, nil) - require.Equal(t, 0*time.Nanosecond, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 1*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 2*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 4*time.Second, rw.calculateWait()) - rw.failures += 1 - require.Equal(t, 5*time.Second, rw.calculateWait()) - rw.failures = 31 - require.Equal(t, 5*time.Second, rw.calculateWait()) - }) -} - -func TestRetryWaiter_WaitChans(t *testing.T) { - t.Parallel() - - t.Run("Minimum Wait - Success", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil) - - select { - case <-time.After(200 * time.Millisecond): - case <-rw.Success(): - require.Fail(t, "minimum wait not respected") - } - }) - - t.Run("Minimum Wait - WaitIf", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil) - - select { - case <-time.After(200 * time.Millisecond): - case <-rw.WaitIf(false): - require.Fail(t, "minimum wait not respected") - } - }) - - t.Run("Minimum Wait - WaitIfErr", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil) - - select { - case <-time.After(200 * time.Millisecond): - case <-rw.WaitIfErr(nil): - require.Fail(t, "minimum wait not respected") - } - }) - - t.Run("Maximum Wait - Failed", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil) - - select { - case <-time.After(500 * time.Millisecond): - require.Fail(t, "maximum wait not respected") - case <-rw.Failed(): - } - }) - - t.Run("Maximum Wait - WaitIf", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil) - - select { - case <-time.After(500 * time.Millisecond): - require.Fail(t, "maximum wait not respected") - case <-rw.WaitIf(true): - } - }) - - t.Run("Maximum Wait - WaitIfErr", func(t *testing.T) { - t.Parallel() - - rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil) - - select { - case <-time.After(500 * time.Millisecond): - require.Fail(t, "maximum wait not respected") - case <-rw.WaitIfErr(fmt.Errorf("Fake Error")): - } - }) -}
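
Usage note (not part of the patch): the sketch below shows how a caller might drive the new `retry.Waiter` from lib/retry/retry.go, mirroring the Reset/Wait loop this diff introduces in `getInitialConfiguration`, `autoEncryptInitialCerts`, and `Replicator.Run`. The `fetchConfig` helper, the attempt counter, and the short timings are illustrative stand-ins, not code from the patch.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/hashicorp/consul/lib/retry"
)

// fetchConfig is a hypothetical operation standing in for the RPCs retried in
// this diff (getInitialConfigurationOnce, autoEncryptInitialCertsOnce): it
// fails a few times before succeeding.
func fetchConfig(attempt int) (string, error) {
	if attempt < 3 {
		return "", errors.New("no servers available")
	}
	return "config payload", nil
}

func main() {
	// Configure the waiter along the lines of AutoConfig's default
	// (MinFailures 1, 25% jitter), but with a small MaxWait so the
	// example finishes quickly instead of backing off for minutes.
	waiter := &retry.Waiter{
		MinFailures: 1,
		MaxWait:     250 * time.Millisecond,
		Jitter:      retry.NewJitter(25),
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Reset clears any previously recorded failures so the first attempt is
	// made immediately, matching the Waiter.Reset() calls in this diff.
	waiter.Reset()
	for attempt := 0; ; attempt++ {
		cfg, err := fetchConfig(attempt)
		if err == nil {
			fmt.Println("fetched:", cfg)
			return
		}
		fmt.Printf("attempt %d failed: %v\n", attempt, err)

		// Wait records the failure, then blocks for an exponentially growing,
		// jittered delay; it returns ctx.Err() if the context ends first.
		if err := waiter.Wait(ctx); err != nil {
			fmt.Println("giving up:", err)
			return
		}
		fmt.Printf("retrying after %d consecutive failure(s)\n", waiter.Failures())
	}
}
```

Compared to the removed `lib.NewRetryWaiter(minFailures, minWait, maxWait, jitter)` constructor, the struct literal makes each knob explicit, and the single blocking `Wait(ctx)` call replaces the old `Success`/`Failed` channel juggling and the separate `ctx.Done()` select.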