From 12deb2336028639ff11b6a3e08043e2961bed5c4 Mon Sep 17 00:00:00 2001 From: Ivan Belyakov Date: Fri, 29 Mar 2024 13:44:50 +0100 Subject: [PATCH] chore(wallet)_: removed all addresses parameter as redundant. Instead we check directly if passed addresses are all wallet addresses that we have in accounts DB. --- services/wallet/activity/service.go | 56 +++++++++++++++++++- services/wallet/activity/service_test.go | 23 +++++--- services/wallet/activity/session.go | 67 +++++++++++------------- services/wallet/api.go | 12 ++--- services/wallet/service.go | 2 +- 5 files changed, 107 insertions(+), 53 deletions(-) diff --git a/services/wallet/activity/service.go b/services/wallet/activity/service.go index 9b86dd4b1..e0ad4d353 100644 --- a/services/wallet/activity/service.go +++ b/services/wallet/activity/service.go @@ -14,6 +14,7 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" + "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/collectibles" w_common "github.com/status-im/status-go/services/wallet/common" @@ -57,6 +58,7 @@ var ( // Service provides an async interface, ensuring only one filter request, of each type, is running at a time. It also provides lazy load of NFT info and token mapping type Service struct { db *sql.DB + accountsDB *accounts.Database tokenManager token.ManagerInterface collectibles collectibles.ManagerInterface eventFeed *event.Feed @@ -78,9 +80,10 @@ func (s *Service) nextSessionID() SessionID { return SessionID(s.lastSessionID.Add(1)) } -func NewService(db *sql.DB, tokenManager token.ManagerInterface, collectibles collectibles.ManagerInterface, eventFeed *event.Feed, pendingTracker *transactions.PendingTxTracker) *Service { +func NewService(db *sql.DB, accountsDB *accounts.Database, tokenManager token.ManagerInterface, collectibles collectibles.ManagerInterface, eventFeed *event.Feed, pendingTracker *transactions.PendingTxTracker) *Service { return &Service{ db: db, + accountsDB: accountsDB, tokenManager: tokenManager, collectibles: collectibles, eventFeed: eventFeed, @@ -117,8 +120,9 @@ type FilterResponse struct { // // All calls will trigger an EventActivityFilteringDone event with the result of the filtering // TODO #12120: replace with session based APIs -func (s *Service) FilterActivityAsync(requestID int32, addresses []common.Address, allAddresses bool, chainIDs []w_common.ChainID, filter Filter, offset int, limit int) { +func (s *Service) FilterActivityAsync(requestID int32, addresses []common.Address, chainIDs []w_common.ChainID, filter Filter, offset int, limit int) { s.scheduler.Enqueue(requestID, filterTask, func(ctx context.Context) (interface{}, error) { + allAddresses := s.areAllAddresses(addresses) activities, err := getActivityEntries(ctx, s.getDeps(), addresses, allAddresses, chainIDs, filter, offset, limit) return activities, err }, func(result interface{}, taskType async.TaskType, err error) { @@ -396,3 +400,51 @@ func sendResponseEvent(eventFeed *event.Feed, requestID *int32, eventType wallet eventFeed.Send(event) } + +func (s *Service) getWalletAddreses() ([]common.Address, error) { + ethAddresses, err := s.accountsDB.GetWalletAddresses() + if err != nil { + return nil, err + } + + addresses := make([]common.Address, 0, len(ethAddresses)) + for _, ethAddress := range ethAddresses { + addresses = append(addresses, common.Address(ethAddress)) + } + + return addresses, nil +} + +func (s *Service) areAllAddresses(addresses []common.Address) bool { + // Compare with addresses in accountsDB + walletAddresses, err := s.getWalletAddreses() + if err != nil { + log.Error("Error getting wallet addresses", "error", err) + return false + } + + // Check if passed addresses are the same as in the accountsDB ignoring the order + return areSlicesEqual(walletAddresses, addresses) +} + +// Comparison function to check if slices are the same ignoring the order +func areSlicesEqual(a, b []common.Address) bool { + if len(a) != len(b) { + return false + } + + // Create a map of addresses + aMap := make(map[common.Address]struct{}, len(a)) + for _, address := range a { + aMap[address] = struct{}{} + } + + // Check if all passed addresses are in the map + for _, address := range b { + if _, ok := aMap[address]; !ok { + return false + } + } + + return true +} diff --git a/services/wallet/activity/service_test.go b/services/wallet/activity/service_test.go index 02380809d..26c05f378 100644 --- a/services/wallet/activity/service_test.go +++ b/services/wallet/activity/service_test.go @@ -10,6 +10,8 @@ import ( eth "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/event" + "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/thirdparty" @@ -73,6 +75,11 @@ func setupTestService(tb testing.TB) (state testState) { db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(tb, err) + appDB, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) + require.NoError(tb, err) + accountsDB, err := accounts.NewDB(appDB) + require.NoError(tb, err) + state.eventFeed = new(event.Feed) state.tokenMock = &mockTokenManager{} state.collectiblesMock = &mockCollectiblesManager{} @@ -83,7 +90,7 @@ func setupTestService(tb testing.TB) (state testState) { pendingCheckInterval := time.Second state.pendingTracker = transactions.NewPendingTxTracker(db, state.chainClient, nil, state.eventFeed, pendingCheckInterval) - state.service = NewService(db, state.tokenMock, state.collectiblesMock, state.eventFeed, state.pendingTracker) + state.service = NewService(db, accountsDB, state.tokenMock, state.collectiblesMock, state.eventFeed, state.pendingTracker) state.service.debounceDuration = 0 state.close = func() { require.NoError(tb, state.pendingTracker.Stop()) @@ -171,7 +178,7 @@ func TestService_UpdateCollectibleInfo(t *testing.T) { }, }, nil).Once() - state.service.FilterActivityAsync(0, append(fromAddresses, toAddresses...), true, allNetworksFilter(), Filter{}, 0, 3) + state.service.FilterActivityAsync(0, append(fromAddresses, toAddresses...), allNetworksFilter(), Filter{}, 0, 3) filterResponseCount := 0 var updates []EntryData @@ -221,7 +228,7 @@ func TestService_UpdateCollectibleInfo_Error(t *testing.T) { state.collectiblesMock.On("FetchAssetsByCollectibleUniqueID", mock.Anything).Return(nil, thirdparty.ErrChainIDNotSupported).Once() - state.service.FilterActivityAsync(0, append(fromAddresses, toAddresses...), true, allNetworksFilter(), Filter{}, 0, 5) + state.service.FilterActivityAsync(0, append(fromAddresses, toAddresses...), allNetworksFilter(), Filter{}, 0, 5) filterResponseCount := 0 updatesCount := 0 @@ -378,7 +385,7 @@ func TestService_IncrementalUpdateOnTop(t *testing.T) { allAddresses, pendings, ch, cleanup := setupTransactions(t, state, transactionCount, []transactions.TestTxSummary{{DontConfirm: true, Timestamp: transactionCount + 1}}) defer cleanup() - sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 5) + sessionID := state.service.StartFilterSession(allAddresses, allNetworksFilter(), Filter{}, 5) require.Greater(t, sessionID, SessionID(0)) defer state.service.StopFilterSession(sessionID) @@ -453,7 +460,7 @@ func TestService_IncrementalUpdateMixed(t *testing.T) { ) defer cleanup() - sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 5) + sessionID := state.service.StartFilterSession(allAddresses, allNetworksFilter(), Filter{}, 5) require.Greater(t, sessionID, SessionID(0)) defer state.service.StopFilterSession(sessionID) @@ -500,7 +507,7 @@ func TestService_IncrementalUpdateFetchWindow(t *testing.T) { allAddresses, pendings, ch, cleanup := setupTransactions(t, state, transactionCount, []transactions.TestTxSummary{{DontConfirm: true, Timestamp: transactionCount + 1}}) defer cleanup() - sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 2) + sessionID := state.service.StartFilterSession(allAddresses, allNetworksFilter(), Filter{}, 2) require.Greater(t, sessionID, SessionID(0)) defer state.service.StopFilterSession(sessionID) @@ -549,7 +556,7 @@ func TestService_IncrementalUpdateFetchWindowNoReset(t *testing.T) { allAddresses, pendings, ch, cleanup := setupTransactions(t, state, transactionCount, []transactions.TestTxSummary{{DontConfirm: true, Timestamp: transactionCount + 1}}) defer cleanup() - sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 2) + sessionID := state.service.StartFilterSession(allAddresses, allNetworksFilter(), Filter{}, 2) require.Greater(t, sessionID, SessionID(0)) defer state.service.StopFilterSession(sessionID) @@ -596,7 +603,7 @@ func TestService_FilteredIncrementalUpdateResetAndClear(t *testing.T) { allAddresses = append(append(allAddresses, newFromTrs...), newToTrs...) // 1. User visualizes transactions for the first time - sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 4) + sessionID := state.service.StartFilterSession(allAddresses, allNetworksFilter(), Filter{}, 4) require.Greater(t, sessionID, SessionID(0)) defer state.service.StopFilterSession(sessionID) diff --git a/services/wallet/activity/session.go b/services/wallet/activity/session.go index 69aaaf2cd..b6c162935 100644 --- a/services/wallet/activity/session.go +++ b/services/wallet/activity/session.go @@ -55,10 +55,9 @@ type Session struct { // Filter info // - addresses []eth.Address - allAddresses bool - chainIDs []common.ChainID - filter Filter + addresses []eth.Address + chainIDs []common.ChainID + filter Filter // model is a mirror of the data model presentation has (sent by EventActivityFilteringDone) model []EntryIdentity @@ -81,16 +80,16 @@ type SessionUpdate struct { } type fullFilterParams struct { - sessionID SessionID - addresses []eth.Address - allAddresses bool - chainIDs []common.ChainID - filter Filter + sessionID SessionID + addresses []eth.Address + chainIDs []common.ChainID + filter Filter } func (s *Service) internalFilter(f fullFilterParams, offset int, count int, processResults func(entries []Entry) (offsetOverride int)) { s.scheduler.Enqueue(int32(f.sessionID), filterTask, func(ctx context.Context) (interface{}, error) { - activities, err := getActivityEntries(ctx, s.getDeps(), f.addresses, f.allAddresses, f.chainIDs, f.filter, offset, count) + allAddresses := s.areAllAddresses(f.addresses) + activities, err := getActivityEntries(ctx, s.getDeps(), f.addresses, allAddresses, f.chainIDs, f.filter, offset, count) return activities, err }, func(result interface{}, taskType async.TaskType, err error) { res := FilterResponse{ @@ -131,11 +130,10 @@ func mirrorIdentities(entries []Entry) []EntryIdentity { func (s *Service) internalFilterForSession(session *Session, firstPageCount int) { s.internalFilter( fullFilterParams{ - sessionID: session.id, - addresses: session.addresses, - allAddresses: session.allAddresses, - chainIDs: session.chainIDs, - filter: session.filter, + sessionID: session.id, + addresses: session.addresses, + chainIDs: session.chainIDs, + filter: session.filter, }, 0, firstPageCount, @@ -150,16 +148,15 @@ func (s *Service) internalFilterForSession(session *Session, firstPageCount int) ) } -func (s *Service) StartFilterSession(addresses []eth.Address, allAddresses bool, chainIDs []common.ChainID, filter Filter, firstPageCount int) SessionID { +func (s *Service) StartFilterSession(addresses []eth.Address, chainIDs []common.ChainID, filter Filter, firstPageCount int) SessionID { sessionID := s.nextSessionID() session := &Session{ id: sessionID, - addresses: addresses, - allAddresses: allAddresses, - chainIDs: chainIDs, - filter: filter, + addresses: addresses, + chainIDs: chainIDs, + filter: filter, model: make([]EntryIdentity, 0, firstPageCount), } @@ -214,11 +211,10 @@ func (s *Service) UpdateFilterForSession(id SessionID, filter Filter, firstPageC // In this case we need to flag all the new entries that are not in the noFilterModel s.internalFilter( fullFilterParams{ - sessionID: session.id, - addresses: session.addresses, - allAddresses: session.allAddresses, - chainIDs: session.chainIDs, - filter: session.filter, + sessionID: session.id, + addresses: session.addresses, + chainIDs: session.chainIDs, + filter: session.filter, }, 0, firstPageCount, @@ -257,11 +253,10 @@ func (s *Service) ResetFilterSession(id SessionID, firstPageCount int) error { s.internalFilter( fullFilterParams{ - sessionID: id, - addresses: session.addresses, - allAddresses: session.allAddresses, - chainIDs: session.chainIDs, - filter: session.filter, + sessionID: id, + addresses: session.addresses, + chainIDs: session.chainIDs, + filter: session.filter, }, 0, firstPageCount, @@ -302,11 +297,10 @@ func (s *Service) GetMoreForFilterSession(id SessionID, pageCount int) error { prevModelLen := len(session.model) s.internalFilter( fullFilterParams{ - sessionID: id, - addresses: session.addresses, - allAddresses: session.allAddresses, - chainIDs: session.chainIDs, - filter: session.filter, + sessionID: id, + addresses: session.addresses, + chainIDs: session.chainIDs, + filter: session.filter, }, prevModelLen+len(session.new), pageCount, @@ -362,7 +356,8 @@ func (s *Service) detectNew(changeCount int) { session := s.sessions[sessionID] fetchLen := len(session.model) + changeCount - activities, err := getActivityEntries(context.Background(), s.getDeps(), session.addresses, session.allAddresses, session.chainIDs, session.filter, 0, fetchLen) + allAddresses := s.areAllAddresses(session.addresses) + activities, err := getActivityEntries(context.Background(), s.getDeps(), session.addresses, allAddresses, session.chainIDs, session.filter, 0, fetchLen) if err != nil { log.Error("Error getting activity entries", "error", err) continue diff --git a/services/wallet/api.go b/services/wallet/api.go index be919d8da..d8b9520ad 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -598,10 +598,10 @@ func (api *API) FetchAllCurrencyFormats() (currency.FormatPerSymbol, error) { } // @deprecated replaced by session APIs; see #12120 -func (api *API) FilterActivityAsync(requestID int32, addresses []common.Address, allAddresses bool, chainIDs []wcommon.ChainID, filter activity.Filter, offset int, limit int) error { - log.Debug("wallet.api.FilterActivityAsync", "requestID", requestID, "addr.count", len(addresses), "allAddresses", allAddresses, "chainIDs.count", len(chainIDs), "offset", offset, "limit", limit) +func (api *API) FilterActivityAsync(requestID int32, addresses []common.Address, chainIDs []wcommon.ChainID, filter activity.Filter, offset int, limit int) error { + log.Debug("wallet.api.FilterActivityAsync", "requestID", requestID, "addr.count", len(addresses), "chainIDs.count", len(chainIDs), "offset", offset, "limit", limit) - api.s.activity.FilterActivityAsync(requestID, addresses, allAddresses, chainIDs, filter, offset, limit) + api.s.activity.FilterActivityAsync(requestID, addresses, chainIDs, filter, offset, limit) return nil } @@ -613,10 +613,10 @@ func (api *API) CancelActivityFilterTask(requestID int32) error { return nil } -func (api *API) StartActivityFilterSession(addresses []common.Address, allAddresses bool, chainIDs []wcommon.ChainID, filter activity.Filter, firstPageCount int) (activity.SessionID, error) { - log.Debug("wallet.api.StartActivityFilterSession", "addr.count", len(addresses), "allAddresses", allAddresses, "chainIDs.count", len(chainIDs), "firstPageCount", firstPageCount) +func (api *API) StartActivityFilterSession(addresses []common.Address, chainIDs []wcommon.ChainID, filter activity.Filter, firstPageCount int) (activity.SessionID, error) { + log.Debug("wallet.api.StartActivityFilterSession", "addr.count", len(addresses), "chainIDs.count", len(chainIDs), "firstPageCount", firstPageCount) - return api.s.activity.StartFilterSession(addresses, allAddresses, chainIDs, filter, firstPageCount), nil + return api.s.activity.StartFilterSession(addresses, chainIDs, filter, firstPageCount), nil } func (api *API) UpdateActivityFilterForSession(sessionID activity.SessionID, filter activity.Filter, firstPageCount int) error { diff --git a/services/wallet/service.go b/services/wallet/service.go index 3e3c9bc6b..decec5175 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -168,7 +168,7 @@ func NewService( ) collectibles := collectibles.NewService(db, feed, accountsDB, accountFeed, settingsFeed, communityManager, rpcClient.NetworkManager, collectiblesManager) - activity := activity.NewService(db, tokenManager, collectiblesManager, feed, pendingTxManager) + activity := activity.NewService(db, accountsDB, tokenManager, collectiblesManager, feed, pendingTxManager) walletconnect := walletconnect.NewService(db, rpcClient.NetworkManager, accountsDB, transactionManager, gethManager, feed, config)