feat(wallet)_: Send a new event 'wallet-blockchain-health-changed' #5923

Keeping old event for the backward compatibility
This commit is contained in:
Andrey Bocharnikov 2024-10-07 15:16:50 +04:00
parent 3ad928a627
commit 3c9f9cb9c6
14 changed files with 1010 additions and 73 deletions

View File

@ -3,6 +3,7 @@ package circuitbreaker
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/afex/hystrix-go/hystrix" "github.com/afex/hystrix-go/hystrix"
@ -12,8 +13,16 @@ import (
type FallbackFunc func() ([]any, error) type FallbackFunc func() ([]any, error)
type CommandResult struct { type CommandResult struct {
res []any res []any
err error err error
functorCallStatuses []FunctorCallStatus
cancelled bool
}
type FunctorCallStatus struct {
Name string
Timestamp time.Time
Err error
} }
func (cr CommandResult) Result() []any { func (cr CommandResult) Result() []any {
@ -23,6 +32,21 @@ func (cr CommandResult) Result() []any {
func (cr CommandResult) Error() error { func (cr CommandResult) Error() error {
return cr.err return cr.err
} }
func (cr CommandResult) Cancelled() bool {
return cr.cancelled
}
func (cr CommandResult) FunctorCallStatuses() []FunctorCallStatus {
return cr.functorCallStatuses
}
func (cr *CommandResult) addCallStatus(circuitName string, err error) {
cr.functorCallStatuses = append(cr.functorCallStatuses, FunctorCallStatus{
Name: circuitName,
Timestamp: time.Now(),
Err: err,
})
}
type Command struct { type Command struct {
ctx context.Context ctx context.Context
@ -106,23 +130,26 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult {
for i, f := range cmd.functors { for i, f := range cmd.functors {
if cmd.cancel { if cmd.cancel {
result.cancelled = true
break break
} }
var err error var err error
circuitName := f.circuitName
if cb.circuitNameHandler != nil {
circuitName = cb.circuitNameHandler(circuitName)
}
// if last command, execute without circuit // if last command, execute without circuit
if i == len(cmd.functors)-1 { if i == len(cmd.functors)-1 {
res, execErr := f.exec() res, execErr := f.exec()
err = execErr err = execErr
if err == nil { if err == nil {
result = CommandResult{res: res} result.res = res
result.err = nil
} }
result.addCallStatus(circuitName, err)
} else { } else {
circuitName := f.circuitName
if cb.circuitNameHandler != nil {
circuitName = cb.circuitNameHandler(circuitName)
}
if hystrix.GetCircuitSettings()[circuitName] == nil { if hystrix.GetCircuitSettings()[circuitName] == nil {
hystrix.ConfigureCommand(circuitName, hystrix.CommandConfig{ hystrix.ConfigureCommand(circuitName, hystrix.CommandConfig{
Timeout: cb.config.Timeout, Timeout: cb.config.Timeout,
@ -137,13 +164,16 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult {
res, err := f.exec() res, err := f.exec()
// Write to result only if success // Write to result only if success
if err == nil { if err == nil {
result = CommandResult{res: res} result.res = res
result.err = nil
} }
result.addCallStatus(circuitName, err)
// If the command has been cancelled, we don't count // If the command has been cancelled, we don't count
// the error towars breaking the circuit, and then we break // the error towars breaking the circuit, and then we break
if cmd.cancel { if cmd.cancel {
result = accumulateCommandError(result, f.circuitName, err) result = accumulateCommandError(result, circuitName, err)
result.cancelled = true
return nil return nil
} }
if err != nil { if err != nil {
@ -156,7 +186,7 @@ func (cb *CircuitBreaker) Execute(cmd *Command) CommandResult {
break break
} }
result = accumulateCommandError(result, f.circuitName, err) result = accumulateCommandError(result, circuitName, err)
// Lets abuse every provider with the same amount of MaxConcurrentRequests, // Lets abuse every provider with the same amount of MaxConcurrentRequests,
// keep iterating even in case of ErrMaxConcurrency error // keep iterating even in case of ErrMaxConcurrency error

View File

@ -34,6 +34,7 @@ func TestCircuitBreaker_ExecuteSuccessSingle(t *testing.T) {
result := cb.Execute(cmd) result := cb.Execute(cmd)
require.NoError(t, result.Error()) require.NoError(t, result.Error())
require.Equal(t, expectedResult, result.Result()[0].(string)) require.Equal(t, expectedResult, result.Result()[0].(string))
require.False(t, result.Cancelled())
} }
func TestCircuitBreaker_ExecuteMultipleFallbacksFail(t *testing.T) { func TestCircuitBreaker_ExecuteMultipleFallbacksFail(t *testing.T) {
@ -219,9 +220,11 @@ func TestCircuitBreaker_CommandCancel(t *testing.T) {
result := cb.Execute(cmd) result := cb.Execute(cmd)
require.True(t, errors.Is(result.Error(), expectedErr)) require.True(t, errors.Is(result.Error(), expectedErr))
require.True(t, result.Cancelled())
assert.Equal(t, 1, prov1Called) assert.Equal(t, 1, prov1Called)
assert.Equal(t, 0, prov2Called) assert.Equal(t, 0, prov2Called)
} }
func TestCircuitBreaker_EmptyOrNilCommand(t *testing.T) { func TestCircuitBreaker_EmptyOrNilCommand(t *testing.T) {
@ -301,3 +304,149 @@ func TestCircuitBreaker_Fallback(t *testing.T) {
assert.Equal(t, 1, prov1Called) assert.Equal(t, 1, prov1Called)
} }
func TestCircuitBreaker_SuccessCallStatus(t *testing.T) {
cb := NewCircuitBreaker(Config{})
functor := NewFunctor(func() ([]any, error) {
return []any{"success"}, nil
}, "successCircuit")
cmd := NewCommand(context.Background(), []*Functor{functor})
result := cb.Execute(cmd)
require.Nil(t, result.Error())
require.False(t, result.Cancelled())
assert.Len(t, result.Result(), 1)
require.Equal(t, "success", result.Result()[0])
assert.Len(t, result.FunctorCallStatuses(), 1)
status := result.FunctorCallStatuses()[0]
if status.name != "successCircuit" {
t.Errorf("Expected functor name to be 'successCircuit', got %s", status.name)
}
if status.err != nil {
t.Errorf("Expected no error in functor status, got %v", status.err)
}
}
func TestCircuitBreaker_ErrorCallStatus(t *testing.T) {
cb := NewCircuitBreaker(Config{})
expectedError := errors.New("functor error")
functor := NewFunctor(func() ([]any, error) {
return nil, expectedError
}, "errorCircuit")
cmd := NewCommand(context.Background(), []*Functor{functor})
result := cb.Execute(cmd)
require.NotNil(t, result.Error())
require.True(t, errors.Is(result.Error(), expectedError))
assert.Len(t, result.Result(), 0)
assert.Len(t, result.FunctorCallStatuses(), 1)
status := result.FunctorCallStatuses()[0]
if status.name != "errorCircuit" {
t.Errorf("Expected functor name to be 'errorCircuit', got %s", status.name)
}
if !errors.Is(status.err, expectedError) {
t.Errorf("Expected functor error to be '%v', got '%v'", expectedError, status.err)
}
}
func TestCircuitBreaker_CancelledResult(t *testing.T) {
cb := NewCircuitBreaker(Config{Timeout: 1000})
functor := NewFunctor(func() ([]any, error) {
time.Sleep(500 * time.Millisecond)
return []any{"should not be returned"}, nil
}, "cancelCircuit")
cmd := NewCommand(context.Background(), []*Functor{functor})
cmd.Cancel()
result := cb.Execute(cmd)
assert.True(t, result.Cancelled())
require.Nil(t, result.Error())
require.Empty(t, result.Result())
require.Empty(t, result.FunctorCallStatuses())
}
func TestCircuitBreaker_MultipleFunctorsResult(t *testing.T) {
cb := NewCircuitBreaker(Config{
Timeout: 1000,
MaxConcurrentRequests: 100,
RequestVolumeThreshold: 20,
SleepWindow: 5000,
ErrorPercentThreshold: 50,
})
functor1 := NewFunctor(func() ([]any, error) {
return nil, errors.New("functor1 error")
}, "circuit1")
functor2 := NewFunctor(func() ([]any, error) {
return []any{"success from functor2"}, nil
}, "circuit2")
cmd := NewCommand(context.Background(), []*Functor{functor1, functor2})
result := cb.Execute(cmd)
require.Nil(t, result.Error())
require.Len(t, result.Result(), 1)
require.Equal(t, result.Result()[0], "success from functor2")
statuses := result.FunctorCallStatuses()
require.Len(t, statuses, 2)
require.Equal(t, statuses[0].name, "circuit1")
require.NotNil(t, statuses[0].err)
require.Equal(t, statuses[1].name, "circuit2")
require.Nil(t, statuses[1].err)
}
func TestCircuitBreaker_LastFunctorDirectExecution(t *testing.T) {
cb := NewCircuitBreaker(Config{
Timeout: 10, // short timeout to open circuit
MaxConcurrentRequests: 1,
RequestVolumeThreshold: 1,
SleepWindow: 1000,
ErrorPercentThreshold: 1,
})
failingFunctor := NewFunctor(func() ([]any, error) {
time.Sleep(20 * time.Millisecond)
return nil, errors.New("should time out")
}, "circuitName")
successFunctor := NewFunctor(func() ([]any, error) {
return []any{"success without circuit"}, nil
}, "circuitName")
cmd := NewCommand(context.Background(), []*Functor{failingFunctor, successFunctor})
require.False(t, IsCircuitOpen("circuitName"))
result := cb.Execute(cmd)
require.True(t, CircuitExists("circuitName"))
require.Nil(t, result.Error())
require.Len(t, result.Result(), 1)
require.Equal(t, result.Result()[0], "success without circuit")
statuses := result.FunctorCallStatuses()
require.Len(t, statuses, 2)
require.Equal(t, statuses[0].name, "circuitName")
require.NotNil(t, statuses[0].err)
require.Equal(t, statuses[1].name, "circuitName")
require.Nil(t, statuses[1].err)
}

View File

@ -14,6 +14,12 @@ type BlockchainFullStatus struct {
StatusPerChainPerProvider map[uint64]map[string]rpcstatus.ProviderStatus `json:"statusPerChainPerProvider"` StatusPerChainPerProvider map[uint64]map[string]rpcstatus.ProviderStatus `json:"statusPerChainPerProvider"`
} }
// BlockchainStatus contains the status of the blockchain
type BlockchainStatus struct {
Status rpcstatus.ProviderStatus `json:"status"`
StatusPerChain map[uint64]rpcstatus.ProviderStatus `json:"statusPerChain"`
}
// BlockchainHealthManager manages the state of all providers and aggregates their statuses. // BlockchainHealthManager manages the state of all providers and aggregates their statuses.
type BlockchainHealthManager struct { type BlockchainHealthManager struct {
mu sync.RWMutex mu sync.RWMutex

View File

@ -331,11 +331,11 @@ func (n *StatusNode) setupRPCClient() (err error) {
}, },
} }
n.rpcClient, err = rpc.NewClient(gethNodeClient, n.config.NetworkID, n.config.Networks, n.appDB, providerConfigs) n.rpcClient, err = rpc.NewClient(gethNodeClient, n.config.NetworkID, n.config.Networks, n.appDB, &n.walletFeed, providerConfigs)
n.rpcClient.Start(context.Background())
if err != nil { if err != nil {
return return
} }
return return
} }
@ -451,6 +451,7 @@ func (n *StatusNode) stop() error {
return err return err
} }
n.rpcClient.Stop()
n.rpcClient = nil n.rpcClient = nil
// We need to clear `gethNode` because config is passed to `Start()` // We need to clear `gethNode` because config is passed to `Start()`
// and may be completely different. Similarly with `config`. // and may be completely different. Similarly with `config`.

View File

@ -0,0 +1,295 @@
package chain
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/status-im/status-go/healthmanager"
"github.com/status-im/status-go/healthmanager/rpcstatus"
mockEthclient "github.com/status-im/status-go/rpc/chain/ethclient/mock/client/ethclient"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/status-im/status-go/rpc/chain/ethclient"
"go.uber.org/mock/gomock"
)
type BlockchainHealthManagerSuite struct {
suite.Suite
blockchainHealthManager *healthmanager.BlockchainHealthManager
mockProviders map[uint64]*healthmanager.ProvidersHealthManager
mockEthClients map[uint64]*mockEthclient.MockRPSLimitedEthClientInterface
clients map[uint64]*ClientWithFallback
mockCtrl *gomock.Controller
}
func (s *BlockchainHealthManagerSuite) SetupTest() {
s.blockchainHealthManager = healthmanager.NewBlockchainHealthManager()
s.mockProviders = make(map[uint64]*healthmanager.ProvidersHealthManager)
s.mockEthClients = make(map[uint64]*mockEthclient.MockRPSLimitedEthClientInterface)
s.clients = make(map[uint64]*ClientWithFallback)
s.mockCtrl = gomock.NewController(s.T())
}
func (s *BlockchainHealthManagerSuite) TearDownTest() {
s.blockchainHealthManager.Stop()
s.mockCtrl.Finish()
}
func (s *BlockchainHealthManagerSuite) setupClients(chainIDs []uint64) {
ctx := context.Background()
for _, chainID := range chainIDs {
mockEthClient := mockEthclient.NewMockRPSLimitedEthClientInterface(s.mockCtrl)
mockEthClient.EXPECT().GetName().AnyTimes().Return(fmt.Sprintf("test_client_chain_%d", chainID))
mockEthClient.EXPECT().GetLimiter().AnyTimes().Return(nil)
phm := healthmanager.NewProvidersHealthManager(chainID)
client := NewClient([]ethclient.RPSLimitedEthClientInterface{mockEthClient}, chainID, phm)
s.blockchainHealthManager.RegisterProvidersHealthManager(ctx, phm)
s.mockProviders[chainID] = phm
s.mockEthClients[chainID] = mockEthClient
s.clients[chainID] = client
}
}
func (s *BlockchainHealthManagerSuite) simulateChainStatus(chainID uint64, up bool) {
client, exists := s.clients[chainID]
require.True(s.T(), exists, "Client for chainID %d not found", chainID)
mockEthClient := s.mockEthClients[chainID]
ctx := context.Background()
hash := common.HexToHash("0x1234")
if up {
block := &types.Block{}
mockEthClient.EXPECT().BlockByHash(ctx, hash).Return(block, nil).Times(1)
_, err := client.BlockByHash(ctx, hash)
require.NoError(s.T(), err)
} else {
mockEthClient.EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1)
_, err := client.BlockByHash(ctx, hash)
require.Error(s.T(), err)
}
}
func (s *BlockchainHealthManagerSuite) waitForStatus(statusCh chan struct{}, expectedStatus rpcstatus.StatusType) {
timeout := time.After(2 * time.Second)
for {
select {
case <-statusCh:
status := s.blockchainHealthManager.Status()
if status.Status == expectedStatus {
return
}
case <-timeout:
s.T().Errorf("Did not receive expected blockchain status update in time")
return
}
}
}
func (s *BlockchainHealthManagerSuite) TestAllChainsUp() {
s.setupClients([]uint64{1, 2, 3})
statusCh := s.blockchainHealthManager.Subscribe()
defer s.blockchainHealthManager.Unsubscribe(statusCh)
s.simulateChainStatus(1, true)
s.simulateChainStatus(2, true)
s.simulateChainStatus(3, true)
s.waitForStatus(statusCh, rpcstatus.StatusUp)
}
func (s *BlockchainHealthManagerSuite) TestSomeChainsDown() {
s.setupClients([]uint64{1, 2, 3})
statusCh := s.blockchainHealthManager.Subscribe()
defer s.blockchainHealthManager.Unsubscribe(statusCh)
s.simulateChainStatus(1, true)
s.simulateChainStatus(2, false)
s.simulateChainStatus(3, true)
s.waitForStatus(statusCh, rpcstatus.StatusUp)
}
func (s *BlockchainHealthManagerSuite) TestAllChainsDown() {
s.setupClients([]uint64{1, 2})
statusCh := s.blockchainHealthManager.Subscribe()
defer s.blockchainHealthManager.Unsubscribe(statusCh)
s.simulateChainStatus(1, false)
s.simulateChainStatus(2, false)
s.waitForStatus(statusCh, rpcstatus.StatusDown)
}
func (s *BlockchainHealthManagerSuite) TestChainStatusChanges() {
s.setupClients([]uint64{1, 2})
statusCh := s.blockchainHealthManager.Subscribe()
defer s.blockchainHealthManager.Unsubscribe(statusCh)
s.simulateChainStatus(1, false)
s.simulateChainStatus(2, false)
s.waitForStatus(statusCh, rpcstatus.StatusDown)
s.simulateChainStatus(1, true)
s.waitForStatus(statusCh, rpcstatus.StatusUp)
}
func (s *BlockchainHealthManagerSuite) TestGetFullStatus() {
// Setup clients for chain IDs 1 and 2
s.setupClients([]uint64{1, 2})
// Subscribe to blockchain status updates
statusCh := s.blockchainHealthManager.Subscribe()
defer s.blockchainHealthManager.Unsubscribe(statusCh)
// Simulate provider statuses for chain 1
providerCallStatusesChain1 := []rpcstatus.RpcProviderCallStatus{
{
Name: "provider1_chain1",
Timestamp: time.Now(),
Err: nil, // Up
},
{
Name: "provider2_chain1",
Timestamp: time.Now(),
Err: errors.New("connection error"), // Down
},
}
ctx := context.Background()
s.mockProviders[1].Update(ctx, providerCallStatusesChain1)
// Simulate provider statuses for chain 2
providerCallStatusesChain2 := []rpcstatus.RpcProviderCallStatus{
{
Name: "provider1_chain2",
Timestamp: time.Now(),
Err: nil, // Up
},
{
Name: "provider2_chain2",
Timestamp: time.Now(),
Err: nil, // Up
},
}
s.mockProviders[2].Update(ctx, providerCallStatusesChain2)
// Wait for status event to be triggered before getting full status
s.waitForStatus(statusCh, rpcstatus.StatusUp)
// Get the full status from the BlockchainHealthManager
fullStatus := s.blockchainHealthManager.GetFullStatus()
// Assert overall blockchain status
require.Equal(s.T(), rpcstatus.StatusUp, fullStatus.Status.Status)
// Assert provider statuses per chain
require.Contains(s.T(), fullStatus.StatusPerChainPerProvider, uint64(1))
require.Contains(s.T(), fullStatus.StatusPerChainPerProvider, uint64(2))
// Provider statuses for chain 1
providerStatusesChain1 := fullStatus.StatusPerChainPerProvider[1]
require.Contains(s.T(), providerStatusesChain1, "provider1_chain1")
require.Contains(s.T(), providerStatusesChain1, "provider2_chain1")
provider1Chain1Status := providerStatusesChain1["provider1_chain1"]
require.Equal(s.T(), rpcstatus.StatusUp, provider1Chain1Status.Status)
provider2Chain1Status := providerStatusesChain1["provider2_chain1"]
require.Equal(s.T(), rpcstatus.StatusDown, provider2Chain1Status.Status)
// Provider statuses for chain 2
providerStatusesChain2 := fullStatus.StatusPerChainPerProvider[2]
require.Contains(s.T(), providerStatusesChain2, "provider1_chain2")
require.Contains(s.T(), providerStatusesChain2, "provider2_chain2")
provider1Chain2Status := providerStatusesChain2["provider1_chain2"]
require.Equal(s.T(), rpcstatus.StatusUp, provider1Chain2Status.Status)
provider2Chain2Status := providerStatusesChain2["provider2_chain2"]
require.Equal(s.T(), rpcstatus.StatusUp, provider2Chain2Status.Status)
// Serialization to JSON works without errors
jsonData, err := json.MarshalIndent(fullStatus, "", " ")
require.NoError(s.T(), err)
require.NotEmpty(s.T(), jsonData)
}
func (s *BlockchainHealthManagerSuite) TestGetShortStatus() {
// Setup clients for chain IDs 1 and 2
s.setupClients([]uint64{1, 2})
// Subscribe to blockchain status updates
statusCh := s.blockchainHealthManager.Subscribe()
defer s.blockchainHealthManager.Unsubscribe(statusCh)
// Simulate provider statuses for chain 1
providerCallStatusesChain1 := []rpcstatus.RpcProviderCallStatus{
{
Name: "provider1_chain1",
Timestamp: time.Now(),
Err: nil, // Up
},
{
Name: "provider2_chain1",
Timestamp: time.Now(),
Err: errors.New("connection error"), // Down
},
}
ctx := context.Background()
s.mockProviders[1].Update(ctx, providerCallStatusesChain1)
// Simulate provider statuses for chain 2
providerCallStatusesChain2 := []rpcstatus.RpcProviderCallStatus{
{
Name: "provider1_chain2",
Timestamp: time.Now(),
Err: nil, // Up
},
{
Name: "provider2_chain2",
Timestamp: time.Now(),
Err: nil, // Up
},
}
s.mockProviders[2].Update(ctx, providerCallStatusesChain2)
// Wait for status event to be triggered before getting short status
s.waitForStatus(statusCh, rpcstatus.StatusUp)
// Get the short status from the BlockchainHealthManager
shortStatus := s.blockchainHealthManager.GetShortStatus()
// Assert overall blockchain status
require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.Status.Status)
// Assert chain statuses
require.Contains(s.T(), shortStatus.StatusPerChain, uint64(1))
require.Contains(s.T(), shortStatus.StatusPerChain, uint64(2))
require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.StatusPerChain[1].Status)
require.Equal(s.T(), rpcstatus.StatusUp, shortStatus.StatusPerChain[2].Status)
// Serialization to JSON works without errors
jsonData, err := json.MarshalIndent(shortStatus, "", " ")
require.NoError(s.T(), err)
require.NotEmpty(s.T(), jsonData)
}
func TestBlockchainHealthManagerSuite(t *testing.T) {
suite.Run(t, new(BlockchainHealthManagerSuite))
}

View File

@ -20,6 +20,8 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
"github.com/status-im/status-go/circuitbreaker" "github.com/status-im/status-go/circuitbreaker"
"github.com/status-im/status-go/healthmanager"
"github.com/status-im/status-go/healthmanager/rpcstatus"
"github.com/status-im/status-go/rpc/chain/ethclient" "github.com/status-im/status-go/rpc/chain/ethclient"
"github.com/status-im/status-go/rpc/chain/rpclimiter" "github.com/status-im/status-go/rpc/chain/rpclimiter"
"github.com/status-im/status-go/rpc/chain/tagger" "github.com/status-im/status-go/rpc/chain/tagger"
@ -63,10 +65,11 @@ func ClientWithTag(chainClient ClientInterface, tag, groupTag string) ClientInte
} }
type ClientWithFallback struct { type ClientWithFallback struct {
ChainID uint64 ChainID uint64
ethClients []ethclient.RPSLimitedEthClientInterface ethClients []ethclient.RPSLimitedEthClientInterface
commonLimiter rpclimiter.RequestLimiter commonLimiter rpclimiter.RequestLimiter
circuitbreaker *circuitbreaker.CircuitBreaker circuitbreaker *circuitbreaker.CircuitBreaker
providersHealthManager *healthmanager.ProvidersHealthManager
WalletNotifier func(chainId uint64, message string) WalletNotifier func(chainId uint64, message string)
@ -111,7 +114,7 @@ var propagateErrors = []error{
bind.ErrNoCode, bind.ErrNoCode,
} }
func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint64) *ClientWithFallback { func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint64, providersHealthManager *healthmanager.ProvidersHealthManager) *ClientWithFallback {
cbConfig := circuitbreaker.Config{ cbConfig := circuitbreaker.Config{
Timeout: 20000, Timeout: 20000,
MaxConcurrentRequests: 100, MaxConcurrentRequests: 100,
@ -123,11 +126,12 @@ func NewClient(ethClients []ethclient.RPSLimitedEthClientInterface, chainID uint
isConnected.Store(true) isConnected.Store(true)
return &ClientWithFallback{ return &ClientWithFallback{
ChainID: chainID, ChainID: chainID,
ethClients: ethClients, ethClients: ethClients,
isConnected: isConnected, isConnected: isConnected,
LastCheckedAt: time.Now().Unix(), LastCheckedAt: time.Now().Unix(),
circuitbreaker: circuitbreaker.NewCircuitBreaker(cbConfig), circuitbreaker: circuitbreaker.NewCircuitBreaker(cbConfig),
providersHealthManager: providersHealthManager,
} }
} }
@ -238,6 +242,10 @@ func (c *ClientWithFallback) makeCall(ctx context.Context, ethClients []ethclien
} }
result := c.circuitbreaker.Execute(cmd) result := c.circuitbreaker.Execute(cmd)
if c.providersHealthManager != nil {
rpcCallStatuses := convertFunctorCallStatuses(result.FunctorCallStatuses())
c.providersHealthManager.Update(ctx, rpcCallStatuses)
}
if result.Error() != nil { if result.Error() != nil {
return nil, result.Error() return nil, result.Error()
} }
@ -842,3 +850,10 @@ func (c *ClientWithFallback) GetCircuitBreaker() *circuitbreaker.CircuitBreaker
func (c *ClientWithFallback) SetCircuitBreaker(cb *circuitbreaker.CircuitBreaker) { func (c *ClientWithFallback) SetCircuitBreaker(cb *circuitbreaker.CircuitBreaker) {
c.circuitbreaker = cb c.circuitbreaker = cb
} }
func convertFunctorCallStatuses(statuses []circuitbreaker.FunctorCallStatus) (result []rpcstatus.RpcProviderCallStatus) {
for _, f := range statuses {
result = append(result, rpcstatus.RpcProviderCallStatus{Name: f.Name, Timestamp: f.Timestamp, Err: f.Err})
}
return
}

View File

@ -0,0 +1,240 @@
package chain
import (
"context"
"errors"
"github.com/ethereum/go-ethereum/core/vm"
"strconv"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
healthManager "github.com/status-im/status-go/healthmanager"
"github.com/status-im/status-go/healthmanager/rpcstatus"
"github.com/status-im/status-go/rpc/chain/ethclient"
"github.com/status-im/status-go/rpc/chain/rpclimiter"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
mockEthclient "github.com/status-im/status-go/rpc/chain/ethclient/mock/client/ethclient"
)
type ClientWithFallbackSuite struct {
suite.Suite
client *ClientWithFallback
mockEthClients []*mockEthclient.MockRPSLimitedEthClientInterface
providersHealthManager *healthManager.ProvidersHealthManager
mockCtrl *gomock.Controller
}
func (s *ClientWithFallbackSuite) SetupTest() {
s.mockCtrl = gomock.NewController(s.T())
}
func (s *ClientWithFallbackSuite) TearDownTest() {
s.mockCtrl.Finish()
}
func (s *ClientWithFallbackSuite) setupClients(numClients int) {
s.mockEthClients = make([]*mockEthclient.MockRPSLimitedEthClientInterface, 0)
ethClients := make([]ethclient.RPSLimitedEthClientInterface, 0)
for i := 0; i < numClients; i++ {
ethClient := mockEthclient.NewMockRPSLimitedEthClientInterface(s.mockCtrl)
ethClient.EXPECT().GetName().AnyTimes().Return("test" + strconv.Itoa(i))
ethClient.EXPECT().GetLimiter().AnyTimes().Return(nil)
s.mockEthClients = append(s.mockEthClients, ethClient)
ethClients = append(ethClients, ethClient)
}
var chainID uint64 = 0
s.providersHealthManager = healthManager.NewProvidersHealthManager(chainID)
s.client = NewClient(ethClients, chainID, s.providersHealthManager)
}
func (s *ClientWithFallbackSuite) TestSingleClientSuccess() {
s.setupClients(1)
ctx := context.Background()
hash := common.HexToHash("0x1234")
block := &types.Block{}
// GIVEN
s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(block, nil).Times(1)
// WHEN
result, err := s.client.BlockByHash(ctx, hash)
require.NoError(s.T(), err)
require.Equal(s.T(), block, result)
// THEN
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status)
providerStatuses := s.providersHealthManager.GetStatuses()
require.Len(s.T(), providerStatuses, 1)
require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp)
}
func (s *ClientWithFallbackSuite) TestSingleClientConnectionError() {
s.setupClients(1)
ctx := context.Background()
hash := common.HexToHash("0x1234")
// GIVEN
s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("connection error")).Times(1)
// WHEN
_, err := s.client.BlockByHash(ctx, hash)
require.Error(s.T(), err)
// THEN
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusDown, chainStatus.Status)
providerStatuses := s.providersHealthManager.GetStatuses()
require.Len(s.T(), providerStatuses, 1)
require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown)
}
func (s *ClientWithFallbackSuite) TestRPSLimitErrorDoesNotMarkChainDown() {
s.setupClients(1)
ctx := context.Background()
hash := common.HexToHash("0x1234")
// WHEN
s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, rpclimiter.ErrRequestsOverLimit).Times(1)
_, err := s.client.BlockByHash(ctx, hash)
require.Error(s.T(), err)
// THEN
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status)
providerStatuses := s.providersHealthManager.GetStatuses()
require.Len(s.T(), providerStatuses, 1)
require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp)
status := providerStatuses["test0"]
require.Equal(s.T(), status.Status, rpcstatus.StatusUp, "provider shouldn't be DOWN on RPS limit")
}
func (s *ClientWithFallbackSuite) TestContextCanceledDoesNotMarkChainDown() {
s.setupClients(1)
ctx, cancel := context.WithCancel(context.Background())
cancel()
hash := common.HexToHash("0x1234")
// WHEN
s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, context.Canceled).Times(1)
_, err := s.client.BlockByHash(ctx, hash)
require.Error(s.T(), err)
require.True(s.T(), errors.Is(err, context.Canceled))
// THEN
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status)
providerStatuses := s.providersHealthManager.GetStatuses()
require.Len(s.T(), providerStatuses, 1)
require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp)
}
func (s *ClientWithFallbackSuite) TestVMErrorDoesNotMarkChainDown() {
s.setupClients(1)
ctx := context.Background()
hash := common.HexToHash("0x1234")
vmError := vm.ErrOutOfGas
// GIVEN
s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, vmError).Times(1)
// WHEN
_, err := s.client.BlockByHash(ctx, hash)
require.Error(s.T(), err)
require.True(s.T(), errors.Is(err, vm.ErrOutOfGas))
// THEN
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status)
providerStatuses := s.providersHealthManager.GetStatuses()
require.Len(s.T(), providerStatuses, 1)
require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusUp)
}
func (s *ClientWithFallbackSuite) TestNoClientsChainUnknown() {
s.setupClients(0)
ctx := context.Background()
hash := common.HexToHash("0x1234")
// WHEN
_, err := s.client.BlockByHash(ctx, hash)
require.Error(s.T(), err)
// THEN
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusUnknown, chainStatus.Status)
}
func (s *ClientWithFallbackSuite) TestAllClientsDifferentErrors() {
s.setupClients(3)
ctx := context.Background()
hash := common.HexToHash("0x1234")
// GIVEN
s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1)
s.mockEthClients[1].EXPECT().BlockByHash(ctx, hash).Return(nil, rpclimiter.ErrRequestsOverLimit).Times(1)
s.mockEthClients[2].EXPECT().BlockByHash(ctx, hash).Return(nil, vm.ErrOutOfGas).Times(1)
// WHEN
_, err := s.client.BlockByHash(ctx, hash)
require.Error(s.T(), err)
// THEN
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusUp, chainStatus.Status)
providerStatuses := s.providersHealthManager.GetStatuses()
require.Len(s.T(), providerStatuses, 3)
require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown, "provider test0 should be DOWN due to a connection error")
require.Equal(s.T(), providerStatuses["test1"].Status, rpcstatus.StatusUp, "provider test1 should not be marked DOWN due to RPS limit error")
require.Equal(s.T(), providerStatuses["test2"].Status, rpcstatus.StatusUp, "provider test2 should not be labelled DOWN due to a VM error")
}
func (s *ClientWithFallbackSuite) TestAllClientsNetworkErrors() {
s.setupClients(3)
ctx := context.Background()
hash := common.HexToHash("0x1234")
// GIVEN
s.mockEthClients[0].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1)
s.mockEthClients[1].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1)
s.mockEthClients[2].EXPECT().BlockByHash(ctx, hash).Return(nil, errors.New("no such host")).Times(1)
// WHEN
_, err := s.client.BlockByHash(ctx, hash)
require.Error(s.T(), err)
// THEN
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusDown, chainStatus.Status)
providerStatuses := s.providersHealthManager.GetStatuses()
require.Len(s.T(), providerStatuses, 3)
require.Equal(s.T(), providerStatuses["test0"].Status, rpcstatus.StatusDown)
require.Equal(s.T(), providerStatuses["test1"].Status, rpcstatus.StatusDown)
require.Equal(s.T(), providerStatuses["test2"].Status, rpcstatus.StatusDown)
}
func (s *ClientWithFallbackSuite) TestChainStatusUnknownWhenAllProvidersUnknown() {
s.setupClients(2)
chainStatus := s.providersHealthManager.Status()
require.Equal(s.T(), rpcstatus.StatusUnknown, chainStatus.Status)
}
func TestClientWithFallbackSuite(t *testing.T) {
suite.Run(t, new(ClientWithFallbackSuite))
}

View File

@ -32,7 +32,7 @@ func setupClientTest(t *testing.T) (*ClientWithFallback, []*mock_ethclient.MockR
ethClients = append(ethClients, ethCl) ethClients = append(ethClients, ethCl)
} }
client := NewClient(ethClients, 0) client := NewClient(ethClients, 0, nil)
cleanup := func() { cleanup := func() {
mockCtrl.Finish() mockCtrl.Finish()

View File

@ -19,7 +19,9 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
gethrpc "github.com/ethereum/go-ethereum/rpc" gethrpc "github.com/ethereum/go-ethereum/rpc"
"github.com/ethereum/go-ethereum/event"
appCommon "github.com/status-im/status-go/common" appCommon "github.com/status-im/status-go/common"
"github.com/status-im/status-go/healthmanager"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/rpc/chain"
"github.com/status-im/status-go/rpc/chain/ethclient" "github.com/status-im/status-go/rpc/chain/ethclient"
@ -27,6 +29,7 @@ import (
"github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/services/rpcstats" "github.com/status-im/status-go/services/rpcstats"
"github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/walletevent"
) )
const ( const (
@ -48,6 +51,8 @@ const (
// rpcUserAgentUpstreamFormat a separate user agent format for upstream, because we should not be using upstream // rpcUserAgentUpstreamFormat a separate user agent format for upstream, because we should not be using upstream
// if we see this user agent in the logs that means parts of the application are using a malconfigured http client // if we see this user agent in the logs that means parts of the application are using a malconfigured http client
rpcUserAgentUpstreamFormat = "procuratee-%s-upstream/%s" rpcUserAgentUpstreamFormat = "procuratee-%s-upstream/%s"
EventBlockchainHealthChanged walletevent.EventType = "wallet-blockchain-health-changed" // Full status of the blockchain (including provider statuses)
) )
// List of RPC client errors. // List of RPC client errors.
@ -101,6 +106,10 @@ type Client struct {
router *router router *router
NetworkManager *network.Manager NetworkManager *network.Manager
healthMgr *healthmanager.BlockchainHealthManager
stopMonitoringFunc context.CancelFunc
walletFeed *event.Feed
handlersMx sync.RWMutex // mx guards handlers handlersMx sync.RWMutex // mx guards handlers
handlers map[string]Handler // locally registered handlers handlers map[string]Handler // locally registered handlers
log log.Logger log log.Logger
@ -116,7 +125,7 @@ var verifProxyInitFn func(c *Client)
// //
// Client is safe for concurrent use and will automatically // Client is safe for concurrent use and will automatically
// reconnect to the server if connection is lost. // reconnect to the server if connection is lost.
func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params.Network, db *sql.DB, providerConfigs []params.ProviderConfig) (*Client, error) { func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params.Network, db *sql.DB, walletFeed *event.Feed, providerConfigs []params.ProviderConfig) (*Client, error) {
var err error var err error
log := log.New("package", "status-go/rpc.Client") log := log.New("package", "status-go/rpc.Client")
@ -138,6 +147,8 @@ func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params
limiterPerProvider: make(map[string]*rpclimiter.RPCRpsLimiter), limiterPerProvider: make(map[string]*rpclimiter.RPCRpsLimiter),
log: log, log: log,
providerConfigs: providerConfigs, providerConfigs: providerConfigs,
healthMgr: healthmanager.NewBlockchainHealthManager(),
walletFeed: walletFeed,
} }
c.UpstreamChainID = upstreamChainID c.UpstreamChainID = upstreamChainID
@ -150,6 +161,55 @@ func NewClient(client *gethrpc.Client, upstreamChainID uint64, networks []params
return &c, nil return &c, nil
} }
func (c *Client) Start(ctx context.Context) {
if c.stopMonitoringFunc != nil {
c.log.Warn("Blockchain health manager already started")
return
}
cancelableCtx, cancel := context.WithCancel(ctx)
c.stopMonitoringFunc = cancel
statusCh := c.healthMgr.Subscribe()
go c.monitorHealth(cancelableCtx, statusCh)
}
func (c *Client) Stop() {
c.healthMgr.Stop()
if c.stopMonitoringFunc == nil {
return
}
c.stopMonitoringFunc()
c.stopMonitoringFunc = nil
}
func (c *Client) monitorHealth(ctx context.Context, statusCh chan struct{}) {
sendFullStatusEventFunc := func() {
blockchainStatus := c.healthMgr.GetFullStatus()
encodedMessage, err := json.Marshal(blockchainStatus)
if err != nil {
c.log.Warn("could not marshal full blockchain status", "error", err)
return
}
if c.walletFeed == nil {
return
}
c.walletFeed.Send(walletevent.Event{
Type: EventBlockchainHealthChanged,
Message: string(encodedMessage),
At: time.Now().Unix(),
})
}
for {
select {
case <-ctx.Done():
return
case <-statusCh:
sendFullStatusEventFunc()
}
}
}
func (c *Client) GetNetworkManager() *network.Manager { func (c *Client) GetNetworkManager() *network.Manager {
return c.NetworkManager return c.NetworkManager
} }
@ -207,8 +267,10 @@ func (c *Client) getClientUsingCache(chainID uint64) (chain.ClientInterface, err
return nil, fmt.Errorf("could not find any RPC URL for chain: %d", chainID) return nil, fmt.Errorf("could not find any RPC URL for chain: %d", chainID)
} }
client := chain.NewClient(ethClients, chainID) phm := healthmanager.NewProvidersHealthManager(chainID)
client.SetWalletNotifier(c.walletNotifier) c.healthMgr.RegisterProvidersHealthManager(context.Background(), phm)
client := chain.NewClient(ethClients, chainID, phm)
c.rpcClients[chainID] = client c.rpcClients[chainID] = client
return client, nil return client, nil
} }

View File

@ -44,7 +44,7 @@ func TestBlockedRoutesCall(t *testing.T) {
gethRPCClient, err := gethrpc.Dial(ts.URL) gethRPCClient, err := gethrpc.Dial(ts.URL)
require.NoError(t, err) require.NoError(t, err)
c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil) c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil, nil)
require.NoError(t, err) require.NoError(t, err)
for _, m := range blockedMethods { for _, m := range blockedMethods {
@ -83,7 +83,7 @@ func TestBlockedRoutesRawCall(t *testing.T) {
gethRPCClient, err := gethrpc.Dial(ts.URL) gethRPCClient, err := gethrpc.Dial(ts.URL)
require.NoError(t, err) require.NoError(t, err)
c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil) c, err := NewClient(gethRPCClient, 1, []params.Network{}, db, nil, nil)
require.NoError(t, err) require.NoError(t, err)
for _, m := range blockedMethods { for _, m := range blockedMethods {
@ -142,7 +142,7 @@ func TestGetClientsUsingCache(t *testing.T) {
DefaultFallbackURL2: server.URL + path3, DefaultFallbackURL2: server.URL + path3,
}, },
} }
c, err := NewClient(nil, 1, networks, db, providerConfigs) c, err := NewClient(nil, 1, networks, db, nil, providerConfigs)
require.NoError(t, err) require.NoError(t, err)
// Networks from DB must pick up DefaultRPCURL, DefaultFallbackURL, DefaultFallbackURL2 // Networks from DB must pick up DefaultRPCURL, DefaultFallbackURL, DefaultFallbackURL2

View File

@ -0,0 +1,46 @@
package mock_common
import (
"time"
"github.com/ethereum/go-ethereum/event"
"github.com/status-im/status-go/services/wallet/walletevent"
)
type FeedSubscription struct {
events chan walletevent.Event
feed *event.Feed
done chan struct{}
}
func NewFeedSubscription(feed *event.Feed) *FeedSubscription {
events := make(chan walletevent.Event, 100)
done := make(chan struct{})
subscription := feed.Subscribe(events)
go func() {
<-done
subscription.Unsubscribe()
close(events)
}()
return &FeedSubscription{events: events, feed: feed, done: done}
}
func (f *FeedSubscription) WaitForEvent(timeout time.Duration) (walletevent.Event, bool) {
select {
case evt := <-f.events:
return evt, true
case <-time.After(timeout):
return walletevent.Event{}, false
}
}
func (f *FeedSubscription) GetFeed() *event.Feed {
return f.feed
}
func (f *FeedSubscription) Close() {
close(f.done)
}

View File

@ -0,0 +1,73 @@
package market
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"github.com/ethereum/go-ethereum/event"
mock_common "github.com/status-im/status-go/services/wallet/common/mock"
mock_market "github.com/status-im/status-go/services/wallet/market/mock"
"github.com/status-im/status-go/services/wallet/thirdparty"
)
type MarketTestSuite struct {
suite.Suite
feedSub *mock_common.FeedSubscription
symbols []string
currencies []string
}
func (s *MarketTestSuite) SetupTest() {
feed := new(event.Feed)
s.feedSub = mock_common.NewFeedSubscription(feed)
s.symbols = []string{"BTC", "ETH"}
s.currencies = []string{"USD", "EUR"}
}
func (s *MarketTestSuite) TearDownTest() {
s.feedSub.Close()
}
func (s *MarketTestSuite) TestEventOnRpsError() {
ctrl := gomock.NewController(s.T())
defer ctrl.Finish()
// GIVEN
customErr := errors.New("request rate exceeded")
priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr)
manager := NewManager([]thirdparty.MarketDataProvider{priceProviderWithError}, s.feedSub.GetFeed())
// WHEN
_, err := manager.FetchPrices(s.symbols, s.currencies)
s.Require().Error(err, "expected error from FetchPrices due to MockPriceProviderWithError")
event, ok := s.feedSub.WaitForEvent(5 * time.Second)
s.Require().True(ok, "expected an event, but none was received")
// THEN
s.Require().Equal(event.Type, EventMarketStatusChanged)
}
func (s *MarketTestSuite) TestNoEventOnNetworkError() {
ctrl := gomock.NewController(s.T())
defer ctrl.Finish()
// GIVEN
customErr := errors.New("dial tcp: lookup optimism-goerli.infura.io: no such host")
priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr)
manager := NewManager([]thirdparty.MarketDataProvider{priceProviderWithError}, s.feedSub.GetFeed())
_, err := manager.FetchPrices(s.symbols, s.currencies)
s.Require().Error(err, "expected error from FetchPrices due to MockPriceProviderWithError")
_, ok := s.feedSub.WaitForEvent(time.Millisecond * 500)
//THEN
s.Require().False(ok, "expected no event, but one was received")
}
func TestMarketTestSuite(t *testing.T) {
suite.Run(t, new(MarketTestSuite))
}

View File

@ -10,48 +10,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
mock_market "github.com/status-im/status-go/services/wallet/market/mock"
"github.com/status-im/status-go/services/wallet/thirdparty" "github.com/status-im/status-go/services/wallet/thirdparty"
mock_thirdparty "github.com/status-im/status-go/services/wallet/thirdparty/mock" mock_thirdparty "github.com/status-im/status-go/services/wallet/thirdparty/mock"
) )
type MockPriceProvider struct {
mock_thirdparty.MockMarketDataProvider
mockPrices map[string]map[string]float64
}
func NewMockPriceProvider(ctrl *gomock.Controller) *MockPriceProvider {
return &MockPriceProvider{
MockMarketDataProvider: *mock_thirdparty.NewMockMarketDataProvider(ctrl),
}
}
func (mpp *MockPriceProvider) setMockPrices(prices map[string]map[string]float64) {
mpp.mockPrices = prices
}
func (mpp *MockPriceProvider) ID() string {
return "MockPriceProvider"
}
func (mpp *MockPriceProvider) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) {
res := make(map[string]map[string]float64)
for _, symbol := range symbols {
res[symbol] = make(map[string]float64)
for _, currency := range currencies {
res[symbol][currency] = mpp.mockPrices[symbol][currency]
}
}
return res, nil
}
type MockPriceProviderWithError struct {
MockPriceProvider
}
func (mpp *MockPriceProviderWithError) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) {
return nil, errors.New("error")
}
func setupMarketManager(t *testing.T, providers []thirdparty.MarketDataProvider) *Manager { func setupMarketManager(t *testing.T, providers []thirdparty.MarketDataProvider) *Manager {
return NewManager(providers, &event.Feed{}) return NewManager(providers, &event.Feed{})
} }
@ -80,8 +43,8 @@ var mockPrices = map[string]map[string]float64{
func TestPrice(t *testing.T) { func TestPrice(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
priceProvider := NewMockPriceProvider(ctrl) priceProvider := mock_market.NewMockPriceProvider(ctrl)
priceProvider.setMockPrices(mockPrices) priceProvider.SetMockPrices(mockPrices)
manager := setupMarketManager(t, []thirdparty.MarketDataProvider{priceProvider, priceProvider}) manager := setupMarketManager(t, []thirdparty.MarketDataProvider{priceProvider, priceProvider})
@ -125,9 +88,12 @@ func TestPrice(t *testing.T) {
func TestFetchPriceErrorFirstProvider(t *testing.T) { func TestFetchPriceErrorFirstProvider(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
priceProvider := NewMockPriceProvider(ctrl) priceProvider := mock_market.NewMockPriceProvider(ctrl)
priceProvider.setMockPrices(mockPrices) priceProvider.SetMockPrices(mockPrices)
priceProviderWithError := &MockPriceProviderWithError{}
customErr := errors.New("error")
priceProviderWithError := mock_market.NewMockPriceProviderWithError(ctrl, customErr)
symbols := []string{"BTC", "ETH"} symbols := []string{"BTC", "ETH"}
currencies := []string{"USD", "EUR"} currencies := []string{"USD", "EUR"}

View File

@ -0,0 +1,54 @@
package mock_market
import (
"go.uber.org/mock/gomock"
mock_thirdparty "github.com/status-im/status-go/services/wallet/thirdparty/mock"
)
type MockPriceProvider struct {
mock_thirdparty.MockMarketDataProvider
mockPrices map[string]map[string]float64
}
func NewMockPriceProvider(ctrl *gomock.Controller) *MockPriceProvider {
return &MockPriceProvider{
MockMarketDataProvider: *mock_thirdparty.NewMockMarketDataProvider(ctrl),
}
}
func (mpp *MockPriceProvider) SetMockPrices(prices map[string]map[string]float64) {
mpp.mockPrices = prices
}
func (mpp *MockPriceProvider) ID() string {
return "MockPriceProvider"
}
func (mpp *MockPriceProvider) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) {
res := make(map[string]map[string]float64)
for _, symbol := range symbols {
res[symbol] = make(map[string]float64)
for _, currency := range currencies {
res[symbol][currency] = mpp.mockPrices[symbol][currency]
}
}
return res, nil
}
type MockPriceProviderWithError struct {
MockPriceProvider
err error
}
// NewMockPriceProviderWithError creates a new MockPriceProviderWithError with the specified error
func NewMockPriceProviderWithError(ctrl *gomock.Controller, err error) *MockPriceProviderWithError {
return &MockPriceProviderWithError{
MockPriceProvider: *NewMockPriceProvider(ctrl),
err: err,
}
}
func (mpp *MockPriceProviderWithError) FetchPrices(symbols []string, currencies []string) (map[string]map[string]float64, error) {
return nil, mpp.err
}