diff --git a/healthmanager/aggregator/aggregator.go b/healthmanager/aggregator/aggregator.go new file mode 100644 index 000000000..b81edc01a --- /dev/null +++ b/healthmanager/aggregator/aggregator.go @@ -0,0 +1,134 @@ +package aggregator + +import ( + "sync" + "time" + + "github.com/status-im/status-go/healthmanager/rpcstatus" +) + +// Aggregator manages and aggregates the statuses of multiple providers. +type Aggregator struct { + mu sync.RWMutex + Name string + providerStatuses map[string]*rpcstatus.ProviderStatus +} + +// NewAggregator creates a new instance of Aggregator with the given name. +func NewAggregator(name string) *Aggregator { + return &Aggregator{ + Name: name, + providerStatuses: make(map[string]*rpcstatus.ProviderStatus), + } +} + +// RegisterProvider adds a new provider to the aggregator. +// If the provider already exists, it does nothing. +func (a *Aggregator) RegisterProvider(providerName string) { + a.mu.Lock() + defer a.mu.Unlock() + if _, exists := a.providerStatuses[providerName]; !exists { + a.providerStatuses[providerName] = &rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusUnknown, + } + } +} + +// Update modifies the status of a specific provider. +// If the provider is not already registered, it adds the provider. +func (a *Aggregator) Update(providerStatus rpcstatus.ProviderStatus) { + a.mu.Lock() + defer a.mu.Unlock() + + // Update existing provider status or add a new provider. + if ps, exists := a.providerStatuses[providerStatus.Name]; exists { + ps.Status = providerStatus.Status + if providerStatus.Status == rpcstatus.StatusUp { + ps.LastSuccessAt = providerStatus.LastSuccessAt + } else if providerStatus.Status == rpcstatus.StatusDown { + ps.LastErrorAt = providerStatus.LastErrorAt + ps.LastError = providerStatus.LastError + } + } else { + a.providerStatuses[providerStatus.Name] = &rpcstatus.ProviderStatus{ + Name: providerStatus.Name, + LastSuccessAt: providerStatus.LastSuccessAt, + LastErrorAt: providerStatus.LastErrorAt, + LastError: providerStatus.LastError, + Status: providerStatus.Status, + } + } +} + +// UpdateBatch processes a batch of provider statuses. +func (a *Aggregator) UpdateBatch(statuses []rpcstatus.ProviderStatus) { + for _, status := range statuses { + a.Update(status) + } +} + +// ComputeAggregatedStatus calculates the overall aggregated status based on individual provider statuses. +// The logic is as follows: +// - If any provider is up, the aggregated status is up. +// - If no providers are up but at least one is unknown, the aggregated status is unknown. +// - If all providers are down, the aggregated status is down. +func (a *Aggregator) ComputeAggregatedStatus() rpcstatus.ProviderStatus { + a.mu.RLock() + defer a.mu.RUnlock() + + var lastSuccessAt, lastErrorAt time.Time + var lastError error + anyUp := false + anyUnknown := false + + for _, ps := range a.providerStatuses { + switch ps.Status { + case rpcstatus.StatusUp: + anyUp = true + if ps.LastSuccessAt.After(lastSuccessAt) { + lastSuccessAt = ps.LastSuccessAt + } + case rpcstatus.StatusUnknown: + anyUnknown = true + case rpcstatus.StatusDown: + if ps.LastErrorAt.After(lastErrorAt) { + lastErrorAt = ps.LastErrorAt + lastError = ps.LastError + } + } + } + + aggregatedStatus := rpcstatus.ProviderStatus{ + Name: a.Name, + LastSuccessAt: lastSuccessAt, + LastErrorAt: lastErrorAt, + LastError: lastError, + } + if len(a.providerStatuses) == 0 { + aggregatedStatus.Status = rpcstatus.StatusDown + } else if anyUp { + aggregatedStatus.Status = rpcstatus.StatusUp + } else if anyUnknown { + aggregatedStatus.Status = rpcstatus.StatusUnknown + } else { + aggregatedStatus.Status = rpcstatus.StatusDown + } + + return aggregatedStatus +} + +func (a *Aggregator) GetAggregatedStatus() rpcstatus.ProviderStatus { + return a.ComputeAggregatedStatus() +} + +func (a *Aggregator) GetStatuses() map[string]rpcstatus.ProviderStatus { + a.mu.RLock() + defer a.mu.RUnlock() + + statusesCopy := make(map[string]rpcstatus.ProviderStatus) + for k, v := range a.providerStatuses { + statusesCopy[k] = *v + } + return statusesCopy +} diff --git a/healthmanager/aggregator/aggregator_test.go b/healthmanager/aggregator/aggregator_test.go new file mode 100644 index 000000000..22246e485 --- /dev/null +++ b/healthmanager/aggregator/aggregator_test.go @@ -0,0 +1,311 @@ +package aggregator + +import ( + "sync" + "testing" + "time" + + "github.com/status-im/status-go/healthmanager/rpcstatus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +// StatusAggregatorTestSuite defines the test suite for Aggregator. +type StatusAggregatorTestSuite struct { + suite.Suite + aggregator *Aggregator +} + +// SetupTest runs before each test in the suite. +func (suite *StatusAggregatorTestSuite) SetupTest() { + suite.aggregator = NewAggregator("TestAggregator") +} + +// TestNewAggregator verifies that a new Aggregator is initialized correctly. +func (suite *StatusAggregatorTestSuite) TestNewAggregator() { + assert.Equal(suite.T(), "TestAggregator", suite.aggregator.Name, "Aggregator name should be set correctly") + assert.Empty(suite.T(), suite.aggregator.providerStatuses, "Aggregator should have no providers initially") +} + +// TestRegisterProvider verifies that providers are registered correctly. +func (suite *StatusAggregatorTestSuite) TestRegisterProvider() { + providerName := "Provider1" + suite.aggregator.RegisterProvider(providerName) + + assert.Len(suite.T(), suite.aggregator.providerStatuses, 1, "Expected 1 provider after registration") + _, exists := suite.aggregator.providerStatuses[providerName] + assert.True(suite.T(), exists, "Provider1 should be registered") + + // Attempt to register the same provider again + suite.aggregator.RegisterProvider(providerName) + assert.Len(suite.T(), suite.aggregator.providerStatuses, 1, "Duplicate registration should not increase provider count") +} + +// TestUpdate verifies that updating a provider's status works correctly. +func (suite *StatusAggregatorTestSuite) TestUpdate() { + providerName := "Provider1" + suite.aggregator.RegisterProvider(providerName) + + now := time.Now() + + // Update existing provider to up + statusUp := rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusUp, + LastSuccessAt: now, + } + suite.aggregator.Update(statusUp) + + ps, exists := suite.aggregator.providerStatuses[providerName] + assert.True(suite.T(), exists, "Provider1 should exist after update") + assert.Equal(suite.T(), rpcstatus.StatusUp, ps.Status, "Provider1 status should be 'up'") + assert.Equal(suite.T(), now, ps.LastSuccessAt, "Provider1 LastSuccessAt should be updated") + + // Update existing provider to down + nowDown := now.Add(1 * time.Hour) + statusDown := rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusDown, + LastErrorAt: nowDown, + } + suite.aggregator.Update(statusDown) + + ps, exists = suite.aggregator.providerStatuses[providerName] + assert.True(suite.T(), exists, "Provider1 should exist after second update") + assert.Equal(suite.T(), rpcstatus.StatusDown, ps.Status, "Provider1 status should be 'down'") + assert.Equal(suite.T(), nowDown, ps.LastErrorAt, "Provider1 LastErrorAt should be updated") + + // Update a non-registered provider via Update (should add it) + provider2 := "Provider2" + statusUp2 := rpcstatus.ProviderStatus{ + Name: provider2, + Status: rpcstatus.StatusUp, + LastSuccessAt: now, + } + suite.aggregator.Update(statusUp2) + + assert.Len(suite.T(), suite.aggregator.providerStatuses, 2, "Expected 2 providers after updating a new provider") + ps2, exists := suite.aggregator.providerStatuses[provider2] + assert.True(suite.T(), exists, "Provider2 should be added via Update") + assert.Equal(suite.T(), rpcstatus.StatusUp, ps2.Status, "Provider2 status should be 'up'") +} + +// TestComputeAggregatedStatus_NoProviders verifies aggregated status when no providers are registered. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_NoProviders() { + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUnknown, aggStatus.Status, "Aggregated status should be 'unknown' when no providers are registered") + assert.True(suite.T(), aggStatus.LastSuccessAt.IsZero(), "LastSuccessAt should be zero when no providers are registered") + assert.True(suite.T(), aggStatus.LastErrorAt.IsZero(), "LastErrorAt should be zero when no providers are registered") +} + +// TestComputeAggregatedStatus_AllUnknown verifies aggregated status when all providers are unknown. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_AllUnknown() { + // Register multiple providers with unknown status + suite.aggregator.RegisterProvider("Provider1") + suite.aggregator.RegisterProvider("Provider2") + suite.aggregator.RegisterProvider("Provider3") + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUnknown, aggStatus.Status, "Aggregated status should be 'unknown' when all providers are unknown") + assert.True(suite.T(), aggStatus.LastSuccessAt.IsZero(), "LastSuccessAt should be zero when all providers are unknown") + assert.True(suite.T(), aggStatus.LastErrorAt.IsZero(), "LastErrorAt should be zero when all providers are unknown") +} + +// TestComputeAggregatedStatus_AllUp verifies aggregated status when all providers are up. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_AllUp() { + // Register providers + suite.aggregator.RegisterProvider("Provider1") + suite.aggregator.RegisterProvider("Provider2") + + now1 := time.Now() + now2 := now1.Add(1 * time.Hour) + + // Update all providers to up + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusUp, + LastSuccessAt: now1, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider2", + Status: rpcstatus.StatusUp, + LastSuccessAt: now2, + }) + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUp, aggStatus.Status, "Aggregated status should be 'up' when all providers are up") + assert.Equal(suite.T(), now2, aggStatus.LastSuccessAt, "LastSuccessAt should reflect the latest success time") + assert.True(suite.T(), aggStatus.LastErrorAt.IsZero(), "LastErrorAt should be zero when all providers are up") +} + +// TestComputeAggregatedStatus_AllDown verifies aggregated status when all providers are down. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_AllDown() { + // Register providers + suite.aggregator.RegisterProvider("Provider1") + suite.aggregator.RegisterProvider("Provider2") + + now1 := time.Now() + now2 := now1.Add(1 * time.Hour) + + // Update all providers to down + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusDown, + LastErrorAt: now1, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider2", + Status: rpcstatus.StatusDown, + LastErrorAt: now2, + }) + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusDown, aggStatus.Status, "Aggregated status should be 'down' when all providers are down") + assert.Equal(suite.T(), now2, aggStatus.LastErrorAt, "LastErrorAt should reflect the latest error time") + assert.True(suite.T(), aggStatus.LastSuccessAt.IsZero(), "LastSuccessAt should be zero when all providers are down") +} + +// TestComputeAggregatedStatus_MixedUpAndUnknown verifies aggregated status with mixed up and unknown providers. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_MixedUpAndUnknown() { + // Register providers + suite.aggregator.RegisterProvider("Provider1") // up + suite.aggregator.RegisterProvider("Provider2") // unknown + suite.aggregator.RegisterProvider("Provider3") // up + + now1 := time.Now() + now2 := now1.Add(30 * time.Minute) + + // Update some providers to up + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusUp, + LastSuccessAt: now1, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider3", + Status: rpcstatus.StatusUp, + LastSuccessAt: now2, + }) + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUp, aggStatus.Status, "Aggregated status should be 'up' when at least one provider is up") + assert.Equal(suite.T(), now2, aggStatus.LastSuccessAt, "LastSuccessAt should reflect the latest success time") + assert.True(suite.T(), aggStatus.LastErrorAt.IsZero(), "LastErrorAt should be zero when no providers are down") +} + +// TestComputeAggregatedStatus_MixedUpAndDown verifies aggregated status with mixed up and down providers. +func (suite *StatusAggregatorTestSuite) TestComputeAggregatedStatus_MixedUpAndDown() { + // Register providers + suite.aggregator.RegisterProvider("Provider1") // up + suite.aggregator.RegisterProvider("Provider2") // down + suite.aggregator.RegisterProvider("Provider3") // up + + now1 := time.Now() + now2 := now1.Add(15 * time.Minute) + + // Update providers + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusUp, + LastSuccessAt: now1, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider2", + Status: rpcstatus.StatusDown, + LastErrorAt: now2, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider3", + Status: rpcstatus.StatusUp, + LastSuccessAt: now1, + }) + + aggStatus := suite.aggregator.ComputeAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUp, aggStatus.Status, "Aggregated status should be 'up' when at least one provider is up") + assert.Equal(suite.T(), now1, aggStatus.LastSuccessAt, "LastSuccessAt should reflect the latest success time") + assert.Equal(suite.T(), now2, aggStatus.LastErrorAt, "LastErrorAt should reflect the latest error time") +} + +// TestGetAggregatedStatus verifies that GetAggregatedStatus returns the correct aggregated status. +func (suite *StatusAggregatorTestSuite) TestGetAggregatedStatus() { + // Register and update providers + suite.aggregator.RegisterProvider("Provider1") + suite.aggregator.RegisterProvider("Provider2") + + now := time.Now() + + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider1", + Status: rpcstatus.StatusUp, + LastSuccessAt: now, + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: "Provider2", + Status: rpcstatus.StatusDown, + LastErrorAt: now.Add(1 * time.Hour), + }) + + aggStatus := suite.aggregator.GetAggregatedStatus() + + assert.Equal(suite.T(), rpcstatus.StatusUp, aggStatus.Status, "Aggregated status should be 'up' when at least one provider is up") + assert.Equal(suite.T(), now, aggStatus.LastSuccessAt, "LastSuccessAt should reflect the provider's success time") + assert.Equal(suite.T(), now.Add(1*time.Hour), aggStatus.LastErrorAt, "LastErrorAt should reflect the provider's error time") +} + +// TestConcurrentAccess verifies that the Aggregator is safe for concurrent use. +func (suite *StatusAggregatorTestSuite) TestConcurrentAccess() { + // Register multiple providers + providers := []string{"Provider1", "Provider2", "Provider3", "Provider4", "Provider5"} + for _, p := range providers { + suite.aggregator.RegisterProvider(p) + } + + var wg sync.WaitGroup + + // Concurrently update providers + for _, p := range providers { + wg.Add(1) + go func(providerName string) { + defer wg.Done() + for i := 0; i < 1000; i++ { + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusUp, + LastSuccessAt: time.Now(), + }) + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: providerName, + Status: rpcstatus.StatusDown, + LastErrorAt: time.Now(), + }) + } + }(p) + } + + // Wait for all goroutines to finish + wg.Wait() + + // Set all providers to down to ensure deterministic aggregated status + now := time.Now() + for _, p := range providers { + suite.aggregator.Update(rpcstatus.ProviderStatus{ + Name: p, + Status: rpcstatus.StatusDown, + LastErrorAt: now, + }) + } + + aggStatus := suite.aggregator.GetAggregatedStatus() + assert.Equal(suite.T(), rpcstatus.StatusDown, aggStatus.Status, "Aggregated status should be 'down' after setting all providers to down") +} + +// TestStatusAggregatorTestSuite runs the test suite. +func TestStatusAggregatorTestSuite(t *testing.T) { + suite.Run(t, new(StatusAggregatorTestSuite)) +} diff --git a/healthmanager/blockchain_health_manager.go b/healthmanager/blockchain_health_manager.go new file mode 100644 index 000000000..d6084ab48 --- /dev/null +++ b/healthmanager/blockchain_health_manager.go @@ -0,0 +1,224 @@ +package healthmanager + +import ( + "context" + "fmt" + "github.com/status-im/status-go/healthmanager/aggregator" + "github.com/status-im/status-go/healthmanager/rpcstatus" + "sync" +) + +// BlockchainFullStatus contains the full status of the blockchain, including provider statuses. +type BlockchainFullStatus struct { + Status rpcstatus.ProviderStatus `json:"status"` + StatusPerChainPerProvider map[uint64]map[string]rpcstatus.ProviderStatus `json:"statusPerChainPerProvider"` +} + +// BlockchainHealthManager manages the state of all providers and aggregates their statuses. +type BlockchainHealthManager struct { + mu sync.RWMutex + aggregator *aggregator.Aggregator + subscribers []chan struct{} + + providers map[uint64]*ProvidersHealthManager + cancelFuncs map[uint64]context.CancelFunc // Map chainID to cancel functions + lastStatus BlockchainStatus + wg sync.WaitGroup +} + +// NewBlockchainHealthManager creates a new instance of BlockchainHealthManager. +func NewBlockchainHealthManager() *BlockchainHealthManager { + agg := aggregator.NewAggregator("blockchain") + return &BlockchainHealthManager{ + aggregator: agg, + providers: make(map[uint64]*ProvidersHealthManager), + cancelFuncs: make(map[uint64]context.CancelFunc), + } +} + +// RegisterProvidersHealthManager registers the provider health manager. +// It prevents registering the same provider twice for the same chain. +func (b *BlockchainHealthManager) RegisterProvidersHealthManager(ctx context.Context, phm *ProvidersHealthManager) error { + b.mu.Lock() + defer b.mu.Unlock() + + // Check if the provider for the given chainID is already registered + if _, exists := b.providers[phm.ChainID()]; exists { + // Log a warning or return an error to indicate that the provider is already registered + return fmt.Errorf("provider for chainID %d is already registered", phm.ChainID()) + } + + // Proceed with the registration + b.providers[phm.ChainID()] = phm + + // Create a new context for the provider + providerCtx, cancel := context.WithCancel(ctx) + b.cancelFuncs[phm.ChainID()] = cancel + + statusCh := phm.Subscribe() + b.wg.Add(1) + go func(phm *ProvidersHealthManager, statusCh chan struct{}, providerCtx context.Context) { + defer func() { + b.wg.Done() + phm.Unsubscribe(statusCh) + }() + for { + select { + case <-statusCh: + // When the provider updates its status, check the statuses of all providers + b.aggregateAndUpdateStatus(providerCtx) + case <-providerCtx.Done(): + // Stop processing when the context is cancelled + return + } + } + }(phm, statusCh, providerCtx) + + return nil +} + +// Stop stops the event processing and unsubscribes. +func (b *BlockchainHealthManager) Stop() { + b.mu.Lock() + + for _, cancel := range b.cancelFuncs { + cancel() + } + clear(b.cancelFuncs) + + b.mu.Unlock() + b.wg.Wait() +} + +// Subscribe allows clients to receive notifications about changes. +func (b *BlockchainHealthManager) Subscribe() chan struct{} { + ch := make(chan struct{}, 1) + b.mu.Lock() + defer b.mu.Unlock() + b.subscribers = append(b.subscribers, ch) + return ch +} + +// Unsubscribe removes a subscriber from receiving notifications. +func (b *BlockchainHealthManager) Unsubscribe(ch chan struct{}) { + b.mu.Lock() + defer b.mu.Unlock() + + // Remove the subscriber channel from the list + for i, subscriber := range b.subscribers { + if subscriber == ch { + b.subscribers = append(b.subscribers[:i], b.subscribers[i+1:]...) + close(ch) + break + } + } +} + +// aggregateAndUpdateStatus collects statuses from all providers and updates the overall and short status. +func (b *BlockchainHealthManager) aggregateAndUpdateStatus(ctx context.Context) { + b.mu.Lock() + + // Collect statuses from all providers + providerStatuses := make([]rpcstatus.ProviderStatus, 0) + for _, provider := range b.providers { + providerStatuses = append(providerStatuses, provider.Status()) + } + + // Update the aggregator with the new list of provider statuses + b.aggregator.UpdateBatch(providerStatuses) + + // Get the new aggregated full and short status + newShortStatus := b.getShortStatus() + b.mu.Unlock() + + // Compare full and short statuses and emit if changed + if !compareShortStatus(newShortStatus, b.lastStatus) { + b.emitBlockchainHealthStatus(ctx) + b.mu.Lock() + defer b.mu.Unlock() + b.lastStatus = newShortStatus + } +} + +// compareShortStatus compares two BlockchainStatus structs and returns true if they are identical. +func compareShortStatus(newStatus, previousStatus BlockchainStatus) bool { + if newStatus.Status.Status != previousStatus.Status.Status { + return false + } + + if len(newStatus.StatusPerChain) != len(previousStatus.StatusPerChain) { + return false + } + + for chainID, newChainStatus := range newStatus.StatusPerChain { + if prevChainStatus, ok := previousStatus.StatusPerChain[chainID]; !ok || newChainStatus.Status != prevChainStatus.Status { + return false + } + } + + return true +} + +// emitBlockchainHealthStatus sends a notification to all subscribers about the new blockchain status. +func (b *BlockchainHealthManager) emitBlockchainHealthStatus(ctx context.Context) { + b.mu.RLock() + defer b.mu.RUnlock() + for _, subscriber := range b.subscribers { + select { + case <-ctx.Done(): + // Stop sending notifications when the context is cancelled + return + case subscriber <- struct{}{}: + default: + // Skip notification if the subscriber's channel is full + } + } +} + +func (b *BlockchainHealthManager) GetFullStatus() BlockchainFullStatus { + b.mu.RLock() + defer b.mu.RUnlock() + + statusPerChainPerProvider := make(map[uint64]map[string]rpcstatus.ProviderStatus) + + for chainID, phm := range b.providers { + providerStatuses := phm.GetStatuses() + statusPerChainPerProvider[chainID] = providerStatuses + } + + blockchainStatus := b.aggregator.GetAggregatedStatus() + + return BlockchainFullStatus{ + Status: blockchainStatus, + StatusPerChainPerProvider: statusPerChainPerProvider, + } +} + +func (b *BlockchainHealthManager) getShortStatus() BlockchainStatus { + statusPerChain := make(map[uint64]rpcstatus.ProviderStatus) + + for chainID, phm := range b.providers { + chainStatus := phm.Status() + statusPerChain[chainID] = chainStatus + } + + blockchainStatus := b.aggregator.GetAggregatedStatus() + + return BlockchainStatus{ + Status: blockchainStatus, + StatusPerChain: statusPerChain, + } +} + +func (b *BlockchainHealthManager) GetShortStatus() BlockchainStatus { + b.mu.RLock() + defer b.mu.RUnlock() + return b.getShortStatus() +} + +// Status returns the current aggregated status. +func (b *BlockchainHealthManager) Status() rpcstatus.ProviderStatus { + b.mu.RLock() + defer b.mu.RUnlock() + return b.aggregator.GetAggregatedStatus() +} diff --git a/healthmanager/blockchain_health_manager_test.go b/healthmanager/blockchain_health_manager_test.go new file mode 100644 index 000000000..8b65fcae6 --- /dev/null +++ b/healthmanager/blockchain_health_manager_test.go @@ -0,0 +1,222 @@ +package healthmanager + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/status-im/status-go/healthmanager/rpcstatus" + "github.com/stretchr/testify/suite" +) + +type BlockchainHealthManagerSuite struct { + suite.Suite + manager *BlockchainHealthManager + ctx context.Context + cancel context.CancelFunc +} + +func (s *BlockchainHealthManagerSuite) SetupTest() { + s.manager = NewBlockchainHealthManager() + s.ctx, s.cancel = context.WithCancel(context.Background()) +} + +func (s *BlockchainHealthManagerSuite) TearDownTest() { + s.manager.Stop() + s.cancel() +} + +// Helper method to update providers and wait for a notification on the given channel +func (s *BlockchainHealthManagerSuite) waitForUpdate(ch <-chan struct{}, expectedChainStatus rpcstatus.StatusType, timeout time.Duration) { + select { + case <-ch: + // Received notification + case <-time.After(timeout): + s.Fail("Timeout waiting for chain status update") + } + + s.assertBlockChainStatus(expectedChainStatus) +} + +// Helper method to assert the current chain status +func (s *BlockchainHealthManagerSuite) assertBlockChainStatus(expected rpcstatus.StatusType) { + actual := s.manager.Status().Status + s.Equal(expected, actual, fmt.Sprintf("Expected blockchain status to be %s", expected)) +} + +// Test registering a provider health manager +func (s *BlockchainHealthManagerSuite) TestRegisterProvidersHealthManager() { + phm := NewProvidersHealthManager(1) // Create a real ProvidersHealthManager + s.manager.RegisterProvidersHealthManager(context.Background(), phm) + + // Verify that the provider is registered + s.Require().NotNil(s.manager.providers[1]) +} + +// Test status updates and notifications +func (s *BlockchainHealthManagerSuite) TestStatusUpdateNotification() { + phm := NewProvidersHealthManager(1) + s.manager.RegisterProvidersHealthManager(context.Background(), phm) + ch := s.manager.Subscribe() + + // Update the provider status + phm.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "providerName", Timestamp: time.Now(), Err: nil}, + }) + + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) +} + +// Test getting the full status +func (s *BlockchainHealthManagerSuite) TestGetFullStatus() { + phm1 := NewProvidersHealthManager(1) + phm2 := NewProvidersHealthManager(2) + ctx := context.Background() + s.manager.RegisterProvidersHealthManager(ctx, phm1) + s.manager.RegisterProvidersHealthManager(ctx, phm2) + ch := s.manager.Subscribe() + + // Update the provider status + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "providerName1", Timestamp: time.Now(), Err: nil}, + }) + phm2.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "providerName2", Timestamp: time.Now(), Err: errors.New("connection error")}, + }) + + s.waitForUpdate(ch, rpcstatus.StatusUp, 10*time.Millisecond) + fullStatus := s.manager.GetFullStatus() + s.Len(fullStatus.StatusPerChainPerProvider, 2, "Expected statuses for 2 chains") +} + +func (s *BlockchainHealthManagerSuite) TestConcurrentSubscriptionUnsubscription() { + var wg sync.WaitGroup + subscribersCount := 100 + + // Concurrently add and remove subscribers + for i := 0; i < subscribersCount; i++ { + wg.Add(1) + go func() { + defer wg.Done() + subCh := s.manager.Subscribe() + time.Sleep(10 * time.Millisecond) + s.manager.Unsubscribe(subCh) + }() + } + + wg.Wait() + // After all subscribers are removed, there should be no active subscribers + s.Equal(0, len(s.manager.subscribers), "Expected no subscribers after unsubscription") +} + +func (s *BlockchainHealthManagerSuite) TestConcurrency() { + var wg sync.WaitGroup + chainsCount := 10 + providersCount := 100 + ctx, cancel := context.WithCancel(s.ctx) + defer cancel() + for i := 1; i <= chainsCount; i++ { + phm := NewProvidersHealthManager(uint64(i)) + s.manager.RegisterProvidersHealthManager(ctx, phm) + } + + ch := s.manager.Subscribe() + + for i := 1; i <= chainsCount; i++ { + wg.Add(1) + go func(chainID uint64) { + defer wg.Done() + phm := s.manager.providers[chainID] + for j := 0; j < providersCount; j++ { + err := errors.New("connection error") + if j == providersCount-1 { + err = nil + } + name := fmt.Sprintf("provider-%d", j) + go phm.Update(ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: name, Timestamp: time.Now(), Err: err}, + }) + } + }(uint64(i)) + } + + wg.Wait() + + s.waitForUpdate(ch, rpcstatus.StatusUp, 2*time.Second) +} + +func (s *BlockchainHealthManagerSuite) TestMultipleStartAndStop() { + s.manager.Stop() + + s.manager.Stop() + + // Ensure that the manager is in a clean state after multiple starts and stops + s.Equal(0, len(s.manager.cancelFuncs), "Expected no cancel functions after stop") +} + +func (s *BlockchainHealthManagerSuite) TestUnsubscribeOneOfMultipleSubscribers() { + // Create an instance of BlockchainHealthManager and register a provider manager + phm := NewProvidersHealthManager(1) + ctx, cancel := context.WithCancel(s.ctx) + s.manager.RegisterProvidersHealthManager(ctx, phm) + + defer cancel() + + // Subscribe two subscribers + subscriber1 := s.manager.Subscribe() + subscriber2 := s.manager.Subscribe() + + // Unsubscribe the first subscriber + s.manager.Unsubscribe(subscriber1) + + phm.Update(ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "provider-1", Timestamp: time.Now(), Err: nil}, + }) + + // Ensure the first subscriber did not receive a notification + select { + case _, ok := <-subscriber1: + s.False(ok, "First subscriber channel should be closed") + default: + s.Fail("First subscriber channel was not closed") + } + + // Ensure the second subscriber received a notification + select { + case <-subscriber2: + // Notification received by the second subscriber + case <-time.After(100 * time.Millisecond): + s.Fail("Second subscriber should have received a notification") + } +} + +func (s *BlockchainHealthManagerSuite) TestMixedProviderStatusInSingleChain() { + // Register a provider for chain 1 + phm := NewProvidersHealthManager(1) + s.manager.RegisterProvidersHealthManager(s.ctx, phm) + + // Subscribe to status updates + ch := s.manager.Subscribe() + defer s.manager.Unsubscribe(ch) + + // Simulate mixed statuses within the same chain (one provider up, one provider down) + phm.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: "provider1_chain1", Timestamp: time.Now(), Err: nil}, // Provider 1 is up + {Name: "provider2_chain1", Timestamp: time.Now(), Err: errors.New("error")}, // Provider 2 is down + }) + + // Wait for the status to propagate + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Verify that the short status reflects the chain as down, since one provider is down + shortStatus := s.manager.GetShortStatus() + s.Equal(rpcstatus.StatusUp, shortStatus.Status.Status) + s.Equal(rpcstatus.StatusUp, shortStatus.StatusPerChain[1].Status) // Chain 1 should be marked as down +} + +func TestBlockchainHealthManagerSuite(t *testing.T) { + suite.Run(t, new(BlockchainHealthManagerSuite)) +} diff --git a/healthmanager/provider_errors/provider_errors.go b/healthmanager/provider_errors/provider_errors.go new file mode 100644 index 000000000..bec09d0a1 --- /dev/null +++ b/healthmanager/provider_errors/provider_errors.go @@ -0,0 +1,247 @@ +package provider_errors + +import ( + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "strings" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/rpc" + "github.com/status-im/status-go/rpc/chain/rpclimiter" +) + +// ProviderErrorType defines the type of non-RPC error for JSON serialization. +type ProviderErrorType string + +const ( + // Non-RPC Errors + ProviderErrorTypeNone ProviderErrorType = "none" + ProviderErrorTypeContextCanceled ProviderErrorType = "context_canceled" + ProviderErrorTypeContextDeadlineExceeded ProviderErrorType = "context_deadline" + ProviderErrorTypeConnection ProviderErrorType = "connection" + ProviderErrorTypeNotAuthorized ProviderErrorType = "not_authorized" + ProviderErrorTypeForbidden ProviderErrorType = "forbidden" + ProviderErrorTypeBadRequest ProviderErrorType = "bad_request" + ProviderErrorTypeContentTooLarge ProviderErrorType = "content_too_large" + ProviderErrorTypeInternalError ProviderErrorType = "internal" + ProviderErrorTypeServiceUnavailable ProviderErrorType = "service_unavailable" + ProviderErrorTypeRateLimit ProviderErrorType = "rate_limit" + ProviderErrorTypeOther ProviderErrorType = "other" +) + +// IsConnectionError checks if the error is related to network issues. +func IsConnectionError(err error) bool { + if err == nil { + return false + } + + // Check for net.Error (timeout or other network errors) + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + return true + } + } + + // Check for DNS errors + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return true + } + + // Check for network operation errors (e.g., connection refused) + var opErr *net.OpError + if errors.As(err, &opErr) { + return true + } + + // Check for TLS errors + var tlsRecordErr *tls.RecordHeaderError + if errors.As(err, &tlsRecordErr) { + return true + } + + // FIXME: Check for TLS ECH Rejection Error (tls.ECHRejectionError is added in go 1.23) + + // Check for TLS Certificate Verification Error + var certVerifyErr *tls.CertificateVerificationError + if errors.As(err, &certVerifyErr) { + return true + } + + // Check for TLS Alert Error + var alertErr tls.AlertError + if errors.As(err, &alertErr) { + return true + } + + // Check for specific HTTP server closed error + if errors.Is(err, http.ErrServerClosed) { + return true + } + + // Common connection refused or timeout error messages + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, "i/o timeout") || + strings.Contains(errMsg, "connection refused") || + strings.Contains(errMsg, "network is unreachable") || + strings.Contains(errMsg, "no such host") || + strings.Contains(errMsg, "tls handshake timeout") { + return true + } + + return false +} + +func IsRateLimitError(err error) bool { + if err == nil { + return false + } + + if ok, statusCode := IsHTTPError(err); ok && statusCode == 429 { + return true + } + + if errors.Is(err, rpclimiter.ErrRequestsOverLimit) { + return true + } + + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, "backoff_seconds") || + strings.Contains(errMsg, "has exceeded its throughput limit") || + strings.Contains(errMsg, "request rate exceeded") { + return true + } + return false +} + +// Don't mark connection as failed if we get one of these errors +var propagateErrors = []error{ + vm.ErrOutOfGas, + vm.ErrCodeStoreOutOfGas, + vm.ErrDepth, + vm.ErrInsufficientBalance, + vm.ErrContractAddressCollision, + vm.ErrExecutionReverted, + vm.ErrMaxCodeSizeExceeded, + vm.ErrInvalidJump, + vm.ErrWriteProtection, + vm.ErrReturnDataOutOfBounds, + vm.ErrGasUintOverflow, + vm.ErrInvalidCode, + vm.ErrNonceUintOverflow, + + // Used by balance history to check state + bind.ErrNoCode, +} + +func IsHTTPError(err error) (bool, int) { + var httpErrPtr *rpc.HTTPError + if errors.As(err, &httpErrPtr) { + return true, httpErrPtr.StatusCode + } + + var httpErr rpc.HTTPError + if errors.As(err, &httpErr) { + return true, httpErr.StatusCode + } + + return false, 0 +} + +func IsNotAuthorizedError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 401 + } + return false +} + +func IsForbiddenError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 403 + } + return false +} + +func IsBadRequestError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 400 + } + return false +} + +func IsContentTooLargeError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 413 + } + return false +} + +func IsInternalServerError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 500 + } + return false +} + +func IsServiceUnavailableError(err error) bool { + if ok, statusCode := IsHTTPError(err); ok { + return statusCode == 503 + } + return false +} + +// determineProviderErrorType determines the ProviderErrorType based on the error. +func determineProviderErrorType(err error) ProviderErrorType { + if err == nil { + return ProviderErrorTypeNone + } + if errors.Is(err, context.Canceled) { + return ProviderErrorTypeContextCanceled + } + if errors.Is(err, context.DeadlineExceeded) { + return ProviderErrorTypeContextDeadlineExceeded + } + if IsConnectionError(err) { + return ProviderErrorTypeConnection + } + if IsNotAuthorizedError(err) { + return ProviderErrorTypeNotAuthorized + } + if IsForbiddenError(err) { + return ProviderErrorTypeForbidden + } + if IsBadRequestError(err) { + return ProviderErrorTypeBadRequest + } + if IsContentTooLargeError(err) { + return ProviderErrorTypeContentTooLarge + } + if IsInternalServerError(err) { + return ProviderErrorTypeInternalError + } + if IsServiceUnavailableError(err) { + return ProviderErrorTypeServiceUnavailable + } + if IsRateLimitError(err) { + return ProviderErrorTypeRateLimit + } + // Add additional non-RPC checks as necessary + return ProviderErrorTypeOther +} + +// IsNonCriticalProviderError determines if the non-RPC error is not critical. +func IsNonCriticalProviderError(err error) bool { + errorType := determineProviderErrorType(err) + + switch errorType { + case ProviderErrorTypeNone, ProviderErrorTypeContextCanceled, ProviderErrorTypeContentTooLarge, ProviderErrorTypeRateLimit: + return true + default: + return false + } +} diff --git a/healthmanager/provider_errors/provider_errors_test.go b/healthmanager/provider_errors/provider_errors_test.go new file mode 100644 index 000000000..34af6b81b --- /dev/null +++ b/healthmanager/provider_errors/provider_errors_test.go @@ -0,0 +1,115 @@ +package provider_errors + +import ( + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "testing" +) + +// TestIsConnectionError tests the IsConnectionError function. +func TestIsConnectionError(t *testing.T) { + tests := []struct { + name string + err error + wantResult bool + }{ + { + name: "nil error", + err: nil, + wantResult: false, + }, + { + name: "net.DNSError with timeout", + err: &net.DNSError{IsTimeout: true}, + wantResult: true, + }, + { + name: "DNS error without timeout", + err: &net.DNSError{}, + wantResult: true, + }, + { + name: "net.OpError", + err: &net.OpError{}, + wantResult: true, + }, + { + name: "tls.RecordHeaderError", + err: &tls.RecordHeaderError{}, + wantResult: true, + }, + { + name: "tls.CertificateVerificationError", + err: &tls.CertificateVerificationError{}, + wantResult: true, + }, + { + name: "tls.AlertError", + err: tls.AlertError(0), + wantResult: true, + }, + { + name: "context.DeadlineExceeded", + err: context.DeadlineExceeded, + wantResult: true, + }, + { + name: "http.ErrServerClosed", + err: http.ErrServerClosed, + wantResult: true, + }, + { + name: "i/o timeout error message", + err: errors.New("i/o timeout"), + wantResult: true, + }, + { + name: "connection refused error message", + err: errors.New("connection refused"), + wantResult: true, + }, + { + name: "network is unreachable error message", + err: errors.New("network is unreachable"), + wantResult: true, + }, + { + name: "no such host error message", + err: errors.New("no such host"), + wantResult: true, + }, + { + name: "tls handshake timeout error message", + err: errors.New("tls handshake timeout"), + wantResult: true, + }, + { + name: "rps limit error 1", + err: errors.New("backoff_seconds"), + wantResult: false, + }, + { + name: "rps limit error 2", + err: errors.New("has exceeded its throughput limit"), + wantResult: false, + }, + { + name: "rps limit error 3", + err: errors.New("request rate exceeded"), + wantResult: false, + }, + } + + for _, tt := range tests { + tt := tt // capture the variable + t.Run(tt.name, func(t *testing.T) { + got := IsConnectionError(tt.err) + if got != tt.wantResult { + t.Errorf("IsConnectionError(%v) = %v; want %v", tt.err, got, tt.wantResult) + } + }) + } +} diff --git a/healthmanager/provider_errors/rpc_provider_errors.go b/healthmanager/provider_errors/rpc_provider_errors.go new file mode 100644 index 000000000..49f8c3447 --- /dev/null +++ b/healthmanager/provider_errors/rpc_provider_errors.go @@ -0,0 +1,86 @@ +package provider_errors + +import ( + "errors" + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/rpc" + "strings" +) + +type RpcProviderErrorType string + +const ( + // RPC Errors + RpcErrorTypeNone RpcProviderErrorType = "none" + RpcErrorTypeMethodNotFound RpcProviderErrorType = "rpc_method_not_found" + RpcErrorTypeRPSLimit RpcProviderErrorType = "rpc_rps_limit" + RpcErrorTypeVMError RpcProviderErrorType = "rpc_vm_error" + RpcErrorTypeRPCOther RpcProviderErrorType = "rpc_other" +) + +// Not found should not be cancelling the requests, as that's returned +// when we are hitting a non archival node for example, it should continue the +// chain as the next provider might have archival support. +func IsNotFoundError(err error) bool { + return strings.Contains(err.Error(), ethereum.NotFound.Error()) +} + +func IsRPCError(err error) (rpc.Error, bool) { + var rpcErr rpc.Error + if errors.As(err, &rpcErr) { + return rpcErr, true + } + return nil, false +} + +func IsMethodNotFoundError(err error) bool { + if rpcErr, ok := IsRPCError(err); ok { + return rpcErr.ErrorCode() == -32601 + } + return false +} + +func IsVMError(err error) bool { + if rpcErr, ok := IsRPCError(err); ok { + return rpcErr.ErrorCode() == -32015 // Код ошибки VM execution error + } + if strings.Contains(err.Error(), core.ErrInsufficientFunds.Error()) { + return true + } + for _, vmError := range propagateErrors { + if strings.Contains(err.Error(), vmError.Error()) { + return true + } + } + return false +} + +// determineRpcErrorType determines the RpcProviderErrorType based on the error. +func determineRpcErrorType(err error) RpcProviderErrorType { + if err == nil { + return RpcErrorTypeNone + } + //if IsRpsLimitError(err) { + // return RpcErrorTypeRPSLimit + //} + if IsMethodNotFoundError(err) || IsNotFoundError(err) { + return RpcErrorTypeMethodNotFound + } + if IsVMError(err) { + return RpcErrorTypeVMError + } + return RpcErrorTypeRPCOther +} + +// IsCriticalRpcError determines if the RPC error is critical. +func IsNonCriticalRpcError(err error) bool { + errorType := determineRpcErrorType(err) + + switch errorType { + case RpcErrorTypeNone, RpcErrorTypeMethodNotFound, RpcErrorTypeRPSLimit, RpcErrorTypeVMError: + return true + default: + return false + } +} diff --git a/healthmanager/provider_errors/rpc_provider_errors_test.go b/healthmanager/provider_errors/rpc_provider_errors_test.go new file mode 100644 index 000000000..3d16fae7f --- /dev/null +++ b/healthmanager/provider_errors/rpc_provider_errors_test.go @@ -0,0 +1,51 @@ +package provider_errors + +import ( + "errors" + "testing" +) + +// TestIsRpsLimitError tests the IsRpsLimitError function. +func TestIsRpsLimitError(t *testing.T) { + tests := []struct { + name string + err error + wantResult bool + }{ + { + name: "Error contains 'backoff_seconds'", + err: errors.New("Error: backoff_seconds: 30"), + wantResult: true, + }, + { + name: "Error contains 'has exceeded its throughput limit'", + err: errors.New("Your application has exceeded its throughput limit."), + wantResult: true, + }, + { + name: "Error contains 'request rate exceeded'", + err: errors.New("Request rate exceeded. Please try again later."), + wantResult: true, + }, + { + name: "Error does not contain any matching phrases", + err: errors.New("Some other error occurred."), + wantResult: false, + }, + { + name: "Error is nil", + err: nil, + wantResult: false, + }, + } + + for _, tt := range tests { + tt := tt // capture the variable + t.Run(tt.name, func(t *testing.T) { + got := IsRateLimitError(tt.err) + if got != tt.wantResult { + t.Errorf("IsRpsLimitError(%v) = %v; want %v", tt.err, got, tt.wantResult) + } + }) + } +} diff --git a/healthmanager/providers_health_manager.go b/healthmanager/providers_health_manager.go new file mode 100644 index 000000000..9250132c8 --- /dev/null +++ b/healthmanager/providers_health_manager.go @@ -0,0 +1,126 @@ +package healthmanager + +import ( + "context" + "fmt" + "sync" + + statusaggregator "github.com/status-im/status-go/healthmanager/aggregator" + "github.com/status-im/status-go/healthmanager/rpcstatus" +) + +type ProvidersHealthManager struct { + mu sync.RWMutex + chainID uint64 + aggregator *statusaggregator.Aggregator + subscribers []chan struct{} +} + +// NewProvidersHealthManager creates a new instance of ProvidersHealthManager with the given chain ID. +func NewProvidersHealthManager(chainID uint64) *ProvidersHealthManager { + aggregator := statusaggregator.NewAggregator(fmt.Sprintf("%d", chainID)) + + return &ProvidersHealthManager{ + chainID: chainID, + aggregator: aggregator, + } +} + +// Update processes a batch of provider call statuses, updates the aggregated status, and emits chain status changes if necessary. +func (p *ProvidersHealthManager) Update(ctx context.Context, callStatuses []rpcstatus.RpcProviderCallStatus) { + p.mu.Lock() + + previousStatus := p.aggregator.GetAggregatedStatus() + + // Update the aggregator with the new provider statuses + for _, rpcCallStatus := range callStatuses { + providerStatus := rpcstatus.NewRpcProviderStatus(rpcCallStatus) + p.aggregator.Update(providerStatus) + } + + newStatus := p.aggregator.GetAggregatedStatus() + + shouldEmit := newStatus.Status != previousStatus.Status + p.mu.Unlock() + + if !shouldEmit { + return + } + + p.emitChainStatus(ctx) +} + +// GetStatuses returns a copy of the current provider statuses. +func (p *ProvidersHealthManager) GetStatuses() map[string]rpcstatus.ProviderStatus { + p.mu.RLock() + defer p.mu.RUnlock() + return p.aggregator.GetStatuses() +} + +// Subscribe allows providers to receive notifications about changes. +func (p *ProvidersHealthManager) Subscribe() chan struct{} { + p.mu.Lock() + defer p.mu.Unlock() + + ch := make(chan struct{}, 1) + p.subscribers = append(p.subscribers, ch) + return ch +} + +// Unsubscribe removes a subscriber from receiving notifications. +func (p *ProvidersHealthManager) Unsubscribe(ch chan struct{}) { + p.mu.Lock() + defer p.mu.Unlock() + + for i, subscriber := range p.subscribers { + if subscriber == ch { + p.subscribers = append(p.subscribers[:i], p.subscribers[i+1:]...) + close(ch) + break + } + } +} + +// UnsubscribeAll removes all subscriber channels. +func (p *ProvidersHealthManager) UnsubscribeAll() { + p.mu.Lock() + defer p.mu.Unlock() + for _, subscriber := range p.subscribers { + close(subscriber) + } + p.subscribers = nil +} + +// Reset clears all provider statuses and resets the chain status to unknown. +func (p *ProvidersHealthManager) Reset() { + p.mu.Lock() + defer p.mu.Unlock() + p.aggregator = statusaggregator.NewAggregator(fmt.Sprintf("%d", p.chainID)) +} + +// Status Returns the current aggregated status +func (p *ProvidersHealthManager) Status() rpcstatus.ProviderStatus { + p.mu.RLock() + defer p.mu.RUnlock() + return p.aggregator.GetAggregatedStatus() +} + +// ChainID returns the ID of the chain. +func (p *ProvidersHealthManager) ChainID() uint64 { + return p.chainID +} + +// emitChainStatus sends a notification to all subscribers. +func (p *ProvidersHealthManager) emitChainStatus(ctx context.Context) { + p.mu.RLock() + defer p.mu.RUnlock() + for _, subscriber := range p.subscribers { + select { + case subscriber <- struct{}{}: + case <-ctx.Done(): + return + default: + // Non-blocking send; skip if the channel is full + } + } +} diff --git a/healthmanager/providers_health_manager_test.go b/healthmanager/providers_health_manager_test.go new file mode 100644 index 000000000..4701e56ec --- /dev/null +++ b/healthmanager/providers_health_manager_test.go @@ -0,0 +1,206 @@ +package healthmanager + +import ( + "context" + "errors" + "fmt" + "github.com/status-im/status-go/healthmanager/rpcstatus" + "github.com/stretchr/testify/suite" + "sync" + "testing" + "time" +) + +type ProvidersHealthManagerSuite struct { + suite.Suite + phm *ProvidersHealthManager +} + +// SetupTest initializes the ProvidersHealthManager before each test +func (s *ProvidersHealthManagerSuite) SetupTest() { + s.phm = NewProvidersHealthManager(1) +} + +// Helper method to update providers and wait for a notification on the given channel +func (s *ProvidersHealthManagerSuite) updateAndWait(ch <-chan struct{}, statuses []rpcstatus.RpcProviderCallStatus, expectedChainStatus rpcstatus.StatusType, timeout time.Duration) { + s.phm.Update(context.Background(), statuses) + + select { + case <-ch: + // Received notification + case <-time.After(timeout): + s.Fail("Timeout waiting for chain status update") + } + + s.assertChainStatus(expectedChainStatus) +} + +// Helper method to update providers and wait for a notification on the given channel +func (s *ProvidersHealthManagerSuite) updateAndExpectNoNotification(ch <-chan struct{}, statuses []rpcstatus.RpcProviderCallStatus, expectedChainStatus rpcstatus.StatusType, timeout time.Duration) { + s.phm.Update(context.Background(), statuses) + + select { + case <-ch: + s.Fail("Unexpected status update") + case <-time.After(timeout): + // No notification as expected + } + + s.assertChainStatus(expectedChainStatus) +} + +// Helper method to assert the current chain status +func (s *ProvidersHealthManagerSuite) assertChainStatus(expected rpcstatus.StatusType) { + actual := s.phm.Status().Status + s.Equal(expected, actual, fmt.Sprintf("Expected chain status to be %s", expected)) +} + +func (s *ProvidersHealthManagerSuite) TestInitialStatus() { + s.assertChainStatus(rpcstatus.StatusDown) +} + +func (s *ProvidersHealthManagerSuite) TestUpdateProviderStatuses() { + s.updateAndWait(s.phm.Subscribe(), []rpcstatus.RpcProviderCallStatus{ + {Name: "Provider1", Timestamp: time.Now(), Err: nil}, + {Name: "Provider2", Timestamp: time.Now(), Err: errors.New("connection error")}, + }, rpcstatus.StatusUp, time.Second) + + statusMap := s.phm.GetStatuses() + s.Len(statusMap, 2, "Expected 2 provider statuses") + s.Equal(rpcstatus.StatusUp, statusMap["Provider1"].Status, "Expected Provider1 status to be Up") + s.Equal(rpcstatus.StatusDown, statusMap["Provider2"].Status, "Expected Provider2 status to be Down") +} + +func (s *ProvidersHealthManagerSuite) TestChainStatusUpdatesOnce() { + ch := s.phm.Subscribe() + s.assertChainStatus(rpcstatus.StatusDown) + + // Update providers to Down + statuses := []rpcstatus.RpcProviderCallStatus{ + {Name: "Provider1", Timestamp: time.Now(), Err: errors.New("error")}, + {Name: "Provider2", Timestamp: time.Now(), Err: nil}, + } + s.updateAndWait(ch, statuses, rpcstatus.StatusUp, time.Second) + s.updateAndExpectNoNotification(ch, statuses, rpcstatus.StatusUp, 10*time.Millisecond) +} + +func (s *ProvidersHealthManagerSuite) TestSubscribeReceivesOnlyOnChange() { + ch := s.phm.Subscribe() + + // Update provider to Up and wait for notification + upStatuses := []rpcstatus.RpcProviderCallStatus{ + {Name: "Provider1", Timestamp: time.Now(), Err: nil}, + } + s.updateAndWait(ch, upStatuses, rpcstatus.StatusUp, time.Second) + + // Update provider to Down and wait for notification + downStatuses := []rpcstatus.RpcProviderCallStatus{ + {Name: "Provider1", Timestamp: time.Now(), Err: errors.New("some critical error")}, + } + s.updateAndWait(ch, downStatuses, rpcstatus.StatusDown, time.Second) + + s.updateAndExpectNoNotification(ch, downStatuses, rpcstatus.StatusDown, 10*time.Millisecond) +} + +func (s *ProvidersHealthManagerSuite) TestConcurrency() { + var wg sync.WaitGroup + providerCount := 1000 + + s.phm.Update(context.Background(), []rpcstatus.RpcProviderCallStatus{ + {Name: "ProviderUp", Timestamp: time.Now(), Err: nil}, + }) + + ctx := context.Background() + for i := 0; i < providerCount-1; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + providerName := fmt.Sprintf("Provider%d", i) + var err error + if i%2 == 0 { + err = errors.New("error") + } + s.phm.Update(ctx, []rpcstatus.RpcProviderCallStatus{ + {Name: providerName, Timestamp: time.Now(), Err: err}, + }) + }(i) + } + wg.Wait() + + statuses := s.phm.GetStatuses() + s.Len(statuses, providerCount, "Expected 1000 provider statuses") + + chainStatus := s.phm.Status().Status + s.Equal(chainStatus, rpcstatus.StatusUp, "Expected chain status to be either Up or Down") +} + +func (s *BlockchainHealthManagerSuite) TestInterleavedChainStatusChanges() { + // Register providers for chains 1, 2, and 3 + phm1 := NewProvidersHealthManager(1) + phm2 := NewProvidersHealthManager(2) + phm3 := NewProvidersHealthManager(3) + s.manager.RegisterProvidersHealthManager(s.ctx, phm1) + s.manager.RegisterProvidersHealthManager(s.ctx, phm2) + s.manager.RegisterProvidersHealthManager(s.ctx, phm3) + + // Subscribe to status updates + ch := s.manager.Subscribe() + defer s.manager.Unsubscribe(ch) + + // Initially, all chains are up + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain1", Timestamp: time.Now(), Err: nil}}) + phm2.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain2", Timestamp: time.Now(), Err: nil}}) + phm3.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain3", Timestamp: time.Now(), Err: nil}}) + + // Wait for the status to propagate + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Now chain 1 goes down, and chain 3 goes down at the same time + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain1", Timestamp: time.Now(), Err: errors.New("connection error")}}) + phm3.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider_chain3", Timestamp: time.Now(), Err: errors.New("connection error")}}) + + // Wait for the status to reflect the changes + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Check that short status correctly reflects the mixed state + shortStatus := s.manager.GetShortStatus() + s.Equal(rpcstatus.StatusUp, shortStatus.Status.Status) + s.Equal(rpcstatus.StatusDown, shortStatus.StatusPerChain[1].Status) // Chain 1 is down + s.Equal(rpcstatus.StatusUp, shortStatus.StatusPerChain[2].Status) // Chain 2 is still up + s.Equal(rpcstatus.StatusDown, shortStatus.StatusPerChain[3].Status) // Chain 3 is down +} + +func (s *BlockchainHealthManagerSuite) TestDelayedChainUpdate() { + // Register providers for chains 1 and 2 + phm1 := NewProvidersHealthManager(1) + phm2 := NewProvidersHealthManager(2) + s.manager.RegisterProvidersHealthManager(s.ctx, phm1) + s.manager.RegisterProvidersHealthManager(s.ctx, phm2) + + // Subscribe to status updates + ch := s.manager.Subscribe() + defer s.manager.Unsubscribe(ch) + + // Initially, both chains are up + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider1_chain1", Timestamp: time.Now(), Err: nil}}) + phm2.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider1_chain2", Timestamp: time.Now(), Err: nil}}) + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Chain 2 goes down + phm2.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider1_chain2", Timestamp: time.Now(), Err: errors.New("connection error")}}) + s.waitForUpdate(ch, rpcstatus.StatusUp, 100*time.Millisecond) + + // Chain 1 goes down after a delay + phm1.Update(s.ctx, []rpcstatus.RpcProviderCallStatus{{Name: "provider1_chain1", Timestamp: time.Now(), Err: errors.New("connection error")}}) + s.waitForUpdate(ch, rpcstatus.StatusDown, 100*time.Millisecond) + + // Check that short status reflects the final state where both chains are down + shortStatus := s.manager.GetShortStatus() + s.Equal(rpcstatus.StatusDown, shortStatus.Status.Status) + s.Equal(rpcstatus.StatusDown, shortStatus.StatusPerChain[1].Status) // Chain 1 is down + s.Equal(rpcstatus.StatusDown, shortStatus.StatusPerChain[2].Status) // Chain 2 is down +} + +func TestProvidersHealthManagerSuite(t *testing.T) { + suite.Run(t, new(ProvidersHealthManagerSuite)) +} diff --git a/healthmanager/rpcstatus/provider_status.go b/healthmanager/rpcstatus/provider_status.go new file mode 100644 index 000000000..a79b8066b --- /dev/null +++ b/healthmanager/rpcstatus/provider_status.go @@ -0,0 +1,77 @@ +package rpcstatus + +import ( + "time" + + "github.com/status-im/status-go/healthmanager/provider_errors" +) + +// StatusType represents the possible status values for a provider. +type StatusType string + +const ( + StatusUnknown StatusType = "unknown" + StatusUp StatusType = "up" + StatusDown StatusType = "down" +) + +// ProviderStatus holds the status information for a single provider. +type ProviderStatus struct { + Name string + LastSuccessAt time.Time + LastErrorAt time.Time + LastError error + Status StatusType +} + +// ProviderCallStatus represents the result of an arbitrary provider call. +type ProviderCallStatus struct { + Name string + Timestamp time.Time + Err error +} + +// RpcProviderCallStatus represents the result of an RPC provider call. +type RpcProviderCallStatus struct { + Name string + Timestamp time.Time + Err error +} + +// NewRpcProviderStatus processes RpcProviderCallStatus and returns a new ProviderStatus. +func NewRpcProviderStatus(res RpcProviderCallStatus) ProviderStatus { + status := ProviderStatus{ + Name: res.Name, + } + + // Determine if the error is critical + if res.Err == nil || provider_errors.IsNonCriticalRpcError(res.Err) || provider_errors.IsNonCriticalProviderError(res.Err) { + status.LastSuccessAt = res.Timestamp + status.Status = StatusUp + } else { + status.LastErrorAt = res.Timestamp + status.LastError = res.Err + status.Status = StatusDown + } + + return status +} + +// NewProviderStatus processes ProviderCallStatus and returns a new ProviderStatus. +func NewProviderStatus(res ProviderCallStatus) ProviderStatus { + status := ProviderStatus{ + Name: res.Name, + } + + // Determine if the error is critical + if res.Err == nil || provider_errors.IsNonCriticalProviderError(res.Err) { + status.LastSuccessAt = res.Timestamp + status.Status = StatusUp + } else { + status.LastErrorAt = res.Timestamp + status.LastError = res.Err + status.Status = StatusDown + } + + return status +} diff --git a/healthmanager/rpcstatus/provider_status_test.go b/healthmanager/rpcstatus/provider_status_test.go new file mode 100644 index 000000000..c06cef188 --- /dev/null +++ b/healthmanager/rpcstatus/provider_status_test.go @@ -0,0 +1,174 @@ +package rpcstatus + +import ( + "errors" + "github.com/status-im/status-go/rpc/chain/rpclimiter" + "testing" + "time" +) + +func TestNewRpcProviderStatus(t *testing.T) { + tests := []struct { + name string + res RpcProviderCallStatus + expected ProviderStatus + }{ + { + name: "No error, should be up", + res: RpcProviderCallStatus{ + Name: "Provider1", + Timestamp: time.Now(), + Err: nil, + }, + expected: ProviderStatus{ + Name: "Provider1", + Status: StatusUp, + }, + }, + { + name: "Critical RPC error, should be down", + res: RpcProviderCallStatus{ + Name: "Provider1", + Timestamp: time.Now(), + Err: errors.New("Some critical RPC error"), + }, + expected: ProviderStatus{ + Name: "Provider1", + LastError: errors.New("Some critical RPC error"), + Status: StatusDown, + }, + }, + { + name: "Non-critical RPC error, should be up", + res: RpcProviderCallStatus{ + Name: "Provider2", + Timestamp: time.Now(), + Err: rpclimiter.ErrRequestsOverLimit, // Assuming this is non-critical + }, + expected: ProviderStatus{ + Name: "Provider2", + Status: StatusUp, + }, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + got := NewRpcProviderStatus(tt.res) + + // Compare expected and got + if got.Name != tt.expected.Name { + t.Errorf("expected name %v, got %v", tt.expected.Name, got.Name) + } + + // Check LastSuccessAt for StatusUp + if tt.expected.Status == StatusUp { + if got.LastSuccessAt.IsZero() { + t.Errorf("expected LastSuccessAt to be set, but got zero value") + } + if !got.LastErrorAt.IsZero() { + t.Errorf("expected LastErrorAt to be zero, but got %v", got.LastErrorAt) + } + } else if tt.expected.Status == StatusDown { + if got.LastErrorAt.IsZero() { + t.Errorf("expected LastErrorAt to be set, but got zero value") + } + if !got.LastSuccessAt.IsZero() { + t.Errorf("expected LastSuccessAt to be zero, but got %v", got.LastSuccessAt) + } + } + + if got.Status != tt.expected.Status { + t.Errorf("expected status %v, got %v", tt.expected.Status, got.Status) + } + + if got.LastError != nil && tt.expected.LastError != nil && got.LastError.Error() != tt.expected.LastError.Error() { + t.Errorf("expected last error %v, got %v", tt.expected.LastError, got.LastError) + } + }) + } +} + +func TestNewProviderStatus(t *testing.T) { + tests := []struct { + name string + res ProviderCallStatus + expected ProviderStatus + }{ + { + name: "No error, should be up", + res: ProviderCallStatus{ + Name: "Provider1", + Timestamp: time.Now(), + Err: nil, + }, + expected: ProviderStatus{ + Name: "Provider1", + Status: StatusUp, + }, + }, + { + name: "Critical provider error, should be down", + res: ProviderCallStatus{ + Name: "Provider1", + Timestamp: time.Now(), + Err: errors.New("Some critical provider error"), + }, + expected: ProviderStatus{ + Name: "Provider1", + LastError: errors.New("Some critical provider error"), + Status: StatusDown, + }, + }, + { + name: "Non-critical provider error, should be up", + res: ProviderCallStatus{ + Name: "Provider2", + Timestamp: time.Now(), + Err: errors.New("backoff_seconds"), // Assuming this is non-critical + }, + expected: ProviderStatus{ + Name: "Provider2", + Status: StatusUp, + }, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + got := NewProviderStatus(tt.res) + + // Compare expected and got + if got.Name != tt.expected.Name { + t.Errorf("expected name %v, got %v", tt.expected.Name, got.Name) + } + + // Check LastSuccessAt for StatusUp + if tt.expected.Status == StatusUp { + if got.LastSuccessAt.IsZero() { + t.Errorf("expected LastSuccessAt to be set, but got zero value") + } + if !got.LastErrorAt.IsZero() { + t.Errorf("expected LastErrorAt to be zero, but got %v", got.LastErrorAt) + } + } else if tt.expected.Status == StatusDown { + if got.LastErrorAt.IsZero() { + t.Errorf("expected LastErrorAt to be set, but got zero value") + } + if !got.LastSuccessAt.IsZero() { + t.Errorf("expected LastSuccessAt to be zero, but got %v", got.LastSuccessAt) + } + } + + if got.Status != tt.expected.Status { + t.Errorf("expected status %v, got %v", tt.expected.Status, got.Status) + } + + if got.LastError != nil && tt.expected.LastError != nil && got.LastError.Error() != tt.expected.LastError.Error() { + t.Errorf("expected last error %v, got %v", tt.expected.LastError, got.LastError) + } + }) + } +}