diff --git a/Makefile b/Makefile index edbf472b3..eda9a6f82 100644 --- a/Makefile +++ b/Makefile @@ -346,6 +346,8 @@ mock: ##@other Regenerate mocks mockgen -package=mock_client -destination=rpc/chain/mock/client/client.go -source=rpc/chain/client.go mockgen -package=mock_token -destination=services/wallet/token/mock/token/tokenmanager.go -source=services/wallet/token/token.go mockgen -package=mock_balance_persistence -destination=services/wallet/token/mock/balance_persistence/balance_persistence.go -source=services/wallet/token/balance_persistence.go + mockgen -package=mock_network -destination=rpc/network/mock/network.go -source=rpc/network/network.go + mockgen -package=mock_rpcclient -destination=rpc/mock/client/client.go -source=rpc/client.go docker-test: ##@tests Run tests in a docker container with golang. docker run --privileged --rm -it -v "$(PWD):$(DOCKER_TEST_WORKDIR)" -w "$(DOCKER_TEST_WORKDIR)" $(DOCKER_TEST_IMAGE) go test ${ARGS} diff --git a/contracts/contracts.go b/contracts/contracts.go index 6a4063e8e..f5f123f9e 100644 --- a/contracts/contracts.go +++ b/contracts/contracts.go @@ -16,15 +16,17 @@ import ( ) type ContractMakerIface interface { - NewEthScan(chainID uint64) (*ethscan.BalanceScanner, uint, error) + NewEthScan(chainID uint64) (ethscan.BalanceScannerIface, uint, error) + NewERC20(chainID uint64, contractAddr common.Address) (ierc20.IERC20Iface, error) + NewERC20Caller(chainID uint64, contractAddr common.Address) (ierc20.IERC20CallerIface, error) // TODO extend with other contracts } type ContractMaker struct { - RPCClient *rpc.Client + RPCClient rpc.ClientInterface } -func NewContractMaker(client *rpc.Client) (*ContractMaker, error) { +func NewContractMaker(client rpc.ClientInterface) (*ContractMaker, error) { if client == nil { return nil, errors.New("could not initialize ContractMaker with an rpc client") } @@ -72,7 +74,7 @@ func (c *ContractMaker) NewUsernameRegistrar(chainID uint64, contractAddr common ) } -func (c *ContractMaker) NewERC20(chainID uint64, contractAddr common.Address) (*ierc20.IERC20, error) { +func (c *ContractMaker) NewERC20(chainID uint64, contractAddr common.Address) (ierc20.IERC20Iface, error) { backend, err := c.RPCClient.EthClient(chainID) if err != nil { return nil, err @@ -83,6 +85,14 @@ func (c *ContractMaker) NewERC20(chainID uint64, contractAddr common.Address) (* backend, ) } +func (c *ContractMaker) NewERC20Caller(chainID uint64, contractAddr common.Address) (ierc20.IERC20CallerIface, error) { + backend, err := c.RPCClient.EthClient(chainID) + if err != nil { + return nil, err + } + + return ierc20.NewIERC20Caller(contractAddr, backend) +} func (c *ContractMaker) NewSNT(chainID uint64) (*snt.SNT, error) { contractAddr, err := snt.ContractAddress(chainID) @@ -166,7 +176,7 @@ func (c *ContractMaker) NewDirectory(chainID uint64) (*directory.Directory, erro ) } -func (c *ContractMaker) NewEthScan(chainID uint64) (*ethscan.BalanceScanner, uint, error) { +func (c *ContractMaker) NewEthScan(chainID uint64) (ethscan.BalanceScannerIface, uint, error) { contractAddr, err := ethscan.ContractAddress(chainID) if err != nil { return nil, 0, err diff --git a/contracts/ethscan/ethscan_iface.go b/contracts/ethscan/ethscan_iface.go new file mode 100644 index 000000000..2bb08ba81 --- /dev/null +++ b/contracts/ethscan/ethscan_iface.go @@ -0,0 +1,15 @@ +package ethscan + +import ( + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" +) + +type BalanceScannerIface interface { + EtherBalances(opts *bind.CallOpts, addresses []common.Address) ([]BalanceScannerResult, error) + TokenBalances(opts *bind.CallOpts, addresses []common.Address, tokenAddress common.Address) ([]BalanceScannerResult, error) + TokensBalance(opts *bind.CallOpts, owner common.Address, contracts []common.Address) ([]BalanceScannerResult, error) +} + +// Verify that BalanceScanner implements BalanceScannerIface. If contract changes, this will fail to compile. +var _ BalanceScannerIface = (*BalanceScanner)(nil) diff --git a/contracts/ierc20/ierc20_iface.go b/contracts/ierc20/ierc20_iface.go new file mode 100644 index 000000000..4766995c7 --- /dev/null +++ b/contracts/ierc20/ierc20_iface.go @@ -0,0 +1,28 @@ +package ierc20 + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" +) + +type IERC20Iface interface { + BalanceOf(opts *bind.CallOpts, account common.Address) (*big.Int, error) + Name(opts *bind.CallOpts) (string, error) + Symbol(opts *bind.CallOpts) (string, error) + Decimals(opts *bind.CallOpts) (uint8, error) + Allowance(opts *bind.CallOpts, owner common.Address, spender common.Address) (*big.Int, error) +} + +// Verify that IERC20 implements IERC20Iface. If contract changes, this will fail to compile, update interface to match. +var _ IERC20Iface = (*IERC20)(nil) + +type IERC20CallerIface interface { + BalanceOf(opts *bind.CallOpts, account common.Address) (*big.Int, error) + Name(opts *bind.CallOpts) (string, error) + Symbol(opts *bind.CallOpts) (string, error) + Decimals(opts *bind.CallOpts) (uint8, error) +} + +var _ IERC20CallerIface = (*IERC20Caller)(nil) diff --git a/contracts/mock/contracts.go b/contracts/mock/contracts.go new file mode 100644 index 000000000..ee95ffca5 --- /dev/null +++ b/contracts/mock/contracts.go @@ -0,0 +1,83 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: contracts/contracts.go + +// Package mock_contracts is a generated GoMock package. +package mock_contracts + +import ( + reflect "reflect" + + common "github.com/ethereum/go-ethereum/common" + gomock "github.com/golang/mock/gomock" + ethscan "github.com/status-im/status-go/contracts/ethscan" + ierc20 "github.com/status-im/status-go/contracts/ierc20" +) + +// MockContractMakerIface is a mock of ContractMakerIface interface. +type MockContractMakerIface struct { + ctrl *gomock.Controller + recorder *MockContractMakerIfaceMockRecorder +} + +// MockContractMakerIfaceMockRecorder is the mock recorder for MockContractMakerIface. +type MockContractMakerIfaceMockRecorder struct { + mock *MockContractMakerIface +} + +// NewMockContractMakerIface creates a new mock instance. +func NewMockContractMakerIface(ctrl *gomock.Controller) *MockContractMakerIface { + mock := &MockContractMakerIface{ctrl: ctrl} + mock.recorder = &MockContractMakerIfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockContractMakerIface) EXPECT() *MockContractMakerIfaceMockRecorder { + return m.recorder +} + +// NewERC20 mocks base method. +func (m *MockContractMakerIface) NewERC20(chainID uint64, contractAddr common.Address) (ierc20.IERC20Iface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewERC20", chainID, contractAddr) + ret0, _ := ret[0].(ierc20.IERC20Iface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewERC20 indicates an expected call of NewERC20. +func (mr *MockContractMakerIfaceMockRecorder) NewERC20(chainID, contractAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewERC20", reflect.TypeOf((*MockContractMakerIface)(nil).NewERC20), chainID, contractAddr) +} + +// NewERC20Caller mocks base method. +func (m *MockContractMakerIface) NewERC20Caller(chainID uint64, contractAddr common.Address) (ierc20.IERC20CallerIface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewERC20Caller", chainID, contractAddr) + ret0, _ := ret[0].(ierc20.IERC20CallerIface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewERC20Caller indicates an expected call of NewERC20Caller. +func (mr *MockContractMakerIfaceMockRecorder) NewERC20Caller(chainID, contractAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewERC20Caller", reflect.TypeOf((*MockContractMakerIface)(nil).NewERC20Caller), chainID, contractAddr) +} + +// NewEthScan mocks base method. +func (m *MockContractMakerIface) NewEthScan(chainID uint64) (ethscan.BalanceScannerIface, uint, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewEthScan", chainID) + ret0, _ := ret[0].(ethscan.BalanceScannerIface) + ret1, _ := ret[1].(uint) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// NewEthScan indicates an expected call of NewEthScan. +func (mr *MockContractMakerIfaceMockRecorder) NewEthScan(chainID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewEthScan", reflect.TypeOf((*MockContractMakerIface)(nil).NewEthScan), chainID) +} diff --git a/protocol/communities/manager.go b/protocol/communities/manager.go index 0ee422ae2..3b8b4de88 100644 --- a/protocol/communities/manager.go +++ b/protocol/communities/manager.go @@ -37,6 +37,7 @@ import ( "github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/requests" "github.com/status-im/status-go/protocol/transport" + "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/services/wallet/bigint" walletcommon "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/thirdparty" @@ -271,22 +272,23 @@ type CommunityTokensServiceInterface interface { } type DefaultTokenManager struct { - tokenManager *token.Manager + tokenManager *token.Manager + networkManager network.ManagerInterface } -func NewDefaultTokenManager(tm *token.Manager) *DefaultTokenManager { - return &DefaultTokenManager{tokenManager: tm} +func NewDefaultTokenManager(tm *token.Manager, nm network.ManagerInterface) *DefaultTokenManager { + return &DefaultTokenManager{tokenManager: tm, networkManager: nm} } type BalancesByChain = map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big func (m *DefaultTokenManager) GetAllChainIDs() ([]uint64, error) { - networks, err := m.tokenManager.RPCClient.NetworkManager.Get(false) + networks, err := m.networkManager.Get(false) if err != nil { return nil, err } - areTestNetworksEnabled, err := m.tokenManager.RPCClient.NetworkManager.GetTestNetworksEnabled() + areTestNetworksEnabled, err := m.networkManager.GetTestNetworksEnabled() if err != nil { return nil, err } diff --git a/protocol/messenger.go b/protocol/messenger.go index 40e1467f0..e9e6a950d 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -472,7 +472,7 @@ func NewMessenger( managerOptions = append(managerOptions, communities.WithTokenManager(c.tokenManager)) } else if c.rpcClient != nil { tokenManager := token.NewTokenManager(c.walletDb, c.rpcClient, community.NewManager(database, c.httpServer, nil), c.rpcClient.NetworkManager, database, c.httpServer, nil, nil, nil, token.NewPersistence(c.walletDb)) - managerOptions = append(managerOptions, communities.WithTokenManager(communities.NewDefaultTokenManager(tokenManager))) + managerOptions = append(managerOptions, communities.WithTokenManager(communities.NewDefaultTokenManager(tokenManager, c.rpcClient.NetworkManager))) } if c.walletConfig != nil { diff --git a/rpc/chain/client.go b/rpc/chain/client.go index 7023e7906..cb14f399a 100644 --- a/rpc/chain/client.go +++ b/rpc/chain/client.go @@ -28,7 +28,7 @@ type BatchCallClient interface { } type ChainInterface interface { - BatchCallContext(ctx context.Context, b []rpc.BatchElem) error + BatchCallClient HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error) diff --git a/rpc/chain/mock/client/client.go b/rpc/chain/mock/client/client.go index 57ddeb5df..ca67d0631 100644 --- a/rpc/chain/mock/client/client.go +++ b/rpc/chain/mock/client/client.go @@ -9,12 +9,11 @@ import ( big "math/big" reflect "reflect" - gomock "github.com/golang/mock/gomock" - ethereum "github.com/ethereum/go-ethereum" common "github.com/ethereum/go-ethereum/common" types "github.com/ethereum/go-ethereum/core/types" rpc "github.com/ethereum/go-ethereum/rpc" + gomock "github.com/golang/mock/gomock" chain "github.com/status-im/status-go/rpc/chain" ) diff --git a/rpc/client.go b/rpc/client.go index 0eb471dad..857371d79 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -37,6 +37,9 @@ type Handler func(context.Context, uint64, ...interface{}) (interface{}, error) type ClientInterface interface { AbstractEthClient(chainID common.ChainID) (chain.BatchCallClient, error) + EthClient(chainID uint64) (chain.ClientInterface, error) + EthClients(chainIDs []uint64) (map[uint64]chain.ClientInterface, error) + CallContext(context context.Context, result interface{}, chainID uint64, method string, args ...interface{}) error } // Client represents RPC client with custom routing diff --git a/rpc/mock/client/client.go b/rpc/mock/client/client.go new file mode 100644 index 000000000..1e1468344 --- /dev/null +++ b/rpc/mock/client/client.go @@ -0,0 +1,101 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: rpc/client.go + +// Package mock_rpcclient is a generated GoMock package. +package mock_rpcclient + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + chain "github.com/status-im/status-go/rpc/chain" + common "github.com/status-im/status-go/services/wallet/common" +) + +// MockClientInterface is a mock of ClientInterface interface. +type MockClientInterface struct { + ctrl *gomock.Controller + recorder *MockClientInterfaceMockRecorder +} + +// MockClientInterfaceMockRecorder is the mock recorder for MockClientInterface. +type MockClientInterfaceMockRecorder struct { + mock *MockClientInterface +} + +// NewMockClientInterface creates a new mock instance. +func NewMockClientInterface(ctrl *gomock.Controller) *MockClientInterface { + mock := &MockClientInterface{ctrl: ctrl} + mock.recorder = &MockClientInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientInterface) EXPECT() *MockClientInterfaceMockRecorder { + return m.recorder +} + +// AbstractEthClient mocks base method. +func (m *MockClientInterface) AbstractEthClient(chainID common.ChainID) (chain.BatchCallClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AbstractEthClient", chainID) + ret0, _ := ret[0].(chain.BatchCallClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AbstractEthClient indicates an expected call of AbstractEthClient. +func (mr *MockClientInterfaceMockRecorder) AbstractEthClient(chainID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AbstractEthClient", reflect.TypeOf((*MockClientInterface)(nil).AbstractEthClient), chainID) +} + +// CallContext mocks base method. +func (m *MockClientInterface) CallContext(context context.Context, result interface{}, chainID uint64, method string, args ...interface{}) error { + m.ctrl.T.Helper() + varargs := []interface{}{context, result, chainID, method} + for _, a := range args { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CallContext", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// CallContext indicates an expected call of CallContext. +func (mr *MockClientInterfaceMockRecorder) CallContext(context, result, chainID, method interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{context, result, chainID, method}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallContext", reflect.TypeOf((*MockClientInterface)(nil).CallContext), varargs...) +} + +// EthClient mocks base method. +func (m *MockClientInterface) EthClient(chainID uint64) (chain.ClientInterface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EthClient", chainID) + ret0, _ := ret[0].(chain.ClientInterface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EthClient indicates an expected call of EthClient. +func (mr *MockClientInterfaceMockRecorder) EthClient(chainID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EthClient", reflect.TypeOf((*MockClientInterface)(nil).EthClient), chainID) +} + +// EthClients mocks base method. +func (m *MockClientInterface) EthClients(chainIDs []uint64) (map[uint64]chain.ClientInterface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EthClients", chainIDs) + ret0, _ := ret[0].(map[uint64]chain.ClientInterface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EthClients indicates an expected call of EthClients. +func (mr *MockClientInterfaceMockRecorder) EthClients(chainIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EthClients", reflect.TypeOf((*MockClientInterface)(nil).EthClients), chainIDs) +} diff --git a/rpc/network/mock/network.go b/rpc/network/mock/network.go new file mode 100644 index 000000000..62db4335b --- /dev/null +++ b/rpc/network/mock/network.go @@ -0,0 +1,108 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: rpc/network/network.go + +// Package mock_network is a generated GoMock package. +package mock_network + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + params "github.com/status-im/status-go/params" +) + +// MockManagerInterface is a mock of ManagerInterface interface. +type MockManagerInterface struct { + ctrl *gomock.Controller + recorder *MockManagerInterfaceMockRecorder +} + +// MockManagerInterfaceMockRecorder is the mock recorder for MockManagerInterface. +type MockManagerInterfaceMockRecorder struct { + mock *MockManagerInterface +} + +// NewMockManagerInterface creates a new mock instance. +func NewMockManagerInterface(ctrl *gomock.Controller) *MockManagerInterface { + mock := &MockManagerInterface{ctrl: ctrl} + mock.recorder = &MockManagerInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockManagerInterface) EXPECT() *MockManagerInterfaceMockRecorder { + return m.recorder +} + +// Find mocks base method. +func (m *MockManagerInterface) Find(chainID uint64) *params.Network { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Find", chainID) + ret0, _ := ret[0].(*params.Network) + return ret0 +} + +// Find indicates an expected call of Find. +func (mr *MockManagerInterfaceMockRecorder) Find(chainID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockManagerInterface)(nil).Find), chainID) +} + +// Get mocks base method. +func (m *MockManagerInterface) Get(onlyEnabled bool) ([]*params.Network, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", onlyEnabled) + ret0, _ := ret[0].([]*params.Network) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockManagerInterfaceMockRecorder) Get(onlyEnabled interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockManagerInterface)(nil).Get), onlyEnabled) +} + +// GetAll mocks base method. +func (m *MockManagerInterface) GetAll() ([]*params.Network, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAll") + ret0, _ := ret[0].([]*params.Network) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAll indicates an expected call of GetAll. +func (mr *MockManagerInterfaceMockRecorder) GetAll() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAll", reflect.TypeOf((*MockManagerInterface)(nil).GetAll)) +} + +// GetConfiguredNetworks mocks base method. +func (m *MockManagerInterface) GetConfiguredNetworks() []params.Network { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetConfiguredNetworks") + ret0, _ := ret[0].([]params.Network) + return ret0 +} + +// GetConfiguredNetworks indicates an expected call of GetConfiguredNetworks. +func (mr *MockManagerInterfaceMockRecorder) GetConfiguredNetworks() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfiguredNetworks", reflect.TypeOf((*MockManagerInterface)(nil).GetConfiguredNetworks)) +} + +// GetTestNetworksEnabled mocks base method. +func (m *MockManagerInterface) GetTestNetworksEnabled() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTestNetworksEnabled") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTestNetworksEnabled indicates an expected call of GetTestNetworksEnabled. +func (mr *MockManagerInterfaceMockRecorder) GetTestNetworksEnabled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTestNetworksEnabled", reflect.TypeOf((*MockManagerInterface)(nil).GetTestNetworksEnabled)) +} diff --git a/rpc/network/network.go b/rpc/network/network.go index bef22c0be..682deee65 100644 --- a/rpc/network/network.go +++ b/rpc/network/network.go @@ -81,6 +81,14 @@ func (nq *networksQuery) exec(db *sql.DB) ([]*params.Network, error) { return res, err } +type ManagerInterface interface { + Get(onlyEnabled bool) ([]*params.Network, error) + GetAll() ([]*params.Network, error) + Find(chainID uint64) *params.Network + GetConfiguredNetworks() []params.Network + GetTestNetworksEnabled() (bool, error) +} + type Manager struct { db *sql.DB configuredNetworks []params.Network diff --git a/services/wallet/activity/service_test.go b/services/wallet/activity/service_test.go index a2ffd37ee..f191f0473 100644 --- a/services/wallet/activity/service_test.go +++ b/services/wallet/activity/service_test.go @@ -7,17 +7,20 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" + eth "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/event" "github.com/status-im/status-go/appdatabase" "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/rpc/chain" + mock_rpcclient "github.com/status-im/status-go/rpc/mock/client" "github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/thirdparty" "github.com/status-im/status-go/services/wallet/token" + mock_token "github.com/status-im/status-go/services/wallet/token/mock/token" "github.com/status-im/status-go/services/wallet/transfer" "github.com/status-im/status-go/services/wallet/walletevent" "github.com/status-im/status-go/t/helpers" @@ -53,45 +56,15 @@ func (m *mockCollectiblesManager) FetchCollectionSocialsAsync(contractID thirdpa return nil } -// mockTokenManager implements the token.ManagerInterface -type mockTokenManager struct { - mock.Mock -} - -func (m *mockTokenManager) LookupTokenIdentity(chainID uint64, address eth.Address, native bool) *token.Token { - args := m.Called(chainID, address, native) - res := args.Get(0) - if res == nil { - return nil - } - return res.(*token.Token) -} - -func (m *mockTokenManager) LookupToken(chainID *uint64, tokenSymbol string) (tkn *token.Token, isNative bool) { - args := m.Called(chainID, tokenSymbol) - return args.Get(0).(*token.Token), args.Bool(1) -} - -func (m *mockTokenManager) GetBalancesByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []eth.Address) (map[uint64]map[eth.Address]map[eth.Address]*hexutil.Big, error) { - return nil, nil // Not used here -} - -func (m *mockTokenManager) GetTokensByChainIDs(chainIDs []uint64) ([]*token.Token, error) { - return nil, nil // Not used here -} - -func (m *mockTokenManager) GetTokenHistoricalBalance(account eth.Address, chainID uint64, symbol string, timestamp int64) (*big.Int, error) { - return nil, nil // Not used here -} - type testState struct { service *Service eventFeed *event.Feed - tokenMock *mockTokenManager + tokenMock *mock_token.MockManagerInterface collectiblesMock *mockCollectiblesManager close func() pendingTracker *transactions.PendingTxTracker chainClient *transactions.MockChainClient + rpcClient *mock_rpcclient.MockClientInterface } func setupTestService(tb testing.TB) (state testState) { @@ -104,20 +77,26 @@ func setupTestService(tb testing.TB) (state testState) { require.NoError(tb, err) state.eventFeed = new(event.Feed) - state.tokenMock = &mockTokenManager{} + mockCtrl := gomock.NewController(tb) + state.tokenMock = mock_token.NewMockManagerInterface(mockCtrl) state.collectiblesMock = &mockCollectiblesManager{} state.chainClient = transactions.NewMockChainClient() + state.rpcClient = mock_rpcclient.NewMockClientInterface(mockCtrl) + state.rpcClient.EXPECT().AbstractEthClient(gomock.Any()).DoAndReturn(func(chainID common.ChainID) (chain.BatchCallClient, error) { + return state.chainClient.AbstractEthClient(chainID) + }).AnyTimes() // Ensure we process pending transactions as needed, only once pendingCheckInterval := time.Second - state.pendingTracker = transactions.NewPendingTxTracker(db, state.chainClient, nil, state.eventFeed, pendingCheckInterval) + state.pendingTracker = transactions.NewPendingTxTracker(db, state.rpcClient, nil, state.eventFeed, pendingCheckInterval) state.service = NewService(db, accountsDB, state.tokenMock, state.collectiblesMock, state.eventFeed, state.pendingTracker) state.service.debounceDuration = 0 state.close = func() { require.NoError(tb, state.pendingTracker.Stop()) require.NoError(tb, db.Close()) + defer mockCtrl.Finish() } return state @@ -168,13 +147,13 @@ func TestService_UpdateCollectibleInfo(t *testing.T) { sub := state.eventFeed.Subscribe(ch) // Expect one call for the fungible token - state.tokenMock.On("LookupTokenIdentity", uint64(5), eth.HexToAddress("0x3d6afaa395c31fcd391fe3d562e75fe9e8ec7e6a"), false).Return( + state.tokenMock.EXPECT().LookupTokenIdentity(uint64(5), eth.HexToAddress("0x3d6afaa395c31fcd391fe3d562e75fe9e8ec7e6a"), false).Return( &token.Token{ ChainID: 5, Address: eth.HexToAddress("0x3d6afaa395c31fcd391fe3d562e75fe9e8ec7e6a"), Symbol: "STT", - }, false, - ).Once() + }, + ).Times(1) state.collectiblesMock.On("FetchAssetsByCollectibleUniqueID", []thirdparty.CollectibleUniqueID{ { ContractID: thirdparty.ContractID{ @@ -296,21 +275,21 @@ func setupTransactions(t *testing.T, state testState, txCount int, testTxs []tra allAddresses = append(append(allAddresses, fromTrs...), toTrs...) - state.tokenMock.On("LookupTokenIdentity", mock.Anything, mock.Anything, mock.Anything).Return( + state.tokenMock.EXPECT().LookupTokenIdentity(gomock.Any(), gomock.Any(), gomock.Any()).Return( &token.Token{ ChainID: 5, Address: eth.Address{}, Symbol: "ETH", - }, true, - ).Times(0) + }, + ).AnyTimes() - state.tokenMock.On("LookupToken", mock.Anything, mock.Anything).Return( + state.tokenMock.EXPECT().LookupToken(gomock.Any(), gomock.Any()).Return( &token.Token{ ChainID: 5, Address: eth.Address{}, Symbol: "ETH", }, true, - ).Times(0) + ).AnyTimes() return allAddresses, pendings, ch, func() { sub.Unsubscribe() diff --git a/services/wallet/api.go b/services/wallet/api.go index 61edd5f5a..76df13ef5 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -55,7 +55,7 @@ func NewAPI(s *Service) *API { erc1155Transfer := pathprocessor.NewERC1155Processor(rpcClient, transactor) router.AddPathProcessor(erc1155Transfer) - hop := pathprocessor.NewHopBridgeProcessor(rpcClient, transactor, tokenManager) + hop := pathprocessor.NewHopBridgeProcessor(rpcClient, transactor, tokenManager, rpcClient.NetworkManager) router.AddPathProcessor(hop) if featureFlags.EnableCelerBridge { diff --git a/services/wallet/router/pathprocessor/processor_bridge_hop.go b/services/wallet/router/pathprocessor/processor_bridge_hop.go index 6943098be..7379f0760 100644 --- a/services/wallet/router/pathprocessor/processor_bridge_hop.go +++ b/services/wallet/router/pathprocessor/processor_bridge_hop.go @@ -31,6 +31,7 @@ import ( "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/rpc/chain" + "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/services/wallet/bigint" walletCommon "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/thirdparty" @@ -106,20 +107,22 @@ func (bf *BonderFee) UnmarshalJSON(data []byte) error { } type HopBridgeProcessor struct { - transactor transactions.TransactorIface - httpClient *thirdparty.HTTPClient - tokenManager *token.Manager - contractMaker *contracts.ContractMaker - bonderFee *sync.Map // [fromChainName-toChainName]BonderFee + transactor transactions.TransactorIface + httpClient *thirdparty.HTTPClient + tokenManager *token.Manager + contractMaker *contracts.ContractMaker + networkManager network.ManagerInterface + bonderFee *sync.Map // [fromChainName-toChainName]BonderFee } -func NewHopBridgeProcessor(rpcClient *rpc.Client, transactor transactions.TransactorIface, tokenManager *token.Manager) *HopBridgeProcessor { +func NewHopBridgeProcessor(rpcClient rpc.ClientInterface, transactor transactions.TransactorIface, tokenManager *token.Manager, networkManager network.ManagerInterface) *HopBridgeProcessor { return &HopBridgeProcessor{ - contractMaker: &contracts.ContractMaker{RPCClient: rpcClient}, - httpClient: thirdparty.NewHTTPClient(), - transactor: transactor, - tokenManager: tokenManager, - bonderFee: &sync.Map{}, + contractMaker: &contracts.ContractMaker{RPCClient: rpcClient}, + httpClient: thirdparty.NewHTTPClient(), + transactor: transactor, + tokenManager: tokenManager, + networkManager: networkManager, + bonderFee: &sync.Map{}, } } @@ -299,7 +302,7 @@ func (h *HopBridgeProcessor) GetContractAddress(params ProcessorInputParams) (co } func (h *HopBridgeProcessor) sendOrBuild(sendArgs *MultipathProcessorTxArgs, signerFn bind.SignerFn) (tx *ethTypes.Transaction, err error) { - fromChain := h.contractMaker.RPCClient.NetworkManager.Find(sendArgs.ChainID) + fromChain := h.networkManager.Find(sendArgs.ChainID) if fromChain == nil { return tx, fmt.Errorf("ChainID not supported %d", sendArgs.ChainID) } diff --git a/services/wallet/router/pathprocessor/processor_test.go b/services/wallet/router/pathprocessor/processor_test.go index 5c5bce8eb..5cd560a4c 100644 --- a/services/wallet/router/pathprocessor/processor_test.go +++ b/services/wallet/router/pathprocessor/processor_test.go @@ -309,7 +309,7 @@ func TestPathProcessors(t *testing.T) { if processorName == ProcessorTransferName { processor = NewTransferProcessor(nil, nil) } else if processorName == ProcessorBridgeHopName { - processor = NewHopBridgeProcessor(nil, nil, nil) + processor = NewHopBridgeProcessor(nil, nil, nil, nil) } else if processorName == ProcessorSwapParaswapName { processor = NewSwapParaswapProcessor(nil, nil, nil) } diff --git a/services/wallet/router/router_v2_test.go b/services/wallet/router/router_v2_test.go index d1aa75943..95c12df75 100644 --- a/services/wallet/router/router_v2_test.go +++ b/services/wallet/router/router_v2_test.go @@ -223,7 +223,7 @@ func setupRouter(t *testing.T) (*Router, func()) { erc1155Transfer := pathprocessor.NewERC1155Processor(nil, nil) router.AddPathProcessor(erc1155Transfer) - hop := pathprocessor.NewHopBridgeProcessor(nil, nil, nil) + hop := pathprocessor.NewHopBridgeProcessor(nil, nil, nil, nil) router.AddPathProcessor(hop) paraswap := pathprocessor.NewSwapParaswapProcessor(nil, nil, nil) diff --git a/services/wallet/token/balancefetcher/balance_fetcher.go b/services/wallet/token/balancefetcher/balance_fetcher.go index bd3e3fc55..76e0a083b 100644 --- a/services/wallet/token/balancefetcher/balance_fetcher.go +++ b/services/wallet/token/balancefetcher/balance_fetcher.go @@ -18,7 +18,7 @@ import ( "github.com/status-im/status-go/services/wallet/async" ) -var nativeChainAddress = common.HexToAddress("0x") +var NativeChainAddress = common.HexToAddress("0x") var requestTimeout = 20 * time.Second const ( @@ -26,7 +26,6 @@ const ( ) type BalanceFetcher interface { - FetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) GetTokenBalanceAt(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address, blockNumber *big.Int) (*big.Int, error) GetBalancesAtByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) GetBalancesByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) @@ -44,7 +43,7 @@ func NewDefaultBalanceFetcher(contractMaker contracts.ContractMakerIface) *Defau } } -func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { +func (bf *DefaultBalanceFetcher) fetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { var ( group = async.NewAtomicGroup(parent) mu sync.Mutex @@ -61,14 +60,7 @@ func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, c } for token, balance := range tokenBalance { - if _, ok := balances[account][token]; !ok { - zeroHex := hexutil.Big(*big.NewInt(0)) - balances[account][token] = &zeroHex - } - - sum := big.NewInt(0).Add(balances[account][token].ToInt(), balance.ToInt()) - sumHex := hexutil.Big(*sum) - balances[account][token] = &sumHex + balances[account][token] = balance } } } @@ -79,12 +71,10 @@ func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, c return nil, err } - atBlock := atBlocks[client.NetworkID()] - fetchChainBalance := false for _, token := range tokens { - if token == nativeChainAddress { + if token == NativeChainAddress { fetchChainBalance = true } } @@ -117,7 +107,7 @@ func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, c if atBlock == nil || big.NewInt(int64(availableAtBlock)).Cmp(atBlock) < 0 { accTokenBalance, err = bf.FetchTokenBalancesWithScanContract(ctx, ethScanContract, account, chunk, atBlock) } else { - accTokenBalance, err = bf.FetchTokenBalancesWithTokenContracts(ctx, client, account, chunk, atBlock) + accTokenBalance, err = bf.fetchTokenBalancesWithTokenContracts(ctx, client, account, chunk, atBlock) } if err != nil { @@ -138,7 +128,7 @@ func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, c return balances, group.Error() } -func (bf *DefaultBalanceFetcher) FetchChainBalances(parent context.Context, accounts []common.Address, ethScanContract *ethscan.BalanceScanner, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { +func (bf *DefaultBalanceFetcher) FetchChainBalances(parent context.Context, accounts []common.Address, ethScanContract ethscan.BalanceScannerIface, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big) ctx, cancel := context.WithTimeout(parent, requestTimeout) @@ -160,13 +150,13 @@ func (bf *DefaultBalanceFetcher) FetchChainBalances(parent context.Context, acco accTokenBalance[account] = make(map[common.Address]*hexutil.Big) } - accTokenBalance[account][nativeChainAddress] = (*hexutil.Big)(balance) + accTokenBalance[account][NativeChainAddress] = (*hexutil.Big)(balance) } return accTokenBalance, nil } -func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context.Context, ethScanContract *ethscan.BalanceScanner, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { +func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context.Context, ethScanContract ethscan.BalanceScannerIface, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big) res, err := ethScanContract.TokensBalance(&bind.CallOpts{ Context: ctx, @@ -178,7 +168,7 @@ func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context. } if len(res) != len(chunk) { - log.Error("can't fetch erc20 token balance 7", "account", account, "error", "response not complete") + log.Error("can't fetch erc20 token balance 7", "account", account, "error", "response not complete", "expected", len(chunk), "got", len(res)) return nil, errors.New("response not complete") } @@ -198,7 +188,7 @@ func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context. return accTokenBalance, nil } -func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithTokenContracts(ctx context.Context, client chain.ClientInterface, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { +func (bf *DefaultBalanceFetcher) fetchTokenBalancesWithTokenContracts(ctx context.Context, client chain.ClientInterface, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big) for _, token := range chunk { balance, err := bf.GetTokenBalanceAt(ctx, client, account, token, atBlock) @@ -220,7 +210,7 @@ func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithTokenContracts(ctx contex } func (bf *DefaultBalanceFetcher) GetTokenBalanceAt(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address, blockNumber *big.Int) (*big.Int, error) { - caller, err := ierc20.NewIERC20Caller(token, client) + caller, err := bf.contractMaker.NewERC20Caller(client.NetworkID(), token) if err != nil { return nil, err } @@ -270,7 +260,7 @@ func (bf *DefaultBalanceFetcher) GetChainBalance(ctx context.Context, client cha } func (bf *DefaultBalanceFetcher) GetBalance(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address) (*big.Int, error) { - if token == nativeChainAddress { + if token == NativeChainAddress { return bf.GetChainBalance(ctx, client, account) } @@ -305,12 +295,15 @@ func (bf *DefaultBalanceFetcher) GetBalancesAtByChain(parent context.Context, cl // Keep the reference to the client. DO NOT USE A LOOP, the client will be overridden in the coroutine client := clients[clientIdx] - balances, err := bf.FetchBalancesForChain(parent, client, accounts, tokens, atBlocks) - if err != nil { - return nil, err - } + group.Add(func(parent context.Context) error { + balances, err := bf.fetchBalancesForChain(parent, client, accounts, tokens, atBlocks[client.NetworkID()]) + if err != nil { + return nil + } - updateBalance(client.NetworkID(), balances) + updateBalance(client.NetworkID(), balances) + return nil + }) } select { case <-group.WaitAsync(): diff --git a/services/wallet/token/balancefetcher/balance_fetcher_test.go b/services/wallet/token/balancefetcher/balance_fetcher_test.go new file mode 100644 index 000000000..e4ce49d72 --- /dev/null +++ b/services/wallet/token/balancefetcher/balance_fetcher_test.go @@ -0,0 +1,386 @@ +package balancefetcher + +import ( + "context" + "errors" + "math/big" + "os" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/log" + "github.com/status-im/status-go/contracts/ethscan" + "github.com/status-im/status-go/params" + + mock_contracts "github.com/status-im/status-go/contracts/mock" + "github.com/status-im/status-go/rpc/chain" + mock_client "github.com/status-im/status-go/rpc/chain/mock/client" + mock_network "github.com/status-im/status-go/rpc/network/mock" + w_common "github.com/status-im/status-go/services/wallet/common" +) + +type FakeBalanceScanner struct { + etherBalances map[common.Address]*big.Int + tokenBalances map[common.Address]map[common.Address]*big.Int +} + +func (f *FakeBalanceScanner) EtherBalances(opts *bind.CallOpts, addresses []common.Address) ([]ethscan.BalanceScannerResult, error) { + result := make([]ethscan.BalanceScannerResult, 0, len(addresses)) + + for _, address := range addresses { + balance, ok := f.etherBalances[address] + if !ok { + result = append(result, ethscan.BalanceScannerResult{ + Success: false, + Data: []byte{}, + }) + } else { + result = append(result, ethscan.BalanceScannerResult{ + Success: true, + Data: balance.Bytes(), + }) + } + } + + return result, nil +} + +func (f *FakeBalanceScanner) TokenBalances(opts *bind.CallOpts, addresses []common.Address, tokenAddress common.Address) ([]ethscan.BalanceScannerResult, error) { + result := make([]ethscan.BalanceScannerResult, 0, len(addresses)) + + for _, address := range addresses { + balances, ok := f.tokenBalances[address] + if !ok { + result = append(result, ethscan.BalanceScannerResult{ + Success: false, + Data: []byte{}, + }) + } else { + balance, ok := balances[tokenAddress] + if !ok { + result = append(result, ethscan.BalanceScannerResult{ + Success: false, + Data: []byte{}, + }) + } else { + result = append(result, ethscan.BalanceScannerResult{ + Success: true, + Data: balance.Bytes(), + }) + } + } + } + + return result, nil +} + +func (f *FakeBalanceScanner) TokensBalance(opts *bind.CallOpts, owner common.Address, contracts []common.Address) ([]ethscan.BalanceScannerResult, error) { + result := make([]ethscan.BalanceScannerResult, 0, len(contracts)) + + for _, contract := range contracts { + balances, ok := f.tokenBalances[owner] + if !ok { + result = append(result, ethscan.BalanceScannerResult{ + Success: false, + Data: []byte{}, + }) + } else { + balance, ok := balances[contract] + if !ok { + result = append(result, ethscan.BalanceScannerResult{ + Success: false, + Data: []byte{}, + }) + } else { + result = append(result, ethscan.BalanceScannerResult{ + Success: true, + Data: balance.Bytes(), + }) + } + } + } + + return result, nil +} + +type FakeERC20Caller struct { + accountBalances map[common.Address]*big.Int +} + +func (f *FakeERC20Caller) BalanceOf(opts *bind.CallOpts, account common.Address) (*big.Int, error) { + balance, ok := f.accountBalances[account] + if !ok { + return nil, errors.New("account not found") + } + + return balance, nil +} + +func (f *FakeERC20Caller) Name(opts *bind.CallOpts) (string, error) { + return "TestToken", nil +} + +func (f *FakeERC20Caller) Symbol(opts *bind.CallOpts) (string, error) { + return "TT", nil +} + +func (f *FakeERC20Caller) Decimals(opts *bind.CallOpts) (uint8, error) { + return 18, nil +} + +func TestBalanceFetcherFetchBalancesForChainNativeAndTokensWithScanContract(t *testing.T) { + log.Root().SetHandler(log.LvlFilterHandler(log.LvlInfo, log.StreamHandler(os.Stdout, log.TerminalFormat(true)))) + + ctx := context.Background() + accounts := []common.Address{ + common.HexToAddress("0x1234567890abcdef"), + common.HexToAddress("0xabcdef1234567890"), + } + tokens := []common.Address{ + NativeChainAddress, + common.HexToAddress("0xabcdef1234567890"), + common.HexToAddress("0x0987654321fedcba"), + } + var atBlock *big.Int // nil triggers using a scan contract + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + networkManager := mock_network.NewMockManagerInterface(ctrl) + networkManager.EXPECT().GetAll().Return([]*params.Network{ + { + ChainID: w_common.EthereumMainnet, + }, + }, nil).AnyTimes() + + chainClient := mock_client.NewMockClientInterface(ctrl) + chainClient.EXPECT().NetworkID().Return(w_common.EthereumMainnet).AnyTimes() + expectedEthBalances := map[common.Address]*big.Int{ + accounts[0]: big.NewInt(100), + accounts[1]: big.NewInt(200), + } + + expectedTokenBalances := map[common.Address]map[common.Address]*big.Int{ + accounts[0]: { + tokens[1]: big.NewInt(1000), + tokens[2]: big.NewInt(2000), + }, + accounts[1]: { + tokens[1]: big.NewInt(3000), + tokens[2]: big.NewInt(4000), + }, + } + + expectedBalances := map[common.Address]map[common.Address]*hexutil.Big{ + accounts[0]: { + tokens[0]: (*hexutil.Big)(expectedEthBalances[accounts[0]]), + tokens[1]: (*hexutil.Big)(expectedTokenBalances[accounts[0]][tokens[1]]), + tokens[2]: (*hexutil.Big)(expectedTokenBalances[accounts[0]][tokens[2]]), + }, + accounts[1]: { + tokens[0]: (*hexutil.Big)(expectedEthBalances[accounts[1]]), + tokens[1]: (*hexutil.Big)(expectedTokenBalances[accounts[1]][tokens[1]]), + tokens[2]: (*hexutil.Big)(expectedTokenBalances[accounts[1]][tokens[2]]), + }, + } + + contractMaker := mock_contracts.NewMockContractMakerIface(ctrl) + contractMaker.EXPECT().NewEthScan(w_common.EthereumMainnet).Return(&FakeBalanceScanner{ + etherBalances: expectedEthBalances, + tokenBalances: expectedTokenBalances, + }, uint(0), nil).AnyTimes() + bf := NewDefaultBalanceFetcher(contractMaker) + + // Fetch native balances and token balances using scan contract + balances, err := bf.fetchBalancesForChain(ctx, chainClient, accounts, tokens, atBlock) + + require.NoError(t, err) + require.Equal(t, expectedBalances, balances) +} + +func TestBalanceFetcherFetchBalancesForChainTokensWithTokenContracts(t *testing.T) { + log.Root().SetHandler(log.LvlFilterHandler(log.LvlInfo, log.StreamHandler(os.Stdout, log.TerminalFormat(true)))) + + ctx := context.Background() + accounts := []common.Address{ + common.HexToAddress("0x1234567890abcdef"), + common.HexToAddress("0xabcdef1234567890"), + } + tokens := []common.Address{ + common.HexToAddress("0xabcdef1234567890"), + common.HexToAddress("0x0987654321fedcba"), + } + atBlock := big.NewInt(0) // will trigger using a token contract + + ctrl := gomock.NewController(t) + networkManager := mock_network.NewMockManagerInterface(ctrl) + networkManager.EXPECT().GetAll().Return([]*params.Network{ + { + ChainID: w_common.EthereumMainnet, + }, + }, nil).AnyTimes() + + chainClient := mock_client.NewMockClientInterface(ctrl) + chainClient.EXPECT().NetworkID().Return(w_common.EthereumMainnet).AnyTimes() + chainClient.EXPECT().CallContract(gomock.Any(), gomock.Any(), atBlock).Return([]byte{}, nil).AnyTimes() + + expectedTokenBalances := map[common.Address]map[common.Address]*big.Int{ + tokens[0]: { + accounts[0]: big.NewInt(1000), + accounts[1]: big.NewInt(2000), + }, + tokens[1]: { + accounts[0]: big.NewInt(3000), + accounts[1]: big.NewInt(4000), + }, + } + + expectedBalances := map[common.Address]map[common.Address]*hexutil.Big{ + accounts[0]: { + tokens[0]: (*hexutil.Big)(expectedTokenBalances[tokens[0]][accounts[0]]), + tokens[1]: (*hexutil.Big)(expectedTokenBalances[tokens[1]][accounts[0]]), + }, + accounts[1]: { + tokens[0]: (*hexutil.Big)(expectedTokenBalances[tokens[0]][accounts[1]]), + tokens[1]: (*hexutil.Big)(expectedTokenBalances[tokens[1]][accounts[1]]), + }, + } + + contractMaker := mock_contracts.NewMockContractMakerIface(ctrl) + contractMaker.EXPECT().NewEthScan(w_common.EthereumMainnet).Return(&FakeBalanceScanner{}, uint(0), nil).Times(1) + for _, token := range tokens { + contractMaker.EXPECT().NewERC20Caller(w_common.EthereumMainnet, token).Return(&FakeERC20Caller{ + accountBalances: expectedTokenBalances[token], + }, nil).AnyTimes() + } + bf := NewDefaultBalanceFetcher(contractMaker) + + // Fetch token balances using tokens contracts + balances, err := bf.fetchBalancesForChain(ctx, chainClient, accounts, tokens, atBlock) + + require.NoError(t, err) + require.Equal(t, expectedBalances, balances) +} + +func TestBalanceFetcherGetBalancesAtByChain(t *testing.T) { + log.Root().SetHandler(log.LvlFilterHandler(log.LvlInfo, log.StreamHandler(os.Stdout, log.TerminalFormat(true)))) + + ctx := context.Background() + accounts := []common.Address{ + common.HexToAddress("0x1234567890abcdef"), + common.HexToAddress("0xabcdef1234567890"), + } + tokens := []common.Address{ + NativeChainAddress, + common.HexToAddress("0xabcdef1234567890"), + common.HexToAddress("0x0987654321fedcba"), + } + var atBlock *big.Int // nil triggers using a scan contract + atBlocks := map[uint64]*big.Int{ + w_common.EthereumMainnet: atBlock, // nil triggers using a scan contract + w_common.OptimismMainnet: atBlock, // nil triggers using a scan contract + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + networkManager := mock_network.NewMockManagerInterface(ctrl) + networkManager.EXPECT().GetAll().Return([]*params.Network{ + { + ChainID: w_common.EthereumMainnet, + }, + { + ChainID: w_common.OptimismMainnet, + }, + { + ChainID: w_common.ArbitrumMainnet, + }, + }, nil).AnyTimes() + + chainClient := mock_client.NewMockClientInterface(ctrl) + chainClient.EXPECT().NetworkID().Return(w_common.EthereumMainnet).AnyTimes() + chainClientOpt := mock_client.NewMockClientInterface(ctrl) + chainClientOpt.EXPECT().NetworkID().Return(w_common.OptimismMainnet).AnyTimes() + chainClientArb := mock_client.NewMockClientInterface(ctrl) + chainClientArb.EXPECT().NetworkID().Return(w_common.ArbitrumMainnet).AnyTimes() + chainClients := map[uint64]chain.ClientInterface{ + w_common.EthereumMainnet: chainClient, + w_common.OptimismMainnet: chainClientOpt, + w_common.ArbitrumMainnet: chainClientArb, + } + + expectedEthBalances := map[common.Address]*big.Int{ + accounts[0]: big.NewInt(100), + accounts[1]: big.NewInt(200), + } + + expectedEthOptBalances := map[common.Address]*big.Int{ + accounts[0]: big.NewInt(300), + accounts[1]: big.NewInt(400), + } + + expectedTokenBalances := map[common.Address]map[common.Address]*big.Int{ + accounts[0]: { + tokens[1]: big.NewInt(1000), + tokens[2]: big.NewInt(2000), + }, + accounts[1]: { + tokens[1]: big.NewInt(3000), + tokens[2]: big.NewInt(4000), + }, + } + + expectedTokenOptBalances := map[common.Address]map[common.Address]*big.Int{ + accounts[0]: { + tokens[1]: big.NewInt(5000), + tokens[2]: big.NewInt(6000), + }, + } + + expectedBalances := map[uint64]map[common.Address]map[common.Address]*hexutil.Big{ + w_common.EthereumMainnet: { + accounts[0]: { + tokens[0]: (*hexutil.Big)(expectedEthBalances[accounts[0]]), + tokens[1]: (*hexutil.Big)(expectedTokenBalances[accounts[0]][tokens[1]]), + tokens[2]: (*hexutil.Big)(expectedTokenBalances[accounts[0]][tokens[2]]), + }, + accounts[1]: { + tokens[0]: (*hexutil.Big)(expectedEthBalances[accounts[1]]), + tokens[1]: (*hexutil.Big)(expectedTokenBalances[accounts[1]][tokens[1]]), + tokens[2]: (*hexutil.Big)(expectedTokenBalances[accounts[1]][tokens[2]]), + }, + }, + w_common.OptimismMainnet: { + accounts[0]: { + tokens[0]: (*hexutil.Big)(expectedEthOptBalances[accounts[0]]), + tokens[1]: (*hexutil.Big)(expectedTokenOptBalances[accounts[0]][tokens[1]]), + tokens[2]: (*hexutil.Big)(expectedTokenOptBalances[accounts[0]][tokens[2]]), + }, + accounts[1]: { + tokens[0]: (*hexutil.Big)(expectedEthOptBalances[accounts[1]]), + }, + }, + } + + contractMaker := mock_contracts.NewMockContractMakerIface(ctrl) + contractMaker.EXPECT().NewEthScan(w_common.EthereumMainnet).Return(&FakeBalanceScanner{ + etherBalances: expectedEthBalances, + tokenBalances: expectedTokenBalances, + }, uint(0), nil).AnyTimes() + contractMaker.EXPECT().NewEthScan(w_common.OptimismMainnet).Return(&FakeBalanceScanner{ + etherBalances: expectedEthOptBalances, + tokenBalances: expectedTokenOptBalances, + }, uint(0), nil).AnyTimes() + contractMaker.EXPECT().NewEthScan(w_common.ArbitrumMainnet).Return(nil, uint(0), errors.New("no scan contract")).AnyTimes() + + bf := NewDefaultBalanceFetcher(contractMaker) + + // Fetch native balances and token balances using scan contract + balances, err := bf.GetBalancesAtByChain(ctx, chainClients, accounts, tokens, atBlocks) + + require.NoError(t, err) + require.Equal(t, expectedBalances, balances) +} diff --git a/services/wallet/token/mock/token/tokenmanager.go b/services/wallet/token/mock/token/tokenmanager.go index bf9c95e94..fe0886cc2 100644 --- a/services/wallet/token/mock/token/tokenmanager.go +++ b/services/wallet/token/mock/token/tokenmanager.go @@ -9,82 +9,82 @@ import ( big "math/big" reflect "reflect" - gomock "github.com/golang/mock/gomock" - common "github.com/ethereum/go-ethereum/common" hexutil "github.com/ethereum/go-ethereum/common/hexutil" + gomock "github.com/golang/mock/gomock" chain "github.com/status-im/status-go/rpc/chain" token "github.com/status-im/status-go/services/wallet/token" ) -// MockManagerInterface is a mock of ManagerInterface interface +// MockManagerInterface is a mock of ManagerInterface interface. type MockManagerInterface struct { ctrl *gomock.Controller recorder *MockManagerInterfaceMockRecorder } -// MockManagerInterfaceMockRecorder is the mock recorder for MockManagerInterface +// MockManagerInterfaceMockRecorder is the mock recorder for MockManagerInterface. type MockManagerInterfaceMockRecorder struct { mock *MockManagerInterface } -// NewMockManagerInterface creates a new mock instance +// NewMockManagerInterface creates a new mock instance. func NewMockManagerInterface(ctrl *gomock.Controller) *MockManagerInterface { mock := &MockManagerInterface{ctrl: ctrl} mock.recorder = &MockManagerInterfaceMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockManagerInterface) EXPECT() *MockManagerInterfaceMockRecorder { return m.recorder } -// LookupTokenIdentity mocks base method -func (m *MockManagerInterface) LookupTokenIdentity(chainID uint64, address common.Address, native bool) *token.Token { +// FetchBalancesForChain mocks base method. +func (m *MockManagerInterface) FetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LookupTokenIdentity", chainID, address, native) - ret0, _ := ret[0].(*token.Token) - return ret0 -} - -// LookupTokenIdentity indicates an expected call of LookupTokenIdentity -func (mr *MockManagerInterfaceMockRecorder) LookupTokenIdentity(chainID, address, native interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupTokenIdentity", reflect.TypeOf((*MockManagerInterface)(nil).LookupTokenIdentity), chainID, address, native) -} - -// LookupToken mocks base method -func (m *MockManagerInterface) LookupToken(chainID *uint64, tokenSymbol string) (*token.Token, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LookupToken", chainID, tokenSymbol) - ret0, _ := ret[0].(*token.Token) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// LookupToken indicates an expected call of LookupToken -func (mr *MockManagerInterfaceMockRecorder) LookupToken(chainID, tokenSymbol interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupToken", reflect.TypeOf((*MockManagerInterface)(nil).LookupToken), chainID, tokenSymbol) -} - -// GetTokensByChainIDs mocks base method -func (m *MockManagerInterface) GetTokensByChainIDs(chainIDs []uint64) ([]*token.Token, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTokensByChainIDs", chainIDs) - ret0, _ := ret[0].([]*token.Token) + ret := m.ctrl.Call(m, "FetchBalancesForChain", parent, client, accounts, tokens, atBlock) + ret0, _ := ret[0].(map[common.Address]map[common.Address]*hexutil.Big) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetTokensByChainIDs indicates an expected call of GetTokensByChainIDs -func (mr *MockManagerInterfaceMockRecorder) GetTokensByChainIDs(chainIDs interface{}) *gomock.Call { +// FetchBalancesForChain indicates an expected call of FetchBalancesForChain. +func (mr *MockManagerInterfaceMockRecorder) FetchBalancesForChain(parent, client, accounts, tokens, atBlock interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokensByChainIDs", reflect.TypeOf((*MockManagerInterface)(nil).GetTokensByChainIDs), chainIDs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchBalancesForChain", reflect.TypeOf((*MockManagerInterface)(nil).FetchBalancesForChain), parent, client, accounts, tokens, atBlock) } -// GetBalancesByChain mocks base method +// GetBalance mocks base method. +func (m *MockManagerInterface) GetBalance(ctx context.Context, client chain.ClientInterface, account, token common.Address) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBalance", ctx, client, account, token) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBalance indicates an expected call of GetBalance. +func (mr *MockManagerInterfaceMockRecorder) GetBalance(ctx, client, account, token interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalance", reflect.TypeOf((*MockManagerInterface)(nil).GetBalance), ctx, client, account, token) +} + +// GetBalancesAtByChain mocks base method. +func (m *MockManagerInterface) GetBalancesAtByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBalancesAtByChain", parent, clients, accounts, tokens, atBlocks) + ret0, _ := ret[0].(map[uint64]map[common.Address]map[common.Address]*hexutil.Big) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBalancesAtByChain indicates an expected call of GetBalancesAtByChain. +func (mr *MockManagerInterfaceMockRecorder) GetBalancesAtByChain(parent, clients, accounts, tokens, atBlocks interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalancesAtByChain", reflect.TypeOf((*MockManagerInterface)(nil).GetBalancesAtByChain), parent, clients, accounts, tokens, atBlocks) +} + +// GetBalancesByChain mocks base method. func (m *MockManagerInterface) GetBalancesByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetBalancesByChain", parent, clients, accounts, tokens) @@ -93,13 +93,43 @@ func (m *MockManagerInterface) GetBalancesByChain(parent context.Context, client return ret0, ret1 } -// GetBalancesByChain indicates an expected call of GetBalancesByChain +// GetBalancesByChain indicates an expected call of GetBalancesByChain. func (mr *MockManagerInterfaceMockRecorder) GetBalancesByChain(parent, clients, accounts, tokens interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalancesByChain", reflect.TypeOf((*MockManagerInterface)(nil).GetBalancesByChain), parent, clients, accounts, tokens) } -// GetTokenHistoricalBalance mocks base method +// GetChainBalance mocks base method. +func (m *MockManagerInterface) GetChainBalance(ctx context.Context, client chain.ClientInterface, account common.Address) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChainBalance", ctx, client, account) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChainBalance indicates an expected call of GetChainBalance. +func (mr *MockManagerInterfaceMockRecorder) GetChainBalance(ctx, client, account interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChainBalance", reflect.TypeOf((*MockManagerInterface)(nil).GetChainBalance), ctx, client, account) +} + +// GetTokenBalanceAt mocks base method. +func (m *MockManagerInterface) GetTokenBalanceAt(ctx context.Context, client chain.ClientInterface, account, token common.Address, blockNumber *big.Int) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTokenBalanceAt", ctx, client, account, token, blockNumber) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTokenBalanceAt indicates an expected call of GetTokenBalanceAt. +func (mr *MockManagerInterfaceMockRecorder) GetTokenBalanceAt(ctx, client, account, token, blockNumber interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokenBalanceAt", reflect.TypeOf((*MockManagerInterface)(nil).GetTokenBalanceAt), ctx, client, account, token, blockNumber) +} + +// GetTokenHistoricalBalance mocks base method. func (m *MockManagerInterface) GetTokenHistoricalBalance(account common.Address, chainID uint64, symbol string, timestamp int64) (*big.Int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetTokenHistoricalBalance", account, chainID, symbol, timestamp) @@ -108,8 +138,52 @@ func (m *MockManagerInterface) GetTokenHistoricalBalance(account common.Address, return ret0, ret1 } -// GetTokenHistoricalBalance indicates an expected call of GetTokenHistoricalBalance +// GetTokenHistoricalBalance indicates an expected call of GetTokenHistoricalBalance. func (mr *MockManagerInterfaceMockRecorder) GetTokenHistoricalBalance(account, chainID, symbol, timestamp interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokenHistoricalBalance", reflect.TypeOf((*MockManagerInterface)(nil).GetTokenHistoricalBalance), account, chainID, symbol, timestamp) } + +// GetTokensByChainIDs mocks base method. +func (m *MockManagerInterface) GetTokensByChainIDs(chainIDs []uint64) ([]*token.Token, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTokensByChainIDs", chainIDs) + ret0, _ := ret[0].([]*token.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTokensByChainIDs indicates an expected call of GetTokensByChainIDs. +func (mr *MockManagerInterfaceMockRecorder) GetTokensByChainIDs(chainIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokensByChainIDs", reflect.TypeOf((*MockManagerInterface)(nil).GetTokensByChainIDs), chainIDs) +} + +// LookupToken mocks base method. +func (m *MockManagerInterface) LookupToken(chainID *uint64, tokenSymbol string) (*token.Token, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LookupToken", chainID, tokenSymbol) + ret0, _ := ret[0].(*token.Token) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// LookupToken indicates an expected call of LookupToken. +func (mr *MockManagerInterfaceMockRecorder) LookupToken(chainID, tokenSymbol interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupToken", reflect.TypeOf((*MockManagerInterface)(nil).LookupToken), chainID, tokenSymbol) +} + +// LookupTokenIdentity mocks base method. +func (m *MockManagerInterface) LookupTokenIdentity(chainID uint64, address common.Address, native bool) *token.Token { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LookupTokenIdentity", chainID, address, native) + ret0, _ := ret[0].(*token.Token) + return ret0 +} + +// LookupTokenIdentity indicates an expected call of LookupTokenIdentity. +func (mr *MockManagerInterfaceMockRecorder) LookupTokenIdentity(chainID, address, native interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupTokenIdentity", reflect.TypeOf((*MockManagerInterface)(nil).LookupTokenIdentity), chainID, address, native) +} diff --git a/services/wallet/token/token.go b/services/wallet/token/token.go index 438ab5172..542eefbc3 100644 --- a/services/wallet/token/token.go +++ b/services/wallet/token/token.go @@ -97,9 +97,9 @@ type ManagerInterface interface { type Manager struct { balancefetcher.BalanceFetcher db *sql.DB - RPCClient *rpc.Client + RPCClient rpc.ClientInterface ContractMaker *contracts.ContractMaker - networkManager *network.Manager + networkManager network.ManagerInterface stores []store // Set on init, not changed afterwards communityTokensDB *communitytokensdatabase.Database communityManager *community.Manager @@ -130,7 +130,7 @@ func mergeTokens(sliceLists [][]*Token) []*Token { return res } -func prepareTokens(networkManager *network.Manager, stores []store) []*Token { +func prepareTokens(networkManager network.ManagerInterface, stores []store) []*Token { tokens := make([]*Token, 0) networks, err := networkManager.GetAll() @@ -158,9 +158,9 @@ func prepareTokens(networkManager *network.Manager, stores []store) []*Token { func NewTokenManager( db *sql.DB, - RPCClient *rpc.Client, + RPCClient rpc.ClientInterface, communityManager *community.Manager, - networkManager *network.Manager, + networkManager network.ManagerInterface, appDB *sql.DB, mediaServer *server.MediaServer, walletFeed *event.Feed, @@ -173,10 +173,10 @@ func NewTokenManager( tokens := prepareTokens(networkManager, stores) return &Manager{ - BalanceFetcher: balancefetcher.NewDefaultBalanceFetcher(maker), - db: db, - RPCClient: RPCClient, - // ContractMaker: maker, + BalanceFetcher: balancefetcher.NewDefaultBalanceFetcher(maker), + db: db, + RPCClient: RPCClient, + ContractMaker: maker, networkManager: networkManager, communityManager: communityManager, stores: stores, diff --git a/services/wallet/token/token_test.go b/services/wallet/token/token_test.go index 2440793d5..bc72ef59c 100644 --- a/services/wallet/token/token_test.go +++ b/services/wallet/token/token_test.go @@ -24,6 +24,7 @@ import ( "github.com/status-im/status-go/services/accounts/accountsevent" "github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/community" + "github.com/status-im/status-go/t/helpers" "github.com/status-im/status-go/t/utils" "github.com/status-im/status-go/transactions/fake" diff --git a/services/wallet/transfer/commands_sequential_test.go b/services/wallet/transfer/commands_sequential_test.go index a21507ff1..e85448847 100644 --- a/services/wallet/transfer/commands_sequential_test.go +++ b/services/wallet/transfer/commands_sequential_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" "github.com/pkg/errors" "github.com/stretchr/testify/mock" "golang.org/x/exp/slices" // since 1.21, this is in the standard library @@ -29,6 +30,8 @@ import ( "github.com/status-im/status-go/contracts/ierc20" ethtypes "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/rpc/chain" + mock_client "github.com/status-im/status-go/rpc/chain/mock/client" + mock_rpcclient "github.com/status-im/status-go/rpc/mock/client" "github.com/status-im/status-go/server" "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/balance" @@ -1295,6 +1298,7 @@ func (m *MockETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) type MockChainClient struct { mock.Mock + mock_client.MockClientInterface clients map[walletcommon.ChainID]*MockETHClient } @@ -1360,7 +1364,11 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) { address := common.HexToAddress("0x1234") chainClient := newMockChainClient() - tracker := transactions.NewPendingTxTracker(db, chainClient, nil, &event.Feed{}, transactions.PendingCheckInterval) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rpcClient := mock_rpcclient.NewMockClientInterface(ctrl) + rpcClient.EXPECT().AbstractEthClient(tc.NetworkID()).Return(chainClient, nil).AnyTimes() + tracker := transactions.NewPendingTxTracker(db, rpcClient, nil, &event.Feed{}, transactions.PendingCheckInterval) accDB, err := accounts.NewDB(appdb) require.NoError(t, err) diff --git a/services/wallet/transfer/transaction_manager_multitransaction_test.go b/services/wallet/transfer/transaction_manager_multitransaction_test.go index be7f0c998..4f8b1187c 100644 --- a/services/wallet/transfer/transaction_manager_multitransaction_test.go +++ b/services/wallet/transfer/transaction_manager_multitransaction_test.go @@ -16,6 +16,8 @@ import ( "github.com/status-im/status-go/account" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/rpc" + "github.com/status-im/status-go/rpc/chain" + mock_rpcclient "github.com/status-im/status-go/rpc/mock/client" wallet_common "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/router/pathprocessor" "github.com/status-im/status-go/services/wallet/router/pathprocessor/mock_pathprocessor" @@ -174,15 +176,21 @@ func TestSendTransactionsETHFailOnTransactor(t *testing.T) { func TestWatchTransaction(t *testing.T) { tm, _, _ := setupTransactionManager(t) - chainID := uint64(1) + chainID := uint64(777) // GeneratePendingTransaction uses this chainID pendingTxTimeout = 2 * time.Millisecond walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) chainClient := transactions.NewMockChainClient() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rpcClient := mock_rpcclient.NewMockClientInterface(ctrl) + rpcClient.EXPECT().AbstractEthClient(wallet_common.ChainID(chainID)).DoAndReturn(func(chainID wallet_common.ChainID) (chain.BatchCallClient, error) { + return chainClient.AbstractEthClient(chainID) + }).AnyTimes() eventFeed := &event.Feed{} // For now, pending tracker is not interface, so we have to use a real one - tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, chainClient, nil, eventFeed, pendingTxTimeout) + tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, rpcClient, nil, eventFeed, pendingTxTimeout) tm.eventFeed = eventFeed // Create a context with timeout @@ -219,16 +227,22 @@ func TestWatchTransaction(t *testing.T) { func TestWatchTransaction_Timeout(t *testing.T) { tm, _, _ := setupTransactionManager(t) - chainID := uint64(1) + chainID := uint64(777) // GeneratePendingTransaction uses this chainID transactionHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") pendingTxTimeout = 2 * time.Millisecond walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) chainClient := transactions.NewMockChainClient() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rpcClient := mock_rpcclient.NewMockClientInterface(gomock.NewController(t)) + rpcClient.EXPECT().AbstractEthClient(wallet_common.ChainID(chainID)).DoAndReturn(func(chainID wallet_common.ChainID) (chain.BatchCallClient, error) { + return chainClient.AbstractEthClient(chainID) + }).AnyTimes() eventFeed := &event.Feed{} // For now, pending tracker is not interface, so we have to use a real one - tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, chainClient, nil, eventFeed, pendingTxTimeout) + tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, rpcClient, nil, eventFeed, pendingTxTimeout) tm.eventFeed = eventFeed // Create a context with timeout diff --git a/transactions/mock_transactor/transactor.go b/transactions/mock_transactor/transactor.go index 9bcdfc1af..77c4458ff 100644 --- a/transactions/mock_transactor/transactor.go +++ b/transactions/mock_transactor/transactor.go @@ -8,10 +8,9 @@ import ( big "math/big" reflect "reflect" - gomock "github.com/golang/mock/gomock" - common "github.com/ethereum/go-ethereum/common" types "github.com/ethereum/go-ethereum/core/types" + gomock "github.com/golang/mock/gomock" account "github.com/status-im/status-go/account" types0 "github.com/status-im/status-go/eth-node/types" params "github.com/status-im/status-go/params" @@ -89,7 +88,7 @@ func (mr *MockTransactorIfaceMockRecorder) EstimateGas(network, from, to, value, } // NextNonce mocks base method. -func (m *MockTransactorIface) NextNonce(rpcClient *rpc.Client, chainID uint64, from types0.Address) (uint64, error) { +func (m *MockTransactorIface) NextNonce(rpcClient rpc.ClientInterface, chainID uint64, from types0.Address) (uint64, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "NextNonce", rpcClient, chainID, from) ret0, _ := ret[0].(uint64) diff --git a/transactions/pendingtxtracker_test.go b/transactions/pendingtxtracker_test.go index 87cb7a088..638b2f23f 100644 --- a/transactions/pendingtxtracker_test.go +++ b/transactions/pendingtxtracker_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -16,6 +17,8 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/rpc" + "github.com/status-im/status-go/rpc/chain" + mock_rpcclient "github.com/status-im/status-go/rpc/mock/client" "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/walletevent" @@ -29,12 +32,21 @@ func setupTestTransactionDB(t *testing.T, checkInterval *time.Duration) (*Pendin require.NoError(t, err) chainClient := NewMockChainClient() + ctrl := gomock.NewController(t) + defer ctrl.Finish() eventFeed := &event.Feed{} pendingCheckInterval := PendingCheckInterval if checkInterval != nil { pendingCheckInterval = *checkInterval } - return NewPendingTxTracker(db, chainClient, nil, eventFeed, pendingCheckInterval), func() { + rpcClient := mock_rpcclient.NewMockClientInterface(ctrl) + rpcClient.EXPECT().EthClient(common.EthereumMainnet).Return(chainClient, nil).AnyTimes() + + // Delegate the call to the fake implementation + rpcClient.EXPECT().AbstractEthClient(gomock.Any()).DoAndReturn(func(chainID common.ChainID) (chain.BatchCallClient, error) { + return chainClient.AbstractEthClient(chainID) + }).AnyTimes() + return NewPendingTxTracker(db, rpcClient, nil, eventFeed, pendingCheckInterval), func() { require.NoError(t, db.Close()) }, chainClient, eventFeed } diff --git a/transactions/rpc_wrapper.go b/transactions/rpc_wrapper.go index 3e06e3201..87ea476bf 100644 --- a/transactions/rpc_wrapper.go +++ b/transactions/rpc_wrapper.go @@ -16,11 +16,11 @@ import ( // rpcWrapper wraps provides convenient interface for ethereum RPC APIs we need for sending transactions type rpcWrapper struct { - RPCClient *rpc.Client + RPCClient rpc.ClientInterface chainID uint64 } -func newRPCWrapper(client *rpc.Client, chainID uint64) *rpcWrapper { +func newRPCWrapper(client rpc.ClientInterface, chainID uint64) *rpcWrapper { return &rpcWrapper{RPCClient: client, chainID: chainID} } diff --git a/transactions/testhelpers.go b/transactions/testhelpers.go index d06e6c77f..b908ce5ff 100644 --- a/transactions/testhelpers.go +++ b/transactions/testhelpers.go @@ -10,6 +10,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rpc" "github.com/status-im/status-go/rpc/chain" + mock_client "github.com/status-im/status-go/rpc/chain/mock/client" "github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/common" @@ -21,6 +22,8 @@ type MockETHClient struct { mock.Mock } +var _ chain.BatchCallClient = (*MockETHClient)(nil) + func (m *MockETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error { args := m.Called(ctx, b) return args.Error(0) @@ -28,10 +31,13 @@ func (m *MockETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) type MockChainClient struct { mock.Mock + mock_client.MockClientInterface Clients map[common.ChainID]*MockETHClient } +var _ chain.ClientInterface = (*MockChainClient)(nil) + func NewMockChainClient() *MockChainClient { return &MockChainClient{ Clients: make(map[common.ChainID]*MockETHClient), diff --git a/transactions/transactor.go b/transactions/transactor.go index 032831064..2efbb8dc3 100644 --- a/transactions/transactor.go +++ b/transactions/transactor.go @@ -46,7 +46,7 @@ func (e *ErrBadNonce) Error() string { // Transactor is an interface that defines the methods for validating and sending transactions. type TransactorIface interface { - NextNonce(rpcClient *rpc.Client, chainID uint64, from types.Address) (uint64, error) + NextNonce(rpcClient rpc.ClientInterface, chainID uint64, from types.Address) (uint64, error) EstimateGas(network *params.Network, from common.Address, to common.Address, value *big.Int, input []byte) (uint64, error) SendTransaction(sendArgs SendTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) SendTransactionWithChainID(chainID uint64, sendArgs SendTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) @@ -96,7 +96,7 @@ func (t *Transactor) SetRPC(rpcClient *rpc.Client, timeout time.Duration) { t.rpcCallTimeout = timeout } -func (t *Transactor) NextNonce(rpcClient *rpc.Client, chainID uint64, from types.Address) (uint64, error) { +func (t *Transactor) NextNonce(rpcClient rpc.ClientInterface, chainID uint64, from types.Address) (uint64, error) { wrapper := newRPCWrapper(rpcClient, chainID) ctx := context.Background() nonce, err := wrapper.PendingNonceAt(ctx, common.Address(from))