test(wallet)_: add unit tests for balance fetcher

replace types with interfaces where necessary to allow mocking
implement fake eth scanner and erc20 contracts
This commit is contained in:
Ivan Belyakov 2024-06-27 23:27:09 +02:00 committed by IvanBelyakoff
parent d180e19fa8
commit 3983114ae5
30 changed files with 1001 additions and 167 deletions

View File

@ -346,6 +346,8 @@ mock: ##@other Regenerate mocks
mockgen -package=mock_client -destination=rpc/chain/mock/client/client.go -source=rpc/chain/client.go
mockgen -package=mock_token -destination=services/wallet/token/mock/token/tokenmanager.go -source=services/wallet/token/token.go
mockgen -package=mock_balance_persistence -destination=services/wallet/token/mock/balance_persistence/balance_persistence.go -source=services/wallet/token/balance_persistence.go
mockgen -package=mock_network -destination=rpc/network/mock/network.go -source=rpc/network/network.go
mockgen -package=mock_rpcclient -destination=rpc/mock/client/client.go -source=rpc/client.go
docker-test: ##@tests Run tests in a docker container with golang.
docker run --privileged --rm -it -v "$(PWD):$(DOCKER_TEST_WORKDIR)" -w "$(DOCKER_TEST_WORKDIR)" $(DOCKER_TEST_IMAGE) go test ${ARGS}

View File

@ -16,15 +16,17 @@ import (
)
type ContractMakerIface interface {
NewEthScan(chainID uint64) (*ethscan.BalanceScanner, uint, error)
NewEthScan(chainID uint64) (ethscan.BalanceScannerIface, uint, error)
NewERC20(chainID uint64, contractAddr common.Address) (ierc20.IERC20Iface, error)
NewERC20Caller(chainID uint64, contractAddr common.Address) (ierc20.IERC20CallerIface, error)
// TODO extend with other contracts
}
type ContractMaker struct {
RPCClient *rpc.Client
RPCClient rpc.ClientInterface
}
func NewContractMaker(client *rpc.Client) (*ContractMaker, error) {
func NewContractMaker(client rpc.ClientInterface) (*ContractMaker, error) {
if client == nil {
return nil, errors.New("could not initialize ContractMaker with an rpc client")
}
@ -72,7 +74,7 @@ func (c *ContractMaker) NewUsernameRegistrar(chainID uint64, contractAddr common
)
}
func (c *ContractMaker) NewERC20(chainID uint64, contractAddr common.Address) (*ierc20.IERC20, error) {
func (c *ContractMaker) NewERC20(chainID uint64, contractAddr common.Address) (ierc20.IERC20Iface, error) {
backend, err := c.RPCClient.EthClient(chainID)
if err != nil {
return nil, err
@ -83,6 +85,14 @@ func (c *ContractMaker) NewERC20(chainID uint64, contractAddr common.Address) (*
backend,
)
}
func (c *ContractMaker) NewERC20Caller(chainID uint64, contractAddr common.Address) (ierc20.IERC20CallerIface, error) {
backend, err := c.RPCClient.EthClient(chainID)
if err != nil {
return nil, err
}
return ierc20.NewIERC20Caller(contractAddr, backend)
}
func (c *ContractMaker) NewSNT(chainID uint64) (*snt.SNT, error) {
contractAddr, err := snt.ContractAddress(chainID)
@ -166,7 +176,7 @@ func (c *ContractMaker) NewDirectory(chainID uint64) (*directory.Directory, erro
)
}
func (c *ContractMaker) NewEthScan(chainID uint64) (*ethscan.BalanceScanner, uint, error) {
func (c *ContractMaker) NewEthScan(chainID uint64) (ethscan.BalanceScannerIface, uint, error) {
contractAddr, err := ethscan.ContractAddress(chainID)
if err != nil {
return nil, 0, err

View File

@ -0,0 +1,15 @@
package ethscan
import (
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
)
type BalanceScannerIface interface {
EtherBalances(opts *bind.CallOpts, addresses []common.Address) ([]BalanceScannerResult, error)
TokenBalances(opts *bind.CallOpts, addresses []common.Address, tokenAddress common.Address) ([]BalanceScannerResult, error)
TokensBalance(opts *bind.CallOpts, owner common.Address, contracts []common.Address) ([]BalanceScannerResult, error)
}
// Verify that BalanceScanner implements BalanceScannerIface. If contract changes, this will fail to compile.
var _ BalanceScannerIface = (*BalanceScanner)(nil)

View File

@ -0,0 +1,28 @@
package ierc20
import (
"math/big"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
)
type IERC20Iface interface {
BalanceOf(opts *bind.CallOpts, account common.Address) (*big.Int, error)
Name(opts *bind.CallOpts) (string, error)
Symbol(opts *bind.CallOpts) (string, error)
Decimals(opts *bind.CallOpts) (uint8, error)
Allowance(opts *bind.CallOpts, owner common.Address, spender common.Address) (*big.Int, error)
}
// Verify that IERC20 implements IERC20Iface. If contract changes, this will fail to compile, update interface to match.
var _ IERC20Iface = (*IERC20)(nil)
type IERC20CallerIface interface {
BalanceOf(opts *bind.CallOpts, account common.Address) (*big.Int, error)
Name(opts *bind.CallOpts) (string, error)
Symbol(opts *bind.CallOpts) (string, error)
Decimals(opts *bind.CallOpts) (uint8, error)
}
var _ IERC20CallerIface = (*IERC20Caller)(nil)

View File

@ -0,0 +1,83 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: contracts/contracts.go
// Package mock_contracts is a generated GoMock package.
package mock_contracts
import (
reflect "reflect"
common "github.com/ethereum/go-ethereum/common"
gomock "github.com/golang/mock/gomock"
ethscan "github.com/status-im/status-go/contracts/ethscan"
ierc20 "github.com/status-im/status-go/contracts/ierc20"
)
// MockContractMakerIface is a mock of ContractMakerIface interface.
type MockContractMakerIface struct {
ctrl *gomock.Controller
recorder *MockContractMakerIfaceMockRecorder
}
// MockContractMakerIfaceMockRecorder is the mock recorder for MockContractMakerIface.
type MockContractMakerIfaceMockRecorder struct {
mock *MockContractMakerIface
}
// NewMockContractMakerIface creates a new mock instance.
func NewMockContractMakerIface(ctrl *gomock.Controller) *MockContractMakerIface {
mock := &MockContractMakerIface{ctrl: ctrl}
mock.recorder = &MockContractMakerIfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockContractMakerIface) EXPECT() *MockContractMakerIfaceMockRecorder {
return m.recorder
}
// NewERC20 mocks base method.
func (m *MockContractMakerIface) NewERC20(chainID uint64, contractAddr common.Address) (ierc20.IERC20Iface, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewERC20", chainID, contractAddr)
ret0, _ := ret[0].(ierc20.IERC20Iface)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NewERC20 indicates an expected call of NewERC20.
func (mr *MockContractMakerIfaceMockRecorder) NewERC20(chainID, contractAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewERC20", reflect.TypeOf((*MockContractMakerIface)(nil).NewERC20), chainID, contractAddr)
}
// NewERC20Caller mocks base method.
func (m *MockContractMakerIface) NewERC20Caller(chainID uint64, contractAddr common.Address) (ierc20.IERC20CallerIface, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewERC20Caller", chainID, contractAddr)
ret0, _ := ret[0].(ierc20.IERC20CallerIface)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NewERC20Caller indicates an expected call of NewERC20Caller.
func (mr *MockContractMakerIfaceMockRecorder) NewERC20Caller(chainID, contractAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewERC20Caller", reflect.TypeOf((*MockContractMakerIface)(nil).NewERC20Caller), chainID, contractAddr)
}
// NewEthScan mocks base method.
func (m *MockContractMakerIface) NewEthScan(chainID uint64) (ethscan.BalanceScannerIface, uint, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewEthScan", chainID)
ret0, _ := ret[0].(ethscan.BalanceScannerIface)
ret1, _ := ret[1].(uint)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// NewEthScan indicates an expected call of NewEthScan.
func (mr *MockContractMakerIfaceMockRecorder) NewEthScan(chainID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewEthScan", reflect.TypeOf((*MockContractMakerIface)(nil).NewEthScan), chainID)
}

View File

@ -37,6 +37,7 @@ import (
"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/requests"
"github.com/status-im/status-go/protocol/transport"
"github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/services/wallet/bigint"
walletcommon "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/thirdparty"
@ -271,22 +272,23 @@ type CommunityTokensServiceInterface interface {
}
type DefaultTokenManager struct {
tokenManager *token.Manager
tokenManager *token.Manager
networkManager network.ManagerInterface
}
func NewDefaultTokenManager(tm *token.Manager) *DefaultTokenManager {
return &DefaultTokenManager{tokenManager: tm}
func NewDefaultTokenManager(tm *token.Manager, nm network.ManagerInterface) *DefaultTokenManager {
return &DefaultTokenManager{tokenManager: tm, networkManager: nm}
}
type BalancesByChain = map[uint64]map[gethcommon.Address]map[gethcommon.Address]*hexutil.Big
func (m *DefaultTokenManager) GetAllChainIDs() ([]uint64, error) {
networks, err := m.tokenManager.RPCClient.NetworkManager.Get(false)
networks, err := m.networkManager.Get(false)
if err != nil {
return nil, err
}
areTestNetworksEnabled, err := m.tokenManager.RPCClient.NetworkManager.GetTestNetworksEnabled()
areTestNetworksEnabled, err := m.networkManager.GetTestNetworksEnabled()
if err != nil {
return nil, err
}

View File

@ -472,7 +472,7 @@ func NewMessenger(
managerOptions = append(managerOptions, communities.WithTokenManager(c.tokenManager))
} else if c.rpcClient != nil {
tokenManager := token.NewTokenManager(c.walletDb, c.rpcClient, community.NewManager(database, c.httpServer, nil), c.rpcClient.NetworkManager, database, c.httpServer, nil, nil, nil, token.NewPersistence(c.walletDb))
managerOptions = append(managerOptions, communities.WithTokenManager(communities.NewDefaultTokenManager(tokenManager)))
managerOptions = append(managerOptions, communities.WithTokenManager(communities.NewDefaultTokenManager(tokenManager, c.rpcClient.NetworkManager)))
}
if c.walletConfig != nil {

View File

@ -28,7 +28,7 @@ type BatchCallClient interface {
}
type ChainInterface interface {
BatchCallContext(ctx context.Context, b []rpc.BatchElem) error
BatchCallClient
HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error)
BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error)
BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error)

View File

@ -9,12 +9,11 @@ import (
big "math/big"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
ethereum "github.com/ethereum/go-ethereum"
common "github.com/ethereum/go-ethereum/common"
types "github.com/ethereum/go-ethereum/core/types"
rpc "github.com/ethereum/go-ethereum/rpc"
gomock "github.com/golang/mock/gomock"
chain "github.com/status-im/status-go/rpc/chain"
)

View File

@ -37,6 +37,9 @@ type Handler func(context.Context, uint64, ...interface{}) (interface{}, error)
type ClientInterface interface {
AbstractEthClient(chainID common.ChainID) (chain.BatchCallClient, error)
EthClient(chainID uint64) (chain.ClientInterface, error)
EthClients(chainIDs []uint64) (map[uint64]chain.ClientInterface, error)
CallContext(context context.Context, result interface{}, chainID uint64, method string, args ...interface{}) error
}
// Client represents RPC client with custom routing

101
rpc/mock/client/client.go Normal file
View File

@ -0,0 +1,101 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: rpc/client.go
// Package mock_rpcclient is a generated GoMock package.
package mock_rpcclient
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
chain "github.com/status-im/status-go/rpc/chain"
common "github.com/status-im/status-go/services/wallet/common"
)
// MockClientInterface is a mock of ClientInterface interface.
type MockClientInterface struct {
ctrl *gomock.Controller
recorder *MockClientInterfaceMockRecorder
}
// MockClientInterfaceMockRecorder is the mock recorder for MockClientInterface.
type MockClientInterfaceMockRecorder struct {
mock *MockClientInterface
}
// NewMockClientInterface creates a new mock instance.
func NewMockClientInterface(ctrl *gomock.Controller) *MockClientInterface {
mock := &MockClientInterface{ctrl: ctrl}
mock.recorder = &MockClientInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockClientInterface) EXPECT() *MockClientInterfaceMockRecorder {
return m.recorder
}
// AbstractEthClient mocks base method.
func (m *MockClientInterface) AbstractEthClient(chainID common.ChainID) (chain.BatchCallClient, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AbstractEthClient", chainID)
ret0, _ := ret[0].(chain.BatchCallClient)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AbstractEthClient indicates an expected call of AbstractEthClient.
func (mr *MockClientInterfaceMockRecorder) AbstractEthClient(chainID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AbstractEthClient", reflect.TypeOf((*MockClientInterface)(nil).AbstractEthClient), chainID)
}
// CallContext mocks base method.
func (m *MockClientInterface) CallContext(context context.Context, result interface{}, chainID uint64, method string, args ...interface{}) error {
m.ctrl.T.Helper()
varargs := []interface{}{context, result, chainID, method}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "CallContext", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// CallContext indicates an expected call of CallContext.
func (mr *MockClientInterfaceMockRecorder) CallContext(context, result, chainID, method interface{}, args ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{context, result, chainID, method}, args...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallContext", reflect.TypeOf((*MockClientInterface)(nil).CallContext), varargs...)
}
// EthClient mocks base method.
func (m *MockClientInterface) EthClient(chainID uint64) (chain.ClientInterface, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EthClient", chainID)
ret0, _ := ret[0].(chain.ClientInterface)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// EthClient indicates an expected call of EthClient.
func (mr *MockClientInterfaceMockRecorder) EthClient(chainID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EthClient", reflect.TypeOf((*MockClientInterface)(nil).EthClient), chainID)
}
// EthClients mocks base method.
func (m *MockClientInterface) EthClients(chainIDs []uint64) (map[uint64]chain.ClientInterface, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EthClients", chainIDs)
ret0, _ := ret[0].(map[uint64]chain.ClientInterface)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// EthClients indicates an expected call of EthClients.
func (mr *MockClientInterfaceMockRecorder) EthClients(chainIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EthClients", reflect.TypeOf((*MockClientInterface)(nil).EthClients), chainIDs)
}

108
rpc/network/mock/network.go Normal file
View File

@ -0,0 +1,108 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: rpc/network/network.go
// Package mock_network is a generated GoMock package.
package mock_network
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
params "github.com/status-im/status-go/params"
)
// MockManagerInterface is a mock of ManagerInterface interface.
type MockManagerInterface struct {
ctrl *gomock.Controller
recorder *MockManagerInterfaceMockRecorder
}
// MockManagerInterfaceMockRecorder is the mock recorder for MockManagerInterface.
type MockManagerInterfaceMockRecorder struct {
mock *MockManagerInterface
}
// NewMockManagerInterface creates a new mock instance.
func NewMockManagerInterface(ctrl *gomock.Controller) *MockManagerInterface {
mock := &MockManagerInterface{ctrl: ctrl}
mock.recorder = &MockManagerInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManagerInterface) EXPECT() *MockManagerInterfaceMockRecorder {
return m.recorder
}
// Find mocks base method.
func (m *MockManagerInterface) Find(chainID uint64) *params.Network {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Find", chainID)
ret0, _ := ret[0].(*params.Network)
return ret0
}
// Find indicates an expected call of Find.
func (mr *MockManagerInterfaceMockRecorder) Find(chainID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockManagerInterface)(nil).Find), chainID)
}
// Get mocks base method.
func (m *MockManagerInterface) Get(onlyEnabled bool) ([]*params.Network, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", onlyEnabled)
ret0, _ := ret[0].([]*params.Network)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockManagerInterfaceMockRecorder) Get(onlyEnabled interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockManagerInterface)(nil).Get), onlyEnabled)
}
// GetAll mocks base method.
func (m *MockManagerInterface) GetAll() ([]*params.Network, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAll")
ret0, _ := ret[0].([]*params.Network)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAll indicates an expected call of GetAll.
func (mr *MockManagerInterfaceMockRecorder) GetAll() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAll", reflect.TypeOf((*MockManagerInterface)(nil).GetAll))
}
// GetConfiguredNetworks mocks base method.
func (m *MockManagerInterface) GetConfiguredNetworks() []params.Network {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetConfiguredNetworks")
ret0, _ := ret[0].([]params.Network)
return ret0
}
// GetConfiguredNetworks indicates an expected call of GetConfiguredNetworks.
func (mr *MockManagerInterfaceMockRecorder) GetConfiguredNetworks() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfiguredNetworks", reflect.TypeOf((*MockManagerInterface)(nil).GetConfiguredNetworks))
}
// GetTestNetworksEnabled mocks base method.
func (m *MockManagerInterface) GetTestNetworksEnabled() (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTestNetworksEnabled")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTestNetworksEnabled indicates an expected call of GetTestNetworksEnabled.
func (mr *MockManagerInterfaceMockRecorder) GetTestNetworksEnabled() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTestNetworksEnabled", reflect.TypeOf((*MockManagerInterface)(nil).GetTestNetworksEnabled))
}

View File

@ -81,6 +81,14 @@ func (nq *networksQuery) exec(db *sql.DB) ([]*params.Network, error) {
return res, err
}
type ManagerInterface interface {
Get(onlyEnabled bool) ([]*params.Network, error)
GetAll() ([]*params.Network, error)
Find(chainID uint64) *params.Network
GetConfiguredNetworks() []params.Network
GetTestNetworksEnabled() (bool, error)
}
type Manager struct {
db *sql.DB
configuredNetworks []params.Network

View File

@ -7,17 +7,20 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
eth "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/event"
"github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/rpc/chain"
mock_rpcclient "github.com/status-im/status-go/rpc/mock/client"
"github.com/status-im/status-go/services/wallet/bigint"
"github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/thirdparty"
"github.com/status-im/status-go/services/wallet/token"
mock_token "github.com/status-im/status-go/services/wallet/token/mock/token"
"github.com/status-im/status-go/services/wallet/transfer"
"github.com/status-im/status-go/services/wallet/walletevent"
"github.com/status-im/status-go/t/helpers"
@ -53,45 +56,15 @@ func (m *mockCollectiblesManager) FetchCollectionSocialsAsync(contractID thirdpa
return nil
}
// mockTokenManager implements the token.ManagerInterface
type mockTokenManager struct {
mock.Mock
}
func (m *mockTokenManager) LookupTokenIdentity(chainID uint64, address eth.Address, native bool) *token.Token {
args := m.Called(chainID, address, native)
res := args.Get(0)
if res == nil {
return nil
}
return res.(*token.Token)
}
func (m *mockTokenManager) LookupToken(chainID *uint64, tokenSymbol string) (tkn *token.Token, isNative bool) {
args := m.Called(chainID, tokenSymbol)
return args.Get(0).(*token.Token), args.Bool(1)
}
func (m *mockTokenManager) GetBalancesByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []eth.Address) (map[uint64]map[eth.Address]map[eth.Address]*hexutil.Big, error) {
return nil, nil // Not used here
}
func (m *mockTokenManager) GetTokensByChainIDs(chainIDs []uint64) ([]*token.Token, error) {
return nil, nil // Not used here
}
func (m *mockTokenManager) GetTokenHistoricalBalance(account eth.Address, chainID uint64, symbol string, timestamp int64) (*big.Int, error) {
return nil, nil // Not used here
}
type testState struct {
service *Service
eventFeed *event.Feed
tokenMock *mockTokenManager
tokenMock *mock_token.MockManagerInterface
collectiblesMock *mockCollectiblesManager
close func()
pendingTracker *transactions.PendingTxTracker
chainClient *transactions.MockChainClient
rpcClient *mock_rpcclient.MockClientInterface
}
func setupTestService(tb testing.TB) (state testState) {
@ -104,20 +77,26 @@ func setupTestService(tb testing.TB) (state testState) {
require.NoError(tb, err)
state.eventFeed = new(event.Feed)
state.tokenMock = &mockTokenManager{}
mockCtrl := gomock.NewController(tb)
state.tokenMock = mock_token.NewMockManagerInterface(mockCtrl)
state.collectiblesMock = &mockCollectiblesManager{}
state.chainClient = transactions.NewMockChainClient()
state.rpcClient = mock_rpcclient.NewMockClientInterface(mockCtrl)
state.rpcClient.EXPECT().AbstractEthClient(gomock.Any()).DoAndReturn(func(chainID common.ChainID) (chain.BatchCallClient, error) {
return state.chainClient.AbstractEthClient(chainID)
}).AnyTimes()
// Ensure we process pending transactions as needed, only once
pendingCheckInterval := time.Second
state.pendingTracker = transactions.NewPendingTxTracker(db, state.chainClient, nil, state.eventFeed, pendingCheckInterval)
state.pendingTracker = transactions.NewPendingTxTracker(db, state.rpcClient, nil, state.eventFeed, pendingCheckInterval)
state.service = NewService(db, accountsDB, state.tokenMock, state.collectiblesMock, state.eventFeed, state.pendingTracker)
state.service.debounceDuration = 0
state.close = func() {
require.NoError(tb, state.pendingTracker.Stop())
require.NoError(tb, db.Close())
defer mockCtrl.Finish()
}
return state
@ -168,13 +147,13 @@ func TestService_UpdateCollectibleInfo(t *testing.T) {
sub := state.eventFeed.Subscribe(ch)
// Expect one call for the fungible token
state.tokenMock.On("LookupTokenIdentity", uint64(5), eth.HexToAddress("0x3d6afaa395c31fcd391fe3d562e75fe9e8ec7e6a"), false).Return(
state.tokenMock.EXPECT().LookupTokenIdentity(uint64(5), eth.HexToAddress("0x3d6afaa395c31fcd391fe3d562e75fe9e8ec7e6a"), false).Return(
&token.Token{
ChainID: 5,
Address: eth.HexToAddress("0x3d6afaa395c31fcd391fe3d562e75fe9e8ec7e6a"),
Symbol: "STT",
}, false,
).Once()
},
).Times(1)
state.collectiblesMock.On("FetchAssetsByCollectibleUniqueID", []thirdparty.CollectibleUniqueID{
{
ContractID: thirdparty.ContractID{
@ -296,21 +275,21 @@ func setupTransactions(t *testing.T, state testState, txCount int, testTxs []tra
allAddresses = append(append(allAddresses, fromTrs...), toTrs...)
state.tokenMock.On("LookupTokenIdentity", mock.Anything, mock.Anything, mock.Anything).Return(
state.tokenMock.EXPECT().LookupTokenIdentity(gomock.Any(), gomock.Any(), gomock.Any()).Return(
&token.Token{
ChainID: 5,
Address: eth.Address{},
Symbol: "ETH",
}, true,
).Times(0)
},
).AnyTimes()
state.tokenMock.On("LookupToken", mock.Anything, mock.Anything).Return(
state.tokenMock.EXPECT().LookupToken(gomock.Any(), gomock.Any()).Return(
&token.Token{
ChainID: 5,
Address: eth.Address{},
Symbol: "ETH",
}, true,
).Times(0)
).AnyTimes()
return allAddresses, pendings, ch, func() {
sub.Unsubscribe()

View File

@ -55,7 +55,7 @@ func NewAPI(s *Service) *API {
erc1155Transfer := pathprocessor.NewERC1155Processor(rpcClient, transactor)
router.AddPathProcessor(erc1155Transfer)
hop := pathprocessor.NewHopBridgeProcessor(rpcClient, transactor, tokenManager)
hop := pathprocessor.NewHopBridgeProcessor(rpcClient, transactor, tokenManager, rpcClient.NetworkManager)
router.AddPathProcessor(hop)
if featureFlags.EnableCelerBridge {

View File

@ -31,6 +31,7 @@ import (
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/rpc/chain"
"github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/services/wallet/bigint"
walletCommon "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/thirdparty"
@ -106,20 +107,22 @@ func (bf *BonderFee) UnmarshalJSON(data []byte) error {
}
type HopBridgeProcessor struct {
transactor transactions.TransactorIface
httpClient *thirdparty.HTTPClient
tokenManager *token.Manager
contractMaker *contracts.ContractMaker
bonderFee *sync.Map // [fromChainName-toChainName]BonderFee
transactor transactions.TransactorIface
httpClient *thirdparty.HTTPClient
tokenManager *token.Manager
contractMaker *contracts.ContractMaker
networkManager network.ManagerInterface
bonderFee *sync.Map // [fromChainName-toChainName]BonderFee
}
func NewHopBridgeProcessor(rpcClient *rpc.Client, transactor transactions.TransactorIface, tokenManager *token.Manager) *HopBridgeProcessor {
func NewHopBridgeProcessor(rpcClient rpc.ClientInterface, transactor transactions.TransactorIface, tokenManager *token.Manager, networkManager network.ManagerInterface) *HopBridgeProcessor {
return &HopBridgeProcessor{
contractMaker: &contracts.ContractMaker{RPCClient: rpcClient},
httpClient: thirdparty.NewHTTPClient(),
transactor: transactor,
tokenManager: tokenManager,
bonderFee: &sync.Map{},
contractMaker: &contracts.ContractMaker{RPCClient: rpcClient},
httpClient: thirdparty.NewHTTPClient(),
transactor: transactor,
tokenManager: tokenManager,
networkManager: networkManager,
bonderFee: &sync.Map{},
}
}
@ -299,7 +302,7 @@ func (h *HopBridgeProcessor) GetContractAddress(params ProcessorInputParams) (co
}
func (h *HopBridgeProcessor) sendOrBuild(sendArgs *MultipathProcessorTxArgs, signerFn bind.SignerFn) (tx *ethTypes.Transaction, err error) {
fromChain := h.contractMaker.RPCClient.NetworkManager.Find(sendArgs.ChainID)
fromChain := h.networkManager.Find(sendArgs.ChainID)
if fromChain == nil {
return tx, fmt.Errorf("ChainID not supported %d", sendArgs.ChainID)
}

View File

@ -309,7 +309,7 @@ func TestPathProcessors(t *testing.T) {
if processorName == ProcessorTransferName {
processor = NewTransferProcessor(nil, nil)
} else if processorName == ProcessorBridgeHopName {
processor = NewHopBridgeProcessor(nil, nil, nil)
processor = NewHopBridgeProcessor(nil, nil, nil, nil)
} else if processorName == ProcessorSwapParaswapName {
processor = NewSwapParaswapProcessor(nil, nil, nil)
}

View File

@ -223,7 +223,7 @@ func setupRouter(t *testing.T) (*Router, func()) {
erc1155Transfer := pathprocessor.NewERC1155Processor(nil, nil)
router.AddPathProcessor(erc1155Transfer)
hop := pathprocessor.NewHopBridgeProcessor(nil, nil, nil)
hop := pathprocessor.NewHopBridgeProcessor(nil, nil, nil, nil)
router.AddPathProcessor(hop)
paraswap := pathprocessor.NewSwapParaswapProcessor(nil, nil, nil)

View File

@ -18,7 +18,7 @@ import (
"github.com/status-im/status-go/services/wallet/async"
)
var nativeChainAddress = common.HexToAddress("0x")
var NativeChainAddress = common.HexToAddress("0x")
var requestTimeout = 20 * time.Second
const (
@ -26,7 +26,6 @@ const (
)
type BalanceFetcher interface {
FetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error)
GetTokenBalanceAt(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address, blockNumber *big.Int) (*big.Int, error)
GetBalancesAtByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error)
GetBalancesByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error)
@ -44,7 +43,7 @@ func NewDefaultBalanceFetcher(contractMaker contracts.ContractMakerIface) *Defau
}
}
func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
func (bf *DefaultBalanceFetcher) fetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
var (
group = async.NewAtomicGroup(parent)
mu sync.Mutex
@ -61,14 +60,7 @@ func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, c
}
for token, balance := range tokenBalance {
if _, ok := balances[account][token]; !ok {
zeroHex := hexutil.Big(*big.NewInt(0))
balances[account][token] = &zeroHex
}
sum := big.NewInt(0).Add(balances[account][token].ToInt(), balance.ToInt())
sumHex := hexutil.Big(*sum)
balances[account][token] = &sumHex
balances[account][token] = balance
}
}
}
@ -79,12 +71,10 @@ func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, c
return nil, err
}
atBlock := atBlocks[client.NetworkID()]
fetchChainBalance := false
for _, token := range tokens {
if token == nativeChainAddress {
if token == NativeChainAddress {
fetchChainBalance = true
}
}
@ -117,7 +107,7 @@ func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, c
if atBlock == nil || big.NewInt(int64(availableAtBlock)).Cmp(atBlock) < 0 {
accTokenBalance, err = bf.FetchTokenBalancesWithScanContract(ctx, ethScanContract, account, chunk, atBlock)
} else {
accTokenBalance, err = bf.FetchTokenBalancesWithTokenContracts(ctx, client, account, chunk, atBlock)
accTokenBalance, err = bf.fetchTokenBalancesWithTokenContracts(ctx, client, account, chunk, atBlock)
}
if err != nil {
@ -138,7 +128,7 @@ func (bf *DefaultBalanceFetcher) FetchBalancesForChain(parent context.Context, c
return balances, group.Error()
}
func (bf *DefaultBalanceFetcher) FetchChainBalances(parent context.Context, accounts []common.Address, ethScanContract *ethscan.BalanceScanner, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
func (bf *DefaultBalanceFetcher) FetchChainBalances(parent context.Context, accounts []common.Address, ethScanContract ethscan.BalanceScannerIface, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big)
ctx, cancel := context.WithTimeout(parent, requestTimeout)
@ -160,13 +150,13 @@ func (bf *DefaultBalanceFetcher) FetchChainBalances(parent context.Context, acco
accTokenBalance[account] = make(map[common.Address]*hexutil.Big)
}
accTokenBalance[account][nativeChainAddress] = (*hexutil.Big)(balance)
accTokenBalance[account][NativeChainAddress] = (*hexutil.Big)(balance)
}
return accTokenBalance, nil
}
func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context.Context, ethScanContract *ethscan.BalanceScanner, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context.Context, ethScanContract ethscan.BalanceScannerIface, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big)
res, err := ethScanContract.TokensBalance(&bind.CallOpts{
Context: ctx,
@ -178,7 +168,7 @@ func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context.
}
if len(res) != len(chunk) {
log.Error("can't fetch erc20 token balance 7", "account", account, "error", "response not complete")
log.Error("can't fetch erc20 token balance 7", "account", account, "error", "response not complete", "expected", len(chunk), "got", len(res))
return nil, errors.New("response not complete")
}
@ -198,7 +188,7 @@ func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context.
return accTokenBalance, nil
}
func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithTokenContracts(ctx context.Context, client chain.ClientInterface, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
func (bf *DefaultBalanceFetcher) fetchTokenBalancesWithTokenContracts(ctx context.Context, client chain.ClientInterface, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big)
for _, token := range chunk {
balance, err := bf.GetTokenBalanceAt(ctx, client, account, token, atBlock)
@ -220,7 +210,7 @@ func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithTokenContracts(ctx contex
}
func (bf *DefaultBalanceFetcher) GetTokenBalanceAt(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address, blockNumber *big.Int) (*big.Int, error) {
caller, err := ierc20.NewIERC20Caller(token, client)
caller, err := bf.contractMaker.NewERC20Caller(client.NetworkID(), token)
if err != nil {
return nil, err
}
@ -270,7 +260,7 @@ func (bf *DefaultBalanceFetcher) GetChainBalance(ctx context.Context, client cha
}
func (bf *DefaultBalanceFetcher) GetBalance(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address) (*big.Int, error) {
if token == nativeChainAddress {
if token == NativeChainAddress {
return bf.GetChainBalance(ctx, client, account)
}
@ -305,12 +295,15 @@ func (bf *DefaultBalanceFetcher) GetBalancesAtByChain(parent context.Context, cl
// Keep the reference to the client. DO NOT USE A LOOP, the client will be overridden in the coroutine
client := clients[clientIdx]
balances, err := bf.FetchBalancesForChain(parent, client, accounts, tokens, atBlocks)
if err != nil {
return nil, err
}
group.Add(func(parent context.Context) error {
balances, err := bf.fetchBalancesForChain(parent, client, accounts, tokens, atBlocks[client.NetworkID()])
if err != nil {
return nil
}
updateBalance(client.NetworkID(), balances)
updateBalance(client.NetworkID(), balances)
return nil
})
}
select {
case <-group.WaitAsync():

View File

@ -0,0 +1,386 @@
package balancefetcher
import (
"context"
"errors"
"math/big"
"os"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/log"
"github.com/status-im/status-go/contracts/ethscan"
"github.com/status-im/status-go/params"
mock_contracts "github.com/status-im/status-go/contracts/mock"
"github.com/status-im/status-go/rpc/chain"
mock_client "github.com/status-im/status-go/rpc/chain/mock/client"
mock_network "github.com/status-im/status-go/rpc/network/mock"
w_common "github.com/status-im/status-go/services/wallet/common"
)
type FakeBalanceScanner struct {
etherBalances map[common.Address]*big.Int
tokenBalances map[common.Address]map[common.Address]*big.Int
}
func (f *FakeBalanceScanner) EtherBalances(opts *bind.CallOpts, addresses []common.Address) ([]ethscan.BalanceScannerResult, error) {
result := make([]ethscan.BalanceScannerResult, 0, len(addresses))
for _, address := range addresses {
balance, ok := f.etherBalances[address]
if !ok {
result = append(result, ethscan.BalanceScannerResult{
Success: false,
Data: []byte{},
})
} else {
result = append(result, ethscan.BalanceScannerResult{
Success: true,
Data: balance.Bytes(),
})
}
}
return result, nil
}
func (f *FakeBalanceScanner) TokenBalances(opts *bind.CallOpts, addresses []common.Address, tokenAddress common.Address) ([]ethscan.BalanceScannerResult, error) {
result := make([]ethscan.BalanceScannerResult, 0, len(addresses))
for _, address := range addresses {
balances, ok := f.tokenBalances[address]
if !ok {
result = append(result, ethscan.BalanceScannerResult{
Success: false,
Data: []byte{},
})
} else {
balance, ok := balances[tokenAddress]
if !ok {
result = append(result, ethscan.BalanceScannerResult{
Success: false,
Data: []byte{},
})
} else {
result = append(result, ethscan.BalanceScannerResult{
Success: true,
Data: balance.Bytes(),
})
}
}
}
return result, nil
}
func (f *FakeBalanceScanner) TokensBalance(opts *bind.CallOpts, owner common.Address, contracts []common.Address) ([]ethscan.BalanceScannerResult, error) {
result := make([]ethscan.BalanceScannerResult, 0, len(contracts))
for _, contract := range contracts {
balances, ok := f.tokenBalances[owner]
if !ok {
result = append(result, ethscan.BalanceScannerResult{
Success: false,
Data: []byte{},
})
} else {
balance, ok := balances[contract]
if !ok {
result = append(result, ethscan.BalanceScannerResult{
Success: false,
Data: []byte{},
})
} else {
result = append(result, ethscan.BalanceScannerResult{
Success: true,
Data: balance.Bytes(),
})
}
}
}
return result, nil
}
type FakeERC20Caller struct {
accountBalances map[common.Address]*big.Int
}
func (f *FakeERC20Caller) BalanceOf(opts *bind.CallOpts, account common.Address) (*big.Int, error) {
balance, ok := f.accountBalances[account]
if !ok {
return nil, errors.New("account not found")
}
return balance, nil
}
func (f *FakeERC20Caller) Name(opts *bind.CallOpts) (string, error) {
return "TestToken", nil
}
func (f *FakeERC20Caller) Symbol(opts *bind.CallOpts) (string, error) {
return "TT", nil
}
func (f *FakeERC20Caller) Decimals(opts *bind.CallOpts) (uint8, error) {
return 18, nil
}
func TestBalanceFetcherFetchBalancesForChainNativeAndTokensWithScanContract(t *testing.T) {
log.Root().SetHandler(log.LvlFilterHandler(log.LvlInfo, log.StreamHandler(os.Stdout, log.TerminalFormat(true))))
ctx := context.Background()
accounts := []common.Address{
common.HexToAddress("0x1234567890abcdef"),
common.HexToAddress("0xabcdef1234567890"),
}
tokens := []common.Address{
NativeChainAddress,
common.HexToAddress("0xabcdef1234567890"),
common.HexToAddress("0x0987654321fedcba"),
}
var atBlock *big.Int // nil triggers using a scan contract
ctrl := gomock.NewController(t)
defer ctrl.Finish()
networkManager := mock_network.NewMockManagerInterface(ctrl)
networkManager.EXPECT().GetAll().Return([]*params.Network{
{
ChainID: w_common.EthereumMainnet,
},
}, nil).AnyTimes()
chainClient := mock_client.NewMockClientInterface(ctrl)
chainClient.EXPECT().NetworkID().Return(w_common.EthereumMainnet).AnyTimes()
expectedEthBalances := map[common.Address]*big.Int{
accounts[0]: big.NewInt(100),
accounts[1]: big.NewInt(200),
}
expectedTokenBalances := map[common.Address]map[common.Address]*big.Int{
accounts[0]: {
tokens[1]: big.NewInt(1000),
tokens[2]: big.NewInt(2000),
},
accounts[1]: {
tokens[1]: big.NewInt(3000),
tokens[2]: big.NewInt(4000),
},
}
expectedBalances := map[common.Address]map[common.Address]*hexutil.Big{
accounts[0]: {
tokens[0]: (*hexutil.Big)(expectedEthBalances[accounts[0]]),
tokens[1]: (*hexutil.Big)(expectedTokenBalances[accounts[0]][tokens[1]]),
tokens[2]: (*hexutil.Big)(expectedTokenBalances[accounts[0]][tokens[2]]),
},
accounts[1]: {
tokens[0]: (*hexutil.Big)(expectedEthBalances[accounts[1]]),
tokens[1]: (*hexutil.Big)(expectedTokenBalances[accounts[1]][tokens[1]]),
tokens[2]: (*hexutil.Big)(expectedTokenBalances[accounts[1]][tokens[2]]),
},
}
contractMaker := mock_contracts.NewMockContractMakerIface(ctrl)
contractMaker.EXPECT().NewEthScan(w_common.EthereumMainnet).Return(&FakeBalanceScanner{
etherBalances: expectedEthBalances,
tokenBalances: expectedTokenBalances,
}, uint(0), nil).AnyTimes()
bf := NewDefaultBalanceFetcher(contractMaker)
// Fetch native balances and token balances using scan contract
balances, err := bf.fetchBalancesForChain(ctx, chainClient, accounts, tokens, atBlock)
require.NoError(t, err)
require.Equal(t, expectedBalances, balances)
}
func TestBalanceFetcherFetchBalancesForChainTokensWithTokenContracts(t *testing.T) {
log.Root().SetHandler(log.LvlFilterHandler(log.LvlInfo, log.StreamHandler(os.Stdout, log.TerminalFormat(true))))
ctx := context.Background()
accounts := []common.Address{
common.HexToAddress("0x1234567890abcdef"),
common.HexToAddress("0xabcdef1234567890"),
}
tokens := []common.Address{
common.HexToAddress("0xabcdef1234567890"),
common.HexToAddress("0x0987654321fedcba"),
}
atBlock := big.NewInt(0) // will trigger using a token contract
ctrl := gomock.NewController(t)
networkManager := mock_network.NewMockManagerInterface(ctrl)
networkManager.EXPECT().GetAll().Return([]*params.Network{
{
ChainID: w_common.EthereumMainnet,
},
}, nil).AnyTimes()
chainClient := mock_client.NewMockClientInterface(ctrl)
chainClient.EXPECT().NetworkID().Return(w_common.EthereumMainnet).AnyTimes()
chainClient.EXPECT().CallContract(gomock.Any(), gomock.Any(), atBlock).Return([]byte{}, nil).AnyTimes()
expectedTokenBalances := map[common.Address]map[common.Address]*big.Int{
tokens[0]: {
accounts[0]: big.NewInt(1000),
accounts[1]: big.NewInt(2000),
},
tokens[1]: {
accounts[0]: big.NewInt(3000),
accounts[1]: big.NewInt(4000),
},
}
expectedBalances := map[common.Address]map[common.Address]*hexutil.Big{
accounts[0]: {
tokens[0]: (*hexutil.Big)(expectedTokenBalances[tokens[0]][accounts[0]]),
tokens[1]: (*hexutil.Big)(expectedTokenBalances[tokens[1]][accounts[0]]),
},
accounts[1]: {
tokens[0]: (*hexutil.Big)(expectedTokenBalances[tokens[0]][accounts[1]]),
tokens[1]: (*hexutil.Big)(expectedTokenBalances[tokens[1]][accounts[1]]),
},
}
contractMaker := mock_contracts.NewMockContractMakerIface(ctrl)
contractMaker.EXPECT().NewEthScan(w_common.EthereumMainnet).Return(&FakeBalanceScanner{}, uint(0), nil).Times(1)
for _, token := range tokens {
contractMaker.EXPECT().NewERC20Caller(w_common.EthereumMainnet, token).Return(&FakeERC20Caller{
accountBalances: expectedTokenBalances[token],
}, nil).AnyTimes()
}
bf := NewDefaultBalanceFetcher(contractMaker)
// Fetch token balances using tokens contracts
balances, err := bf.fetchBalancesForChain(ctx, chainClient, accounts, tokens, atBlock)
require.NoError(t, err)
require.Equal(t, expectedBalances, balances)
}
func TestBalanceFetcherGetBalancesAtByChain(t *testing.T) {
log.Root().SetHandler(log.LvlFilterHandler(log.LvlInfo, log.StreamHandler(os.Stdout, log.TerminalFormat(true))))
ctx := context.Background()
accounts := []common.Address{
common.HexToAddress("0x1234567890abcdef"),
common.HexToAddress("0xabcdef1234567890"),
}
tokens := []common.Address{
NativeChainAddress,
common.HexToAddress("0xabcdef1234567890"),
common.HexToAddress("0x0987654321fedcba"),
}
var atBlock *big.Int // nil triggers using a scan contract
atBlocks := map[uint64]*big.Int{
w_common.EthereumMainnet: atBlock, // nil triggers using a scan contract
w_common.OptimismMainnet: atBlock, // nil triggers using a scan contract
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
networkManager := mock_network.NewMockManagerInterface(ctrl)
networkManager.EXPECT().GetAll().Return([]*params.Network{
{
ChainID: w_common.EthereumMainnet,
},
{
ChainID: w_common.OptimismMainnet,
},
{
ChainID: w_common.ArbitrumMainnet,
},
}, nil).AnyTimes()
chainClient := mock_client.NewMockClientInterface(ctrl)
chainClient.EXPECT().NetworkID().Return(w_common.EthereumMainnet).AnyTimes()
chainClientOpt := mock_client.NewMockClientInterface(ctrl)
chainClientOpt.EXPECT().NetworkID().Return(w_common.OptimismMainnet).AnyTimes()
chainClientArb := mock_client.NewMockClientInterface(ctrl)
chainClientArb.EXPECT().NetworkID().Return(w_common.ArbitrumMainnet).AnyTimes()
chainClients := map[uint64]chain.ClientInterface{
w_common.EthereumMainnet: chainClient,
w_common.OptimismMainnet: chainClientOpt,
w_common.ArbitrumMainnet: chainClientArb,
}
expectedEthBalances := map[common.Address]*big.Int{
accounts[0]: big.NewInt(100),
accounts[1]: big.NewInt(200),
}
expectedEthOptBalances := map[common.Address]*big.Int{
accounts[0]: big.NewInt(300),
accounts[1]: big.NewInt(400),
}
expectedTokenBalances := map[common.Address]map[common.Address]*big.Int{
accounts[0]: {
tokens[1]: big.NewInt(1000),
tokens[2]: big.NewInt(2000),
},
accounts[1]: {
tokens[1]: big.NewInt(3000),
tokens[2]: big.NewInt(4000),
},
}
expectedTokenOptBalances := map[common.Address]map[common.Address]*big.Int{
accounts[0]: {
tokens[1]: big.NewInt(5000),
tokens[2]: big.NewInt(6000),
},
}
expectedBalances := map[uint64]map[common.Address]map[common.Address]*hexutil.Big{
w_common.EthereumMainnet: {
accounts[0]: {
tokens[0]: (*hexutil.Big)(expectedEthBalances[accounts[0]]),
tokens[1]: (*hexutil.Big)(expectedTokenBalances[accounts[0]][tokens[1]]),
tokens[2]: (*hexutil.Big)(expectedTokenBalances[accounts[0]][tokens[2]]),
},
accounts[1]: {
tokens[0]: (*hexutil.Big)(expectedEthBalances[accounts[1]]),
tokens[1]: (*hexutil.Big)(expectedTokenBalances[accounts[1]][tokens[1]]),
tokens[2]: (*hexutil.Big)(expectedTokenBalances[accounts[1]][tokens[2]]),
},
},
w_common.OptimismMainnet: {
accounts[0]: {
tokens[0]: (*hexutil.Big)(expectedEthOptBalances[accounts[0]]),
tokens[1]: (*hexutil.Big)(expectedTokenOptBalances[accounts[0]][tokens[1]]),
tokens[2]: (*hexutil.Big)(expectedTokenOptBalances[accounts[0]][tokens[2]]),
},
accounts[1]: {
tokens[0]: (*hexutil.Big)(expectedEthOptBalances[accounts[1]]),
},
},
}
contractMaker := mock_contracts.NewMockContractMakerIface(ctrl)
contractMaker.EXPECT().NewEthScan(w_common.EthereumMainnet).Return(&FakeBalanceScanner{
etherBalances: expectedEthBalances,
tokenBalances: expectedTokenBalances,
}, uint(0), nil).AnyTimes()
contractMaker.EXPECT().NewEthScan(w_common.OptimismMainnet).Return(&FakeBalanceScanner{
etherBalances: expectedEthOptBalances,
tokenBalances: expectedTokenOptBalances,
}, uint(0), nil).AnyTimes()
contractMaker.EXPECT().NewEthScan(w_common.ArbitrumMainnet).Return(nil, uint(0), errors.New("no scan contract")).AnyTimes()
bf := NewDefaultBalanceFetcher(contractMaker)
// Fetch native balances and token balances using scan contract
balances, err := bf.GetBalancesAtByChain(ctx, chainClients, accounts, tokens, atBlocks)
require.NoError(t, err)
require.Equal(t, expectedBalances, balances)
}

View File

@ -9,82 +9,82 @@ import (
big "math/big"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
common "github.com/ethereum/go-ethereum/common"
hexutil "github.com/ethereum/go-ethereum/common/hexutil"
gomock "github.com/golang/mock/gomock"
chain "github.com/status-im/status-go/rpc/chain"
token "github.com/status-im/status-go/services/wallet/token"
)
// MockManagerInterface is a mock of ManagerInterface interface
// MockManagerInterface is a mock of ManagerInterface interface.
type MockManagerInterface struct {
ctrl *gomock.Controller
recorder *MockManagerInterfaceMockRecorder
}
// MockManagerInterfaceMockRecorder is the mock recorder for MockManagerInterface
// MockManagerInterfaceMockRecorder is the mock recorder for MockManagerInterface.
type MockManagerInterfaceMockRecorder struct {
mock *MockManagerInterface
}
// NewMockManagerInterface creates a new mock instance
// NewMockManagerInterface creates a new mock instance.
func NewMockManagerInterface(ctrl *gomock.Controller) *MockManagerInterface {
mock := &MockManagerInterface{ctrl: ctrl}
mock.recorder = &MockManagerInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManagerInterface) EXPECT() *MockManagerInterfaceMockRecorder {
return m.recorder
}
// LookupTokenIdentity mocks base method
func (m *MockManagerInterface) LookupTokenIdentity(chainID uint64, address common.Address, native bool) *token.Token {
// FetchBalancesForChain mocks base method.
func (m *MockManagerInterface) FetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LookupTokenIdentity", chainID, address, native)
ret0, _ := ret[0].(*token.Token)
return ret0
}
// LookupTokenIdentity indicates an expected call of LookupTokenIdentity
func (mr *MockManagerInterfaceMockRecorder) LookupTokenIdentity(chainID, address, native interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupTokenIdentity", reflect.TypeOf((*MockManagerInterface)(nil).LookupTokenIdentity), chainID, address, native)
}
// LookupToken mocks base method
func (m *MockManagerInterface) LookupToken(chainID *uint64, tokenSymbol string) (*token.Token, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LookupToken", chainID, tokenSymbol)
ret0, _ := ret[0].(*token.Token)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// LookupToken indicates an expected call of LookupToken
func (mr *MockManagerInterfaceMockRecorder) LookupToken(chainID, tokenSymbol interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupToken", reflect.TypeOf((*MockManagerInterface)(nil).LookupToken), chainID, tokenSymbol)
}
// GetTokensByChainIDs mocks base method
func (m *MockManagerInterface) GetTokensByChainIDs(chainIDs []uint64) ([]*token.Token, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTokensByChainIDs", chainIDs)
ret0, _ := ret[0].([]*token.Token)
ret := m.ctrl.Call(m, "FetchBalancesForChain", parent, client, accounts, tokens, atBlock)
ret0, _ := ret[0].(map[common.Address]map[common.Address]*hexutil.Big)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTokensByChainIDs indicates an expected call of GetTokensByChainIDs
func (mr *MockManagerInterfaceMockRecorder) GetTokensByChainIDs(chainIDs interface{}) *gomock.Call {
// FetchBalancesForChain indicates an expected call of FetchBalancesForChain.
func (mr *MockManagerInterfaceMockRecorder) FetchBalancesForChain(parent, client, accounts, tokens, atBlock interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokensByChainIDs", reflect.TypeOf((*MockManagerInterface)(nil).GetTokensByChainIDs), chainIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchBalancesForChain", reflect.TypeOf((*MockManagerInterface)(nil).FetchBalancesForChain), parent, client, accounts, tokens, atBlock)
}
// GetBalancesByChain mocks base method
// GetBalance mocks base method.
func (m *MockManagerInterface) GetBalance(ctx context.Context, client chain.ClientInterface, account, token common.Address) (*big.Int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBalance", ctx, client, account, token)
ret0, _ := ret[0].(*big.Int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetBalance indicates an expected call of GetBalance.
func (mr *MockManagerInterfaceMockRecorder) GetBalance(ctx, client, account, token interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalance", reflect.TypeOf((*MockManagerInterface)(nil).GetBalance), ctx, client, account, token)
}
// GetBalancesAtByChain mocks base method.
func (m *MockManagerInterface) GetBalancesAtByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBalancesAtByChain", parent, clients, accounts, tokens, atBlocks)
ret0, _ := ret[0].(map[uint64]map[common.Address]map[common.Address]*hexutil.Big)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetBalancesAtByChain indicates an expected call of GetBalancesAtByChain.
func (mr *MockManagerInterfaceMockRecorder) GetBalancesAtByChain(parent, clients, accounts, tokens, atBlocks interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalancesAtByChain", reflect.TypeOf((*MockManagerInterface)(nil).GetBalancesAtByChain), parent, clients, accounts, tokens, atBlocks)
}
// GetBalancesByChain mocks base method.
func (m *MockManagerInterface) GetBalancesByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBalancesByChain", parent, clients, accounts, tokens)
@ -93,13 +93,43 @@ func (m *MockManagerInterface) GetBalancesByChain(parent context.Context, client
return ret0, ret1
}
// GetBalancesByChain indicates an expected call of GetBalancesByChain
// GetBalancesByChain indicates an expected call of GetBalancesByChain.
func (mr *MockManagerInterfaceMockRecorder) GetBalancesByChain(parent, clients, accounts, tokens interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalancesByChain", reflect.TypeOf((*MockManagerInterface)(nil).GetBalancesByChain), parent, clients, accounts, tokens)
}
// GetTokenHistoricalBalance mocks base method
// GetChainBalance mocks base method.
func (m *MockManagerInterface) GetChainBalance(ctx context.Context, client chain.ClientInterface, account common.Address) (*big.Int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChainBalance", ctx, client, account)
ret0, _ := ret[0].(*big.Int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChainBalance indicates an expected call of GetChainBalance.
func (mr *MockManagerInterfaceMockRecorder) GetChainBalance(ctx, client, account interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChainBalance", reflect.TypeOf((*MockManagerInterface)(nil).GetChainBalance), ctx, client, account)
}
// GetTokenBalanceAt mocks base method.
func (m *MockManagerInterface) GetTokenBalanceAt(ctx context.Context, client chain.ClientInterface, account, token common.Address, blockNumber *big.Int) (*big.Int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTokenBalanceAt", ctx, client, account, token, blockNumber)
ret0, _ := ret[0].(*big.Int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTokenBalanceAt indicates an expected call of GetTokenBalanceAt.
func (mr *MockManagerInterfaceMockRecorder) GetTokenBalanceAt(ctx, client, account, token, blockNumber interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokenBalanceAt", reflect.TypeOf((*MockManagerInterface)(nil).GetTokenBalanceAt), ctx, client, account, token, blockNumber)
}
// GetTokenHistoricalBalance mocks base method.
func (m *MockManagerInterface) GetTokenHistoricalBalance(account common.Address, chainID uint64, symbol string, timestamp int64) (*big.Int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTokenHistoricalBalance", account, chainID, symbol, timestamp)
@ -108,8 +138,52 @@ func (m *MockManagerInterface) GetTokenHistoricalBalance(account common.Address,
return ret0, ret1
}
// GetTokenHistoricalBalance indicates an expected call of GetTokenHistoricalBalance
// GetTokenHistoricalBalance indicates an expected call of GetTokenHistoricalBalance.
func (mr *MockManagerInterfaceMockRecorder) GetTokenHistoricalBalance(account, chainID, symbol, timestamp interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokenHistoricalBalance", reflect.TypeOf((*MockManagerInterface)(nil).GetTokenHistoricalBalance), account, chainID, symbol, timestamp)
}
// GetTokensByChainIDs mocks base method.
func (m *MockManagerInterface) GetTokensByChainIDs(chainIDs []uint64) ([]*token.Token, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTokensByChainIDs", chainIDs)
ret0, _ := ret[0].([]*token.Token)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTokensByChainIDs indicates an expected call of GetTokensByChainIDs.
func (mr *MockManagerInterfaceMockRecorder) GetTokensByChainIDs(chainIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTokensByChainIDs", reflect.TypeOf((*MockManagerInterface)(nil).GetTokensByChainIDs), chainIDs)
}
// LookupToken mocks base method.
func (m *MockManagerInterface) LookupToken(chainID *uint64, tokenSymbol string) (*token.Token, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LookupToken", chainID, tokenSymbol)
ret0, _ := ret[0].(*token.Token)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// LookupToken indicates an expected call of LookupToken.
func (mr *MockManagerInterfaceMockRecorder) LookupToken(chainID, tokenSymbol interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupToken", reflect.TypeOf((*MockManagerInterface)(nil).LookupToken), chainID, tokenSymbol)
}
// LookupTokenIdentity mocks base method.
func (m *MockManagerInterface) LookupTokenIdentity(chainID uint64, address common.Address, native bool) *token.Token {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LookupTokenIdentity", chainID, address, native)
ret0, _ := ret[0].(*token.Token)
return ret0
}
// LookupTokenIdentity indicates an expected call of LookupTokenIdentity.
func (mr *MockManagerInterfaceMockRecorder) LookupTokenIdentity(chainID, address, native interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupTokenIdentity", reflect.TypeOf((*MockManagerInterface)(nil).LookupTokenIdentity), chainID, address, native)
}

View File

@ -97,9 +97,9 @@ type ManagerInterface interface {
type Manager struct {
balancefetcher.BalanceFetcher
db *sql.DB
RPCClient *rpc.Client
RPCClient rpc.ClientInterface
ContractMaker *contracts.ContractMaker
networkManager *network.Manager
networkManager network.ManagerInterface
stores []store // Set on init, not changed afterwards
communityTokensDB *communitytokensdatabase.Database
communityManager *community.Manager
@ -130,7 +130,7 @@ func mergeTokens(sliceLists [][]*Token) []*Token {
return res
}
func prepareTokens(networkManager *network.Manager, stores []store) []*Token {
func prepareTokens(networkManager network.ManagerInterface, stores []store) []*Token {
tokens := make([]*Token, 0)
networks, err := networkManager.GetAll()
@ -158,9 +158,9 @@ func prepareTokens(networkManager *network.Manager, stores []store) []*Token {
func NewTokenManager(
db *sql.DB,
RPCClient *rpc.Client,
RPCClient rpc.ClientInterface,
communityManager *community.Manager,
networkManager *network.Manager,
networkManager network.ManagerInterface,
appDB *sql.DB,
mediaServer *server.MediaServer,
walletFeed *event.Feed,
@ -173,10 +173,10 @@ func NewTokenManager(
tokens := prepareTokens(networkManager, stores)
return &Manager{
BalanceFetcher: balancefetcher.NewDefaultBalanceFetcher(maker),
db: db,
RPCClient: RPCClient,
// ContractMaker: maker,
BalanceFetcher: balancefetcher.NewDefaultBalanceFetcher(maker),
db: db,
RPCClient: RPCClient,
ContractMaker: maker,
networkManager: networkManager,
communityManager: communityManager,
stores: stores,

View File

@ -24,6 +24,7 @@ import (
"github.com/status-im/status-go/services/accounts/accountsevent"
"github.com/status-im/status-go/services/wallet/bigint"
"github.com/status-im/status-go/services/wallet/community"
"github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/t/utils"
"github.com/status-im/status-go/transactions/fake"

View File

@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/mock"
"golang.org/x/exp/slices" // since 1.21, this is in the standard library
@ -29,6 +30,8 @@ import (
"github.com/status-im/status-go/contracts/ierc20"
ethtypes "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/rpc/chain"
mock_client "github.com/status-im/status-go/rpc/chain/mock/client"
mock_rpcclient "github.com/status-im/status-go/rpc/mock/client"
"github.com/status-im/status-go/server"
"github.com/status-im/status-go/services/wallet/async"
"github.com/status-im/status-go/services/wallet/balance"
@ -1295,6 +1298,7 @@ func (m *MockETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem)
type MockChainClient struct {
mock.Mock
mock_client.MockClientInterface
clients map[walletcommon.ChainID]*MockETHClient
}
@ -1360,7 +1364,11 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) {
address := common.HexToAddress("0x1234")
chainClient := newMockChainClient()
tracker := transactions.NewPendingTxTracker(db, chainClient, nil, &event.Feed{}, transactions.PendingCheckInterval)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rpcClient := mock_rpcclient.NewMockClientInterface(ctrl)
rpcClient.EXPECT().AbstractEthClient(tc.NetworkID()).Return(chainClient, nil).AnyTimes()
tracker := transactions.NewPendingTxTracker(db, rpcClient, nil, &event.Feed{}, transactions.PendingCheckInterval)
accDB, err := accounts.NewDB(appdb)
require.NoError(t, err)

View File

@ -16,6 +16,8 @@ import (
"github.com/status-im/status-go/account"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/rpc/chain"
mock_rpcclient "github.com/status-im/status-go/rpc/mock/client"
wallet_common "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/router/pathprocessor"
"github.com/status-im/status-go/services/wallet/router/pathprocessor/mock_pathprocessor"
@ -174,15 +176,21 @@ func TestSendTransactionsETHFailOnTransactor(t *testing.T) {
func TestWatchTransaction(t *testing.T) {
tm, _, _ := setupTransactionManager(t)
chainID := uint64(1)
chainID := uint64(777) // GeneratePendingTransaction uses this chainID
pendingTxTimeout = 2 * time.Millisecond
walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
chainClient := transactions.NewMockChainClient()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rpcClient := mock_rpcclient.NewMockClientInterface(ctrl)
rpcClient.EXPECT().AbstractEthClient(wallet_common.ChainID(chainID)).DoAndReturn(func(chainID wallet_common.ChainID) (chain.BatchCallClient, error) {
return chainClient.AbstractEthClient(chainID)
}).AnyTimes()
eventFeed := &event.Feed{}
// For now, pending tracker is not interface, so we have to use a real one
tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, chainClient, nil, eventFeed, pendingTxTimeout)
tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, rpcClient, nil, eventFeed, pendingTxTimeout)
tm.eventFeed = eventFeed
// Create a context with timeout
@ -219,16 +227,22 @@ func TestWatchTransaction(t *testing.T) {
func TestWatchTransaction_Timeout(t *testing.T) {
tm, _, _ := setupTransactionManager(t)
chainID := uint64(1)
chainID := uint64(777) // GeneratePendingTransaction uses this chainID
transactionHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
pendingTxTimeout = 2 * time.Millisecond
walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
chainClient := transactions.NewMockChainClient()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rpcClient := mock_rpcclient.NewMockClientInterface(gomock.NewController(t))
rpcClient.EXPECT().AbstractEthClient(wallet_common.ChainID(chainID)).DoAndReturn(func(chainID wallet_common.ChainID) (chain.BatchCallClient, error) {
return chainClient.AbstractEthClient(chainID)
}).AnyTimes()
eventFeed := &event.Feed{}
// For now, pending tracker is not interface, so we have to use a real one
tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, chainClient, nil, eventFeed, pendingTxTimeout)
tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, rpcClient, nil, eventFeed, pendingTxTimeout)
tm.eventFeed = eventFeed
// Create a context with timeout

View File

@ -8,10 +8,9 @@ import (
big "math/big"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
common "github.com/ethereum/go-ethereum/common"
types "github.com/ethereum/go-ethereum/core/types"
gomock "github.com/golang/mock/gomock"
account "github.com/status-im/status-go/account"
types0 "github.com/status-im/status-go/eth-node/types"
params "github.com/status-im/status-go/params"
@ -89,7 +88,7 @@ func (mr *MockTransactorIfaceMockRecorder) EstimateGas(network, from, to, value,
}
// NextNonce mocks base method.
func (m *MockTransactorIface) NextNonce(rpcClient *rpc.Client, chainID uint64, from types0.Address) (uint64, error) {
func (m *MockTransactorIface) NextNonce(rpcClient rpc.ClientInterface, chainID uint64, from types0.Address) (uint64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextNonce", rpcClient, chainID, from)
ret0, _ := ret[0].(uint64)

View File

@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@ -16,6 +17,8 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/rpc"
"github.com/status-im/status-go/rpc/chain"
mock_rpcclient "github.com/status-im/status-go/rpc/mock/client"
"github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/walletevent"
@ -29,12 +32,21 @@ func setupTestTransactionDB(t *testing.T, checkInterval *time.Duration) (*Pendin
require.NoError(t, err)
chainClient := NewMockChainClient()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
eventFeed := &event.Feed{}
pendingCheckInterval := PendingCheckInterval
if checkInterval != nil {
pendingCheckInterval = *checkInterval
}
return NewPendingTxTracker(db, chainClient, nil, eventFeed, pendingCheckInterval), func() {
rpcClient := mock_rpcclient.NewMockClientInterface(ctrl)
rpcClient.EXPECT().EthClient(common.EthereumMainnet).Return(chainClient, nil).AnyTimes()
// Delegate the call to the fake implementation
rpcClient.EXPECT().AbstractEthClient(gomock.Any()).DoAndReturn(func(chainID common.ChainID) (chain.BatchCallClient, error) {
return chainClient.AbstractEthClient(chainID)
}).AnyTimes()
return NewPendingTxTracker(db, rpcClient, nil, eventFeed, pendingCheckInterval), func() {
require.NoError(t, db.Close())
}, chainClient, eventFeed
}

View File

@ -16,11 +16,11 @@ import (
// rpcWrapper wraps provides convenient interface for ethereum RPC APIs we need for sending transactions
type rpcWrapper struct {
RPCClient *rpc.Client
RPCClient rpc.ClientInterface
chainID uint64
}
func newRPCWrapper(client *rpc.Client, chainID uint64) *rpcWrapper {
func newRPCWrapper(client rpc.ClientInterface, chainID uint64) *rpcWrapper {
return &rpcWrapper{RPCClient: client, chainID: chainID}
}

View File

@ -10,6 +10,7 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rpc"
"github.com/status-im/status-go/rpc/chain"
mock_client "github.com/status-im/status-go/rpc/chain/mock/client"
"github.com/status-im/status-go/services/wallet/bigint"
"github.com/status-im/status-go/services/wallet/common"
@ -21,6 +22,8 @@ type MockETHClient struct {
mock.Mock
}
var _ chain.BatchCallClient = (*MockETHClient)(nil)
func (m *MockETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
args := m.Called(ctx, b)
return args.Error(0)
@ -28,10 +31,13 @@ func (m *MockETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem)
type MockChainClient struct {
mock.Mock
mock_client.MockClientInterface
Clients map[common.ChainID]*MockETHClient
}
var _ chain.ClientInterface = (*MockChainClient)(nil)
func NewMockChainClient() *MockChainClient {
return &MockChainClient{
Clients: make(map[common.ChainID]*MockETHClient),

View File

@ -46,7 +46,7 @@ func (e *ErrBadNonce) Error() string {
// Transactor is an interface that defines the methods for validating and sending transactions.
type TransactorIface interface {
NextNonce(rpcClient *rpc.Client, chainID uint64, from types.Address) (uint64, error)
NextNonce(rpcClient rpc.ClientInterface, chainID uint64, from types.Address) (uint64, error)
EstimateGas(network *params.Network, from common.Address, to common.Address, value *big.Int, input []byte) (uint64, error)
SendTransaction(sendArgs SendTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error)
SendTransactionWithChainID(chainID uint64, sendArgs SendTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error)
@ -96,7 +96,7 @@ func (t *Transactor) SetRPC(rpcClient *rpc.Client, timeout time.Duration) {
t.rpcCallTimeout = timeout
}
func (t *Transactor) NextNonce(rpcClient *rpc.Client, chainID uint64, from types.Address) (uint64, error) {
func (t *Transactor) NextNonce(rpcClient rpc.ClientInterface, chainID uint64, from types.Address) (uint64, error) {
wrapper := newRPCWrapper(rpcClient, chainID)
ctx := context.Background()
nonce, err := wrapper.PendingNonceAt(ctx, common.Address(from))