diff --git a/circuitbreaker/circuit_breaker.go b/circuitbreaker/circuit_breaker.go index a1dfda3a2..0b6a97e2c 100644 --- a/circuitbreaker/circuit_breaker.go +++ b/circuitbreaker/circuit_breaker.go @@ -3,6 +3,7 @@ package circuitbreaker import ( "context" "fmt" + "time" "github.com/afex/hystrix-go/hystrix" @@ -12,8 +13,16 @@ import ( type FallbackFunc func() ([]any, error) type CommandResult struct { - res []any - err error + res []any + err error + functorCallStatuses []FunctorCallStatus + cancelled bool +} + +type FunctorCallStatus struct { + Name string + Timestamp time.Time + Err error } func (cr CommandResult) Result() []any { @@ -23,6 +32,21 @@ func (cr CommandResult) Result() []any { func (cr CommandResult) Error() error { return cr.err } +func (cr CommandResult) Cancelled() bool { + return cr.cancelled +} + +func (cr CommandResult) FunctorCallStatuses() []FunctorCallStatus { + return cr.functorCallStatuses +} + +func (cr *CommandResult) addCallStatus(circuitName string, err error) { + cr.functorCallStatuses = append(cr.functorCallStatuses, FunctorCallStatus{ + Name: circuitName, + Timestamp: time.Now(), + Err: err, + }) +} type Command struct { ctx context.Context @@ -106,23 +130,26 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult { for i, f := range cmd.functors { if cmd.cancel { + result.cancelled = true break } var err error + circuitName := f.circuitName + if cb.circuitNameHandler != nil { + circuitName = cb.circuitNameHandler(circuitName) + } + // if last command, execute without circuit if i == len(cmd.functors)-1 { res, execErr := f.exec() err = execErr if err == nil { - result = CommandResult{res: res} + result.res = res + result.err = nil } + result.addCallStatus(circuitName, err) } else { - circuitName := f.circuitName - if cb.circuitNameHandler != nil { - circuitName = cb.circuitNameHandler(circuitName) - } - if hystrix.GetCircuitSettings()[circuitName] == nil { hystrix.ConfigureCommand(circuitName, hystrix.CommandConfig{ Timeout: cb.config.Timeout, @@ -137,13 +164,16 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult { res, err := f.exec() // Write to result only if success if err == nil { - result = CommandResult{res: res} + result.res = res + result.err = nil } + result.addCallStatus(circuitName, err) // If the command has been cancelled, we don't count // the error towars breaking the circuit, and then we break if cmd.cancel { - result = accumulateCommandError(result, f.circuitName, err) + result = accumulateCommandError(result, circuitName, err) + result.cancelled = true return nil } if err != nil { @@ -156,7 +186,7 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult { break } - result = accumulateCommandError(result, f.circuitName, err) + result = accumulateCommandError(result, circuitName, err) // Lets abuse every provider with the same amount of MaxConcurrentRequests, // keep iterating even in case of ErrMaxConcurrency error diff --git a/circuitbreaker/circuit_breaker_test.go b/circuitbreaker/circuit_breaker_test.go index 42d35d0c3..c159f26e6 100644 --- a/circuitbreaker/circuit_breaker_test.go +++ b/circuitbreaker/circuit_breaker_test.go @@ -34,6 +34,7 @@ func TestCircuitBreaker_ExecuteSuccessSingle(t *testing.T) { result := cb.Execute(cmd) require.NoError(t, result.Error()) require.Equal(t, expectedResult, result.Result()[0].(string)) + require.False(t, result.Cancelled()) } func TestCircuitBreaker_ExecuteMultipleFallbacksFail(t *testing.T) { @@ -219,9 +220,11 @@ func TestCircuitBreaker_CommandCancel(t *testing.T) { result := cb.Execute(cmd) require.True(t, errors.Is(result.Error(), expectedErr)) + require.True(t, result.Cancelled()) assert.Equal(t, 1, prov1Called) assert.Equal(t, 0, prov2Called) + } func TestCircuitBreaker_EmptyOrNilCommand(t *testing.T) { @@ -301,3 +304,149 @@ func TestCircuitBreaker_Fallback(t *testing.T) { assert.Equal(t, 1, prov1Called) } + +func TestCircuitBreaker_SuccessCallStatus(t *testing.T) { + cb := NewCircuitBreaker(Config{}) + + functor := NewFunctor(func() ([]any, error) { + return []any{"success"}, nil + }, "successCircuit") + + cmd := NewCommand(context.Background(), []*Functor{functor}) + + result := cb.Execute(cmd) + + require.Nil(t, result.Error()) + require.False(t, result.Cancelled()) + assert.Len(t, result.Result(), 1) + require.Equal(t, "success", result.Result()[0]) + assert.Len(t, result.FunctorCallStatuses(), 1) + + status := result.FunctorCallStatuses()[0] + if status.Name != "successCircuit" { + t.Errorf("Expected functor name to be 'successCircuit', got %s", status.Name) + } + if status.Err != nil { + t.Errorf("Expected no error in functor status, got %v", status.Err) + } +} + +func TestCircuitBreaker_ErrorCallStatus(t *testing.T) { + cb := NewCircuitBreaker(Config{}) + + expectedError := errors.New("functor error") + functor := NewFunctor(func() ([]any, error) { + return nil, expectedError + }, "errorCircuit") + + cmd := NewCommand(context.Background(), []*Functor{functor}) + + result := cb.Execute(cmd) + + require.NotNil(t, result.Error()) + require.True(t, errors.Is(result.Error(), expectedError)) + + assert.Len(t, result.Result(), 0) + assert.Len(t, result.FunctorCallStatuses(), 1) + + status := result.FunctorCallStatuses()[0] + if status.Name != "errorCircuit" { + t.Errorf("Expected functor name to be 'errorCircuit', got %s", status.Name) + } + if !errors.Is(status.Err, expectedError) { + t.Errorf("Expected functor error to be '%v', got '%v'", expectedError, status.Err) + } +} + +func TestCircuitBreaker_CancelledResult(t *testing.T) { + cb := NewCircuitBreaker(Config{Timeout: 1000}) + + functor := NewFunctor(func() ([]any, error) { + time.Sleep(500 * time.Millisecond) + return []any{"should not be returned"}, nil + }, "cancelCircuit") + + cmd := NewCommand(context.Background(), []*Functor{functor}) + cmd.Cancel() + + result := cb.Execute(cmd) + + assert.True(t, result.Cancelled()) + require.Nil(t, result.Error()) + require.Empty(t, result.Result()) + require.Empty(t, result.FunctorCallStatuses()) +} + +func TestCircuitBreaker_MultipleFunctorsResult(t *testing.T) { + cb := NewCircuitBreaker(Config{ + Timeout: 1000, + MaxConcurrentRequests: 100, + RequestVolumeThreshold: 20, + SleepWindow: 5000, + ErrorPercentThreshold: 50, + }) + + functor1 := NewFunctor(func() ([]any, error) { + return nil, errors.New("functor1 error") + }, "circuit1") + + functor2 := NewFunctor(func() ([]any, error) { + return []any{"success from functor2"}, nil + }, "circuit2") + + cmd := NewCommand(context.Background(), []*Functor{functor1, functor2}) + + result := cb.Execute(cmd) + + require.Nil(t, result.Error()) + + require.Len(t, result.Result(), 1) + require.Equal(t, result.Result()[0], "success from functor2") + statuses := result.FunctorCallStatuses() + require.Len(t, statuses, 2) + + require.Equal(t, statuses[0].Name, "circuit1") + require.NotNil(t, statuses[0].Err) + + require.Equal(t, statuses[1].Name, "circuit2") + require.Nil(t, statuses[1].Err) +} + +func TestCircuitBreaker_LastFunctorDirectExecution(t *testing.T) { + cb := NewCircuitBreaker(Config{ + Timeout: 10, // short timeout to open circuit + MaxConcurrentRequests: 1, + RequestVolumeThreshold: 1, + SleepWindow: 1000, + ErrorPercentThreshold: 1, + }) + + failingFunctor := NewFunctor(func() ([]any, error) { + time.Sleep(20 * time.Millisecond) + return nil, errors.New("should time out") + }, "circuitName") + + successFunctor := NewFunctor(func() ([]any, error) { + return []any{"success without circuit"}, nil + }, "circuitName") + + cmd := NewCommand(context.Background(), []*Functor{failingFunctor, successFunctor}) + + require.False(t, IsCircuitOpen("circuitName")) + result := cb.Execute(cmd) + + require.True(t, CircuitExists("circuitName")) + require.Nil(t, result.Error()) + + require.Len(t, result.Result(), 1) + require.Equal(t, result.Result()[0], "success without circuit") + + statuses := result.FunctorCallStatuses() + require.Len(t, statuses, 2) + + require.Equal(t, statuses[0].Name, "circuitName") + require.NotNil(t, statuses[0].Err) + + require.Equal(t, statuses[1].Name, "circuitName") + require.Nil(t, statuses[1].Err) +} diff --git a/healthmanager/aggregator/aggregator.go b/healthmanager/aggregator/aggregator.go new file mode 100644 index 000000000..ef872a184 --- /dev/null +++ b/healthmanager/aggregator/aggregator.go @@ -0,0 +1,134 @@ +package aggregator + +import ( + "sync" + "time" + + "github.com/status-im/status-go/healthmanager/rpcstatus" +) + +// Aggregator manages and aggregates the statuses of multiple providers. +type Aggregator struct { + mu sync.RWMutex + name string + providerStatuses map[string]*rpcstatus.ProviderStatus +} + +// NewAggregator creates a new instance of Aggregator with the given name. +func NewAggregator(name string) *Aggregator { + return &Aggregator{ + name: name, + providerStatuses: make(map[string]*rpcstatus.ProviderStatus), + } +} + +// RegisterProvider adds a new provider to the aggregator. +// If the provider already exists, it does nothing. +func (a *Aggregator) RegisterProvider(providerName string) { + a.mu.Lock() + defer a.mu.Unlock() + if _, exists := a.providerStatuses[providerName]; !exists { + a.providerStatuses[providerName] = &rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusUnknown, + } + } +} + +// Update modifies the status of a specific provider. +// If the provider is not already registered, it adds the provider. +func (a *Aggregator) Update(providerStatus rpcstatus.ProviderStatus) { + a.mu.Lock() + defer a.mu.Unlock() + + // Update existing provider status or add a new provider. + if ps, exists := a.providerStatuses[providerStatus.Name]; exists { + ps.Status = providerStatus.Status + if providerStatus.Status == rpcstatus.StatusUp { + ps.LastSuccessAt = providerStatus.LastSuccessAt + } else if providerStatus.Status == rpcstatus.StatusDown { + ps.LastErrorAt = providerStatus.LastErrorAt + ps.LastError = providerStatus.LastError + } + } else { + a.providerStatuses[providerStatus.Name] = &rpcstatus.ProviderStatus{ + Name: providerStatus.Name, + LastSuccessAt: providerStatus.LastSuccessAt, + LastErrorAt: providerStatus.LastErrorAt, + LastError: providerStatus.LastError, + Status: providerStatus.Status, + } + } +} + +// UpdateBatch processes a batch of provider statuses. +func (a *Aggregator) UpdateBatch(statuses []rpcstatus.ProviderStatus) { + for _, status := range statuses { + a.Update(status) + } +} + +// ComputeAggregatedStatus calculates the overall aggregated status based on individual provider statuses. +// The logic is as follows: +// - If any provider is up, the aggregated status is up. +// - If no providers are up but at least one is unknown, the aggregated status is unknown. +// - If all providers are down, the aggregated status is down. +func (a *Aggregator) ComputeAggregatedStatus() rpcstatus.ProviderStatus { + a.mu.RLock() + defer a.mu.RUnlock() + + var lastSuccessAt, lastErrorAt time.Time + var lastError error + anyUp := false + anyUnknown := false + + for _, ps := range a.providerStatuses { + switch ps.Status { + case rpcstatus.StatusUp: + anyUp = true + if ps.LastSuccessAt.After(lastSuccessAt) { + lastSuccessAt = ps.LastSuccessAt + } + case rpcstatus.StatusUnknown: + anyUnknown = true + case rpcstatus.StatusDown: + if ps.LastErrorAt.After(lastErrorAt) { + lastErrorAt = ps.LastErrorAt + lastError = ps.LastError + } + } + } + + aggregatedStatus := rpcstatus.ProviderStatus{ + Name: a.name, + LastSuccessAt: lastSuccessAt, + LastErrorAt: lastErrorAt, + LastError: lastError, + } + if len(a.providerStatuses) == 0 { + aggregatedStatus.Status = rpcstatus.StatusDown + } else if anyUp { + aggregatedStatus.Status = rpcstatus.StatusUp + } else if anyUnknown { + aggregatedStatus.Status = rpcstatus.StatusUnknown + } else { + aggregatedStatus.Status = rpcstatus.StatusDown + } + + return aggregatedStatus +} + +func (a *Aggregator) GetAggregatedStatus() rpcstatus.ProviderStatus { + return a.ComputeAggregatedStatus() +} + +func (a *Aggregator) GetStatuses() map[string]rpcstatus.ProviderStatus { + a.mu.RLock() + defer a.mu.RUnlock() + + statusesCopy := make(map[string]rpcstatus.ProviderStatus) + for k, v := range a.providerStatuses { + statusesCopy[k] = *v + } + return statusesCopy +} diff --git a/healthmanager/aggregator/aggregator_test.go b/healthmanager/aggregator/aggregator_test.go new file mode 100644 index 000000000..87e029bf2 --- /dev/null +++ b/healthmanager/aggregator/aggregator_test.go @@ -0,0 +1,312 @@ +package aggregator + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/status-im/status-go/healthmanager/rpcstatus" +) + +// StatusAggregatorTestSuite defines the test suite for Aggregator. +type StatusAggregatorTestSuite struct { + suite.Suite + aggregator *Aggregator +} + +// SetupTest runs before each test in the suite. +func (suite *StatusAggregatorTestSuite) SetupTest() { + suite.aggregator = NewAggregator("TestAggregator") +} + +// TestNewAggregator verifies that a new Aggregator is initialized correctly. +func (suite *StatusAggregatorTestSuite) TestNewAggregator() { + assert.Equal(suite.T(), "TestAggregator", suite.aggregator.name, "Aggregator name should be set correctly") + assert.Empty(suite.T(), suite.aggregator.providerStatuses, "Aggregator should have no providers initially") +} + +// TestRegisterProvider verifies that providers are registered correctly. +func (suite *StatusAggregatorTestSuite) TestRegisterProvider() { + providerName := "Provider1" + suite.aggregator.RegisterProvider(providerName) + + assert.Len(suite.T(), suite.aggregator.providerStatuses, 1, "Expected 1 provider after registration") + _, exists := suite.aggregator.providerStatuses[providerName] + assert.True(suite.T(), exists, "Provider1 should be registered") + + // Attempt to register the same provider again + suite.aggregator.RegisterProvider(providerName) + assert.Len(suite.T(), suite.aggregator.providerStatuses, 1, "Duplicate registration should not increase provider count") +} + +// TestUpdate verifies that updating a provider's status works correctly. +func (suite *StatusAggregatorTestSuite) TestUpdate() { + providerName := "Provider1" + suite.aggregator.RegisterProvider(providerName) + + now := time.Now() + + // Update existing provider to up + statusUp := rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusUp, + LastSuccessAt: now, + } + suite.aggregator.Update(statusUp) + + ps, exists := suite.aggregator.providerStatuses[providerName] + assert.True(suite.T(), exists, "Provider1 should exist after update") + assert.Equal(suite.T(), rpcstatus.StatusUp, ps.Status, "Provider1 status should be 'up'") + assert.Equal(suite.T(), now, ps.LastSuccessAt, "Provider1 LastSuccessAt should be updated") + + // Update existing provider to down + nowDown := now.Add(1 * time.Hour) + statusDown := rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusDown, + LastErrorAt: nowDown, + } + suite.aggregator.Update(statusDown) + + ps, exists = suite.aggregator.providerStatuses[providerName] + assert.True(suite.T(), exists, "Provider1 should exist after second update") + assert.Equal(suite.T(), rpcstatus.StatusDown, ps.Status, "Provider1 status should be 'down'") + assert.Equal(suite.T(), nowDown, ps.LastErrorAt, "Provider1 LastErrorAt should be updated") + + // Update a non-registered provider via Update (should add it) + provider2 := "Provider2" + statusUp2 := rpcstatus.ProviderStatus{ + Name: provider2, + Status: rpcstatus.StatusUp, + LastSuccessAt: now, + } + suite.aggregator.Update(statusUp2) + + assert.Len(suite.T(), suite.aggregator.providerStatuses, 2, "Expected 2 providers after updating a new provider") + ps2, exists := suite.aggregator.providerStatuses[provider2] + assert.True(suite.T(), exists, "Provider2 should be added via Update") + assert.Equal(suite.T(), rpcstatus.StatusUp, ps2.Status, "Provider2 status should be 'up'") +} + +// TestComputeAggregatedStatus_NoProviders verifies aggregated status when no providers are registered. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_NoProviders() { + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusDown, aggStatus.Status, "Aggregated status should be 'down' when no providers are registered") + assert.True(suite.T(), aggStatus.LastSuccessAt.IsZero(), "LastSuccessAt should be zero when no providers are registered") + assert.True(suite.T(), aggStatus.LastErrorAt.IsZero(), "LastErrorAt should be zero when no providers are registered") +} + +// TestComputeAggregatedStatus_AllUnknown verifies aggregated status when all providers are unknown. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_AllUnknown() { + // Register multiple providers with unknown status + suite.aggregator.RegisterProvider("Provider1") + suite.aggregator.RegisterProvider("Provider2") + suite.aggregator.RegisterProvider("Provider3") + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUnknown, aggStatus.Status, "Aggregated status should be 'unknown' when all providers are unknown") + assert.True(suite.T(), aggStatus.LastSuccessAt.IsZero(), "LastSuccessAt should be zero when all providers are unknown") + assert.True(suite.T(), aggStatus.LastErrorAt.IsZero(), "LastErrorAt should be zero when all providers are unknown") +} + +// TestComputeAggregatedStatus_AllUp verifies aggregated status when all providers are up. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_AllUp() { + // Register providers + suite.aggregator.RegisterProvider("Provider1") + suite.aggregator.RegisterProvider("Provider2") + + now1 := time.Now() + now2 := now1.Add(1 * time.Hour) + + // Update all providers to up + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusUp, + LastSuccessAt: now1, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider2", + Status: rpcstatus.StatusUp, + LastSuccessAt: now2, + }) + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUp, aggStatus.Status, "Aggregated status should be 'up' when all providers are up") + assert.Equal(suite.T(), now2, aggStatus.LastSuccessAt, "LastSuccessAt should reflect the latest success time") + assert.True(suite.T(), aggStatus.LastErrorAt.IsZero(), "LastErrorAt should be zero when all providers are up") +} + +// TestComputeAggregatedStatus_AllDown verifies aggregated status when all providers are down. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_AllDown() { + // Register providers + suite.aggregator.RegisterProvider("Provider1") + suite.aggregator.RegisterProvider("Provider2") + + now1 := time.Now() + now2 := now1.Add(1 * time.Hour) + + // Update all providers to down + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusDown, + LastErrorAt: now1, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider2", + Status: rpcstatus.StatusDown, + LastErrorAt: now2, + }) + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusDown, aggStatus.Status, "Aggregated status should be 'down' when all providers are down") + assert.Equal(suite.T(), now2, aggStatus.LastErrorAt, "LastErrorAt should reflect the latest error time") + assert.True(suite.T(), aggStatus.LastSuccessAt.IsZero(), "LastSuccessAt should be zero when all providers are down") +} + +// TestComputeAggregatedStatus_MixedUpAndUnknown verifies aggregated status with mixed up and unknown providers. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_MixedUpAndUnknown() { + // Register providers + suite.aggregator.RegisterProvider("Provider1") // up + suite.aggregator.RegisterProvider("Provider2") // unknown + suite.aggregator.RegisterProvider("Provider3") // up + + now1 := time.Now() + now2 := now1.Add(30 * time.Minute) + + // Update some providers to up + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusUp, + LastSuccessAt: now1, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider3", + Status: rpcstatus.StatusUp, + LastSuccessAt: now2, + }) + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUp, aggStatus.Status, "Aggregated status should be 'up' when at least one provider is up") + assert.Equal(suite.T(), now2, aggStatus.LastSuccessAt, "LastSuccessAt should reflect the latest success time") + assert.True(suite.T(), aggStatus.LastErrorAt.IsZero(), "LastErrorAt should be zero when no providers are down") +} + +// TestComputeAggregatedStatus_MixedUpAndDown verifies aggregated status with mixed up and down providers. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_MixedUpAndDown() { + // Register providers + suite.aggregator.RegisterProvider("Provider1") // up + suite.aggregator.RegisterProvider("Provider2") // down + suite.aggregator.RegisterProvider("Provider3") // up + + now1 := time.Now() + now2 := now1.Add(15 * time.Minute) + + // Update providers + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusUp, + LastSuccessAt: now1, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider2", + Status: rpcstatus.StatusDown, + LastErrorAt: now2, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider3", + Status: rpcstatus.StatusUp, + LastSuccessAt: now1, + }) + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUp, aggStatus.Status, "Aggregated status should be 'up' when at least one provider is up") + assert.Equal(suite.T(), now1, aggStatus.LastSuccessAt, "LastSuccessAt should reflect the latest success time") + assert.Equal(suite.T(), now2, aggStatus.LastErrorAt, "LastErrorAt should reflect the latest error time") +} + +// TestGetAggregatedStatus verifies that GetAggregatedStatus returns the correct aggregated status. +func (suite *StatusAggregatorTestSuite) TestGetAggregatedStatus() { + // Register and update providers + suite.aggregator.RegisterProvider("Provider1") + suite.aggregator.RegisterProvider("Provider2") + + now := time.Now() + + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusUp, + LastSuccessAt: now, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider2", + Status: rpcstatus.StatusDown, + LastErrorAt: now.Add(1 * time.Hour), + }) + + aggStatus := suite.aggregator.GetAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUp, aggStatus.Status, "Aggregated status should be 'up' when at least one provider is up") + assert.Equal(suite.T(), now, aggStatus.LastSuccessAt, "LastSuccessAt should reflect the provider's success time") + assert.Equal(suite.T(), now.Add(1*time.Hour), aggStatus.LastErrorAt, "LastErrorAt should reflect the provider's error time") +} + +// TestConcurrentAccess verifies that the Aggregator is safe for concurrent use. +func (suite *StatusAggregatorTestSuite) TestConcurrentAccess() { + // Register multiple providers + providers := []string{"Provider1", "Provider2", "Provider3", "Provider4", "Provider5"} + for _, p := range providers { + suite.aggregator.RegisterProvider(p) + } + + var wg sync.WaitGroup + + // Concurrently update providers + for _, p := range providers { + wg.Add(1) + go func(providerName string) { + defer wg.Done() + for i := 0; i < 1000; i++ { + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusUp, + LastSuccessAt: time.Now(), + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusDown, + LastErrorAt: time.Now(), + }) + } + }(p) + } + + // Wait for all goroutines to finish + wg.Wait() + + // Set all providers to down to ensure deterministic aggregated status + now := time.Now() + for _, p := range providers { + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: p, + Status: rpcstatus.StatusDown, + LastErrorAt: now, + }) + } + + aggStatus := suite.aggregator.GetAggregatedStatus() + assert.Equal(suite.T(), rpcstatus.StatusDown, aggStatus.Status, "Aggregated status should be 'down' after setting all providers to down") +} + +// TestStatusAggregatorTestSuite runs the test suite. +func TestStatusAggregatorTestSuite(t *testing.T) { + suite.Run(t, new(StatusAggregatorTestSuite)) +} diff --git a/healthmanager/blockchain_health_manager.go b/healthmanager/blockchain_health_manager.go new file mode 100644 index 000000000..7fcdf5d6d --- /dev/null +++ b/healthmanager/blockchain_health_manager.go @@ -0,0 +1,247 @@ +package healthmanager + +import ( + "context" + "sync" + + "github.com/status-im/status-go/healthmanager/aggregator" + "github.com/status-im/status-go/healthmanager/rpcstatus" +) + +// BlockchainFullStatus contains the full status of the blockchain, including provider statuses. +type BlockchainFullStatus struct { + Status rpcstatus.ProviderStatus `json:"status"` + StatusPerChain map[uint64]rpcstatus.ProviderStatus `json:"statusPerChain"` + StatusPerChainPerProvider map[uint64]map[string]rpcstatus.ProviderStatus `json:"statusPerChainPerProvider"` +} + +// BlockchainStatus contains the status of the blockchain +type BlockchainStatus struct { + Status rpcstatus.ProviderStatus `json:"status"` + StatusPerChain map[uint64]rpcstatus.ProviderStatus `json:"statusPerChain"` +} + +// BlockchainHealthManager manages the state of all providers and aggregates their statuses. +type BlockchainHealthManager struct { + mu sync.RWMutex + aggregator *aggregator.Aggregator + subscribers sync.Map // thread-safe + + providers map[uint64]*ProvidersHealthManager + cancelFuncs map[uint64]context.CancelFunc // Map chainID to cancel functions + lastStatus *BlockchainStatus + wg sync.WaitGroup +} + +// NewBlockchainHealthManager creates a new instance of BlockchainHealthManager. +func NewBlockchainHealthManager() *BlockchainHealthManager { + agg := aggregator.NewAggregator("blockchain") + return &BlockchainHealthManager{ + aggregator: agg, + providers: make(map[uint64]*ProvidersHealthManager), + cancelFuncs: make(map[uint64]context.CancelFunc), + } +} + +// RegisterProvidersHealthManager registers the provider health manager. +// It removes any existing provider for the same chain before registering the new one. +func (b *BlockchainHealthManager) RegisterProvidersHealthManager(ctx context.Context, phm *ProvidersHealthManager) error { + b.mu.Lock() + defer b.mu.Unlock() + + chainID := phm.ChainID() + + // Check if a provider for the given chainID is already registered and remove it + if _, exists := b.providers[chainID]; exists { + // Cancel the existing context + if cancel, cancelExists := b.cancelFuncs[chainID]; cancelExists { + cancel() + } + // Remove the old registration + delete(b.providers, chainID) + delete(b.cancelFuncs, chainID) + } + + // Proceed with the registration + b.providers[chainID] = phm + + // Create a new context for the provider + providerCtx, cancel := context.WithCancel(ctx) + b.cancelFuncs[chainID] = cancel + + statusCh := phm.Subscribe() + b.wg.Add(1) + go func(phm *ProvidersHealthManager, statusCh chan struct{}, providerCtx context.Context) { + defer func() { + phm.Unsubscribe(statusCh) + b.wg.Done() + }() + for { + select { + case <-statusCh: + // When the provider updates its status, check the statuses of all providers + b.aggregateAndUpdateStatus(providerCtx) + case <-providerCtx.Done(): + // Stop processing when the context is cancelled + return + } + } + }(phm, statusCh, providerCtx) + + return nil +} + +// Stop stops the event processing and unsubscribes. +func (b *BlockchainHealthManager) Stop() { + b.mu.Lock() + + for _, cancel := range b.cancelFuncs { + cancel() + } + clear(b.cancelFuncs) + clear(b.providers) + + b.mu.Unlock() + b.wg.Wait() +} + +// Subscribe allows clients to receive notifications about changes. +func (b *BlockchainHealthManager) Subscribe() chan struct{} { + ch := make(chan struct{}, 1) + b.subscribers.Store(ch, struct{}{}) + return ch +} + +// Unsubscribe removes a subscriber from receiving notifications. +func (b *BlockchainHealthManager) Unsubscribe(ch chan struct{}) { + b.subscribers.Delete(ch) // Удаляем подписчика из sync.Map + close(ch) +} + +// aggregateAndUpdateStatus collects statuses from all providers and updates the overall and short status. +func (b *BlockchainHealthManager) aggregateAndUpdateStatus(ctx context.Context) { + newShortStatus := b.aggregateStatus() + + // If status has changed, update the last status and emit notifications + if b.shouldUpdateStatus(newShortStatus) { + b.updateStatus(newShortStatus) + b.emitBlockchainHealthStatus(ctx) + } +} + +// aggregateStatus aggregates provider statuses and returns the new short status. +func (b *BlockchainHealthManager) aggregateStatus() BlockchainStatus { + b.mu.Lock() + defer b.mu.Unlock() + + // Collect statuses from all providers + providerStatuses := make([]rpcstatus.ProviderStatus, 0) + for _, provider := range b.providers { + providerStatuses = append(providerStatuses, provider.Status()) + } + + // Update the aggregator with the new list of provider statuses + b.aggregator.UpdateBatch(providerStatuses) + + // Get the new aggregated full and short status + return b.getStatusPerChain() +} + +// shouldUpdateStatus checks if the status has changed and needs to be updated. +func (b *BlockchainHealthManager) shouldUpdateStatus(newShortStatus BlockchainStatus) bool { + b.mu.RLock() + defer b.mu.RUnlock() + + return b.lastStatus == nil || !compareShortStatus(newShortStatus, *b.lastStatus) +} + +// updateStatus updates the last known status with the new status. +func (b *BlockchainHealthManager) updateStatus(newShortStatus BlockchainStatus) { + b.mu.Lock() + defer b.mu.Unlock() + b.lastStatus = &newShortStatus +} + +// compareShortStatus compares two BlockchainStatus structs and returns true if they are identical. +func compareShortStatus(newStatus, previousStatus BlockchainStatus) bool { + if newStatus.Status.Status != previousStatus.Status.Status { + return false + } + + if len(newStatus.StatusPerChain) != len(previousStatus.StatusPerChain) { + return false + } + + for chainID, newChainStatus := range newStatus.StatusPerChain { + if prevChainStatus, ok := previousStatus.StatusPerChain[chainID]; !ok || newChainStatus.Status != prevChainStatus.Status { + return false + } + } + + return true +} + +// emitBlockchainHealthStatus sends a notification to all subscribers about the new blockchain status. +func (b *BlockchainHealthManager) emitBlockchainHealthStatus(ctx context.Context) { + b.subscribers.Range(func(key, value interface{}) bool { + subscriber := key.(chan struct{}) + select { + case <-ctx.Done(): + // Stop sending notifications when the context is cancelled + return false + case subscriber <- struct{}{}: + default: + // Skip notification if the subscriber's channel is full (non-blocking) + } + return true + }) +} + +func (b *BlockchainHealthManager) GetFullStatus() BlockchainFullStatus { + b.mu.RLock() + defer b.mu.RUnlock() + + statusPerChainPerProvider := make(map[uint64]map[string]rpcstatus.ProviderStatus) + + for chainID, phm := range b.providers { + providerStatuses := phm.GetStatuses() + statusPerChainPerProvider[chainID] = providerStatuses + } + + statusPerChain := b.getStatusPerChain() + + return BlockchainFullStatus{ + Status: statusPerChain.Status, + StatusPerChain: statusPerChain.StatusPerChain, + StatusPerChainPerProvider: statusPerChainPerProvider, + } +} + +func (b *BlockchainHealthManager) getStatusPerChain() BlockchainStatus { + statusPerChain := make(map[uint64]rpcstatus.ProviderStatus) + + for chainID, phm := range b.providers { + chainStatus := phm.Status() + statusPerChain[chainID] = chainStatus + } + + blockchainStatus := b.aggregator.GetAggregatedStatus() + + return BlockchainStatus{ + Status: blockchainStatus, + StatusPerChain: statusPerChain, + } +} + +func (b *BlockchainHealthManager) GetStatusPerChain() BlockchainStatus { + b.mu.RLock() + defer b.mu.RUnlock() + return b.getStatusPerChain() +} + +// Status returns the current aggregated status. +func (b *BlockchainHealthManager) Status() rpcstatus.ProviderStatus { + b.mu.RLock() + defer b.mu.RUnlock() + return b.aggregator.GetAggregatedStatus() +} diff --git a/healthmanager/blockchain_health_manager_test.go b/healthmanager/blockchain_health_manager_test.go new file mode 100644 index 000000000..88fdc447c --- /dev/null +++ b/healthmanager/blockchain_health_manager_test.go @@ -0,0 +1,237 @@ +package healthmanager + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/status-im/status-go/healthmanager/rpcstatus" +) + +type BlockchainHealthManagerSuite struct { + suite.Suite + manager *BlockchainHealthManager + ctx context.Context + cancel context.CancelFunc +} + +func (s *BlockchainHealthManagerSuite) SetupTest() { + s.manager = NewBlockchainHealthManager() + s.ctx, s.cancel = context.WithCancel(context.Background()) +} + +func (s *BlockchainHealthManagerSuite) TearDownTest() { + s.manager.Stop() + s.cancel() +} + +// Helper method to update providers and wait for a notification on the given channel +func (s *BlockchainHealthManagerSuite) waitForUpdate(ch <-chan struct{}, expectedChainStatus rpcstatus.StatusType, timeout time.Duration) { + select { + case <-ch: + // Received notification + case <-time.After(timeout): + s.Fail("Timeout waiting for chain status update") + } + + s.assertBlockChainStatus(expectedChainStatus) +} + +// Helper method to assert the current chain status +func (s *BlockchainHealthManagerSuite) assertBlockChainStatus(expected rpcstatus.StatusType) { + actual := s.manager.Status().Status + s.Equal(expected, actual, fmt.Sprintf("Expected blockchain status to be %s", expected)) +} + +// Test registering a provider health manager +func (s *BlockchainHealthManagerSuite) TestRegisterProvidersHealthManager() { + phm := NewProvidersHealthManager(1) // Create a real ProvidersHealthManager + err := s.manager.RegisterProvidersHealthManager(context.Background(), phm) + s.Require().NoError(err) + + // Verify that the provider is registered + s.Require().NotNil(s.manager.providers[1]) +} + +// Test status updates and notifications +func (s *BlockchainHealthManagerSuite) TestStatusUpdateNotification() { + phm := NewProvidersHealthManager(1) + err := s.manager.RegisterProvidersHealthManager(context.Background(), phm) + s.Require().NoError(err) + ch := s.manager.Subscribe() + + // Update the provider status + phm.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "providerName", Timestamp: time.Now(), Err: nil}, + }) + + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) +} + +// Test getting the full status +func (s *BlockchainHealthManagerSuite) TestGetFullStatus() { + phm1 := NewProvidersHealthManager(1) + phm2 := NewProvidersHealthManager(2) + ctx := context.Background() + err := s.manager.RegisterProvidersHealthManager(ctx, phm1) + s.Require().NoError(err) + err = s.manager.RegisterProvidersHealthManager(ctx, phm2) + s.Require().NoError(err) + ch := s.manager.Subscribe() + + // Update the provider status + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "providerName1", Timestamp: time.Now(), Err: nil}, + }) + phm2.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "providerName2", Timestamp: time.Now(), Err: errors.New("connection error")}, + }) + + s.waitForUpdate(ch, rpcstatus.StatusUp, 10*time.Millisecond) + fullStatus := s.manager.GetFullStatus() + s.Len(fullStatus.StatusPerChainPerProvider, 2, "Expected statuses for 2 chains") +} + +func (s *BlockchainHealthManagerSuite) TestConcurrentSubscriptionUnsubscription() { + var wg sync.WaitGroup + subscribersCount := 100 + + // Concurrently add and remove subscribers + for i := 0; i < subscribersCount; i++ { + wg.Add(1) + go func() { + defer wg.Done() + subCh := s.manager.Subscribe() + time.Sleep(10 * time.Millisecond) + s.manager.Unsubscribe(subCh) + }() + } + + wg.Wait() + + activeSubscribersCount := 0 + s.manager.subscribers.Range(func(key, value interface{}) bool { + activeSubscribersCount++ + return true + }) + + // After all subscribers are removed, there should be no active subscribers + s.Equal(0, activeSubscribersCount, "Expected no subscribers after unsubscription") +} + +func (s *BlockchainHealthManagerSuite) TestConcurrency() { + var wg sync.WaitGroup + chainsCount := 10 + providersCount := 100 + ctx, cancel := context.WithCancel(s.ctx) + defer cancel() + for i := 1; i <= chainsCount; i++ { + phm := NewProvidersHealthManager(uint64(i)) + err := s.manager.RegisterProvidersHealthManager(ctx, phm) + s.Require().NoError(err) + } + + ch := s.manager.Subscribe() + + for i := 1; i <= chainsCount; i++ { + wg.Add(1) + go func(chainID uint64) { + defer wg.Done() + phm := s.manager.providers[chainID] + for j := 0; j < providersCount; j++ { + err := errors.New("connection error") + if j == providersCount-1 { + err = nil + } + name := fmt.Sprintf("provider-%d", j) + go phm.Update(ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: name, Timestamp: time.Now(), Err: err}, + }) + } + }(uint64(i)) + } + + wg.Wait() + + s.waitForUpdate(ch, rpcstatus.StatusUp, 2*time.Second) +} + +func (s *BlockchainHealthManagerSuite) TestMultipleStartAndStop() { + s.manager.Stop() + + s.manager.Stop() + + // Ensure that the manager is in a clean state after multiple starts and stops + s.Equal(0, len(s.manager.cancelFuncs), "Expected no cancel functions after stop") +} + +func (s *BlockchainHealthManagerSuite) TestUnsubscribeOneOfMultipleSubscribers() { + // Create an instance of BlockchainHealthManager and register a provider manager + phm := NewProvidersHealthManager(1) + ctx, cancel := context.WithCancel(s.ctx) + err := s.manager.RegisterProvidersHealthManager(ctx, phm) + s.Require().NoError(err) + + defer cancel() + + // Subscribe two subscribers + subscriber1 := s.manager.Subscribe() + subscriber2 := s.manager.Subscribe() + + // Unsubscribe the first subscriber + s.manager.Unsubscribe(subscriber1) + + phm.Update(ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "provider-1", Timestamp: time.Now(), Err: nil}, + }) + + // Ensure the first subscriber did not receive a notification + select { + case _, ok := <-subscriber1: + s.False(ok, "First subscriber channel should be closed") + default: + s.Fail("First subscriber channel was not closed") + } + + // Ensure the second subscriber received a notification + select { + case <-subscriber2: + // Notification received by the second subscriber + case <-time.After(100 * time.Millisecond): + s.Fail("Second subscriber should have received a notification") + } +} + +func (s *BlockchainHealthManagerSuite) TestMixedProviderStatusInSingleChain() { + // Register a provider for chain 1 + phm := NewProvidersHealthManager(1) + err := s.manager.RegisterProvidersHealthManager(s.ctx, phm) + s.Require().NoError(err) + + // Subscribe to status updates + ch := s.manager.Subscribe() + defer s.manager.Unsubscribe(ch) + + // Simulate mixed statuses within the same chain (one provider up, one provider down) + phm.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "provider1_chain1", Timestamp: time.Now(), Err: nil}, // Provider 1 is up + {Name: "provider2_chain1", Timestamp: time.Now(), Err: errors.New("error")}, // Provider 2 is down + }) + + // Wait for the status to propagate + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Verify that the short status reflects the chain as down, since one provider is down + shortStatus := s.manager.GetStatusPerChain() + s.Equal(rpcstatus.StatusUp, shortStatus.Status.Status) + s.Equal(rpcstatus.StatusUp, shortStatus.StatusPerChain[1].Status) // Chain 1 should be marked as down +} + +func TestBlockchainHealthManagerSuite(t *testing.T) { + suite.Run(t, new(BlockchainHealthManagerSuite)) +} diff --git a/healthmanager/provider_errors/provider_errors.go b/healthmanager/provider_errors/provider_errors.go new file mode 100644 index 000000000..bec09d0a1 --- /dev/null +++ b/healthmanager/provider_errors/provider_errors.go @@ -0,0 +1,247 @@ +package provider_errors + +import ( + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "strings" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/rpc" + "github.com/status-im/status-go/rpc/chain/rpclimiter" +) + +// ProviderErrorType defines the type of non-RPC error for JSON serialization. +type ProviderErrorType string + +const ( + // Non-RPC Errors + ProviderErrorTypeNone ProviderErrorType = "none" + ProviderErrorTypeContextCanceled ProviderErrorType = "context_canceled" + ProviderErrorTypeContextDeadlineExceeded ProviderErrorType = "context_deadline" + ProviderErrorTypeConnection ProviderErrorType = "connection" + ProviderErrorTypeNotAuthorized ProviderErrorType = "not_authorized" + ProviderErrorTypeForbidden ProviderErrorType = "forbidden" + ProviderErrorTypeBadRequest ProviderErrorType = "bad_request" + ProviderErrorTypeContentTooLarge ProviderErrorType = "content_too_large" + ProviderErrorTypeInternalError ProviderErrorType = "internal" + ProviderErrorTypeServiceUnavailable ProviderErrorType = "service_unavailable" + ProviderErrorTypeRateLimit ProviderErrorType = "rate_limit" + ProviderErrorTypeOther ProviderErrorType = "other" +) + +// IsConnectionError checks if the error is related to network issues. +func IsConnectionError(err error) bool { + if err == nil { + return false + } + + // Check for net.Error (timeout or other network errors) + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + return true + } + } + + // Check for DNS errors + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return true + } + + // Check for network operation errors (e.g., connection refused) + var opErr *net.OpError + if errors.As(err, &opErr) { + return true + } + + // Check for TLS errors + var tlsRecordErr *tls.RecordHeaderError + if errors.As(err, &tlsRecordErr) { + return true + } + + // FIXME: Check for TLS ECH Rejection Error (tls.ECHRejectionError is added in go 1.23) + + // Check for TLS Certificate Verification Error + var certVerifyErr *tls.CertificateVerificationError + if errors.As(err, &certVerifyErr) { + return true + } + + // Check for TLS Alert Error + var alertErr tls.AlertError + if errors.As(err, &alertErr) { + return true + } + + // Check for specific HTTP server closed error + if errors.Is(err, http.ErrServerClosed) { + return true + } + + // Common connection refused or timeout error messages + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, "i/o timeout") || + strings.Contains(errMsg, "connection refused") || + strings.Contains(errMsg, "network is unreachable") || + strings.Contains(errMsg, "no such host") || + strings.Contains(errMsg, "tls handshake timeout") { + return true + } + + return false +} + +func IsRateLimitError(err error) bool { + if err == nil { + return false + } + + if ok, statusCode := IsHTTPError(err); ok && statusCode == 429 { + return true + } + + if errors.Is(err, rpclimiter.ErrRequestsOverLimit) { + return true + } + + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, "backoff_seconds") || + strings.Contains(errMsg, "has exceeded its throughput limit") || + strings.Contains(errMsg, "request rate exceeded") { + return true + } + return false +} + +// Don't mark connection as failed if we get one of these errors +var propagateErrors = []error{ + vm.ErrOutOfGas, + vm.ErrCodeStoreOutOfGas, + vm.ErrDepth, + vm.ErrInsufficientBalance, + vm.ErrContractAddressCollision, + vm.ErrExecutionReverted, + vm.ErrMaxCodeSizeExceeded, + vm.ErrInvalidJump, + vm.ErrWriteProtection, + vm.ErrReturnDataOutOfBounds, + vm.ErrGasUintOverflow, + vm.ErrInvalidCode, + vm.ErrNonceUintOverflow, + + // Used by balance history to check state + bind.ErrNoCode, +} + +func IsHTTPError(err error) (bool, int) { + var httpErrPtr *rpc.HTTPError + if errors.As(err, &httpErrPtr) { + return true, httpErrPtr.StatusCode + } + + var httpErr rpc.HTTPError + if errors.As(err, &httpErr) { + return true, httpErr.StatusCode + } + + return false, 0 +} + +func IsNotAuthorizedError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 401 + } + return false +} + +func IsForbiddenError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 403 + } + return false +} + +func IsBadRequestError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 400 + } + return false +} + +func IsContentTooLargeError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 413 + } + return false +} + +func IsInternalServerError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 500 + } + return false +} + +func IsServiceUnavailableError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 503 + } + return false +} + +// determineProviderErrorType determines the ProviderErrorType based on the error. +func determineProviderErrorType(err error) ProviderErrorType { + if err == nil { + return ProviderErrorTypeNone + } + if errors.Is(err, context.Canceled) { + return ProviderErrorTypeContextCanceled + } + if errors.Is(err, context.DeadlineExceeded) { + return ProviderErrorTypeContextDeadlineExceeded + } + if IsConnectionError(err) { + return ProviderErrorTypeConnection + } + if IsNotAuthorizedError(err) { + return ProviderErrorTypeNotAuthorized + } + if IsForbiddenError(err) { + return ProviderErrorTypeForbidden + } + if IsBadRequestError(err) { + return ProviderErrorTypeBadRequest + } + if IsContentTooLargeError(err) { + return ProviderErrorTypeContentTooLarge + } + if IsInternalServerError(err) { + return ProviderErrorTypeInternalError + } + if IsServiceUnavailableError(err) { + return ProviderErrorTypeServiceUnavailable + } + if IsRateLimitError(err) { + return ProviderErrorTypeRateLimit + } + // Add additional non-RPC checks as necessary + return ProviderErrorTypeOther +} + +// IsNonCriticalProviderError determines if the non-RPC error is not critical. +func IsNonCriticalProviderError(err error) bool { + errorType := determineProviderErrorType(err) + + switch errorType { + case ProviderErrorTypeNone, ProviderErrorTypeContextCanceled, ProviderErrorTypeContentTooLarge, ProviderErrorTypeRateLimit: + return true + default: + return false + } +} diff --git a/healthmanager/provider_errors/provider_errors_test.go b/healthmanager/provider_errors/provider_errors_test.go new file mode 100644 index 000000000..34af6b81b --- /dev/null +++ b/healthmanager/provider_errors/provider_errors_test.go @@ -0,0 +1,115 @@ +package provider_errors + +import ( + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "testing" +) + +// TestIsConnectionError tests the IsConnectionError function. +func TestIsConnectionError(t *testing.T) { + tests := []struct { + name string + err error + wantResult bool + }{ + { + name: "nil error", + err: nil, + wantResult: false, + }, + { + name: "net.DNSError with timeout", + err: &net.DNSError{IsTimeout: true}, + wantResult: true, + }, + { + name: "DNS error without timeout", + err: &net.DNSError{}, + wantResult: true, + }, + { + name: "net.OpError", + err: &net.OpError{}, + wantResult: true, + }, + { + name: "tls.RecordHeaderError", + err: &tls.RecordHeaderError{}, + wantResult: true, + }, + { + name: "tls.CertificateVerificationError", + err: &tls.CertificateVerificationError{}, + wantResult: true, + }, + { + name: "tls.AlertError", + err: tls.AlertError(0), + wantResult: true, + }, + { + name: "context.DeadlineExceeded", + err: context.DeadlineExceeded, + wantResult: true, + }, + { + name: "http.ErrServerClosed", + err: http.ErrServerClosed, + wantResult: true, + }, + { + name: "i/o timeout error message", + err: errors.New("i/o timeout"), + wantResult: true, + }, + { + name: "connection refused error message", + err: errors.New("connection refused"), + wantResult: true, + }, + { + name: "network is unreachable error message", + err: errors.New("network is unreachable"), + wantResult: true, + }, + { + name: "no such host error message", + err: errors.New("no such host"), + wantResult: true, + }, + { + name: "tls handshake timeout error message", + err: errors.New("tls handshake timeout"), + wantResult: true, + }, + { + name: "rps limit error 1", + err: errors.New("backoff_seconds"), + wantResult: false, + }, + { + name: "rps limit error 2", + err: errors.New("has exceeded its throughput limit"), + wantResult: false, + }, + { + name: "rps limit error 3", + err: errors.New("request rate exceeded"), + wantResult: false, + }, + } + + for _, tt := range tests { + tt := tt // capture the variable + t.Run(tt.name, func(t *testing.T) { + got := IsConnectionError(tt.err) + if got != tt.wantResult { + t.Errorf("IsConnectionError(%v) = %v; want %v", tt.err, got, tt.wantResult) + } + }) + } +} diff --git a/healthmanager/provider_errors/rpc_provider_errors.go b/healthmanager/provider_errors/rpc_provider_errors.go new file mode 100644 index 000000000..7f1c8c90e --- /dev/null +++ b/healthmanager/provider_errors/rpc_provider_errors.go @@ -0,0 +1,87 @@ +package provider_errors + +import ( + "errors" + "strings" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/rpc" +) + +type RpcProviderErrorType string + +const ( + // RPC Errors + RpcErrorTypeNone RpcProviderErrorType = "none" + RpcErrorTypeMethodNotFound RpcProviderErrorType = "rpc_method_not_found" + RpcErrorTypeRPSLimit RpcProviderErrorType = "rpc_rps_limit" + RpcErrorTypeVMError RpcProviderErrorType = "rpc_vm_error" + RpcErrorTypeRPCOther RpcProviderErrorType = "rpc_other" +) + +// Not found should not be cancelling the requests, as that's returned +// when we are hitting a non archival node for example, it should continue the +// chain as the next provider might have archival support. +func IsNotFoundError(err error) bool { + return strings.Contains(err.Error(), ethereum.NotFound.Error()) +} + +func IsRPCError(err error) (rpc.Error, bool) { + var rpcErr rpc.Error + if errors.As(err, &rpcErr) { + return rpcErr, true + } + return nil, false +} + +func IsMethodNotFoundError(err error) bool { + if rpcErr, ok := IsRPCError(err); ok { + return rpcErr.ErrorCode() == -32601 + } + return false +} + +func IsVMError(err error) bool { + if rpcErr, ok := IsRPCError(err); ok { + return rpcErr.ErrorCode() == -32015 // Код ошибки VM execution error + } + if strings.Contains(err.Error(), core.ErrInsufficientFunds.Error()) { + return true + } + for _, vmError := range propagateErrors { + if strings.Contains(err.Error(), vmError.Error()) { + return true + } + } + return false +} + +// determineRpcErrorType determines the RpcProviderErrorType based on the error. +func determineRpcErrorType(err error) RpcProviderErrorType { + if err == nil { + return RpcErrorTypeNone + } + //if IsRpsLimitError(err) { + // return RpcErrorTypeRPSLimit + //} + if IsMethodNotFoundError(err) || IsNotFoundError(err) { + return RpcErrorTypeMethodNotFound + } + if IsVMError(err) { + return RpcErrorTypeVMError + } + return RpcErrorTypeRPCOther +} + +// IsCriticalRpcError determines if the RPC error is critical. +func IsNonCriticalRpcError(err error) bool { + errorType := determineRpcErrorType(err) + + switch errorType { + case RpcErrorTypeNone, RpcErrorTypeMethodNotFound, RpcErrorTypeRPSLimit, RpcErrorTypeVMError: + return true + default: + return false + } +} diff --git a/healthmanager/provider_errors/rpc_provider_errors_test.go b/healthmanager/provider_errors/rpc_provider_errors_test.go new file mode 100644 index 000000000..3d16fae7f --- /dev/null +++ b/healthmanager/provider_errors/rpc_provider_errors_test.go @@ -0,0 +1,51 @@ +package provider_errors + +import ( + "errors" + "testing" +) + +// TestIsRpsLimitError tests the IsRpsLimitError function. +func TestIsRpsLimitError(t *testing.T) { + tests := []struct { + name string + err error + wantResult bool + }{ + { + name: "Error contains 'backoff_seconds'", + err: errors.New("Error: backoff_seconds: 30"), + wantResult: true, + }, + { + name: "Error contains 'has exceeded its throughput limit'", + err: errors.New("Your application has exceeded its throughput limit."), + wantResult: true, + }, + { + name: "Error contains 'request rate exceeded'", + err: errors.New("Request rate exceeded. Please try again later."), + wantResult: true, + }, + { + name: "Error does not contain any matching phrases", + err: errors.New("Some other error occurred."), + wantResult: false, + }, + { + name: "Error is nil", + err: nil, + wantResult: false, + }, + } + + for _, tt := range tests { + tt := tt // capture the variable + t.Run(tt.name, func(t *testing.T) { + got := IsRateLimitError(tt.err) + if got != tt.wantResult { + t.Errorf("IsRpsLimitError(%v) = %v; want %v", tt.err, got, tt.wantResult) + } + }) + } +} diff --git a/healthmanager/providers_health_manager.go b/healthmanager/providers_health_manager.go new file mode 100644 index 000000000..4d1c340e9 --- /dev/null +++ b/healthmanager/providers_health_manager.go @@ -0,0 +1,117 @@ +package healthmanager + +import ( + "context" + "fmt" + "sync" + + "github.com/status-im/status-go/healthmanager/aggregator" + "github.com/status-im/status-go/healthmanager/rpcstatus" +) + +type ProvidersHealthManager struct { + mu sync.RWMutex + chainID uint64 + aggregator *aggregator.Aggregator + subscribers sync.Map // Use sync.Map for concurrent access to subscribers + lastStatus *rpcstatus.ProviderStatus +} + +// NewProvidersHealthManager creates a new instance of ProvidersHealthManager with the given chain ID. +func NewProvidersHealthManager(chainID uint64) *ProvidersHealthManager { + agg := aggregator.NewAggregator(fmt.Sprintf("%d", chainID)) + + return &ProvidersHealthManager{ + chainID: chainID, + aggregator: agg, + } +} + +// Update processes a batch of provider call statuses, updates the aggregated status, and emits chain status changes if necessary. +func (p *ProvidersHealthManager) Update(ctx context.Context, callStatuses []rpcstatus.RpcProviderCallStatus) { + p.mu.Lock() + + // Update the aggregator with the new provider statuses + for _, rpcCallStatus := range callStatuses { + providerStatus := rpcstatus.NewRpcProviderStatus(rpcCallStatus) + p.aggregator.Update(providerStatus) + } + + newStatus := p.aggregator.GetAggregatedStatus() + + shouldEmit := p.lastStatus == nil || p.lastStatus.Status != newStatus.Status + p.mu.Unlock() + + if !shouldEmit { + return + } + + p.emitChainStatus(ctx) + p.mu.Lock() + defer p.mu.Unlock() + p.lastStatus = &newStatus +} + +// GetStatuses returns a copy of the current provider statuses. +func (p *ProvidersHealthManager) GetStatuses() map[string]rpcstatus.ProviderStatus { + p.mu.RLock() + defer p.mu.RUnlock() + return p.aggregator.GetStatuses() +} + +// Subscribe allows providers to receive notifications about changes. +func (p *ProvidersHealthManager) Subscribe() chan struct{} { + ch := make(chan struct{}, 1) + p.subscribers.Store(ch, struct{}{}) + return ch +} + +// Unsubscribe removes a subscriber from receiving notifications. +func (p *ProvidersHealthManager) Unsubscribe(ch chan struct{}) { + p.subscribers.Delete(ch) + close(ch) +} + +// UnsubscribeAll removes all subscriber channels. +func (p *ProvidersHealthManager) UnsubscribeAll() { + p.subscribers.Range(func(key, value interface{}) bool { + ch := key.(chan struct{}) + close(ch) + p.subscribers.Delete(key) + return true + }) +} + +// Reset clears all provider statuses and resets the chain status to unknown. +func (p *ProvidersHealthManager) Reset() { + p.mu.Lock() + defer p.mu.Unlock() + p.aggregator = aggregator.NewAggregator(fmt.Sprintf("%d", p.chainID)) +} + +// Status Returns the current aggregated status. +func (p *ProvidersHealthManager) Status() rpcstatus.ProviderStatus { + p.mu.RLock() + defer p.mu.RUnlock() + return p.aggregator.GetAggregatedStatus() +} + +// ChainID returns the ID of the chain. +func (p *ProvidersHealthManager) ChainID() uint64 { + return p.chainID +} + +// emitChainStatus sends a notification to all subscribers. +func (p *ProvidersHealthManager) emitChainStatus(ctx context.Context) { + p.subscribers.Range(func(key, value interface{}) bool { + subscriber := key.(chan struct{}) + select { + case subscriber <- struct{}{}: + case <-ctx.Done(): + return false // Stop sending if context is done + default: + // Non-blocking send; skip if the channel is full + } + return true + }) +} diff --git a/healthmanager/providers_health_manager_test.go b/healthmanager/providers_health_manager_test.go new file mode 100644 index 000000000..c40e3228f --- /dev/null +++ b/healthmanager/providers_health_manager_test.go @@ -0,0 +1,213 @@ +package healthmanager + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/status-im/status-go/healthmanager/rpcstatus" +) + +type ProvidersHealthManagerSuite struct { + suite.Suite + phm *ProvidersHealthManager +} + +// SetupTest initializes the ProvidersHealthManager before each test +func (s *ProvidersHealthManagerSuite) SetupTest() { + s.phm = NewProvidersHealthManager(1) +} + +// Helper method to update providers and wait for a notification on the given channel +func (s *ProvidersHealthManagerSuite) updateAndWait(ch <-chan struct{}, statuses []rpcstatus.RpcProviderCallStatus, expectedChainStatus rpcstatus.StatusType, timeout time.Duration) { + s.phm.Update(context.Background(), statuses) + + select { + case <-ch: + // Received notification + case <-time.After(timeout): + s.Fail("Timeout waiting for chain status update") + } + + s.assertChainStatus(expectedChainStatus) +} + +// Helper method to update providers and wait for a notification on the given channel +func (s *ProvidersHealthManagerSuite) updateAndExpectNoNotification(ch <-chan struct{}, statuses []rpcstatus.RpcProviderCallStatus, expectedChainStatus rpcstatus.StatusType, timeout time.Duration) { + s.phm.Update(context.Background(), statuses) + + select { + case <-ch: + s.Fail("Unexpected status update") + case <-time.After(timeout): + // No notification as expected + } + + s.assertChainStatus(expectedChainStatus) +} + +// Helper method to assert the current chain status +func (s *ProvidersHealthManagerSuite) assertChainStatus(expected rpcstatus.StatusType) { + actual := s.phm.Status().Status + s.Equal(expected, actual, fmt.Sprintf("Expected chain status to be %s", expected)) +} + +func (s *ProvidersHealthManagerSuite) TestInitialStatus() { + s.assertChainStatus(rpcstatus.StatusDown) +} + +func (s *ProvidersHealthManagerSuite) TestUpdateProviderStatuses() { + s.updateAndWait(s.phm.Subscribe(), []rpcstatus.RpcProviderCallStatus{ + {Name: "Provider1", Timestamp: time.Now(), Err: nil}, + {Name: "Provider2", Timestamp: time.Now(), Err: errors.New("connection error")}, + }, rpcstatus.StatusUp, time.Second) + + statusMap := s.phm.GetStatuses() + s.Len(statusMap, 2, "Expected 2 provider statuses") + s.Equal(rpcstatus.StatusUp, statusMap["Provider1"].Status, "Expected Provider1 status to be Up") + s.Equal(rpcstatus.StatusDown, statusMap["Provider2"].Status, "Expected Provider2 status to be Down") +} + +func (s *ProvidersHealthManagerSuite) TestChainStatusUpdatesOnce() { + ch := s.phm.Subscribe() + s.assertChainStatus(rpcstatus.StatusDown) + + // Update providers to Down + statuses := []rpcstatus.RpcProviderCallStatus{ + {Name: "Provider1", Timestamp: time.Now(), Err: errors.New("error")}, + {Name: "Provider2", Timestamp: time.Now(), Err: nil}, + } + s.updateAndWait(ch, statuses, rpcstatus.StatusUp, time.Second) + s.updateAndExpectNoNotification(ch, statuses, rpcstatus.StatusUp, 10*time.Millisecond) +} + +func (s *ProvidersHealthManagerSuite) TestSubscribeReceivesOnlyOnChange() { + ch := s.phm.Subscribe() + + // Update provider to Up and wait for notification + upStatuses := []rpcstatus.RpcProviderCallStatus{ + {Name: "Provider1", Timestamp: time.Now(), Err: nil}, + } + s.updateAndWait(ch, upStatuses, rpcstatus.StatusUp, time.Second) + + // Update provider to Down and wait for notification + downStatuses := []rpcstatus.RpcProviderCallStatus{ + {Name: "Provider1", Timestamp: time.Now(), Err: errors.New("some critical error")}, + } + s.updateAndWait(ch, downStatuses, rpcstatus.StatusDown, time.Second) + + s.updateAndExpectNoNotification(ch, downStatuses, rpcstatus.StatusDown, 10*time.Millisecond) +} + +func (s *ProvidersHealthManagerSuite) TestConcurrency() { + var wg sync.WaitGroup + providerCount := 1000 + + s.phm.Update(context.Background(), []rpcstatus.RpcProviderCallStatus{ + {Name: "ProviderUp", Timestamp: time.Now(), Err: nil}, + }) + + ctx := context.Background() + for i := 0; i < providerCount-1; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + providerName := fmt.Sprintf("Provider%d", i) + var err error + if i%2 == 0 { + err = errors.New("error") + } + s.phm.Update(ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: providerName, Timestamp: time.Now(), Err: err}, + }) + }(i) + } + wg.Wait() + + statuses := s.phm.GetStatuses() + s.Len(statuses, providerCount, "Expected 1000 provider statuses") + + chainStatus := s.phm.Status().Status + s.Equal(chainStatus, rpcstatus.StatusUp, "Expected chain status to be either Up or Down") +} + +func (s *BlockchainHealthManagerSuite) TestInterleavedChainStatusChanges() { + // Register providers for chains 1, 2, and 3 + phm1 := NewProvidersHealthManager(1) + phm2 := NewProvidersHealthManager(2) + phm3 := NewProvidersHealthManager(3) + err := s.manager.RegisterProvidersHealthManager(s.ctx, phm1) + s.Require().NoError(err) + err = s.manager.RegisterProvidersHealthManager(s.ctx, phm2) + s.Require().NoError(err) + err = s.manager.RegisterProvidersHealthManager(s.ctx, phm3) + s.Require().NoError(err) + + // Subscribe to status updates + ch := s.manager.Subscribe() + defer s.manager.Unsubscribe(ch) + + // Initially, all chains are up + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain1", Timestamp: time.Now(), Err: nil}}) + phm2.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain2", Timestamp: time.Now(), Err: nil}}) + phm3.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain3", Timestamp: time.Now(), Err: nil}}) + + // Wait for the status to propagate + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Now chain 1 goes down, and chain 3 goes down at the same time + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain1", Timestamp: time.Now(), Err: errors.New("connection error")}}) + phm3.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain3", Timestamp: time.Now(), Err: errors.New("connection error")}}) + + // Wait for the status to reflect the changes + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Check that short status correctly reflects the mixed state + shortStatus := s.manager.GetStatusPerChain() + s.Equal(rpcstatus.StatusUp, shortStatus.Status.Status) + s.Equal(rpcstatus.StatusDown, shortStatus.StatusPerChain[1].Status) // Chain 1 is down + s.Equal(rpcstatus.StatusUp, shortStatus.StatusPerChain[2].Status) // Chain 2 is still up + s.Equal(rpcstatus.StatusDown, shortStatus.StatusPerChain[3].Status) // Chain 3 is down +} + +func (s *BlockchainHealthManagerSuite) TestDelayedChainUpdate() { + // Register providers for chains 1 and 2 + phm1 := NewProvidersHealthManager(1) + phm2 := NewProvidersHealthManager(2) + err := s.manager.RegisterProvidersHealthManager(s.ctx, phm1) + s.Require().NoError(err) + err = s.manager.RegisterProvidersHealthManager(s.ctx, phm2) + s.Require().NoError(err) + + // Subscribe to status updates + ch := s.manager.Subscribe() + defer s.manager.Unsubscribe(ch) + + // Initially, both chains are up + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider1_chain1", Timestamp: time.Now(), Err: nil}}) + phm2.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider1_chain2", Timestamp: time.Now(), Err: nil}}) + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Chain 2 goes down + phm2.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider1_chain2", Timestamp: time.Now(), Err: errors.New("connection error")}}) + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Chain 1 goes down after a delay + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider1_chain1", Timestamp: time.Now(), Err: errors.New("connection error")}}) + s.waitForUpdate(ch, rpcstatus.StatusDown, 100*time.Millisecond) + + // Check that short status reflects the final state where both chains are down + shortStatus := s.manager.GetStatusPerChain() + s.Equal(rpcstatus.StatusDown, shortStatus.Status.Status) + s.Equal(rpcstatus.StatusDown, shortStatus.StatusPerChain[1].Status) // Chain 1 is down + s.Equal(rpcstatus.StatusDown, shortStatus.StatusPerChain[2].Status) // Chain 2 is down +} + +func TestProvidersHealthManagerSuite(t *testing.T) { + suite.Run(t, new(ProvidersHealthManagerSuite)) +} diff --git a/healthmanager/rpcstatus/provider_status.go b/healthmanager/rpcstatus/provider_status.go new file mode 100644 index 000000000..26aabfc44 --- /dev/null +++ b/healthmanager/rpcstatus/provider_status.go @@ -0,0 +1,77 @@ +package rpcstatus + +import ( + "time" + + "github.com/status-im/status-go/healthmanager/provider_errors" +) + +// StatusType represents the possible status values for a provider. +type StatusType string + +const ( + StatusUnknown StatusType = "unknown" + StatusUp StatusType = "up" + StatusDown StatusType = "down" +) + +// ProviderStatus holds the status information for a single provider. +type ProviderStatus struct { + Name string `json:"name"` + LastSuccessAt time.Time `json:"last_success_at"` + LastErrorAt time.Time `json:"last_error_at"` + LastError error `json:"last_error"` + Status StatusType `json:"status"` +} + +// ProviderCallStatus represents the result of an arbitrary provider call. +type ProviderCallStatus struct { + Name string + Timestamp time.Time + Err error +} + +// RpcProviderCallStatus represents the result of an RPC provider call. +type RpcProviderCallStatus struct { + Name string + Timestamp time.Time + Err error +} + +// NewRpcProviderStatus processes RpcProviderCallStatus and returns a new ProviderStatus. +func NewRpcProviderStatus(res RpcProviderCallStatus) ProviderStatus { + status := ProviderStatus{ + Name: res.Name, + } + + // Determine if the error is critical + if res.Err == nil || provider_errors.IsNonCriticalRpcError(res.Err) || provider_errors.IsNonCriticalProviderError(res.Err) { + status.LastSuccessAt = res.Timestamp + status.Status = StatusUp + } else { + status.LastErrorAt = res.Timestamp + status.LastError = res.Err + status.Status = StatusDown + } + + return status +} + +// NewProviderStatus processes ProviderCallStatus and returns a new ProviderStatus. +func NewProviderStatus(res ProviderCallStatus) ProviderStatus { + status := ProviderStatus{ + Name: res.Name, + } + + // Determine if the error is critical + if res.Err == nil || provider_errors.IsNonCriticalProviderError(res.Err) { + status.LastSuccessAt = res.Timestamp + status.Status = StatusUp + } else { + status.LastErrorAt = res.Timestamp + status.LastError = res.Err + status.Status = StatusDown + } + + return status +} diff --git a/healthmanager/rpcstatus/provider_status_test.go b/healthmanager/rpcstatus/provider_status_test.go new file mode 100644 index 000000000..3a4f73b80 --- /dev/null +++ b/healthmanager/rpcstatus/provider_status_test.go @@ -0,0 +1,175 @@ +package rpcstatus + +import ( + "errors" + "testing" + "time" + + "github.com/status-im/status-go/rpc/chain/rpclimiter" +) + +func TestNewRpcProviderStatus(t *testing.T) { + tests := []struct { + name string + res RpcProviderCallStatus + expected ProviderStatus + }{ + { + name: "No error, should be up", + res: RpcProviderCallStatus{ + Name: "Provider1", + Timestamp: time.Now(), + Err: nil, + }, + expected: ProviderStatus{ + Name: "Provider1", + Status: StatusUp, + }, + }, + { + name: "Critical RPC error, should be down", + res: RpcProviderCallStatus{ + Name: "Provider1", + Timestamp: time.Now(), + Err: errors.New("Some critical RPC error"), + }, + expected: ProviderStatus{ + Name: "Provider1", + LastError: errors.New("Some critical RPC error"), + Status: StatusDown, + }, + }, + { + name: "Non-critical RPC error, should be up", + res: RpcProviderCallStatus{ + Name: "Provider2", + Timestamp: time.Now(), + Err: rpclimiter.ErrRequestsOverLimit, // Assuming this is non-critical + }, + expected: ProviderStatus{ + Name: "Provider2", + Status: StatusUp, + }, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + got := NewRpcProviderStatus(tt.res) + + // Compare expected and got + if got.Name != tt.expected.Name { + t.Errorf("expected name %v, got %v", tt.expected.Name, got.Name) + } + + // Check LastSuccessAt for StatusUp + if tt.expected.Status == StatusUp { + if got.LastSuccessAt.IsZero() { + t.Errorf("expected LastSuccessAt to be set, but got zero value") + } + if !got.LastErrorAt.IsZero() { + t.Errorf("expected LastErrorAt to be zero, but got %v", got.LastErrorAt) + } + } else if tt.expected.Status == StatusDown { + if got.LastErrorAt.IsZero() { + t.Errorf("expected LastErrorAt to be set, but got zero value") + } + if !got.LastSuccessAt.IsZero() { + t.Errorf("expected LastSuccessAt to be zero, but got %v", got.LastSuccessAt) + } + } + + if got.Status != tt.expected.Status { + t.Errorf("expected status %v, got %v", tt.expected.Status, got.Status) + } + + if got.LastError != nil && tt.expected.LastError != nil && got.LastError.Error() != tt.expected.LastError.Error() { + t.Errorf("expected last error %v, got %v", tt.expected.LastError, got.LastError) + } + }) + } +} + +func TestNewProviderStatus(t *testing.T) { + tests := []struct { + name string + res ProviderCallStatus + expected ProviderStatus + }{ + { + name: "No error, should be up", + res: ProviderCallStatus{ + Name: "Provider1", + Timestamp: time.Now(), + Err: nil, + }, + expected: ProviderStatus{ + Name: "Provider1", + Status: StatusUp, + }, + }, + { + name: "Critical provider error, should be down", + res: ProviderCallStatus{ + Name: "Provider1", + Timestamp: time.Now(), + Err: errors.New("Some critical provider error"), + }, + expected: ProviderStatus{ + Name: "Provider1", + LastError: errors.New("Some critical provider error"), + Status: StatusDown, + }, + }, + { + name: "Non-critical provider error, should be up", + res: ProviderCallStatus{ + Name: "Provider2", + Timestamp: time.Now(), + Err: errors.New("backoff_seconds"), // Assuming this is non-critical + }, + expected: ProviderStatus{ + Name: "Provider2", + Status: StatusUp, + }, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + got := NewProviderStatus(tt.res) + + // Compare expected and got + if got.Name != tt.expected.Name { + t.Errorf("expected name %v, got %v", tt.expected.Name, got.Name) + } + + // Check LastSuccessAt for StatusUp + if tt.expected.Status == StatusUp { + if got.LastSuccessAt.IsZero() { + t.Errorf("expected LastSuccessAt to be set, but got zero value") + } + if !got.LastErrorAt.IsZero() { + t.Errorf("expected LastErrorAt to be zero, but got %v", got.LastErrorAt) + } + } else if tt.expected.Status == StatusDown { + if got.LastErrorAt.IsZero() { + t.Errorf("expected LastErrorAt to be set, but got zero value") + } + if !got.LastSuccessAt.IsZero() { + t.Errorf("expected LastSuccessAt to be zero, but got %v", got.LastSuccessAt) + } + } + + if got.Status != tt.expected.Status { + t.Errorf("expected status %v, got %v", tt.expected.Status, got.Status) + } + + if got.LastError != nil && tt.expected.LastError != nil && got.LastError.Error() != tt.expected.LastError.Error() { + t.Errorf("expected last error %v, got %v", tt.expected.LastError, got.LastError) + } + }) + } +} diff --git a/node/get_status_node.go b/node/get_status_node.go index ef1ac50fa..0e2984684 100644 --- a/node/get_status_node.go +++ b/node/get_status_node.go @@ -1,6 +1,7 @@ package node import ( + "context" "database/sql" "errors" "fmt" @@ -331,11 +332,19 @@ func (n *StatusNode) setupRPCClient() (err error) { }, } - n.rpcClient, err = rpc.NewClient(gethNodeClient, n.config.NetworkID, n.config.Networks, n.appDB, providerConfigs) + config := rpc.ClientConfig{ + Client: gethNodeClient, + UpstreamChainID: n.config.NetworkID, + Networks: n.config.Networks, + DB: n.appDB, + WalletFeed: &n.walletFeed, + ProviderConfigs: providerConfigs, + } + n.rpcClient, err = rpc.NewClient(config) + n.rpcClient.Start(context.Background()) if err != nil { return } - return } @@ -451,6 +460,7 @@ func (n *StatusNode) stop() error { return err } + n.rpcClient.Stop() n.rpcClient = nil // We need to clear `gethNode` because config is passed to `Start()` // and may be completely different. Similarly with `config`. diff --git a/rpc/chain/blockchain_health_test.go b/rpc/chain/blockchain_health_test.go new file mode 100644 index 000000000..7665ee7f9 --- /dev/null +++ b/rpc/chain/blockchain_health_test.go @@ -0,0 +1,299 @@ +package chain + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/status-im/status-go/healthmanager" + "github.com/status-im/status-go/healthmanager/rpcstatus" + mockEthclient "github.com/status-im/status-go/rpc/chain/ethclient/mock/client/ethclient" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + + "go.uber.org/mock/gomock" + + "github.com/status-im/status-go/rpc/chain/ethclient" +) + +type BlockchainHealthManagerSuite struct { + suite.Suite + blockchainHealthManager *healthmanager.BlockchainHealthManager + mockProviders map[uint64]*healthmanager.ProvidersHealthManager + mockEthClients map[uint64]*mockEthclient.MockRPSLimitedEthClientInterface + clients map[uint64]*ClientWithFallback + mockCtrl *gomock.Controller +} + +func (s *BlockchainHealthManagerSuite) SetupTest() { + s.blockchainHealthManager = healthmanager.NewBlockchainHealthManager() + s.mockProviders = make(map[uint64]*healthmanager.ProvidersHealthManager) + s.mockEthClients = make(map[uint64]*mockEthclient.MockRPSLimitedEthClientInterface) + s.clients = make(map[uint64]*ClientWithFallback) + s.mockCtrl = gomock.NewController(s.T()) +} + +func (s *BlockchainHealthManagerSuite) TearDownTest() { + s.blockchainHealthManager.Stop() + s.mockCtrl.Finish() +} + +func (s *BlockchainHealthManagerSuite) setupClients(chainIDs []uint64) { + ctx := context.Background() + + for _, chainID := range chainIDs { + mockEthClient := mockEthclient.NewMockRPSLimitedEthClientInterface(s.mockCtrl) + mockEthClient.EXPECT().GetName().AnyTimes().Return(fmt.Sprintf("test_client_chain_%d", chainID)) + mockEthClient.EXPECT().GetLimiter().AnyTimes().Return(nil) + + phm := healthmanager.NewProvidersHealthManager(chainID) + client := NewClient([]ethclient.RPSLimitedEthClientInterface{mockEthClient}, chainID, phm) + + err := s.blockchainHealthManager.RegisterProvidersHealthManager(ctx, phm) + require.NoError(s.T(), err) + + s.mockProviders[chainID] = phm + s.mockEthClients[chainID] = mockEthClient + s.clients[chainID] = client + } +} + +func (s *BlockchainHealthManagerSuite) simulateChainStatus(chainID uint64, up bool) { + client, exists := s.clients[chainID] + require.True(s.T(), exists, "Client for chainID %d not found", chainID) + + mockEthClient := s.mockEthClients[chainID] + ctx := context.Background() + hash := common.HexToHash("0x1234") + + if up { + block := &types.Block{} + mockEthClient.EXPECT().BlockByHash(ctx, hash).Return(block, nil).Times(1) + _, err := client.BlockByHash(ctx, hash) + require.NoError(s.T(), err) + } else { + mockEthClient.EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + _, err := client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + } +} + +func (s *BlockchainHealthManagerSuite) waitForStatus(statusCh chan struct{}, expectedStatus rpcstatus.StatusType) { + timeout := time.After(2 * time.Second) + for { + select { + case <-statusCh: + status := s.blockchainHealthManager.Status() + if status.Status == expectedStatus { + return + } + case <-timeout: + s.T().Errorf("Did not receive expected blockchain status update in time") + return + } + } +} + +func (s *BlockchainHealthManagerSuite) TestAllChainsUp() { + s.setupClients([]uint64{1, 2, 3}) + + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + s.simulateChainStatus(1, true) + s.simulateChainStatus(2, true) + s.simulateChainStatus(3, true) + + s.waitForStatus(statusCh, rpcstatus.StatusUp) +} + +func (s *BlockchainHealthManagerSuite) TestSomeChainsDown() { + s.setupClients([]uint64{1, 2, 3}) + + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + s.simulateChainStatus(1, true) + s.simulateChainStatus(2, false) + s.simulateChainStatus(3, true) + + s.waitForStatus(statusCh, rpcstatus.StatusUp) +} + +func (s *BlockchainHealthManagerSuite) TestAllChainsDown() { + s.setupClients([]uint64{1, 2}) + + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + s.simulateChainStatus(1, false) + s.simulateChainStatus(2, false) + + s.waitForStatus(statusCh, rpcstatus.StatusDown) +} + +func (s *BlockchainHealthManagerSuite) TestChainStatusChanges() { + s.setupClients([]uint64{1, 2}) + + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + s.simulateChainStatus(1, false) + s.simulateChainStatus(2, false) + s.waitForStatus(statusCh, rpcstatus.StatusDown) + + s.simulateChainStatus(1, true) + s.waitForStatus(statusCh, rpcstatus.StatusUp) +} + +func (s *BlockchainHealthManagerSuite) TestGetFullStatus() { + // Setup clients for chain IDs 1 and 2 + s.setupClients([]uint64{1, 2}) + + // Subscribe to blockchain status updates + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + // Simulate provider statuses for chain 1 + providerCallStatusesChain1 := []rpcstatus.RpcProviderCallStatus{ + { + Name: "provider1_chain1", + Timestamp: time.Now(), + Err: nil, // Up + }, + { + Name: "provider2_chain1", + Timestamp: time.Now(), + Err: errors.New("connection error"), // Down + }, + } + ctx := context.Background() + s.mockProviders[1].Update(ctx, providerCallStatusesChain1) + + // Simulate provider statuses for chain 2 + providerCallStatusesChain2 := []rpcstatus.RpcProviderCallStatus{ + { + Name: "provider1_chain2", + Timestamp: time.Now(), + Err: nil, // Up + }, + { + Name: "provider2_chain2", + Timestamp: time.Now(), + Err: nil, // Up + }, + } + s.mockProviders[2].Update(ctx, providerCallStatusesChain2) + + // Wait for status event to be triggered before getting full status + s.waitForStatus(statusCh, rpcstatus.StatusUp) + + // Get the full status from the BlockchainHealthManager + fullStatus := s.blockchainHealthManager.GetFullStatus() + + // Assert overall blockchain status + require.Equal(s.T(), rpcstatus.StatusUp, fullStatus.Status.Status) + + // Assert provider statuses per chain + require.Contains(s.T(), fullStatus.StatusPerChainPerProvider, uint64(1)) + require.Contains(s.T(), fullStatus.StatusPerChainPerProvider, uint64(2)) + + // Provider statuses for chain 1 + providerStatusesChain1 := fullStatus.StatusPerChainPerProvider[1] + require.Contains(s.T(), providerStatusesChain1, "provider1_chain1") + require.Contains(s.T(), providerStatusesChain1, "provider2_chain1") + + provider1Chain1Status := providerStatusesChain1["provider1_chain1"] + require.Equal(s.T(), rpcstatus.StatusUp, provider1Chain1Status.Status) + + provider2Chain1Status := providerStatusesChain1["provider2_chain1"] + require.Equal(s.T(), rpcstatus.StatusDown, provider2Chain1Status.Status) + + // Provider statuses for chain 2 + providerStatusesChain2 := fullStatus.StatusPerChainPerProvider[2] + require.Contains(s.T(), providerStatusesChain2, "provider1_chain2") + require.Contains(s.T(), providerStatusesChain2, "provider2_chain2") + + provider1Chain2Status := providerStatusesChain2["provider1_chain2"] + require.Equal(s.T(), rpcstatus.StatusUp, provider1Chain2Status.Status) + + provider2Chain2Status := providerStatusesChain2["provider2_chain2"] + require.Equal(s.T(), rpcstatus.StatusUp, provider2Chain2Status.Status) + + // Serialization to JSON works without errors + jsonData, err := json.MarshalIndent(fullStatus, "", " ") + require.NoError(s.T(), err) + require.NotEmpty(s.T(), jsonData) +} + +func (s *BlockchainHealthManagerSuite) TestGetShortStatus() { + // Setup clients for chain IDs 1 and 2 + s.setupClients([]uint64{1, 2}) + + // Subscribe to blockchain status updates + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + // Simulate provider statuses for chain 1 + providerCallStatusesChain1 := []rpcstatus.RpcProviderCallStatus{ + { + Name: "provider1_chain1", + Timestamp: time.Now(), + Err: nil, // Up + }, + { + Name: "provider2_chain1", + Timestamp: time.Now(), + Err: errors.New("connection error"), // Down + }, + } + ctx := context.Background() + s.mockProviders[1].Update(ctx, providerCallStatusesChain1) + + // Simulate provider statuses for chain 2 + providerCallStatusesChain2 := []rpcstatus.RpcProviderCallStatus{ + { + Name: "provider1_chain2", + Timestamp: time.Now(), + Err: nil, // Up + }, + { + Name: "provider2_chain2", + Timestamp: time.Now(), + Err: nil, // Up + }, + } + s.mockProviders[2].Update(ctx, providerCallStatusesChain2) + + // Wait for status event to be triggered before getting short status + s.waitForStatus(statusCh, rpcstatus.StatusUp) + + // Get the short status from the BlockchainHealthManager + shortStatus := s.blockchainHealthManager.GetStatusPerChain() + + // Assert overall blockchain status + require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.Status.Status) + + // Assert chain statuses + require.Contains(s.T(), shortStatus.StatusPerChain, uint64(1)) + require.Contains(s.T(), shortStatus.StatusPerChain, uint64(2)) + + require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.StatusPerChain[1].Status) + require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.StatusPerChain[2].Status) + + // Serialization to JSON works without errors + jsonData, err := json.MarshalIndent(shortStatus, "", " ") + require.NoError(s.T(), err) + require.NotEmpty(s.T(), jsonData) +} + +func TestBlockchainHealthManagerSuite(t *testing.T) { + suite.Run(t, new(BlockchainHealthManagerSuite)) +} diff --git a/rpc/chain/client.go b/rpc/chain/client.go index 24741d161..226c5aa73 100644 --- a/rpc/chain/client.go +++ b/rpc/chain/client.go @@ -20,6 +20,8 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rpc" "github.com/status-im/status-go/circuitbreaker" + "github.com/status-im/status-go/healthmanager" + "github.com/status-im/status-go/healthmanager/rpcstatus" "github.com/status-im/status-go/rpc/chain/ethclient" "github.com/status-im/status-go/rpc/chain/rpclimiter" "github.com/status-im/status-go/rpc/chain/tagger" @@ -63,10 +65,11 @@ func ClientWithTag(chainClient ClientInterface, tag, groupTag string) ClientInte } type ClientWithFallback struct { - ChainID uint64 - ethClients []ethclient.RPSLimitedEthClientInterface - commonLimiter rpclimiter.RequestLimiter - circuitbreaker *circuitbreaker.CircuitBreaker + ChainID uint64 + ethClients []ethclient.RPSLimitedEthClientInterface + commonLimiter rpclimiter.RequestLimiter + circuitbreaker *circuitbreaker.CircuitBreaker + providersHealthManager *healthmanager.ProvidersHealthManager WalletNotifier func(chainId uint64, message string) @@ -111,7 +114,7 @@ var propagateErrors = []error{ bind.ErrNoCode, } -func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint64) *ClientWithFallback { +func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint64, providersHealthManager *healthmanager.ProvidersHealthManager) *ClientWithFallback { cbConfig := circuitbreaker.Config{ Timeout: 20000, MaxConcurrentRequests: 100, @@ -123,11 +126,12 @@ func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint isConnected.Store(true) return &ClientWithFallback{ - ChainID: chainID, - ethClients: ethClients, - isConnected: isConnected, - LastCheckedAt: time.Now().Unix(), - circuitbreaker: circuitbreaker.NewCircuitBreaker(cbConfig), + ChainID: chainID, + ethClients: ethClients, + isConnected: isConnected, + LastCheckedAt: time.Now().Unix(), + circuitbreaker: circuitbreaker.NewCircuitBreaker(cbConfig), + providersHealthManager: providersHealthManager, } } @@ -238,6 +242,10 @@ func (c *ClientWithFallback) makeCall(ctx context.Context, ethClients []ethclien } result := c.circuitbreaker.Execute(cmd) + if c.providersHealthManager != nil { + rpcCallStatuses := convertFunctorCallStatuses(result.FunctorCallStatuses()) + c.providersHealthManager.Update(ctx, rpcCallStatuses) + } if result.Error() != nil { return nil, result.Error() } @@ -842,3 +850,10 @@ func (c *ClientWithFallback) GetCircuitBreaker() *circuitbreaker.CircuitBreaker func (c *ClientWithFallback) SetCircuitBreaker(cb *circuitbreaker.CircuitBreaker) { c.circuitbreaker = cb } + +func convertFunctorCallStatuses(statuses []circuitbreaker.FunctorCallStatus) (result []rpcstatus.RpcProviderCallStatus) { + for _, f := range statuses { + result = append(result, rpcstatus.RpcProviderCallStatus{Name: f.Name, Timestamp: f.Timestamp, Err: f.Err}) + } + return +} diff --git a/rpc/chain/client_health_test.go b/rpc/chain/client_health_test.go new file mode 100644 index 000000000..ca5ae161b --- /dev/null +++ b/rpc/chain/client_health_test.go @@ -0,0 +1,242 @@ +package chain + +import ( + "context" + "errors" + "strconv" + "testing" + + "github.com/ethereum/go-ethereum/core/vm" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + healthManager "github.com/status-im/status-go/healthmanager" + "github.com/status-im/status-go/healthmanager/rpcstatus" + "github.com/status-im/status-go/rpc/chain/ethclient" + "github.com/status-im/status-go/rpc/chain/rpclimiter" + + mockEthclient "github.com/status-im/status-go/rpc/chain/ethclient/mock/client/ethclient" +) + +type ClientWithFallbackSuite struct { + suite.Suite + client *ClientWithFallback + mockEthClients []*mockEthclient.MockRPSLimitedEthClientInterface + providersHealthManager *healthManager.ProvidersHealthManager + mockCtrl *gomock.Controller +} + +func (s *ClientWithFallbackSuite) SetupTest() { + s.mockCtrl = gomock.NewController(s.T()) +} + +func (s *ClientWithFallbackSuite) TearDownTest() { + s.mockCtrl.Finish() +} + +func (s *ClientWithFallbackSuite) setupClients(numClients int) { + s.mockEthClients = make([]*mockEthclient.MockRPSLimitedEthClientInterface, 0) + ethClients := make([]ethclient.RPSLimitedEthClientInterface, 0) + + for i := 0; i < numClients; i++ { + ethClient := mockEthclient.NewMockRPSLimitedEthClientInterface(s.mockCtrl) + ethClient.EXPECT().GetName().AnyTimes().Return("test" + strconv.Itoa(i)) + ethClient.EXPECT().GetLimiter().AnyTimes().Return(nil) + + s.mockEthClients = append(s.mockEthClients, ethClient) + ethClients = append(ethClients, ethClient) + } + var chainID uint64 = 0 + s.providersHealthManager = healthManager.NewProvidersHealthManager(chainID) + s.client = NewClient(ethClients, chainID, s.providersHealthManager) +} + +func (s *ClientWithFallbackSuite) TestSingleClientSuccess() { + s.setupClients(1) + ctx := context.Background() + hash := common.HexToHash("0x1234") + block := &types.Block{} + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(block, nil).Times(1) + + // WHEN + result, err := s.client.BlockByHash(ctx, hash) + require.NoError(s.T(), err) + require.Equal(s.T(), block, result) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp) +} + +func (s *ClientWithFallbackSuite) TestSingleClientConnectionError() { + s.setupClients(1) + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("connection error")).Times(1) + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusDown, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown) +} + +func (s *ClientWithFallbackSuite) TestRPSLimitErrorDoesNotMarkChainDown() { + s.setupClients(1) + + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // WHEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, rpclimiter.ErrRequestsOverLimit).Times(1) + + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp) + + status := providerStatuses["test0"] + require.Equal(s.T(), status.Status, rpcstatus.StatusUp, "provider shouldn't be DOWN on RPS limit") +} + +func (s *ClientWithFallbackSuite) TestContextCanceledDoesNotMarkChainDown() { + s.setupClients(1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + hash := common.HexToHash("0x1234") + + // WHEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, context.Canceled).Times(1) + + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + require.True(s.T(), errors.Is(err, context.Canceled)) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp) +} + +func (s *ClientWithFallbackSuite) TestVMErrorDoesNotMarkChainDown() { + s.setupClients(1) + ctx := context.Background() + hash := common.HexToHash("0x1234") + vmError := vm.ErrOutOfGas + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, vmError).Times(1) + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + require.True(s.T(), errors.Is(err, vm.ErrOutOfGas)) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp) +} + +func (s *ClientWithFallbackSuite) TestNoClientsChainDown() { + s.setupClients(0) + + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusDown, chainStatus.Status) +} + +func (s *ClientWithFallbackSuite) TestAllClientsDifferentErrors() { + s.setupClients(3) + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + s.mockEthClients[1].EXPECT().BlockByHash(ctx, hash).Return(nil, rpclimiter.ErrRequestsOverLimit).Times(1) + s.mockEthClients[2].EXPECT().BlockByHash(ctx, hash).Return(nil, vm.ErrOutOfGas).Times(1) + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 3) + + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown, "provider test0 should be DOWN due to a connection error") + require.Equal(s.T(), providerStatuses["test1"].Status, rpcstatus.StatusUp, "provider test1 should not be marked DOWN due to RPS limit error") + require.Equal(s.T(), providerStatuses["test2"].Status, rpcstatus.StatusUp, "provider test2 should not be labelled DOWN due to a VM error") +} + +func (s *ClientWithFallbackSuite) TestAllClientsNetworkErrors() { + s.setupClients(3) + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + s.mockEthClients[1].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + s.mockEthClients[2].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusDown, chainStatus.Status) + + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 3) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown) + require.Equal(s.T(), providerStatuses["test1"].Status, rpcstatus.StatusDown) + require.Equal(s.T(), providerStatuses["test2"].Status, rpcstatus.StatusDown) +} + +func (s *ClientWithFallbackSuite) TestChainStatusDownWhenInitial() { + s.setupClients(2) + + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusDown, chainStatus.Status) +} + +func TestClientWithFallbackSuite(t *testing.T) { + suite.Run(t, new(ClientWithFallbackSuite)) +} diff --git a/rpc/chain/client_test.go b/rpc/chain/client_test.go index b73d5b54e..13a7577a5 100644 --- a/rpc/chain/client_test.go +++ b/rpc/chain/client_test.go @@ -32,7 +32,7 @@ func setupClientTest(t *testing.T) (*ClientWithFallback, []*mock_ethclient.MockR ethClients = append(ethClients, ethCl) } - client := NewClient(ethClients, 0) + client := NewClient(ethClients, 0, nil) cleanup := func() { mockCtrl.Finish() diff --git a/rpc/client.go b/rpc/client.go index a06326282..a4b84770b 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -19,7 +19,9 @@ import ( "github.com/ethereum/go-ethereum/log" gethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/ethereum/go-ethereum/event" appCommon "github.com/status-im/status-go/common" + "github.com/status-im/status-go/healthmanager" "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/rpc/chain/ethclient" @@ -27,6 +29,7 @@ import ( "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/services/rpcstats" "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/walletevent" ) const ( @@ -48,6 +51,8 @@ const ( // rpcUserAgentUpstreamFormat a separate user agent format for upstream, because we should not be using upstream // if we see this user agent in the logs that means parts of the application are using a malconfigured http client rpcUserAgentUpstreamFormat = "procuratee-%s-upstream/%s" + + EventBlockchainHealthChanged walletevent.EventType = "wallet-blockchain-health-changed" // Full status of the blockchain (including provider statuses) ) // List of RPC client errors. @@ -101,6 +106,10 @@ type Client struct { router *router NetworkManager *network.Manager + healthMgr *healthmanager.BlockchainHealthManager + stopMonitoringFunc context.CancelFunc + walletFeed *event.Feed + handlersMx sync.RWMutex // mx guards handlers handlers map[string]Handler // locally registered handlers log log.Logger @@ -112,35 +121,47 @@ type Client struct { // Is initialized in a build-tag-dependent module var verifProxyInitFn func(c *Client) +// ClientConfig holds the configuration for initializing a new Client. +type ClientConfig struct { + Client *gethrpc.Client + UpstreamChainID uint64 + Networks []params.Network + DB *sql.DB + WalletFeed *event.Feed + ProviderConfigs []params.ProviderConfig +} + // NewClient initializes Client // // Client is safe for concurrent use and will automatically // reconnect to the server if connection is lost. -func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params.Network, db *sql.DB, providerConfigs []params.ProviderConfig) (*Client, error) { +func NewClient(config ClientConfig) (*Client, error) { var err error log := log.New("package", "status-go/rpc.Client") - networkManager := network.NewManager(db) + networkManager := network.NewManager(config.DB) if networkManager == nil { return nil, errors.New("failed to create network manager") } - err = networkManager.Init(networks) + err = networkManager.Init(config.Networks) if err != nil { log.Error("Network manager failed to initialize", "error", err) } c := Client{ - local: client, + local: config.Client, NetworkManager: networkManager, handlers: make(map[string]Handler), rpcClients: make(map[uint64]chain.ClientInterface), limiterPerProvider: make(map[string]*rpclimiter.RPCRpsLimiter), log: log, - providerConfigs: providerConfigs, + providerConfigs: config.ProviderConfigs, + healthMgr: healthmanager.NewBlockchainHealthManager(), + walletFeed: config.WalletFeed, } - c.UpstreamChainID = upstreamChainID + c.UpstreamChainID = config.UpstreamChainID c.router = newRouter(true) if verifProxyInitFn != nil { @@ -150,6 +171,57 @@ func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params return &c, nil } +func (c *Client) Start(ctx context.Context) { + if c.stopMonitoringFunc != nil { + c.log.Warn("Blockchain health manager already started") + return + } + + cancelableCtx, cancel := context.WithCancel(ctx) + c.stopMonitoringFunc = cancel + statusCh := c.healthMgr.Subscribe() + go c.monitorHealth(cancelableCtx, statusCh) +} + +func (c *Client) Stop() { + c.healthMgr.Stop() + if c.stopMonitoringFunc == nil { + return + } + c.stopMonitoringFunc() + c.stopMonitoringFunc = nil +} + +func (c *Client) monitorHealth(ctx context.Context, statusCh chan struct{}) { + sendFullStatusEventFunc := func() { + blockchainStatus := c.healthMgr.GetFullStatus() + encodedMessage, err := json.Marshal(blockchainStatus) + if err != nil { + c.log.Warn("could not marshal full blockchain status", "error", err) + return + } + if c.walletFeed == nil { + return + } + // FIXME: remove these excessive logs in future release (2.31+) + c.log.Debug("Sending blockchain health status event", "status", string(encodedMessage)) + c.walletFeed.Send(walletevent.Event{ + Type: EventBlockchainHealthChanged, + Message: string(encodedMessage), + At: time.Now().Unix(), + }) + } + + for { + select { + case <-ctx.Done(): + return + case <-statusCh: + sendFullStatusEventFunc() + } + } +} + func (c *Client) GetNetworkManager() *network.Manager { return c.NetworkManager } @@ -207,7 +279,13 @@ func (c *Client) getClientUsingCache(chainID uint64) (chain.ClientInterface, err return nil, fmt.Errorf("could not find any RPC URL for chain: %d", chainID) } - client := chain.NewClient(ethClients, chainID) + phm := healthmanager.NewProvidersHealthManager(chainID) + err := c.healthMgr.RegisterProvidersHealthManager(context.Background(), phm) + if err != nil { + return nil, fmt.Errorf("register providers health manager: %s", err) + } + + client := chain.NewClient(ethClients, chainID, phm) client.SetWalletNotifier(c.walletNotifier) c.rpcClients[chainID] = client return client, nil diff --git a/rpc/client_test.go b/rpc/client_test.go index 877a1af04..36a3ca1af 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -44,7 +44,15 @@ func TestBlockedRoutesCall(t *testing.T) { gethRPCClient, err := gethrpc.Dial(ts.URL) require.NoError(t, err) - c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil) + config := ClientConfig{ + Client: gethRPCClient, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + c, err := NewClient(config) require.NoError(t, err) for _, m := range blockedMethods { @@ -83,7 +91,15 @@ func TestBlockedRoutesRawCall(t *testing.T) { gethRPCClient, err := gethrpc.Dial(ts.URL) require.NoError(t, err) - c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil) + config := ClientConfig{ + Client: gethRPCClient, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + c, err := NewClient(config) require.NoError(t, err) for _, m := range blockedMethods { @@ -142,7 +158,17 @@ func TestGetClientsUsingCache(t *testing.T) { DefaultFallbackURL2: server.URL + path3, }, } - c, err := NewClient(nil, 1, networks, db, providerConfigs) + + config := ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: networks, + DB: db, + WalletFeed: nil, + ProviderConfigs: providerConfigs, + } + + c, err := NewClient(config) require.NoError(t, err) // Networks from DB must pick up DefaultRPCURL, DefaultFallbackURL, DefaultFallbackURL2 diff --git a/rpc/verif_proxy_test.go b/rpc/verif_proxy_test.go index 34f774a11..907f334ca 100644 --- a/rpc/verif_proxy_test.go +++ b/rpc/verif_proxy_test.go @@ -48,7 +48,16 @@ func (s *ProxySuite) startRpcClient(infuraURL string) *Client { db, close := setupTestNetworkDB(s.T()) defer close() - c, err := NewClient(gethRPCClient, 1, []params.Network{}, db) + + config := ClientConfig{ + Client: gethRPCClient, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + c, err := NewClient(config) require.NoError(s.T(), err) return c diff --git a/services/ens/api_test.go b/services/ens/api_test.go index 4deba6432..ac3af2466 100644 --- a/services/ens/api_test.go +++ b/services/ens/api_test.go @@ -33,7 +33,15 @@ func setupTestAPI(t *testing.T) (*API, func()) { _ = client - rpcClient, err := statusRPC.NewClient(nil, 1, nil, db, nil) + config := statusRPC.ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: nil, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + rpcClient, err := statusRPC.NewClient(config) require.NoError(t, err) // import account keys diff --git a/services/wallet/api_test.go b/services/wallet/api_test.go index b4337bbdd..9b14daa79 100644 --- a/services/wallet/api_test.go +++ b/services/wallet/api_test.go @@ -144,7 +144,15 @@ func TestAPI_GetAddressDetails(t *testing.T) { DefaultFallbackURL: serverWith1SecDelay.URL, }, } - c, err := rpc.NewClient(nil, chainID, networks, appDB, providerConfigs) + config := rpc.ClientConfig{ + Client: nil, + UpstreamChainID: chainID, + Networks: networks, + DB: appDB, + WalletFeed: nil, + ProviderConfigs: providerConfigs, + } + c, err := rpc.NewClient(config) require.NoError(t, err) chainClient, err := c.EthClient(chainID) diff --git a/services/wallet/common/mock/feed_subscription.go b/services/wallet/common/mock/feed_subscription.go new file mode 100644 index 000000000..b508ded5f --- /dev/null +++ b/services/wallet/common/mock/feed_subscription.go @@ -0,0 +1,46 @@ +package mock_common + +import ( + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/status-im/status-go/services/wallet/walletevent" +) + +type FeedSubscription struct { + events chan walletevent.Event + feed *event.Feed + done chan struct{} +} + +func NewFeedSubscription(feed *event.Feed) *FeedSubscription { + events := make(chan walletevent.Event, 100) + done := make(chan struct{}) + + subscription := feed.Subscribe(events) + + go func() { + <-done + subscription.Unsubscribe() + close(events) + }() + + return &FeedSubscription{events: events, feed: feed, done: done} +} + +func (f *FeedSubscription) WaitForEvent(timeout time.Duration) (walletevent.Event, bool) { + select { + case evt := <-f.events: + return evt, true + case <-time.After(timeout): + return walletevent.Event{}, false + } +} + +func (f *FeedSubscription) GetFeed() *event.Feed { + return f.feed +} + +func (f *FeedSubscription) Close() { + close(f.done) +} diff --git a/services/wallet/history/service_test.go b/services/wallet/history/service_test.go index 97c5f4c9a..4bcf9e020 100644 --- a/services/wallet/history/service_test.go +++ b/services/wallet/history/service_test.go @@ -404,7 +404,16 @@ func Test_removeBalanceHistoryOnEventAccountRemoved(t *testing.T) { txServiceMockCtrl := gomock.NewController(t) server, _ := fake.NewTestServer(txServiceMockCtrl) client := gethrpc.DialInProc(server) - rpcClient, _ := rpc.NewClient(client, chainID, nil, appDB, nil) + + config := rpc.ClientConfig{ + Client: client, + UpstreamChainID: chainID, + Networks: nil, + DB: appDB, + WalletFeed: nil, + ProviderConfigs: nil, + } + rpcClient, _ := rpc.NewClient(config) rpcClient.UpstreamChainID = chainID service := NewService(walletDB, accountsDB, &accountFeed, &walletFeed, rpcClient, nil, nil, nil) diff --git a/services/wallet/market/market_feed_test.go b/services/wallet/market/market_feed_test.go new file mode 100644 index 000000000..15697040d --- /dev/null +++ b/services/wallet/market/market_feed_test.go @@ -0,0 +1,74 @@ +package market + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + + "github.com/ethereum/go-ethereum/event" + mock_common "github.com/status-im/status-go/services/wallet/common/mock" + mock_market "github.com/status-im/status-go/services/wallet/market/mock" + "github.com/status-im/status-go/services/wallet/thirdparty" +) + +type MarketTestSuite struct { + suite.Suite + feedSub *mock_common.FeedSubscription + symbols []string + currencies []string +} + +func (s *MarketTestSuite) SetupTest() { + feed := new(event.Feed) + s.feedSub = mock_common.NewFeedSubscription(feed) + + s.symbols = []string{"BTC", "ETH"} + s.currencies = []string{"USD", "EUR"} +} + +func (s *MarketTestSuite) TearDownTest() { + s.feedSub.Close() +} + +func (s *MarketTestSuite) TestEventOnRpsError() { + ctrl := gomock.NewController(s.T()) + defer ctrl.Finish() + // GIVEN + customErr := errors.New("request rate exceeded") + priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr) + manager := NewManager([]thirdparty.MarketDataProvider{priceProviderWithError}, s.feedSub.GetFeed()) + + // WHEN + _, err := manager.FetchPrices(s.symbols, s.currencies) + s.Require().Error(err, "expected error from FetchPrices due to MockPriceProviderWithError") + event, ok := s.feedSub.WaitForEvent(5 * time.Second) + s.Require().True(ok, "expected an event, but none was received") + + // THEN + s.Require().Equal(event.Type, EventMarketStatusChanged) +} + +func (s *MarketTestSuite) TestEventOnNetworkError() { + ctrl := gomock.NewController(s.T()) + defer ctrl.Finish() + + // GIVEN + customErr := errors.New("dial tcp: lookup optimism-goerli.infura.io: no such host") + priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr) + manager := NewManager([]thirdparty.MarketDataProvider{priceProviderWithError}, s.feedSub.GetFeed()) + + _, err := manager.FetchPrices(s.symbols, s.currencies) + s.Require().Error(err, "expected error from FetchPrices due to MockPriceProviderWithError") + event, ok := s.feedSub.WaitForEvent(500 * time.Millisecond) + s.Require().True(ok, "expected an event, but none was received") + + // THEN + s.Require().Equal(event.Type, EventMarketStatusChanged) +} + +func TestMarketTestSuite(t *testing.T) { + suite.Run(t, new(MarketTestSuite)) +} diff --git a/services/wallet/market/market_test.go b/services/wallet/market/market_test.go index a7d4f0660..e646fc6ac 100644 --- a/services/wallet/market/market_test.go +++ b/services/wallet/market/market_test.go @@ -10,48 +10,11 @@ import ( "github.com/stretchr/testify/require" + mock_market "github.com/status-im/status-go/services/wallet/market/mock" "github.com/status-im/status-go/services/wallet/thirdparty" mock_thirdparty "github.com/status-im/status-go/services/wallet/thirdparty/mock" ) -type MockPriceProvider struct { - mock_thirdparty.MockMarketDataProvider - mockPrices map[string]map[string]float64 -} - -func NewMockPriceProvider(ctrl *gomock.Controller) *MockPriceProvider { - return &MockPriceProvider{ - MockMarketDataProvider: *mock_thirdparty.NewMockMarketDataProvider(ctrl), - } -} - -func (mpp *MockPriceProvider) setMockPrices(prices map[string]map[string]float64) { - mpp.mockPrices = prices -} - -func (mpp *MockPriceProvider) ID() string { - return "MockPriceProvider" -} - -func (mpp *MockPriceProvider) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) { - res := make(map[string]map[string]float64) - for _, symbol := range symbols { - res[symbol] = make(map[string]float64) - for _, currency := range currencies { - res[symbol][currency] = mpp.mockPrices[symbol][currency] - } - } - return res, nil -} - -type MockPriceProviderWithError struct { - MockPriceProvider -} - -func (mpp *MockPriceProviderWithError) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) { - return nil, errors.New("error") -} - func setupMarketManager(t *testing.T, providers []thirdparty.MarketDataProvider) *Manager { return NewManager(providers, &event.Feed{}) } @@ -80,8 +43,8 @@ var mockPrices = map[string]map[string]float64{ func TestPrice(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - priceProvider := NewMockPriceProvider(ctrl) - priceProvider.setMockPrices(mockPrices) + priceProvider := mock_market.NewMockPriceProvider(ctrl) + priceProvider.SetMockPrices(mockPrices) manager := setupMarketManager(t, []thirdparty.MarketDataProvider{priceProvider, priceProvider}) @@ -125,9 +88,12 @@ func TestPrice(t *testing.T) { func TestFetchPriceErrorFirstProvider(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - priceProvider := NewMockPriceProvider(ctrl) - priceProvider.setMockPrices(mockPrices) - priceProviderWithError := &MockPriceProviderWithError{} + priceProvider := mock_market.NewMockPriceProvider(ctrl) + priceProvider.SetMockPrices(mockPrices) + + customErr := errors.New("error") + priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr) + symbols := []string{"BTC", "ETH"} currencies := []string{"USD", "EUR"} diff --git a/services/wallet/market/mock/mock_price_provider.go b/services/wallet/market/mock/mock_price_provider.go new file mode 100644 index 000000000..5c0170d82 --- /dev/null +++ b/services/wallet/market/mock/mock_price_provider.go @@ -0,0 +1,54 @@ +package mock_market + +import ( + "go.uber.org/mock/gomock" + + mock_thirdparty "github.com/status-im/status-go/services/wallet/thirdparty/mock" +) + +type MockPriceProvider struct { + mock_thirdparty.MockMarketDataProvider + mockPrices map[string]map[string]float64 +} + +func NewMockPriceProvider(ctrl *gomock.Controller) *MockPriceProvider { + return &MockPriceProvider{ + MockMarketDataProvider: *mock_thirdparty.NewMockMarketDataProvider(ctrl), + } +} + +func (mpp *MockPriceProvider) SetMockPrices(prices map[string]map[string]float64) { + mpp.mockPrices = prices +} + +func (mpp *MockPriceProvider) ID() string { + return "MockPriceProvider" +} + +func (mpp *MockPriceProvider) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) { + res := make(map[string]map[string]float64) + for _, symbol := range symbols { + res[symbol] = make(map[string]float64) + for _, currency := range currencies { + res[symbol][currency] = mpp.mockPrices[symbol][currency] + } + } + return res, nil +} + +type MockPriceProviderWithError struct { + MockPriceProvider + err error +} + +// NewMockPriceProviderWithError creates a new MockPriceProviderWithError with the specified error +func NewMockPriceProviderWithError(ctrl *gomock.Controller, err error) *MockPriceProviderWithError { + return &MockPriceProviderWithError{ + MockPriceProvider: *NewMockPriceProvider(ctrl), + err: err, + } +} + +func (mpp *MockPriceProviderWithError) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) { + return nil, mpp.err +} diff --git a/services/wallet/router/router_test.go b/services/wallet/router/router_test.go index 8b06ab174..ac36ede5e 100644 --- a/services/wallet/router/router_test.go +++ b/services/wallet/router/router_test.go @@ -91,7 +91,15 @@ func setupTestNetworkDB(t *testing.T) (*sql.DB, func()) { func setupRouter(t *testing.T) (*Router, func()) { db, cleanTmpDb := setupTestNetworkDB(t) - client, _ := rpc.NewClient(nil, 1, defaultNetworks, db, nil) + config := rpc.ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: defaultNetworks, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + client, _ := rpc.NewClient(config) router := NewRouter(client, nil, nil, nil, nil, nil, nil, nil) diff --git a/services/wallet/token/token_test.go b/services/wallet/token/token_test.go index f8b01c073..b0ea6920f 100644 --- a/services/wallet/token/token_test.go +++ b/services/wallet/token/token_test.go @@ -331,7 +331,17 @@ func Test_removeTokenBalanceOnEventAccountRemoved(t *testing.T) { txServiceMockCtrl := gomock.NewController(t) server, _ := fake.NewTestServer(txServiceMockCtrl) client := gethrpc.DialInProc(server) - rpcClient, _ := rpc.NewClient(client, chainID, nil, appDB, nil) + + config := rpc.ClientConfig{ + Client: client, + UpstreamChainID: chainID, + Networks: nil, + DB: appDB, + WalletFeed: nil, + ProviderConfigs: nil, + } + rpcClient, _ := rpc.NewClient(config) + rpcClient.UpstreamChainID = chainID nm := network.NewManager(appDB) mediaServer, err := mediaserver.NewMediaServer(appDB, nil, nil, walletDB) diff --git a/services/wallet/transfer/commands_sequential_test.go b/services/wallet/transfer/commands_sequential_test.go index 8da1a4ea5..426d14668 100644 --- a/services/wallet/transfer/commands_sequential_test.go +++ b/services/wallet/transfer/commands_sequential_test.go @@ -10,6 +10,10 @@ import ( "testing" "time" + "github.com/status-im/status-go/contracts" + "github.com/status-im/status-go/services/wallet/blockchainstate" + "github.com/status-im/status-go/t/utils" + "github.com/pkg/errors" "github.com/stretchr/testify/mock" "go.uber.org/mock/gomock" @@ -24,30 +28,26 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/rpc" "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/contracts" "github.com/status-im/status-go/contracts/balancechecker" "github.com/status-im/status-go/contracts/ethscan" "github.com/status-im/status-go/contracts/ierc20" ethtypes "github.com/status-im/status-go/eth-node/types" - ethclient "github.com/status-im/status-go/rpc/chain/ethclient" - mock_client "github.com/status-im/status-go/rpc/chain/mock/client" - "github.com/status-im/status-go/rpc/chain/rpclimiter" - mock_rpcclient "github.com/status-im/status-go/rpc/mock/client" - "github.com/status-im/status-go/server" - "github.com/status-im/status-go/services/wallet/async" - "github.com/status-im/status-go/services/wallet/balance" - "github.com/status-im/status-go/services/wallet/blockchainstate" - "github.com/status-im/status-go/services/wallet/community" - "github.com/status-im/status-go/t/helpers" - "github.com/status-im/status-go/t/utils" - "github.com/status-im/status-go/multiaccounts/accounts" multicommon "github.com/status-im/status-go/multiaccounts/common" "github.com/status-im/status-go/params" statusRpc "github.com/status-im/status-go/rpc" + ethclient "github.com/status-im/status-go/rpc/chain/ethclient" + mock_client "github.com/status-im/status-go/rpc/chain/mock/client" + "github.com/status-im/status-go/rpc/chain/rpclimiter" + mock_rpcclient "github.com/status-im/status-go/rpc/mock/client" "github.com/status-im/status-go/rpc/network" + "github.com/status-im/status-go/server" + "github.com/status-im/status-go/services/wallet/async" + "github.com/status-im/status-go/services/wallet/balance" walletcommon "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/community" "github.com/status-im/status-go/services/wallet/token" + "github.com/status-im/status-go/t/helpers" "github.com/status-im/status-go/transactions" "github.com/status-im/status-go/walletdatabase" ) @@ -1079,7 +1079,17 @@ func setupFindBlocksCommand(t *testing.T, accountAddress common.Address, fromBlo return nil } - client, _ := statusRpc.NewClient(nil, 1, []params.Network{}, db, nil) + + config := statusRpc.ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + client, _ := statusRpc.NewClient(config) + client.SetClient(tc.NetworkID(), tc) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil, token.NewPersistence(db)) tokenManager.SetTokens([]*token.Token{ @@ -1342,7 +1352,16 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) { currentBlock: 100, } - client, _ := statusRpc.NewClient(nil, 1, []params.Network{}, db, nil) + config := statusRpc.ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + client, _ := statusRpc.NewClient(config) + client.SetClient(tc.NetworkID(), tc) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil, token.NewPersistence(db)) @@ -1466,7 +1485,16 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) { currentBlock: 100, } - client, _ := statusRpc.NewClient(nil, 1, []params.Network{}, db, nil) + config := statusRpc.ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + client, _ := statusRpc.NewClient(config) + client.SetClient(tc.NetworkID(), tc) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil, token.NewPersistence(db)) @@ -1546,7 +1574,16 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) { mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) require.NoError(t, err) - client, _ := statusRpc.NewClient(nil, 1, []params.Network{}, db, nil) + config := statusRpc.ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + client, _ := statusRpc.NewClient(config) + client.SetClient(tc.NetworkID(), tc) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil, token.NewPersistence(db)) @@ -1660,7 +1697,16 @@ func TestFetchNewBlocksCommand(t *testing.T) { } //tc.printPreparedData = true - client, _ := statusRpc.NewClient(nil, 1, []params.Network{}, db, nil) + config := statusRpc.ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + client, _ := statusRpc.NewClient(config) + client.SetClient(tc.NetworkID(), tc) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil, token.NewPersistence(db)) @@ -1799,7 +1845,16 @@ func TestLoadBlocksAndTransfersCommand_FiniteFinishedInfiniteRunning(t *testing. db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) - client, _ := statusRpc.NewClient(nil, 1, []params.Network{}, db, nil) + config := statusRpc.ClientConfig{ + Client: nil, + UpstreamChainID: 1, + Networks: []params.Network{}, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + client, _ := statusRpc.NewClient(config) + maker, _ := contracts.NewContractMaker(client) wdb := NewDB(db) diff --git a/services/web3provider/api_test.go b/services/web3provider/api_test.go index 22b5b55a3..fe815a212 100644 --- a/services/web3provider/api_test.go +++ b/services/web3provider/api_test.go @@ -39,7 +39,15 @@ func setupTestAPI(t *testing.T) (*API, func()) { server, _ := fake.NewTestServer(txServiceMockCtrl) client := gethrpc.DialInProc(server) - rpcClient, err := statusRPC.NewClient(client, 1, nil, db, nil) + config := statusRPC.ClientConfig{ + Client: client, + UpstreamChainID: 1, + Networks: nil, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + rpcClient, err := statusRPC.NewClient(config) require.NoError(t, err) // import account keys diff --git a/transactions/transactor_test.go b/transactions/transactor_test.go index 7d6252939..b6348023d 100644 --- a/transactions/transactor_test.go +++ b/transactions/transactor_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/suite" "go.uber.org/mock/gomock" + statusRpc "github.com/status-im/status-go/rpc" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" gethtypes "github.com/ethereum/go-ethereum/core/types" @@ -26,7 +28,6 @@ import ( "github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/params" - "github.com/status-im/status-go/rpc" wallet_common "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/sqlite" "github.com/status-im/status-go/t/utils" @@ -60,13 +61,23 @@ func (s *TransactorSuite) SetupTest() { chainID := gethparams.AllEthashProtocolChanges.ChainID.Uint64() db, err := sqlite.OpenUnecryptedDB(sqlite.InMemoryPath) // dummy to make rpc.Client happy s.Require().NoError(err) - rpcClient, _ := rpc.NewClient(s.client, chainID, nil, db, nil) + + config := statusRpc.ClientConfig{ + Client: s.client, + UpstreamChainID: chainID, + Networks: nil, + DB: db, + WalletFeed: nil, + ProviderConfigs: nil, + } + rpcClient, _ := statusRpc.NewClient(config) + rpcClient.UpstreamChainID = chainID ethClients := []ethclient.RPSLimitedEthClientInterface{ ethclient.NewRPSLimitedEthClient(s.client, rpclimiter.NewRPCRpsLimiter(), "local-1-chain-id-1"), } - localClient := chain.NewClient(ethClients, chainID) + localClient := chain.NewClient(ethClients, chainID, nil) rpcClient.SetClient(chainID, localClient) nodeConfig, err := utils.MakeTestNodeConfigWithDataDir("", "/tmp", chainID)