diff --git a/circuitbreaker/circuit_breaker.go b/circuitbreaker/circuit_breaker.go index a1dfda3a2..0b6a97e2c 100644 --- a/circuitbreaker/circuit_breaker.go +++ b/circuitbreaker/circuit_breaker.go @@ -3,6 +3,7 @@ package circuitbreaker import ( "context" "fmt" + "time" "github.com/afex/hystrix-go/hystrix" @@ -12,8 +13,16 @@ import ( type FallbackFunc func() ([]any, error) type CommandResult struct { - res []any - err error + res []any + err error + functorCallStatuses []FunctorCallStatus + cancelled bool +} + +type FunctorCallStatus struct { + Name string + Timestamp time.Time + Err error } func (cr CommandResult) Result() []any { @@ -23,6 +32,21 @@ func (cr CommandResult) Result() []any { func (cr CommandResult) Error() error { return cr.err } +func (cr CommandResult) Cancelled() bool { + return cr.cancelled +} + +func (cr CommandResult) FunctorCallStatuses() []FunctorCallStatus { + return cr.functorCallStatuses +} + +func (cr *CommandResult) addCallStatus(circuitName string, err error) { + cr.functorCallStatuses = append(cr.functorCallStatuses, FunctorCallStatus{ + Name: circuitName, + Timestamp: time.Now(), + Err: err, + }) +} type Command struct { ctx context.Context @@ -106,23 +130,26 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult { for i, f := range cmd.functors { if cmd.cancel { + result.cancelled = true break } var err error + circuitName := f.circuitName + if cb.circuitNameHandler != nil { + circuitName = cb.circuitNameHandler(circuitName) + } + // if last command, execute without circuit if i == len(cmd.functors)-1 { res, execErr := f.exec() err = execErr if err == nil { - result = CommandResult{res: res} + result.res = res + result.err = nil } + result.addCallStatus(circuitName, err) } else { - circuitName := f.circuitName - if cb.circuitNameHandler != nil { - circuitName = cb.circuitNameHandler(circuitName) - } - if hystrix.GetCircuitSettings()[circuitName] == nil { hystrix.ConfigureCommand(circuitName, hystrix.CommandConfig{ Timeout: cb.config.Timeout, @@ -137,13 +164,16 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult { res, err := f.exec() // Write to result only if success if err == nil { - result = CommandResult{res: res} + result.res = res + result.err = nil } + result.addCallStatus(circuitName, err) // If the command has been cancelled, we don't count // the error towars breaking the circuit, and then we break if cmd.cancel { - result = accumulateCommandError(result, f.circuitName, err) + result = accumulateCommandError(result, circuitName, err) + result.cancelled = true return nil } if err != nil { @@ -156,7 +186,7 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult { break } - result = accumulateCommandError(result, f.circuitName, err) + result = accumulateCommandError(result, circuitName, err) // Lets abuse every provider with the same amount of MaxConcurrentRequests, // keep iterating even in case of ErrMaxConcurrency error diff --git a/circuitbreaker/circuit_breaker_test.go b/circuitbreaker/circuit_breaker_test.go index 42d35d0c3..1f7ce820f 100644 --- a/circuitbreaker/circuit_breaker_test.go +++ b/circuitbreaker/circuit_breaker_test.go @@ -34,6 +34,7 @@ func TestCircuitBreaker_ExecuteSuccessSingle(t *testing.T) { result := cb.Execute(cmd) require.NoError(t, result.Error()) require.Equal(t, expectedResult, result.Result()[0].(string)) + require.False(t, result.Cancelled()) } func TestCircuitBreaker_ExecuteMultipleFallbacksFail(t *testing.T) { @@ -219,9 +220,11 @@ func TestCircuitBreaker_CommandCancel(t *testing.T) { result := cb.Execute(cmd) require.True(t, errors.Is(result.Error(), expectedErr)) + require.True(t, result.Cancelled()) assert.Equal(t, 1, prov1Called) assert.Equal(t, 0, prov2Called) + } func TestCircuitBreaker_EmptyOrNilCommand(t *testing.T) { @@ -301,3 +304,149 @@ func TestCircuitBreaker_Fallback(t *testing.T) { assert.Equal(t, 1, prov1Called) } + +func TestCircuitBreaker_SuccessCallStatus(t *testing.T) { + cb := NewCircuitBreaker(Config{}) + + functor := NewFunctor(func() ([]any, error) { + return []any{"success"}, nil + }, "successCircuit") + + cmd := NewCommand(context.Background(), []*Functor{functor}) + + result := cb.Execute(cmd) + + require.Nil(t, result.Error()) + require.False(t, result.Cancelled()) + assert.Len(t, result.Result(), 1) + require.Equal(t, "success", result.Result()[0]) + assert.Len(t, result.FunctorCallStatuses(), 1) + + status := result.FunctorCallStatuses()[0] + if status.name != "successCircuit" { + t.Errorf("Expected functor name to be 'successCircuit', got %s", status.name) + } + if status.err != nil { + t.Errorf("Expected no error in functor status, got %v", status.err) + } +} + +func TestCircuitBreaker_ErrorCallStatus(t *testing.T) { + cb := NewCircuitBreaker(Config{}) + + expectedError := errors.New("functor error") + functor := NewFunctor(func() ([]any, error) { + return nil, expectedError + }, "errorCircuit") + + cmd := NewCommand(context.Background(), []*Functor{functor}) + + result := cb.Execute(cmd) + + require.NotNil(t, result.Error()) + require.True(t, errors.Is(result.Error(), expectedError)) + + assert.Len(t, result.Result(), 0) + assert.Len(t, result.FunctorCallStatuses(), 1) + + status := result.FunctorCallStatuses()[0] + if status.name != "errorCircuit" { + t.Errorf("Expected functor name to be 'errorCircuit', got %s", status.name) + } + if !errors.Is(status.err, expectedError) { + t.Errorf("Expected functor error to be '%v', got '%v'", expectedError, status.err) + } +} + +func TestCircuitBreaker_CancelledResult(t *testing.T) { + cb := NewCircuitBreaker(Config{Timeout: 1000}) + + functor := NewFunctor(func() ([]any, error) { + time.Sleep(500 * time.Millisecond) + return []any{"should not be returned"}, nil + }, "cancelCircuit") + + cmd := NewCommand(context.Background(), []*Functor{functor}) + cmd.Cancel() + + result := cb.Execute(cmd) + + assert.True(t, result.Cancelled()) + require.Nil(t, result.Error()) + require.Empty(t, result.Result()) + require.Empty(t, result.FunctorCallStatuses()) +} + +func TestCircuitBreaker_MultipleFunctorsResult(t *testing.T) { + cb := NewCircuitBreaker(Config{ + Timeout: 1000, + MaxConcurrentRequests: 100, + RequestVolumeThreshold: 20, + SleepWindow: 5000, + ErrorPercentThreshold: 50, + }) + + functor1 := NewFunctor(func() ([]any, error) { + return nil, errors.New("functor1 error") + }, "circuit1") + + functor2 := NewFunctor(func() ([]any, error) { + return []any{"success from functor2"}, nil + }, "circuit2") + + cmd := NewCommand(context.Background(), []*Functor{functor1, functor2}) + + result := cb.Execute(cmd) + + require.Nil(t, result.Error()) + + require.Len(t, result.Result(), 1) + require.Equal(t, result.Result()[0], "success from functor2") + statuses := result.FunctorCallStatuses() + require.Len(t, statuses, 2) + + require.Equal(t, statuses[0].name, "circuit1") + require.NotNil(t, statuses[0].err) + + require.Equal(t, statuses[1].name, "circuit2") + require.Nil(t, statuses[1].err) +} + +func TestCircuitBreaker_LastFunctorDirectExecution(t *testing.T) { + cb := NewCircuitBreaker(Config{ + Timeout: 10, // short timeout to open circuit + MaxConcurrentRequests: 1, + RequestVolumeThreshold: 1, + SleepWindow: 1000, + ErrorPercentThreshold: 1, + }) + + failingFunctor := NewFunctor(func() ([]any, error) { + time.Sleep(20 * time.Millisecond) + return nil, errors.New("should time out") + }, "circuitName") + + successFunctor := NewFunctor(func() ([]any, error) { + return []any{"success without circuit"}, nil + }, "circuitName") + + cmd := NewCommand(context.Background(), []*Functor{failingFunctor, successFunctor}) + + require.False(t, IsCircuitOpen("circuitName")) + result := cb.Execute(cmd) + + require.True(t, CircuitExists("circuitName")) + require.Nil(t, result.Error()) + + require.Len(t, result.Result(), 1) + require.Equal(t, result.Result()[0], "success without circuit") + + statuses := result.FunctorCallStatuses() + require.Len(t, statuses, 2) + + require.Equal(t, statuses[0].name, "circuitName") + require.NotNil(t, statuses[0].err) + + require.Equal(t, statuses[1].name, "circuitName") + require.Nil(t, statuses[1].err) +} diff --git a/healthmanager/blockchain_health_manager.go b/healthmanager/blockchain_health_manager.go index d6084ab48..04ba44e23 100644 --- a/healthmanager/blockchain_health_manager.go +++ b/healthmanager/blockchain_health_manager.go @@ -14,6 +14,12 @@ type BlockchainFullStatus struct { StatusPerChainPerProvider map[uint64]map[string]rpcstatus.ProviderStatus `json:"statusPerChainPerProvider"` } +// BlockchainStatus contains the status of the blockchain +type BlockchainStatus struct { + Status rpcstatus.ProviderStatus `json:"status"` + StatusPerChain map[uint64]rpcstatus.ProviderStatus `json:"statusPerChain"` +} + // BlockchainHealthManager manages the state of all providers and aggregates their statuses. type BlockchainHealthManager struct { mu sync.RWMutex diff --git a/node/get_status_node.go b/node/get_status_node.go index ef1ac50fa..dd4e92265 100644 --- a/node/get_status_node.go +++ b/node/get_status_node.go @@ -331,11 +331,11 @@ func (n *StatusNode) setupRPCClient() (err error) { }, } - n.rpcClient, err = rpc.NewClient(gethNodeClient, n.config.NetworkID, n.config.Networks, n.appDB, providerConfigs) + n.rpcClient, err = rpc.NewClient(gethNodeClient, n.config.NetworkID, n.config.Networks, n.appDB, &n.walletFeed, providerConfigs) + n.rpcClient.Start(context.Background()) if err != nil { return } - return } @@ -451,6 +451,7 @@ func (n *StatusNode) stop() error { return err } + n.rpcClient.Stop() n.rpcClient = nil // We need to clear `gethNode` because config is passed to `Start()` // and may be completely different. Similarly with `config`. diff --git a/rpc/chain/blockchain_health_test.go b/rpc/chain/blockchain_health_test.go new file mode 100644 index 000000000..214a850b9 --- /dev/null +++ b/rpc/chain/blockchain_health_test.go @@ -0,0 +1,295 @@ +package chain + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/status-im/status-go/healthmanager" + "github.com/status-im/status-go/healthmanager/rpcstatus" + mockEthclient "github.com/status-im/status-go/rpc/chain/ethclient/mock/client/ethclient" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/status-im/status-go/rpc/chain/ethclient" + "go.uber.org/mock/gomock" +) + +type BlockchainHealthManagerSuite struct { + suite.Suite + blockchainHealthManager *healthmanager.BlockchainHealthManager + mockProviders map[uint64]*healthmanager.ProvidersHealthManager + mockEthClients map[uint64]*mockEthclient.MockRPSLimitedEthClientInterface + clients map[uint64]*ClientWithFallback + mockCtrl *gomock.Controller +} + +func (s *BlockchainHealthManagerSuite) SetupTest() { + s.blockchainHealthManager = healthmanager.NewBlockchainHealthManager() + s.mockProviders = make(map[uint64]*healthmanager.ProvidersHealthManager) + s.mockEthClients = make(map[uint64]*mockEthclient.MockRPSLimitedEthClientInterface) + s.clients = make(map[uint64]*ClientWithFallback) + s.mockCtrl = gomock.NewController(s.T()) +} + +func (s *BlockchainHealthManagerSuite) TearDownTest() { + s.blockchainHealthManager.Stop() + s.mockCtrl.Finish() +} + +func (s *BlockchainHealthManagerSuite) setupClients(chainIDs []uint64) { + ctx := context.Background() + + for _, chainID := range chainIDs { + mockEthClient := mockEthclient.NewMockRPSLimitedEthClientInterface(s.mockCtrl) + mockEthClient.EXPECT().GetName().AnyTimes().Return(fmt.Sprintf("test_client_chain_%d", chainID)) + mockEthClient.EXPECT().GetLimiter().AnyTimes().Return(nil) + + phm := healthmanager.NewProvidersHealthManager(chainID) + client := NewClient([]ethclient.RPSLimitedEthClientInterface{mockEthClient}, chainID, phm) + + s.blockchainHealthManager.RegisterProvidersHealthManager(ctx, phm) + + s.mockProviders[chainID] = phm + s.mockEthClients[chainID] = mockEthClient + s.clients[chainID] = client + } +} + +func (s *BlockchainHealthManagerSuite) simulateChainStatus(chainID uint64, up bool) { + client, exists := s.clients[chainID] + require.True(s.T(), exists, "Client for chainID %d not found", chainID) + + mockEthClient := s.mockEthClients[chainID] + ctx := context.Background() + hash := common.HexToHash("0x1234") + + if up { + block := &types.Block{} + mockEthClient.EXPECT().BlockByHash(ctx, hash).Return(block, nil).Times(1) + _, err := client.BlockByHash(ctx, hash) + require.NoError(s.T(), err) + } else { + mockEthClient.EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + _, err := client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + } +} + +func (s *BlockchainHealthManagerSuite) waitForStatus(statusCh chan struct{}, expectedStatus rpcstatus.StatusType) { + timeout := time.After(2 * time.Second) + for { + select { + case <-statusCh: + status := s.blockchainHealthManager.Status() + if status.Status == expectedStatus { + return + } + case <-timeout: + s.T().Errorf("Did not receive expected blockchain status update in time") + return + } + } +} + +func (s *BlockchainHealthManagerSuite) TestAllChainsUp() { + s.setupClients([]uint64{1, 2, 3}) + + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + s.simulateChainStatus(1, true) + s.simulateChainStatus(2, true) + s.simulateChainStatus(3, true) + + s.waitForStatus(statusCh, rpcstatus.StatusUp) +} + +func (s *BlockchainHealthManagerSuite) TestSomeChainsDown() { + s.setupClients([]uint64{1, 2, 3}) + + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + s.simulateChainStatus(1, true) + s.simulateChainStatus(2, false) + s.simulateChainStatus(3, true) + + s.waitForStatus(statusCh, rpcstatus.StatusUp) +} + +func (s *BlockchainHealthManagerSuite) TestAllChainsDown() { + s.setupClients([]uint64{1, 2}) + + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + s.simulateChainStatus(1, false) + s.simulateChainStatus(2, false) + + s.waitForStatus(statusCh, rpcstatus.StatusDown) +} + +func (s *BlockchainHealthManagerSuite) TestChainStatusChanges() { + s.setupClients([]uint64{1, 2}) + + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + s.simulateChainStatus(1, false) + s.simulateChainStatus(2, false) + s.waitForStatus(statusCh, rpcstatus.StatusDown) + + s.simulateChainStatus(1, true) + s.waitForStatus(statusCh, rpcstatus.StatusUp) +} + +func (s *BlockchainHealthManagerSuite) TestGetFullStatus() { + // Setup clients for chain IDs 1 and 2 + s.setupClients([]uint64{1, 2}) + + // Subscribe to blockchain status updates + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + // Simulate provider statuses for chain 1 + providerCallStatusesChain1 := []rpcstatus.RpcProviderCallStatus{ + { + Name: "provider1_chain1", + Timestamp: time.Now(), + Err: nil, // Up + }, + { + Name: "provider2_chain1", + Timestamp: time.Now(), + Err: errors.New("connection error"), // Down + }, + } + ctx := context.Background() + s.mockProviders[1].Update(ctx, providerCallStatusesChain1) + + // Simulate provider statuses for chain 2 + providerCallStatusesChain2 := []rpcstatus.RpcProviderCallStatus{ + { + Name: "provider1_chain2", + Timestamp: time.Now(), + Err: nil, // Up + }, + { + Name: "provider2_chain2", + Timestamp: time.Now(), + Err: nil, // Up + }, + } + s.mockProviders[2].Update(ctx, providerCallStatusesChain2) + + // Wait for status event to be triggered before getting full status + s.waitForStatus(statusCh, rpcstatus.StatusUp) + + // Get the full status from the BlockchainHealthManager + fullStatus := s.blockchainHealthManager.GetFullStatus() + + // Assert overall blockchain status + require.Equal(s.T(), rpcstatus.StatusUp, fullStatus.Status.Status) + + // Assert provider statuses per chain + require.Contains(s.T(), fullStatus.StatusPerChainPerProvider, uint64(1)) + require.Contains(s.T(), fullStatus.StatusPerChainPerProvider, uint64(2)) + + // Provider statuses for chain 1 + providerStatusesChain1 := fullStatus.StatusPerChainPerProvider[1] + require.Contains(s.T(), providerStatusesChain1, "provider1_chain1") + require.Contains(s.T(), providerStatusesChain1, "provider2_chain1") + + provider1Chain1Status := providerStatusesChain1["provider1_chain1"] + require.Equal(s.T(), rpcstatus.StatusUp, provider1Chain1Status.Status) + + provider2Chain1Status := providerStatusesChain1["provider2_chain1"] + require.Equal(s.T(), rpcstatus.StatusDown, provider2Chain1Status.Status) + + // Provider statuses for chain 2 + providerStatusesChain2 := fullStatus.StatusPerChainPerProvider[2] + require.Contains(s.T(), providerStatusesChain2, "provider1_chain2") + require.Contains(s.T(), providerStatusesChain2, "provider2_chain2") + + provider1Chain2Status := providerStatusesChain2["provider1_chain2"] + require.Equal(s.T(), rpcstatus.StatusUp, provider1Chain2Status.Status) + + provider2Chain2Status := providerStatusesChain2["provider2_chain2"] + require.Equal(s.T(), rpcstatus.StatusUp, provider2Chain2Status.Status) + + // Serialization to JSON works without errors + jsonData, err := json.MarshalIndent(fullStatus, "", " ") + require.NoError(s.T(), err) + require.NotEmpty(s.T(), jsonData) +} + +func (s *BlockchainHealthManagerSuite) TestGetShortStatus() { + // Setup clients for chain IDs 1 and 2 + s.setupClients([]uint64{1, 2}) + + // Subscribe to blockchain status updates + statusCh := s.blockchainHealthManager.Subscribe() + defer s.blockchainHealthManager.Unsubscribe(statusCh) + + // Simulate provider statuses for chain 1 + providerCallStatusesChain1 := []rpcstatus.RpcProviderCallStatus{ + { + Name: "provider1_chain1", + Timestamp: time.Now(), + Err: nil, // Up + }, + { + Name: "provider2_chain1", + Timestamp: time.Now(), + Err: errors.New("connection error"), // Down + }, + } + ctx := context.Background() + s.mockProviders[1].Update(ctx, providerCallStatusesChain1) + + // Simulate provider statuses for chain 2 + providerCallStatusesChain2 := []rpcstatus.RpcProviderCallStatus{ + { + Name: "provider1_chain2", + Timestamp: time.Now(), + Err: nil, // Up + }, + { + Name: "provider2_chain2", + Timestamp: time.Now(), + Err: nil, // Up + }, + } + s.mockProviders[2].Update(ctx, providerCallStatusesChain2) + + // Wait for status event to be triggered before getting short status + s.waitForStatus(statusCh, rpcstatus.StatusUp) + + // Get the short status from the BlockchainHealthManager + shortStatus := s.blockchainHealthManager.GetShortStatus() + + // Assert overall blockchain status + require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.Status.Status) + + // Assert chain statuses + require.Contains(s.T(), shortStatus.StatusPerChain, uint64(1)) + require.Contains(s.T(), shortStatus.StatusPerChain, uint64(2)) + + require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.StatusPerChain[1].Status) + require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.StatusPerChain[2].Status) + + // Serialization to JSON works without errors + jsonData, err := json.MarshalIndent(shortStatus, "", " ") + require.NoError(s.T(), err) + require.NotEmpty(s.T(), jsonData) +} + +func TestBlockchainHealthManagerSuite(t *testing.T) { + suite.Run(t, new(BlockchainHealthManagerSuite)) +} diff --git a/rpc/chain/client.go b/rpc/chain/client.go index 24741d161..226c5aa73 100644 --- a/rpc/chain/client.go +++ b/rpc/chain/client.go @@ -20,6 +20,8 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rpc" "github.com/status-im/status-go/circuitbreaker" + "github.com/status-im/status-go/healthmanager" + "github.com/status-im/status-go/healthmanager/rpcstatus" "github.com/status-im/status-go/rpc/chain/ethclient" "github.com/status-im/status-go/rpc/chain/rpclimiter" "github.com/status-im/status-go/rpc/chain/tagger" @@ -63,10 +65,11 @@ func ClientWithTag(chainClient ClientInterface, tag, groupTag string) ClientInte } type ClientWithFallback struct { - ChainID uint64 - ethClients []ethclient.RPSLimitedEthClientInterface - commonLimiter rpclimiter.RequestLimiter - circuitbreaker *circuitbreaker.CircuitBreaker + ChainID uint64 + ethClients []ethclient.RPSLimitedEthClientInterface + commonLimiter rpclimiter.RequestLimiter + circuitbreaker *circuitbreaker.CircuitBreaker + providersHealthManager *healthmanager.ProvidersHealthManager WalletNotifier func(chainId uint64, message string) @@ -111,7 +114,7 @@ var propagateErrors = []error{ bind.ErrNoCode, } -func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint64) *ClientWithFallback { +func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint64, providersHealthManager *healthmanager.ProvidersHealthManager) *ClientWithFallback { cbConfig := circuitbreaker.Config{ Timeout: 20000, MaxConcurrentRequests: 100, @@ -123,11 +126,12 @@ func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint isConnected.Store(true) return &ClientWithFallback{ - ChainID: chainID, - ethClients: ethClients, - isConnected: isConnected, - LastCheckedAt: time.Now().Unix(), - circuitbreaker: circuitbreaker.NewCircuitBreaker(cbConfig), + ChainID: chainID, + ethClients: ethClients, + isConnected: isConnected, + LastCheckedAt: time.Now().Unix(), + circuitbreaker: circuitbreaker.NewCircuitBreaker(cbConfig), + providersHealthManager: providersHealthManager, } } @@ -238,6 +242,10 @@ func (c *ClientWithFallback) makeCall(ctx context.Context, ethClients []ethclien } result := c.circuitbreaker.Execute(cmd) + if c.providersHealthManager != nil { + rpcCallStatuses := convertFunctorCallStatuses(result.FunctorCallStatuses()) + c.providersHealthManager.Update(ctx, rpcCallStatuses) + } if result.Error() != nil { return nil, result.Error() } @@ -842,3 +850,10 @@ func (c *ClientWithFallback) GetCircuitBreaker() *circuitbreaker.CircuitBreaker func (c *ClientWithFallback) SetCircuitBreaker(cb *circuitbreaker.CircuitBreaker) { c.circuitbreaker = cb } + +func convertFunctorCallStatuses(statuses []circuitbreaker.FunctorCallStatus) (result []rpcstatus.RpcProviderCallStatus) { + for _, f := range statuses { + result = append(result, rpcstatus.RpcProviderCallStatus{Name: f.Name, Timestamp: f.Timestamp, Err: f.Err}) + } + return +} diff --git a/rpc/chain/client_health_test.go b/rpc/chain/client_health_test.go new file mode 100644 index 000000000..e5422baab --- /dev/null +++ b/rpc/chain/client_health_test.go @@ -0,0 +1,240 @@ +package chain + +import ( + "context" + "errors" + "github.com/ethereum/go-ethereum/core/vm" + "strconv" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + healthManager "github.com/status-im/status-go/healthmanager" + "github.com/status-im/status-go/healthmanager/rpcstatus" + "github.com/status-im/status-go/rpc/chain/ethclient" + "github.com/status-im/status-go/rpc/chain/rpclimiter" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + + mockEthclient "github.com/status-im/status-go/rpc/chain/ethclient/mock/client/ethclient" +) + +type ClientWithFallbackSuite struct { + suite.Suite + client *ClientWithFallback + mockEthClients []*mockEthclient.MockRPSLimitedEthClientInterface + providersHealthManager *healthManager.ProvidersHealthManager + mockCtrl *gomock.Controller +} + +func (s *ClientWithFallbackSuite) SetupTest() { + s.mockCtrl = gomock.NewController(s.T()) +} + +func (s *ClientWithFallbackSuite) TearDownTest() { + s.mockCtrl.Finish() +} + +func (s *ClientWithFallbackSuite) setupClients(numClients int) { + s.mockEthClients = make([]*mockEthclient.MockRPSLimitedEthClientInterface, 0) + ethClients := make([]ethclient.RPSLimitedEthClientInterface, 0) + + for i := 0; i < numClients; i++ { + ethClient := mockEthclient.NewMockRPSLimitedEthClientInterface(s.mockCtrl) + ethClient.EXPECT().GetName().AnyTimes().Return("test" + strconv.Itoa(i)) + ethClient.EXPECT().GetLimiter().AnyTimes().Return(nil) + + s.mockEthClients = append(s.mockEthClients, ethClient) + ethClients = append(ethClients, ethClient) + } + var chainID uint64 = 0 + s.providersHealthManager = healthManager.NewProvidersHealthManager(chainID) + s.client = NewClient(ethClients, chainID, s.providersHealthManager) +} + +func (s *ClientWithFallbackSuite) TestSingleClientSuccess() { + s.setupClients(1) + ctx := context.Background() + hash := common.HexToHash("0x1234") + block := &types.Block{} + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(block, nil).Times(1) + + // WHEN + result, err := s.client.BlockByHash(ctx, hash) + require.NoError(s.T(), err) + require.Equal(s.T(), block, result) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp) +} + +func (s *ClientWithFallbackSuite) TestSingleClientConnectionError() { + s.setupClients(1) + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("connection error")).Times(1) + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusDown, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown) +} + +func (s *ClientWithFallbackSuite) TestRPSLimitErrorDoesNotMarkChainDown() { + s.setupClients(1) + + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // WHEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, rpclimiter.ErrRequestsOverLimit).Times(1) + + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp) + + status := providerStatuses["test0"] + require.Equal(s.T(), status.Status, rpcstatus.StatusUp, "provider shouldn't be DOWN on RPS limit") +} + +func (s *ClientWithFallbackSuite) TestContextCanceledDoesNotMarkChainDown() { + s.setupClients(1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + hash := common.HexToHash("0x1234") + + // WHEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, context.Canceled).Times(1) + + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + require.True(s.T(), errors.Is(err, context.Canceled)) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp) +} + +func (s *ClientWithFallbackSuite) TestVMErrorDoesNotMarkChainDown() { + s.setupClients(1) + ctx := context.Background() + hash := common.HexToHash("0x1234") + vmError := vm.ErrOutOfGas + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, vmError).Times(1) + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + require.True(s.T(), errors.Is(err, vm.ErrOutOfGas)) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 1) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp) +} + +func (s *ClientWithFallbackSuite) TestNoClientsChainUnknown() { + s.setupClients(0) + + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUnknown, chainStatus.Status) +} + +func (s *ClientWithFallbackSuite) TestAllClientsDifferentErrors() { + s.setupClients(3) + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + s.mockEthClients[1].EXPECT().BlockByHash(ctx, hash).Return(nil, rpclimiter.ErrRequestsOverLimit).Times(1) + s.mockEthClients[2].EXPECT().BlockByHash(ctx, hash).Return(nil, vm.ErrOutOfGas).Times(1) + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status) + + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 3) + + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown, "provider test0 should be DOWN due to a connection error") + require.Equal(s.T(), providerStatuses["test1"].Status, rpcstatus.StatusUp, "provider test1 should not be marked DOWN due to RPS limit error") + require.Equal(s.T(), providerStatuses["test2"].Status, rpcstatus.StatusUp, "provider test2 should not be labelled DOWN due to a VM error") +} + +func (s *ClientWithFallbackSuite) TestAllClientsNetworkErrors() { + s.setupClients(3) + ctx := context.Background() + hash := common.HexToHash("0x1234") + + // GIVEN + s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + s.mockEthClients[1].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + s.mockEthClients[2].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1) + + // WHEN + _, err := s.client.BlockByHash(ctx, hash) + require.Error(s.T(), err) + + // THEN + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusDown, chainStatus.Status) + + providerStatuses := s.providersHealthManager.GetStatuses() + require.Len(s.T(), providerStatuses, 3) + require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown) + require.Equal(s.T(), providerStatuses["test1"].Status, rpcstatus.StatusDown) + require.Equal(s.T(), providerStatuses["test2"].Status, rpcstatus.StatusDown) +} + +func (s *ClientWithFallbackSuite) TestChainStatusUnknownWhenAllProvidersUnknown() { + s.setupClients(2) + + chainStatus := s.providersHealthManager.Status() + require.Equal(s.T(), rpcstatus.StatusUnknown, chainStatus.Status) +} + +func TestClientWithFallbackSuite(t *testing.T) { + suite.Run(t, new(ClientWithFallbackSuite)) +} diff --git a/rpc/chain/client_test.go b/rpc/chain/client_test.go index b73d5b54e..13a7577a5 100644 --- a/rpc/chain/client_test.go +++ b/rpc/chain/client_test.go @@ -32,7 +32,7 @@ func setupClientTest(t *testing.T) (*ClientWithFallback, []*mock_ethclient.MockR ethClients = append(ethClients, ethCl) } - client := NewClient(ethClients, 0) + client := NewClient(ethClients, 0, nil) cleanup := func() { mockCtrl.Finish() diff --git a/rpc/client.go b/rpc/client.go index 0ac3d5678..4def86455 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -19,7 +19,9 @@ import ( "github.com/ethereum/go-ethereum/log" gethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/ethereum/go-ethereum/event" appCommon "github.com/status-im/status-go/common" + "github.com/status-im/status-go/healthmanager" "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/rpc/chain/ethclient" @@ -27,6 +29,7 @@ import ( "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/services/rpcstats" "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/walletevent" ) const ( @@ -48,6 +51,8 @@ const ( // rpcUserAgentUpstreamFormat a separate user agent format for upstream, because we should not be using upstream // if we see this user agent in the logs that means parts of the application are using a malconfigured http client rpcUserAgentUpstreamFormat = "procuratee-%s-upstream/%s" + + EventBlockchainHealthChanged walletevent.EventType = "wallet-blockchain-health-changed" // Full status of the blockchain (including provider statuses) ) // List of RPC client errors. @@ -101,6 +106,10 @@ type Client struct { router *router NetworkManager *network.Manager + healthMgr *healthmanager.BlockchainHealthManager + stopMonitoringFunc context.CancelFunc + walletFeed *event.Feed + handlersMx sync.RWMutex // mx guards handlers handlers map[string]Handler // locally registered handlers log log.Logger @@ -116,7 +125,7 @@ var verifProxyInitFn func(c *Client) // // Client is safe for concurrent use and will automatically // reconnect to the server if connection is lost. -func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params.Network, db *sql.DB, providerConfigs []params.ProviderConfig) (*Client, error) { +func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params.Network, db *sql.DB, walletFeed *event.Feed, providerConfigs []params.ProviderConfig) (*Client, error) { var err error log := log.New("package", "status-go/rpc.Client") @@ -138,6 +147,8 @@ func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params limiterPerProvider: make(map[string]*rpclimiter.RPCRpsLimiter), log: log, providerConfigs: providerConfigs, + healthMgr: healthmanager.NewBlockchainHealthManager(), + walletFeed: walletFeed, } c.UpstreamChainID = upstreamChainID @@ -150,6 +161,55 @@ func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params return &c, nil } +func (c *Client) Start(ctx context.Context) { + if c.stopMonitoringFunc != nil { + c.log.Warn("Blockchain health manager already started") + return + } + + cancelableCtx, cancel := context.WithCancel(ctx) + c.stopMonitoringFunc = cancel + statusCh := c.healthMgr.Subscribe() + go c.monitorHealth(cancelableCtx, statusCh) +} + +func (c *Client) Stop() { + c.healthMgr.Stop() + if c.stopMonitoringFunc == nil { + return + } + c.stopMonitoringFunc() + c.stopMonitoringFunc = nil +} + +func (c *Client) monitorHealth(ctx context.Context, statusCh chan struct{}) { + sendFullStatusEventFunc := func() { + blockchainStatus := c.healthMgr.GetFullStatus() + encodedMessage, err := json.Marshal(blockchainStatus) + if err != nil { + c.log.Warn("could not marshal full blockchain status", "error", err) + return + } + if c.walletFeed == nil { + return + } + c.walletFeed.Send(walletevent.Event{ + Type: EventBlockchainHealthChanged, + Message: string(encodedMessage), + At: time.Now().Unix(), + }) + } + + for { + select { + case <-ctx.Done(): + return + case <-statusCh: + sendFullStatusEventFunc() + } + } +} + func (c *Client) GetNetworkManager() *network.Manager { return c.NetworkManager } @@ -207,8 +267,10 @@ func (c *Client) getClientUsingCache(chainID uint64) (chain.ClientInterface, err return nil, fmt.Errorf("could not find any RPC URL for chain: %d", chainID) } - client := chain.NewClient(ethClients, chainID) - client.SetWalletNotifier(c.walletNotifier) + phm := healthmanager.NewProvidersHealthManager(chainID) + c.healthMgr.RegisterProvidersHealthManager(context.Background(), phm) + + client := chain.NewClient(ethClients, chainID, phm) c.rpcClients[chainID] = client return client, nil } diff --git a/rpc/client_test.go b/rpc/client_test.go index 877a1af04..a0ad79af4 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -44,7 +44,7 @@ func TestBlockedRoutesCall(t *testing.T) { gethRPCClient, err := gethrpc.Dial(ts.URL) require.NoError(t, err) - c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil) + c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil, nil) require.NoError(t, err) for _, m := range blockedMethods { @@ -83,7 +83,7 @@ func TestBlockedRoutesRawCall(t *testing.T) { gethRPCClient, err := gethrpc.Dial(ts.URL) require.NoError(t, err) - c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil) + c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil, nil) require.NoError(t, err) for _, m := range blockedMethods { @@ -142,7 +142,7 @@ func TestGetClientsUsingCache(t *testing.T) { DefaultFallbackURL2: server.URL + path3, }, } - c, err := NewClient(nil, 1, networks, db, providerConfigs) + c, err := NewClient(nil, 1, networks, db, nil, providerConfigs) require.NoError(t, err) // Networks from DB must pick up DefaultRPCURL, DefaultFallbackURL, DefaultFallbackURL2 diff --git a/services/wallet/common/mock/feed_subscription.go b/services/wallet/common/mock/feed_subscription.go new file mode 100644 index 000000000..b508ded5f --- /dev/null +++ b/services/wallet/common/mock/feed_subscription.go @@ -0,0 +1,46 @@ +package mock_common + +import ( + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/status-im/status-go/services/wallet/walletevent" +) + +type FeedSubscription struct { + events chan walletevent.Event + feed *event.Feed + done chan struct{} +} + +func NewFeedSubscription(feed *event.Feed) *FeedSubscription { + events := make(chan walletevent.Event, 100) + done := make(chan struct{}) + + subscription := feed.Subscribe(events) + + go func() { + <-done + subscription.Unsubscribe() + close(events) + }() + + return &FeedSubscription{events: events, feed: feed, done: done} +} + +func (f *FeedSubscription) WaitForEvent(timeout time.Duration) (walletevent.Event, bool) { + select { + case evt := <-f.events: + return evt, true + case <-time.After(timeout): + return walletevent.Event{}, false + } +} + +func (f *FeedSubscription) GetFeed() *event.Feed { + return f.feed +} + +func (f *FeedSubscription) Close() { + close(f.done) +} diff --git a/services/wallet/market/market_feed_test.go b/services/wallet/market/market_feed_test.go new file mode 100644 index 000000000..cccdb5930 --- /dev/null +++ b/services/wallet/market/market_feed_test.go @@ -0,0 +1,73 @@ +package market + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + + "github.com/ethereum/go-ethereum/event" + mock_common "github.com/status-im/status-go/services/wallet/common/mock" + mock_market "github.com/status-im/status-go/services/wallet/market/mock" + "github.com/status-im/status-go/services/wallet/thirdparty" +) + +type MarketTestSuite struct { + suite.Suite + feedSub *mock_common.FeedSubscription + symbols []string + currencies []string +} + +func (s *MarketTestSuite) SetupTest() { + feed := new(event.Feed) + s.feedSub = mock_common.NewFeedSubscription(feed) + + s.symbols = []string{"BTC", "ETH"} + s.currencies = []string{"USD", "EUR"} +} + +func (s *MarketTestSuite) TearDownTest() { + s.feedSub.Close() +} + +func (s *MarketTestSuite) TestEventOnRpsError() { + ctrl := gomock.NewController(s.T()) + defer ctrl.Finish() + // GIVEN + customErr := errors.New("request rate exceeded") + priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr) + manager := NewManager([]thirdparty.MarketDataProvider{priceProviderWithError}, s.feedSub.GetFeed()) + + // WHEN + _, err := manager.FetchPrices(s.symbols, s.currencies) + s.Require().Error(err, "expected error from FetchPrices due to MockPriceProviderWithError") + event, ok := s.feedSub.WaitForEvent(5 * time.Second) + s.Require().True(ok, "expected an event, but none was received") + + // THEN + s.Require().Equal(event.Type, EventMarketStatusChanged) +} + +func (s *MarketTestSuite) TestNoEventOnNetworkError() { + ctrl := gomock.NewController(s.T()) + defer ctrl.Finish() + + // GIVEN + customErr := errors.New("dial tcp: lookup optimism-goerli.infura.io: no such host") + priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr) + manager := NewManager([]thirdparty.MarketDataProvider{priceProviderWithError}, s.feedSub.GetFeed()) + + _, err := manager.FetchPrices(s.symbols, s.currencies) + s.Require().Error(err, "expected error from FetchPrices due to MockPriceProviderWithError") + _, ok := s.feedSub.WaitForEvent(time.Millisecond * 500) + + //THEN + s.Require().False(ok, "expected no event, but one was received") +} + +func TestMarketTestSuite(t *testing.T) { + suite.Run(t, new(MarketTestSuite)) +} diff --git a/services/wallet/market/market_test.go b/services/wallet/market/market_test.go index a7d4f0660..e646fc6ac 100644 --- a/services/wallet/market/market_test.go +++ b/services/wallet/market/market_test.go @@ -10,48 +10,11 @@ import ( "github.com/stretchr/testify/require" + mock_market "github.com/status-im/status-go/services/wallet/market/mock" "github.com/status-im/status-go/services/wallet/thirdparty" mock_thirdparty "github.com/status-im/status-go/services/wallet/thirdparty/mock" ) -type MockPriceProvider struct { - mock_thirdparty.MockMarketDataProvider - mockPrices map[string]map[string]float64 -} - -func NewMockPriceProvider(ctrl *gomock.Controller) *MockPriceProvider { - return &MockPriceProvider{ - MockMarketDataProvider: *mock_thirdparty.NewMockMarketDataProvider(ctrl), - } -} - -func (mpp *MockPriceProvider) setMockPrices(prices map[string]map[string]float64) { - mpp.mockPrices = prices -} - -func (mpp *MockPriceProvider) ID() string { - return "MockPriceProvider" -} - -func (mpp *MockPriceProvider) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) { - res := make(map[string]map[string]float64) - for _, symbol := range symbols { - res[symbol] = make(map[string]float64) - for _, currency := range currencies { - res[symbol][currency] = mpp.mockPrices[symbol][currency] - } - } - return res, nil -} - -type MockPriceProviderWithError struct { - MockPriceProvider -} - -func (mpp *MockPriceProviderWithError) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) { - return nil, errors.New("error") -} - func setupMarketManager(t *testing.T, providers []thirdparty.MarketDataProvider) *Manager { return NewManager(providers, &event.Feed{}) } @@ -80,8 +43,8 @@ var mockPrices = map[string]map[string]float64{ func TestPrice(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - priceProvider := NewMockPriceProvider(ctrl) - priceProvider.setMockPrices(mockPrices) + priceProvider := mock_market.NewMockPriceProvider(ctrl) + priceProvider.SetMockPrices(mockPrices) manager := setupMarketManager(t, []thirdparty.MarketDataProvider{priceProvider, priceProvider}) @@ -125,9 +88,12 @@ func TestPrice(t *testing.T) { func TestFetchPriceErrorFirstProvider(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - priceProvider := NewMockPriceProvider(ctrl) - priceProvider.setMockPrices(mockPrices) - priceProviderWithError := &MockPriceProviderWithError{} + priceProvider := mock_market.NewMockPriceProvider(ctrl) + priceProvider.SetMockPrices(mockPrices) + + customErr := errors.New("error") + priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr) + symbols := []string{"BTC", "ETH"} currencies := []string{"USD", "EUR"} diff --git a/services/wallet/market/mock/mock_price_provider.go b/services/wallet/market/mock/mock_price_provider.go new file mode 100644 index 000000000..5c0170d82 --- /dev/null +++ b/services/wallet/market/mock/mock_price_provider.go @@ -0,0 +1,54 @@ +package mock_market + +import ( + "go.uber.org/mock/gomock" + + mock_thirdparty "github.com/status-im/status-go/services/wallet/thirdparty/mock" +) + +type MockPriceProvider struct { + mock_thirdparty.MockMarketDataProvider + mockPrices map[string]map[string]float64 +} + +func NewMockPriceProvider(ctrl *gomock.Controller) *MockPriceProvider { + return &MockPriceProvider{ + MockMarketDataProvider: *mock_thirdparty.NewMockMarketDataProvider(ctrl), + } +} + +func (mpp *MockPriceProvider) SetMockPrices(prices map[string]map[string]float64) { + mpp.mockPrices = prices +} + +func (mpp *MockPriceProvider) ID() string { + return "MockPriceProvider" +} + +func (mpp *MockPriceProvider) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) { + res := make(map[string]map[string]float64) + for _, symbol := range symbols { + res[symbol] = make(map[string]float64) + for _, currency := range currencies { + res[symbol][currency] = mpp.mockPrices[symbol][currency] + } + } + return res, nil +} + +type MockPriceProviderWithError struct { + MockPriceProvider + err error +} + +// NewMockPriceProviderWithError creates a new MockPriceProviderWithError with the specified error +func NewMockPriceProviderWithError(ctrl *gomock.Controller, err error) *MockPriceProviderWithError { + return &MockPriceProviderWithError{ + MockPriceProvider: *NewMockPriceProvider(ctrl), + err: err, + } +} + +func (mpp *MockPriceProviderWithError) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) { + return nil, mpp.err +}