diff --git a/circuitbreaker/circuit_breaker.go b/circuitbreaker/circuit_breaker.go
index c8afb6b29..c465bda20 100644
--- a/circuitbreaker/circuit_breaker.go
+++ b/circuitbreaker/circuit_breaker.go
@@ -1,8 +1,8 @@
 package circuitbreaker
 
 import (
+	"context"
 	"fmt"
-	"strings"
 
 	"github.com/afex/hystrix-go/hystrix"
 )
@@ -23,11 +23,16 @@ func (cr CommandResult) Error() error {
 }
 
 type Command struct {
+	ctx      context.Context
 	functors []*Functor
+	cancel   bool
 }
 
-func NewCommand(functors []*Functor) *Command {
+// NewCommand creates a Command from the functors to try in order and the
+// context handed to hystrix when they are executed.
+func NewCommand(ctx context.Context, functors []*Functor) *Command {
 	return &Command{
+		ctx:      ctx,
 		functors: functors,
 	}
 }
@@ -40,8 +45,13 @@ func (cmd *Command) IsEmpty() bool {
 	return len(cmd.functors) == 0
 }
 
+// Cancel makes Execute stop before it runs the next functor.
+// The flag is a plain bool: it is not synchronized across goroutines.
+func (cmd *Command) Cancel() {
+	cmd.cancel = true
+}
+
 type Config struct {
-	CommandName            string
 	Timeout                int
 	MaxConcurrentRequests  int
 	RequestVolumeThreshold int
@@ -60,76 +70,83 @@ func NewCircuitBreaker(config Config) *CircuitBreaker {
 }
 
 type Functor struct {
-	Exec FallbackFunc
+	exec        FallbackFunc
+	circuitName string
 }
 
-func NewFunctor(exec FallbackFunc) *Functor {
+// NewFunctor wraps exec so that it is executed inside the hystrix circuit
+// named circuitName.
+func NewFunctor(exec FallbackFunc, circuitName string) *Functor {
 	return &Functor{
-		Exec: exec,
+		exec:        exec,
+		circuitName: circuitName,
 	}
 }
 
-// This a blocking function
-func (eh *CircuitBreaker) Execute(cmd Command) CommandResult {
-	resultChan := make(chan CommandResult, 1)
-	var result CommandResult
+// Execute runs the command's functors in order, each inside the hystrix
+// circuit it was created with, and returns the result of the first one that
+// succeeds. A circuit without hystrix settings is configured with the
+// CircuitBreaker's config before use. If no functor succeeds, the returned
+// result wraps the errors of all attempts made.
+// This is a blocking function.
+func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult {
+	if cmd == nil || cmd.IsEmpty() {
+		return CommandResult{err: fmt.Errorf("command is nil or empty")}
+	}
 
-	for i := 0; i < len(cmd.functors); i += 2 {
-		f1 := cmd.functors[i]
-		var f2 *Functor
-		if i+1 < len(cmd.functors) {
-			f2 = cmd.functors[i+1]
+	var result CommandResult
+	ctx := cmd.ctx
+	if ctx == nil {
+		ctx = context.Background()
+	}
+
+	for _, f := range cmd.functors {
+		if cmd.cancel {
+			break
 		}
 
-		circuitName := fmt.Sprintf("%s_%d", eh.config.CommandName, i)
-		if hystrix.GetCircuitSettings()[circuitName] == nil {
-			hystrix.ConfigureCommand(circuitName, hystrix.CommandConfig{
-				Timeout:                eh.config.Timeout,
-				MaxConcurrentRequests:  eh.config.MaxConcurrentRequests,
-				RequestVolumeThreshold: eh.config.RequestVolumeThreshold,
-				SleepWindow:            eh.config.SleepWindow,
-				ErrorPercentThreshold:  eh.config.ErrorPercentThreshold,
+		if hystrix.GetCircuitSettings()[f.circuitName] == nil {
+			hystrix.ConfigureCommand(f.circuitName, hystrix.CommandConfig{
+				Timeout:                cb.config.Timeout,
+				MaxConcurrentRequests:  cb.config.MaxConcurrentRequests,
+				RequestVolumeThreshold: cb.config.RequestVolumeThreshold,
+				SleepWindow:            cb.config.SleepWindow,
+				ErrorPercentThreshold:  cb.config.ErrorPercentThreshold,
 			})
 		}
 
-		// If circuit is the same for all functions, in case of len(cmd.functors) > 2,
-		// main and fallback providers are different next run if first two fail,
-		// which causes health issues for both main and fallback and ErrorPercentThreshold
-		// is reached faster than it should be.
-		errChan := hystrix.Go(circuitName, func() error {
-			res, err := f1.Exec()
-			// Write to resultChan only if success
+		// hystrix runs the functor in a separate goroutine that may outlive DoC
+		// (e.g. on timeout), so it must not write to result directly. The
+		// buffered channel lets a late success finish without racing or blocking.
+		resultChan := make(chan CommandResult, 1)
+		err := hystrix.DoC(ctx, f.circuitName, func(ctx context.Context) error {
+			res, err := f.exec()
+			// Write to resultChan only if success
 			if err == nil {
-				resultChan <- CommandResult{res: res, err: err}
+				resultChan <- CommandResult{res: res}
 			}
 			return err
-		}, func(err error) error {
-			// In case of concurrency, we should not execute the fallback
-			if f2 == nil || err == hystrix.ErrMaxConcurrency {
-				return err
-			}
-			res, err := f2.Exec()
-			if err == nil {
-				resultChan <- CommandResult{res: res, err: err}
-			}
-			return err
-		})
+		}, nil)
 
-		select {
-		case result = <-resultChan:
-			if result.err == nil {
-				return result
-			}
-		case err := <-errChan:
-			result = CommandResult{err: err}
-
-			// In case of max concurrency, we should delay the execution and stop iterating over fallbacks
-			// No error unwrapping here, so use strings.Contains
-			if strings.Contains(err.Error(), hystrix.ErrMaxConcurrency.Error()) {
-				return result
-			}
+		if err == nil {
+			result = <-resultChan // sent before DoC reported success
+			break
 		}
+
+		// Accumulate errors
+		if result.err != nil {
+			result.err = fmt.Errorf("%w, %s.error: %w", result.err, f.circuitName, err)
+		} else {
+			result.err = fmt.Errorf("%s.error: %w", f.circuitName, err)
+		}
+		// Let's abuse every provider with the same amount of MaxConcurrentRequests,
+		// keep iterating even in case of ErrMaxConcurrency error
 	}
 
 	return result
 }
+
+// Config returns the configuration held by the CircuitBreaker.
+func (cb *CircuitBreaker) Config() Config {
+	return cb.config
+}
diff --git a/circuitbreaker/circuit_breaker_test.go b/circuitbreaker/circuit_breaker_test.go
index 183f7dc94..d5a4b1d95 100644
--- a/circuitbreaker/circuit_breaker_test.go
+++ b/circuitbreaker/circuit_breaker_test.go
@@ -1,10 +1,14 @@
 package circuitbreaker
 
 import (
+	"context"
 	"errors"
+	"fmt"
 	"testing"
 	"time"
 
+	"github.com/afex/hystrix-go/hystrix"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
@@ -12,7 +16,6 @@ const success = "Success"
 
 func TestCircuitBreaker_ExecuteSuccessSingle(t *testing.T) {
 	cb := NewCircuitBreaker(Config{
-		CommandName:            "SuccessSingle", // unique name to avoid conflicts with go tests `-count` option
 		Timeout:                1000,
 		MaxConcurrentRequests:  100,
 		RequestVolumeThreshold: 10,
@@ -21,13 +24,12 @@ func TestCircuitBreaker_ExecuteSuccessSingle(t *testing.T) {
 	})
 
 	expectedResult := success
-	cmd := Command{
-		functors: []*Functor{
-			NewFunctor(func() ([]interface{}, error) {
-				return []any{expectedResult}, nil
-			}),
-		},
-	}
+	circuitName := "SuccessSingle"
+	cmd := NewCommand(context.TODO(), []*Functor{
+		NewFunctor(func() ([]interface{}, error) {
+			return []any{expectedResult}, nil
+		}, circuitName)},
+	)
 
 	result := cb.Execute(cmd)
 	require.NoError(t, result.Error())
@@ -36,7 +38,6 @@ func TestCircuitBreaker_ExecuteSuccessSingle(t *testing.T) {
 
 func TestCircuitBreaker_ExecuteMultipleFallbacksFail(t *testing.T) {
 	cb := NewCircuitBreaker(Config{
-		CommandName:            "MultipleFail", // unique name to avoid conflicts with go tests `-count` option
 		Timeout:                10,
 		MaxConcurrentRequests:  100,
 		RequestVolumeThreshold: 10,
@@ -44,28 +45,31 @@ func TestCircuitBreaker_ExecuteMultipleFallbacksFail(t *testing.T) {
 		ErrorPercentThreshold:  10,
 	})
 
-	cmd := Command{
-		functors: []*Functor{
-			NewFunctor(func() ([]interface{}, error) {
-				time.Sleep(100 * time.Millisecond) // will cause hystrix: timeout
-				return []any{success}, nil
-			}),
-			NewFunctor(func() ([]interface{}, error) {
-				return nil, errors.New("provider 2 failed")
-			}),
-			NewFunctor(func() ([]interface{}, error) {
-				return nil, errors.New("provider 3 failed")
-			}),
-		},
-	}
+	circuitName := ""
+	errSecProvFailed := errors.New("provider 2 failed")
+	errThirdProvFailed := errors.New("provider 3 failed")
+	cmd := NewCommand(context.TODO(), []*Functor{
+		NewFunctor(func() ([]interface{}, error) {
+			time.Sleep(100 * time.Millisecond) // will cause hystrix: timeout
+			return []any{success}, nil
+		}, circuitName),
+		NewFunctor(func() ([]interface{}, error) {
+			return nil, errSecProvFailed
+		}, circuitName),
+		NewFunctor(func() ([]interface{}, error) {
+			return nil, errThirdProvFailed
+		}, circuitName),
+	})
 
 	result := cb.Execute(cmd)
 	require.Error(t, result.Error())
+	assert.True(t, errors.Is(result.Error(), hystrix.ErrTimeout))
+	assert.True(t, errors.Is(result.Error(), errSecProvFailed))
+	assert.True(t, errors.Is(result.Error(), errThirdProvFailed))
 }
 
 func TestCircuitBreaker_ExecuteMultipleFallbacksFailButLastSuccessStress(t *testing.T) {
 	cb := NewCircuitBreaker(Config{
-		CommandName:            "LastSuccessStress", // unique name to avoid conflicts with go tests `-count` option
 		Timeout:                10,
 		MaxConcurrentRequests:  100,
 		RequestVolumeThreshold: 10,
@@ -74,26 +78,146 @@ func TestCircuitBreaker_ExecuteMultipleFallbacksFailButLastSuccessStress(t *test
 	})
 
 	expectedResult := success
+	circuitName := fmt.Sprintf("LastSuccessStress_%d", time.Now().Nanosecond()) // unique name to avoid conflicts with go tests `-count` option
 
 	// These are executed sequentially, but I had an issue with the test failing
 	// because of the open circuit
 	for i := 0; i < 1000; i++ {
-		cmd := Command{
-			functors: []*Functor{
-				NewFunctor(func() ([]interface{}, error) {
-					return nil, errors.New("provider 1 failed")
-				}),
-				NewFunctor(func() ([]interface{}, error) {
-					return nil, errors.New("provider 2 failed")
-				}),
-				NewFunctor(func() ([]interface{}, error) {
-					return []any{expectedResult}, nil
-				}),
-			},
-		}
+		cmd := NewCommand(context.TODO(), []*Functor{
+			NewFunctor(func() ([]interface{}, error) {
+				return nil, errors.New("provider 1 failed")
+			}, circuitName+"1"),
+			NewFunctor(func() ([]interface{}, error) {
+				return nil, errors.New("provider 2 failed")
+			}, circuitName+"2"),
+			NewFunctor(func() ([]interface{}, error) {
+				return []any{expectedResult}, nil
+			}, circuitName+"3"),
+		},
+		)
 
 		result := cb.Execute(cmd)
 		require.NoError(t, result.Error())
 		require.Equal(t, expectedResult, result.Result()[0].(string))
 	}
 }
+
+func TestCircuitBreaker_ExecuteSwitchToWorkingProviderOnVolumeThresholdReached(t *testing.T) {
+	cb := NewCircuitBreaker(Config{
+		RequestVolumeThreshold: 10,
+	})
+
+	expectedResult := success
+	circuitName := fmt.Sprintf("SwitchToWorkingProviderOnVolumeThresholdReached_%d", time.Now().Nanosecond()) // unique name to avoid conflicts with go tests `-count` option
+
+	prov1Called := 0
+	prov2Called := 0
+	prov3Called := 0
+	// These are executed sequentially
+	for i := 0; i < 20; i++ {
+		cmd := NewCommand(context.TODO(), []*Functor{
+			NewFunctor(func() ([]interface{}, error) {
+				prov1Called++
+				return nil, errors.New("provider 1 failed")
+			}, circuitName+"1"),
+			NewFunctor(func() ([]interface{}, error) {
+				prov2Called++
+				return nil, errors.New("provider 2 failed")
+			}, circuitName+"2"),
+			NewFunctor(func() ([]interface{}, error) {
+				prov3Called++
+				return []any{expectedResult}, nil
+			}, circuitName+"3"),
+		})
+
+		result := cb.Execute(cmd)
+		require.NoError(t, result.Error())
+		require.Equal(t, expectedResult, result.Result()[0].(string))
+	}
+
+	assert.Equal(t, 10, prov1Called)
+	assert.Equal(t, 10, prov2Called)
+	assert.Equal(t, 20, prov3Called)
+}
+
+func TestCircuitBreaker_ExecuteHealthCheckOnWindowTimeout(t *testing.T) {
+	sleepWindow := 10
+	cb := NewCircuitBreaker(Config{
+		RequestVolumeThreshold: 1, // 1 failed request is enough to trip the circuit
+		SleepWindow:            sleepWindow,
+		ErrorPercentThreshold:  1, // Trip on first error
+	})
+
+	expectedResult := success
+	circuitName := fmt.Sprintf("SwitchToWorkingProviderOnWindowTimeout_%d", time.Now().Nanosecond()) // unique name to avoid conflicts with go tests `-count` option
+
+	prov1Called := 0
+	prov2Called := 0
+	// These are executed sequentially
+	for i := 0; i < 10; i++ {
+		cmd := NewCommand(context.TODO(), []*Functor{
+			NewFunctor(func() ([]interface{}, error) {
+				prov1Called++
+				return nil, errors.New("provider 1 failed")
+			}, circuitName+"1"),
+			NewFunctor(func() ([]interface{}, error) {
+				prov2Called++
+				return []any{expectedResult}, nil
+			}, circuitName+"2"),
+		})
+
+		result := cb.Execute(cmd)
+		require.NoError(t, result.Error())
+		require.Equal(t, expectedResult, result.Result()[0].(string))
+	}
+
+	assert.Equal(t, 1, prov1Called)
+	assert.Equal(t, 10, prov2Called)
+
+	// Wait for the sleep window to expire
+	time.Sleep(time.Duration(sleepWindow+1) * time.Millisecond)
+	cmd := NewCommand(context.TODO(), []*Functor{
+		NewFunctor(func() ([]interface{}, error) {
+			prov1Called++
+			return []any{expectedResult}, nil // Now it is working
+		}, circuitName+"1"),
+		NewFunctor(func() ([]interface{}, error) {
+			prov2Called++
+			return []any{expectedResult}, nil
+		}, circuitName+"2"),
+	})
+	result := cb.Execute(cmd)
+	require.NoError(t, result.Error())
+
+	assert.Equal(t, 2, prov1Called)
+	assert.Equal(t, 10, prov2Called)
+}
+
+func TestCircuitBreaker_CommandCancel(t *testing.T) {
+	cb := NewCircuitBreaker(Config{})
+
+	circuitName := fmt.Sprintf("CommandCancel_%d", time.Now().Nanosecond()) // unique name to avoid conflicts with go tests `-count` option
+
+	prov1Called := 0
+	prov2Called := 0
+
+	expectedErr := errors.New("provider 1 failed")
+	// These are executed sequentially
+	cmd := NewCommand(context.TODO(), nil)
+	cmd.Add(NewFunctor(func() ([]interface{}, error) {
+		prov1Called++
+		cmd.Cancel()
+		return nil, expectedErr
+	}, circuitName+"1"))
+	cmd.Add(NewFunctor(func() ([]interface{}, error) {
+		prov2Called++
+		return nil, errors.New("provider 2 failed")
+	}, circuitName+"2"))
+
+	result := cb.Execute(cmd)
+	t.Log(result.Error())
+	require.True(t, errors.Is(result.Error(), expectedErr))
+
+	assert.Equal(t, 1, prov1Called)
+	assert.Equal(t, 0, prov2Called)
+}