lib/retry: Refactor to reduce the interface surface

Reduce Jitter to one function

Rename NewRetryWaiter

Fix a bug in calculateWait where maxWait was applied before jitter, which would make it
possible to wait longer than maxWait.
This commit is contained in:
Daniel Nephin 2020-10-01 01:14:21 -04:00
parent 7b4aca2088
commit e54567223b
8 changed files with 243 additions and 331 deletions

View File

@ -85,7 +85,11 @@ func New(config Config) (*AutoConfig, error) {
}
if config.Waiter == nil {
config.Waiter = retry.NewRetryWaiter(1, 0, 10*time.Minute, retry.NewJitterRandomStagger(25))
config.Waiter = &retry.Waiter{
MinFailures: 1,
MaxWait: 10 * time.Minute,
Jitter: retry.NewJitter(25),
}
}
return &AutoConfig{
@ -306,23 +310,21 @@ func (ac *AutoConfig) getInitialConfiguration(ctx context.Context) (*pbautoconf.
return nil, err
}
// this resets the failures so that we will perform immediate request
wait := ac.acConfig.Waiter.Success()
ac.acConfig.Waiter.Reset()
for {
select {
case <-wait:
if resp, err := ac.getInitialConfigurationOnce(ctx, csr, key); err == nil && resp != nil {
return resp, nil
} else if err != nil {
ac.logger.Error(err.Error())
} else {
ac.logger.Error("No error returned when fetching configuration from the servers but no response was either")
}
resp, err := ac.getInitialConfigurationOnce(ctx, csr, key)
switch {
case err == nil && resp != nil:
return resp, nil
case err != nil:
ac.logger.Error(err.Error())
default:
ac.logger.Error("No error returned when fetching configuration from the servers but no response was either")
}
wait = ac.acConfig.Waiter.Failed()
case <-ctx.Done():
ac.logger.Info("interrupted during initial auto configuration", "err", ctx.Err())
return nil, ctx.Err()
if err := ac.acConfig.Waiter.Wait(ctx); err != nil {
ac.logger.Info("interrupted during initial auto configuration", "err", err)
return nil, err
}
}
}

View File

@ -413,7 +413,7 @@ func TestInitialConfiguration_retries(t *testing.T) {
mcfg.Config.Loader = loader.Load
// reduce the retry wait times to make this test run faster
mcfg.Config.Waiter = retry.NewWaiter(2, 0, 1*time.Millisecond, nil)
mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond}
indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", "secret")

View File

@ -16,23 +16,21 @@ func (ac *AutoConfig) autoEncryptInitialCerts(ctx context.Context) (*structs.Sig
return nil, err
}
// this resets the failures so that we will perform immediate request
wait := ac.acConfig.Waiter.Success()
ac.acConfig.Waiter.Reset()
for {
select {
case <-wait:
if resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key); err == nil && resp != nil {
return resp, nil
} else if err != nil {
ac.logger.Error(err.Error())
} else {
ac.logger.Error("No error returned when fetching certificates from the servers but no response was either")
}
resp, err := ac.autoEncryptInitialCertsOnce(ctx, csr, key)
switch {
case err == nil && resp != nil:
return resp, nil
case err != nil:
ac.logger.Error(err.Error())
default:
ac.logger.Error("No error returned when fetching certificates from the servers but no response was either")
}
wait = ac.acConfig.Waiter.Failed()
case <-ctx.Done():
ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", ctx.Err())
return nil, ctx.Err()
if err := ac.acConfig.Waiter.Wait(ctx); err != nil {
ac.logger.Info("interrupted during retrieval of auto-encrypt certificates", "err", err)
return nil, err
}
}
}

View File

@ -248,7 +248,7 @@ func TestAutoEncrypt_InitialCerts(t *testing.T) {
resp.VerifyServerHostname = true
})
mcfg.Config.Waiter = retry.NewRetryWaiter(2, 0, 1*time.Millisecond, nil)
mcfg.Config.Waiter = &retry.Waiter{MinFailures: 2, MaxWait: time.Millisecond}
ac := AutoConfig{
config: &config.RuntimeConfig{

View File

@ -68,12 +68,12 @@ type Config struct {
// known servers during fallback operations.
ServerProvider ServerProvider
// Waiter is a RetryWaiter to be used during retrieval of the
// initial configuration. When a round of requests fails we will
// Waiter is used during retrieval of the initial configuration.
// When around of requests fails we will
// wait and eventually make another round of requests (1 round
// is trying the RPC once against each configured server addr). The
// waiting implements some backoff to prevent from retrying these RPCs
// to often. This field is not required and if left unset a waiter will
// too often. This field is not required and if left unset a waiter will
// be used that has a max wait duration of 10 minutes and a randomized
// jitter of 25% of the wait time. Setting this is mainly useful for
// testing purposes to allow testing out the retrying functionality without

View File

@ -18,8 +18,6 @@ const (
// replicationMaxRetryWait is the maximum number of seconds to wait between
// failed blocking queries when backing off.
replicationDefaultMaxRetryWait = 120 * time.Second
replicationDefaultRate = 1
)
type ReplicatorDelegate interface {
@ -36,7 +34,7 @@ type ReplicatorConfig struct {
// The number of replication rounds that can be done in a burst
Burst int
// Minimum number of RPC failures to ignore before backing off
MinFailures int
MinFailures uint
// Maximum wait time between failing RPCs
MaxRetryWait time.Duration
// Where to send our logs
@ -71,12 +69,11 @@ func NewReplicator(config *ReplicatorConfig) (*Replicator, error) {
if maxWait == 0 {
maxWait = replicationDefaultMaxRetryWait
}
minFailures := config.MinFailures
if minFailures < 0 {
minFailures = 0
waiter := &retry.Waiter{
MinFailures: config.MinFailures,
MaxWait: maxWait,
Jitter: retry.NewJitter(10),
}
waiter := retry.NewRetryWaiter(minFailures, 0*time.Second, maxWait, retry.NewJitterRandomStagger(10))
return &Replicator{
limiter: limiter,
waiter: waiter,
@ -100,10 +97,8 @@ func (r *Replicator) Run(ctx context.Context) error {
// Perform a single round of replication
index, exit, err := r.delegate.Replicate(ctx, atomic.LoadUint64(&r.lastRemoteIndex), r.logger)
if exit {
// the replication function told us to exit
return nil
}
if err != nil {
// reset the lastRemoteIndex when there is an RPC failure. This should cause a full sync to be done during
// the next round of replication
@ -112,18 +107,16 @@ func (r *Replicator) Run(ctx context.Context) error {
if r.suppressErrorLog != nil && !r.suppressErrorLog(err) {
r.logger.Warn("replication error (will retry if still leader)", "error", err)
}
} else {
atomic.StoreUint64(&r.lastRemoteIndex, index)
r.logger.Debug("replication completed through remote index", "index", index)
if err := r.waiter.Wait(ctx); err != nil {
return nil
}
continue
}
select {
case <-ctx.Done():
return nil
// wait some amount of time to prevent churning through many replication rounds while replication is failing
case <-r.waiter.WaitIfErr(err):
// do nothing
}
atomic.StoreUint64(&r.lastRemoteIndex, index)
r.logger.Debug("replication completed through remote index", "index", index)
r.waiter.Reset()
}
}

View File

@ -1,9 +1,9 @@
package retry
import (
"context"
"math/rand"
"time"
"github.com/hashicorp/consul/lib"
)
const (
@ -11,153 +11,96 @@ const (
defaultMaxWait = 2 * time.Minute
)
// Interface used for offloading jitter calculations from the RetryWaiter
type Jitter interface {
AddJitter(baseTime time.Duration) time.Duration
}
// Jitter should return a new wait duration optionally with some time added or
// removed to create some randomness in wait time.
type Jitter func(baseTime time.Duration) time.Duration
// Calculates a random jitter between 0 and up to a specific percentage of the baseTime
type JitterRandomStagger struct {
// int64 because we are going to be doing math against an int64 to represent nanoseconds
percent int64
}
// Creates a new JitterRandomStagger
func NewJitterRandomStagger(percent int) *JitterRandomStagger {
// NewJitter returns a new random Jitter that is up to percent longer than the
// original wait time.
func NewJitter(percent int64) Jitter {
if percent < 0 {
percent = 0
}
return &JitterRandomStagger{
percent: int64(percent),
return func(baseTime time.Duration) time.Duration {
if percent == 0 {
return baseTime
}
max := (int64(baseTime) * percent) / 100
if max < 0 { // overflow
return baseTime
}
return baseTime + time.Duration(rand.Int63n(max))
}
}
// Implments the Jitter interface
func (j *JitterRandomStagger) AddJitter(baseTime time.Duration) time.Duration {
if j.percent == 0 {
return baseTime
}
// time.Duration is actually a type alias for int64 which is why casting
// to the duration type and then dividing works
return baseTime + lib.RandomStagger((baseTime*time.Duration(j.percent))/100)
}
// RetryWaiter will record failed and successful operations and provide
// a channel to wait on before a failed operation can be retried.
// Waiter records the number of failures and performs exponential backoff when
// when there are consecutive failures.
type Waiter struct {
// MinFailures before exponential backoff starts. Any failures before
// MinFailures is reached will wait MinWait time.
MinFailures uint
MinWait time.Duration
MaxWait time.Duration
Jitter Jitter
failures uint
// MinWait time. Returned after the first failure.
MinWait time.Duration
// MaxWait time.
MaxWait time.Duration
// Jitter to add to each wait time.
Jitter Jitter
// Factor is the multiplier to use when calculating the delay. Defaults to
// 1 second.
Factor time.Duration
failures uint
}
// Creates a new RetryWaiter
func NewRetryWaiter(minFailures int, minWait, maxWait time.Duration, jitter Jitter) *Waiter {
if minFailures < 0 {
minFailures = defaultMinFailures
// delay calculates the time to wait based on the number of failures
func (w *Waiter) delay() time.Duration {
if w.failures <= w.MinFailures {
return w.MinWait
}
factor := w.Factor
if factor == 0 {
factor = time.Second
}
if maxWait <= 0 {
maxWait = defaultMaxWait
shift := w.failures - w.MinFailures - 1
waitTime := w.MaxWait
if shift < 31 {
waitTime = (1 << shift) * factor
}
if minWait <= 0 {
minWait = 0 * time.Nanosecond
if w.Jitter != nil {
waitTime = w.Jitter(waitTime)
}
return &Waiter{
MinFailures: uint(minFailures),
MinWait: minWait,
MaxWait: maxWait,
failures: 0,
Jitter: jitter,
if w.MaxWait != 0 && waitTime > w.MaxWait {
return w.MaxWait
}
}
// calculates the necessary wait time before the
// next operation should be allowed.
func (rw *Waiter) calculateWait() time.Duration {
waitTime := rw.MinWait
if rw.failures > rw.MinFailures {
shift := rw.failures - rw.MinFailures - 1
waitTime = rw.MaxWait
if shift < 31 {
waitTime = (1 << shift) * time.Second
}
if waitTime > rw.MaxWait {
waitTime = rw.MaxWait
}
if rw.Jitter != nil {
waitTime = rw.Jitter.AddJitter(waitTime)
}
if waitTime < w.MinWait {
return w.MinWait
}
if waitTime < rw.MinWait {
waitTime = rw.MinWait
}
return waitTime
}
// calculates the waitTime and returns a chan
// that will become selectable once that amount
// of time has elapsed.
func (rw *Waiter) wait() <-chan struct{} {
waitTime := rw.calculateWait()
ch := make(chan struct{})
if waitTime > 0 {
time.AfterFunc(waitTime, func() { close(ch) })
} else {
// if there should be 0 wait time then we ensure
// that the chan will be immediately selectable
close(ch)
// Reset the failure count to 0.
func (w *Waiter) Reset() {
w.failures = 0
}
// Failures returns the count of consecutive failures.
func (w *Waiter) Failures() int {
return int(w.failures)
}
// Wait increase the number of failures by one, and then blocks until the context
// is cancelled, or until the wait time is reached.
// The wait time increases exponentially as the number of failures increases.
// Wait will return ctx.Err() if the context is cancelled.
func (w *Waiter) Wait(ctx context.Context) error {
w.failures++
timer := time.NewTimer(w.delay())
select {
case <-ctx.Done():
timer.Stop()
return ctx.Err()
case <-timer.C:
return nil
}
return ch
}
// Marks that an operation is successful which resets the failure count.
// The chan that is returned will be immediately selectable
func (rw *Waiter) Success() <-chan struct{} {
rw.Reset()
return rw.wait()
}
// Marks that an operation failed. The chan returned will be selectable
// once the calculated retry wait amount of time has elapsed
func (rw *Waiter) Failed() <-chan struct{} {
rw.failures += 1
ch := rw.wait()
return ch
}
// Resets the internal failure counter.
func (rw *Waiter) Reset() {
rw.failures = 0
}
// Failures returns the current number of consecutive failures recorded.
func (rw *Waiter) Failures() int {
return int(rw.failures)
}
// WaitIf is a convenice method to record whether the last
// operation was a success or failure and return a chan that
// will be selectablw when the next operation can be done.
func (rw *Waiter) WaitIf(failure bool) <-chan struct{} {
if failure {
return rw.Failed()
}
return rw.Success()
}
// WaitIfErr is a convenience method to record whether the last
// operation was a success or failure based on whether the err
// is nil and then return a chan that will be selectable when
// the next operation can be done.
func (rw *Waiter) WaitIfErr(err error) <-chan struct{} {
return rw.WaitIf(err != nil)
}

View File

@ -1,184 +1,160 @@
package retry
import (
"fmt"
"context"
"math"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestJitterRandomStagger(t *testing.T) {
t.Parallel()
t.Run("0 percent", func(t *testing.T) {
t.Parallel()
jitter := NewJitterRandomStagger(0)
func TestJitter(t *testing.T) {
repeat(t, "0 percent", func(t *testing.T) {
jitter := NewJitter(0)
for i := 0; i < 10; i++ {
baseTime := time.Duration(i) * time.Second
require.Equal(t, baseTime, jitter.AddJitter(baseTime))
require.Equal(t, baseTime, jitter(baseTime))
}
})
t.Run("10 percent", func(t *testing.T) {
t.Parallel()
jitter := NewJitterRandomStagger(10)
for i := 0; i < 10; i++ {
baseTime := 5000 * time.Millisecond
maxTime := 5500 * time.Millisecond
newTime := jitter.AddJitter(baseTime)
require.True(t, newTime > baseTime)
require.True(t, newTime <= maxTime)
}
repeat(t, "10 percent", func(t *testing.T) {
jitter := NewJitter(10)
baseTime := 5000 * time.Millisecond
maxTime := 5500 * time.Millisecond
newTime := jitter(baseTime)
require.True(t, newTime > baseTime)
require.True(t, newTime <= maxTime)
})
t.Run("100 percent", func(t *testing.T) {
t.Parallel()
jitter := NewJitterRandomStagger(100)
for i := 0; i < 10; i++ {
baseTime := 1234 * time.Millisecond
maxTime := 2468 * time.Millisecond
newTime := jitter.AddJitter(baseTime)
require.True(t, newTime > baseTime)
require.True(t, newTime <= maxTime)
repeat(t, "100 percent", func(t *testing.T) {
jitter := NewJitter(100)
baseTime := 1234 * time.Millisecond
maxTime := 2468 * time.Millisecond
newTime := jitter(baseTime)
require.True(t, newTime > baseTime)
require.True(t, newTime <= maxTime)
})
repeat(t, "overflow", func(t *testing.T) {
jitter := NewJitter(100)
baseTime := time.Duration(math.MaxInt64) - 2*time.Hour
newTime := jitter(baseTime)
require.Equal(t, baseTime, newTime)
})
}
func repeat(t *testing.T, name string, fn func(t *testing.T)) {
t.Run(name, func(t *testing.T) {
for i := 0; i < 1000; i++ {
fn(t)
}
})
}
func TestRetryWaiter_calculateWait(t *testing.T) {
t.Parallel()
t.Run("Defaults", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 0, 0, nil)
require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
rw.failures += 1
require.Equal(t, 1*time.Second, rw.calculateWait())
rw.failures += 1
require.Equal(t, 2*time.Second, rw.calculateWait())
rw.failures = 31
require.Equal(t, defaultMaxWait, rw.calculateWait())
})
t.Run("Minimum Wait", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 5*time.Second, 0, nil)
require.Equal(t, 5*time.Second, rw.calculateWait())
rw.failures += 1
require.Equal(t, 5*time.Second, rw.calculateWait())
rw.failures += 1
require.Equal(t, 5*time.Second, rw.calculateWait())
rw.failures += 1
require.Equal(t, 5*time.Second, rw.calculateWait())
rw.failures += 1
require.Equal(t, 8*time.Second, rw.calculateWait())
})
t.Run("Minimum Failures", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(5, 0, 0, nil)
require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
rw.failures += 5
require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
rw.failures += 1
require.Equal(t, 1*time.Second, rw.calculateWait())
})
t.Run("Maximum Wait", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 0, 5*time.Second, nil)
require.Equal(t, 0*time.Nanosecond, rw.calculateWait())
rw.failures += 1
require.Equal(t, 1*time.Second, rw.calculateWait())
rw.failures += 1
require.Equal(t, 2*time.Second, rw.calculateWait())
rw.failures += 1
require.Equal(t, 4*time.Second, rw.calculateWait())
rw.failures += 1
require.Equal(t, 5*time.Second, rw.calculateWait())
rw.failures = 31
require.Equal(t, 5*time.Second, rw.calculateWait())
})
}
func TestRetryWaiter_WaitChans(t *testing.T) {
t.Parallel()
t.Run("Minimum Wait - Success", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
select {
case <-time.After(200 * time.Millisecond):
case <-rw.Success():
require.Fail(t, "minimum wait not respected")
func TestWaiter_Delay(t *testing.T) {
t.Run("zero value", func(t *testing.T) {
w := &Waiter{}
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} {
w.failures = uint(i)
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
}
})
t.Run("Minimum Wait - WaitIf", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
select {
case <-time.After(200 * time.Millisecond):
case <-rw.WaitIf(false):
require.Fail(t, "minimum wait not respected")
t.Run("with minimum wait", func(t *testing.T) {
w := &Waiter{MinWait: 5 * time.Second}
for i, expected := range []time.Duration{5, 5, 5, 5, 8, 16, 32, 64, 128} {
w.failures = uint(i)
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
}
})
t.Run("Minimum Wait - WaitIfErr", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 250*time.Millisecond, 0, nil)
select {
case <-time.After(200 * time.Millisecond):
case <-rw.WaitIfErr(nil):
require.Fail(t, "minimum wait not respected")
t.Run("with maximum wait", func(t *testing.T) {
w := &Waiter{MaxWait: 20 * time.Second}
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 20, 20, 20} {
w.failures = uint(i)
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
}
})
t.Run("Maximum Wait - Failed", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
select {
case <-time.After(500 * time.Millisecond):
require.Fail(t, "maximum wait not respected")
case <-rw.Failed():
t.Run("with minimum failures", func(t *testing.T) {
w := &Waiter{MinFailures: 4}
for i, expected := range []time.Duration{0, 0, 0, 0, 0, 1, 2, 4, 8, 16} {
w.failures = uint(i)
require.Equal(t, expected*time.Second, w.delay(), "failure count: %d", i)
}
})
t.Run("Maximum Wait - WaitIf", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
select {
case <-time.After(500 * time.Millisecond):
require.Fail(t, "maximum wait not respected")
case <-rw.WaitIf(true):
t.Run("with factor", func(t *testing.T) {
w := &Waiter{Factor: time.Millisecond}
for i, expected := range []time.Duration{0, 1, 2, 4, 8, 16, 32, 64, 128} {
w.failures = uint(i)
require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i)
}
})
t.Run("Maximum Wait - WaitIfErr", func(t *testing.T) {
t.Parallel()
rw := NewRetryWaiter(0, 0, 250*time.Millisecond, nil)
select {
case <-time.After(500 * time.Millisecond):
require.Fail(t, "maximum wait not respected")
case <-rw.WaitIfErr(fmt.Errorf("Fake Error")):
t.Run("with all settings", func(t *testing.T) {
w := &Waiter{
MinFailures: 2,
MinWait: 4 * time.Millisecond,
MaxWait: 20 * time.Millisecond,
Factor: time.Millisecond,
}
for i, expected := range []time.Duration{4, 4, 4, 4, 4, 4, 8, 16, 20, 20, 20} {
w.failures = uint(i)
require.Equal(t, expected*time.Millisecond, w.delay(), "failure count: %d", i)
}
})
}
func TestWaiter_Wait(t *testing.T) {
ctx := context.Background()
t.Run("first failure", func(t *testing.T) {
w := &Waiter{MinWait: time.Millisecond, Factor: 1}
elapsed, err := runWait(ctx, w)
require.NoError(t, err)
assertApproximateDuration(t, elapsed, time.Millisecond)
require.Equal(t, w.failures, uint(1))
})
t.Run("max failures", func(t *testing.T) {
w := &Waiter{
MaxWait: 100 * time.Millisecond,
failures: 200,
}
elapsed, err := runWait(ctx, w)
require.NoError(t, err)
assertApproximateDuration(t, elapsed, 100*time.Millisecond)
require.Equal(t, w.failures, uint(201))
})
t.Run("context deadline", func(t *testing.T) {
w := &Waiter{failures: 200, MinWait: time.Second}
ctx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
t.Cleanup(cancel)
elapsed, err := runWait(ctx, w)
require.Equal(t, err, context.DeadlineExceeded)
assertApproximateDuration(t, elapsed, 5*time.Millisecond)
require.Equal(t, w.failures, uint(201))
})
}
func runWait(ctx context.Context, w *Waiter) (time.Duration, error) {
before := time.Now()
err := w.Wait(ctx)
return time.Since(before), err
}
func assertApproximateDuration(t *testing.T, actual time.Duration, expected time.Duration) {
t.Helper()
delta := 20 * time.Millisecond
min, max := expected-delta, expected+delta
if min < 0 {
min = 0
}
if actual < min || actual > max {
t.Fatalf("expected %v to be between %v and %v", actual, min, max)
}
}