diff --git a/services/wallet/history/balance_db.go b/services/wallet/history/balance_db.go index b4a161865..504ae83c0 100644 --- a/services/wallet/history/balance_db.go +++ b/services/wallet/history/balance_db.go @@ -150,3 +150,8 @@ func (b *BalanceDB) getEntryPreviousTo(item *entry) (res *entry, err error) { return res, nil } + +func (b *BalanceDB) removeBalanceHistory(address common.Address) error { + _, err := b.db.Exec("DELETE FROM balance_history WHERE address = ?", address) + return err +} diff --git a/services/wallet/history/service.go b/services/wallet/history/service.go index 102fe8e69..a09778836 100644 --- a/services/wallet/history/service.go +++ b/services/wallet/history/service.go @@ -22,6 +22,7 @@ import ( "github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/rpc/network" + "github.com/status-im/status-go/services/accounts/accountsevent" "github.com/status-im/status-go/services/wallet/balance" "github.com/status-im/status-go/services/wallet/market" "github.com/status-im/status-go/services/wallet/token" @@ -47,6 +48,7 @@ type Service struct { balance *Balance db *sql.DB accountsDB *accounts.Database + accountFeed *event.Feed eventFeed *event.Feed rpcClient *statusrpc.Client networkManager *network.Manager @@ -54,15 +56,17 @@ type Service struct { serviceContext context.Context cancelFn context.CancelFunc transferWatcher *Watcher + accWatcher *accountsevent.Watcher exchange *Exchange balanceCache balance.CacheIface } -func NewService(db *sql.DB, accountsDB *accounts.Database, eventFeed *event.Feed, rpcClient *statusrpc.Client, tokenManager *token.Manager, marketManager *market.Manager, balanceCache balance.CacheIface) *Service { +func NewService(db *sql.DB, accountsDB *accounts.Database, accountFeed *event.Feed, eventFeed *event.Feed, rpcClient *statusrpc.Client, tokenManager *token.Manager, marketManager *market.Manager, balanceCache balance.CacheIface) *Service { return &Service{ balance: NewBalance(NewBalanceDB(db)), db: db, accountsDB: accountsDB, + accountFeed: accountFeed, eventFeed: eventFeed, rpcClient: rpcClient, networkManager: rpcClient.NetworkManager, @@ -78,6 +82,7 @@ func (s *Service) Stop() { } s.stopTransfersWatcher() + s.stopAccountWatcher() } func (s *Service) triggerEvent(eventType walletevent.EventType, account statustypes.Address, message string) { @@ -94,6 +99,7 @@ func (s *Service) Start() { log.Debug("Starting balance history service") s.startTransfersWatcher() + s.startAccountWatcher() go func() { s.serviceContext, s.cancelFn = context.WithCancel(context.Background()) @@ -563,3 +569,30 @@ func (s *Service) stopTransfersWatcher() { s.transferWatcher = nil } } + +func (s *Service) startAccountWatcher() { + if s.accWatcher == nil { + s.accWatcher = accountsevent.NewWatcher(s.accountsDB, s.accountFeed, func(changedAddresses []common.Address, eventType accountsevent.EventType, currentAddresses []common.Address) { + s.onAccountsChanged(changedAddresses, eventType, currentAddresses) + }) + } + s.accWatcher.Start() +} + +func (s *Service) stopAccountWatcher() { + if s.accWatcher != nil { + s.accWatcher.Stop() + s.accWatcher = nil + } +} + +func (s *Service) onAccountsChanged(changedAddresses []common.Address, eventType accountsevent.EventType, currentAddresses []common.Address) { + if eventType == accountsevent.EventTypeRemoved { + for _, address := range changedAddresses { + err := s.balance.db.removeBalanceHistory(address) + if err != nil { + log.Error("Error removing balance history", "address", address, "err", err) + } + } + } +} diff --git a/services/wallet/history/service_test.go b/services/wallet/history/service_test.go index 6906f6aa1..4e07a7209 100644 --- a/services/wallet/history/service_test.go +++ b/services/wallet/history/service_test.go @@ -1,12 +1,29 @@ package history import ( + "errors" "math/big" "reflect" + "sync" "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/event" + gethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/multiaccounts/accounts" + "github.com/status-im/status-go/params" + "github.com/status-im/status-go/rpc" + "github.com/status-im/status-go/services/accounts/accountsevent" + "github.com/status-im/status-go/t/helpers" + "github.com/status-im/status-go/t/utils" + "github.com/status-im/status-go/transactions/fake" + "github.com/status-im/status-go/walletdatabase" ) func Test_entriesToDataPoints(t *testing.T) { @@ -336,3 +353,84 @@ func Test_entriesToDataPoints(t *testing.T) { }) } } + +func Test_removeBalanceHistoryOnEventAccountRemoved(t *testing.T) { + appDB, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) + require.NoError(t, err) + + walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) + require.NoError(t, err) + + accountsDB, err := accounts.NewDB(appDB) + require.NoError(t, err) + + address := common.HexToAddress("0x1234") + accountFeed := event.Feed{} + walletFeed := event.Feed{} + chainID := uint64(1) + txServiceMockCtrl := gomock.NewController(t) + server, _ := fake.NewTestServer(txServiceMockCtrl) + client := gethrpc.DialInProc(server) + rpcClient, _ := rpc.NewClient(client, chainID, params.UpstreamRPCConfig{}, nil, nil) + rpcClient.UpstreamChainID = chainID + + service := NewService(walletDB, accountsDB, &accountFeed, &walletFeed, rpcClient, nil, nil, nil) + + // Insert balances for address + database := service.balance.db + err = database.add(&entry{ + chainID: chainID, + address: address, + block: big.NewInt(1), + balance: big.NewInt(1), + timestamp: 1, + tokenSymbol: "ETH", + }) + require.NoError(t, err) + err = database.add(&entry{ + chainID: chainID, + address: address, + block: big.NewInt(2), + balance: big.NewInt(2), + tokenSymbol: "ETH", + timestamp: 2, + }) + require.NoError(t, err) + + entries, err := database.getNewerThan(&assetIdentity{chainID, []common.Address{address}, "ETH"}, 0) + require.NoError(t, err) + require.Len(t, entries, 2) + + // Start service + service.startAccountWatcher() + + // Watching accounts must start before sending event. + // To avoid running goroutine immediately and let the controller subscribe first, + // use any delay. + group := sync.WaitGroup{} + group.Add(1) + go func() { + defer group.Done() + time.Sleep(1 * time.Millisecond) + + accountFeed.Send(accountsevent.Event{ + Type: accountsevent.EventTypeRemoved, + Accounts: []common.Address{address}, + }) + + require.NoError(t, utils.Eventually(func() error { + entries, err := database.getNewerThan(&assetIdentity{1, []common.Address{address}, "ETH"}, 0) + if err == nil && len(entries) == 0 { + return nil + } + return errors.New("data is not removed") + }, 100*time.Millisecond, 10*time.Millisecond)) + }() + + group.Wait() + + // Stop service + txServiceMockCtrl.Finish() + server.Stop() + service.stopAccountWatcher() +} diff --git a/services/wallet/service.go b/services/wallet/service.go index c57c4ad41..b4c0ce3b5 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -113,7 +113,7 @@ func NewService( coingecko := coingecko.NewClient() marketManager := market.NewManager(cryptoCompare, coingecko, feed) reader := NewReader(rpcClient, tokenManager, marketManager, communityManager, accountsDB, NewPersistence(db), feed) - history := history.NewService(db, accountsDB, feed, rpcClient, tokenManager, marketManager, balanceCacher.Cache()) + history := history.NewService(db, accountsDB, accountFeed, feed, rpcClient, tokenManager, marketManager, balanceCacher.Cache()) currency := currency.NewService(db, feed, tokenManager, marketManager) openseaHTTPClient := opensea.NewHTTPClient()