fix(wallet): cleanup balance_history table on account removed

Updates #4937
This commit is contained in:
Ivan Belyakov 2024-03-21 14:00:34 +01:00 committed by IvanBelyakoff
parent 5b7910ae5a
commit cc839ad7bc
4 changed files with 138 additions and 2 deletions

View File

@ -150,3 +150,8 @@ func (b *BalanceDB) getEntryPreviousTo(item *entry) (res *entry, err error) {
return res, nil return res, nil
} }
func (b *BalanceDB) removeBalanceHistory(address common.Address) error {
_, err := b.db.Exec("DELETE FROM balance_history WHERE address = ?", address)
return err
}

View File

@ -22,6 +22,7 @@ import (
"github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/rpc/chain"
"github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/services/accounts/accountsevent"
"github.com/status-im/status-go/services/wallet/balance" "github.com/status-im/status-go/services/wallet/balance"
"github.com/status-im/status-go/services/wallet/market" "github.com/status-im/status-go/services/wallet/market"
"github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/token"
@ -47,6 +48,7 @@ type Service struct {
balance *Balance balance *Balance
db *sql.DB db *sql.DB
accountsDB *accounts.Database accountsDB *accounts.Database
accountFeed *event.Feed
eventFeed *event.Feed eventFeed *event.Feed
rpcClient *statusrpc.Client rpcClient *statusrpc.Client
networkManager *network.Manager networkManager *network.Manager
@ -54,15 +56,17 @@ type Service struct {
serviceContext context.Context serviceContext context.Context
cancelFn context.CancelFunc cancelFn context.CancelFunc
transferWatcher *Watcher transferWatcher *Watcher
accWatcher *accountsevent.Watcher
exchange *Exchange exchange *Exchange
balanceCache balance.CacheIface balanceCache balance.CacheIface
} }
func NewService(db *sql.DB, accountsDB *accounts.Database, eventFeed *event.Feed, rpcClient *statusrpc.Client, tokenManager *token.Manager, marketManager *market.Manager, balanceCache balance.CacheIface) *Service { func NewService(db *sql.DB, accountsDB *accounts.Database, accountFeed *event.Feed, eventFeed *event.Feed, rpcClient *statusrpc.Client, tokenManager *token.Manager, marketManager *market.Manager, balanceCache balance.CacheIface) *Service {
return &Service{ return &Service{
balance: NewBalance(NewBalanceDB(db)), balance: NewBalance(NewBalanceDB(db)),
db: db, db: db,
accountsDB: accountsDB, accountsDB: accountsDB,
accountFeed: accountFeed,
eventFeed: eventFeed, eventFeed: eventFeed,
rpcClient: rpcClient, rpcClient: rpcClient,
networkManager: rpcClient.NetworkManager, networkManager: rpcClient.NetworkManager,
@ -78,6 +82,7 @@ func (s *Service) Stop() {
} }
s.stopTransfersWatcher() s.stopTransfersWatcher()
s.stopAccountWatcher()
} }
func (s *Service) triggerEvent(eventType walletevent.EventType, account statustypes.Address, message string) { func (s *Service) triggerEvent(eventType walletevent.EventType, account statustypes.Address, message string) {
@ -94,6 +99,7 @@ func (s *Service) Start() {
log.Debug("Starting balance history service") log.Debug("Starting balance history service")
s.startTransfersWatcher() s.startTransfersWatcher()
s.startAccountWatcher()
go func() { go func() {
s.serviceContext, s.cancelFn = context.WithCancel(context.Background()) s.serviceContext, s.cancelFn = context.WithCancel(context.Background())
@ -563,3 +569,30 @@ func (s *Service) stopTransfersWatcher() {
s.transferWatcher = nil s.transferWatcher = nil
} }
} }
func (s *Service) startAccountWatcher() {
if s.accWatcher == nil {
s.accWatcher = accountsevent.NewWatcher(s.accountsDB, s.accountFeed, func(changedAddresses []common.Address, eventType accountsevent.EventType, currentAddresses []common.Address) {
s.onAccountsChanged(changedAddresses, eventType, currentAddresses)
})
}
s.accWatcher.Start()
}
func (s *Service) stopAccountWatcher() {
if s.accWatcher != nil {
s.accWatcher.Stop()
s.accWatcher = nil
}
}
func (s *Service) onAccountsChanged(changedAddresses []common.Address, eventType accountsevent.EventType, currentAddresses []common.Address) {
if eventType == accountsevent.EventTypeRemoved {
for _, address := range changedAddresses {
err := s.balance.db.removeBalanceHistory(address)
if err != nil {
log.Error("Error removing balance history", "address", address, "err", err)
}
}
}
}

View File

@ -1,12 +1,29 @@
package history package history
import ( import (
"errors"
"math/big" "math/big"
"reflect" "reflect"
"sync"
"testing" "testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/event"
gethrpc "github.com/ethereum/go-ethereum/rpc"
"github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/services/accounts/accountsevent"
"github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/t/utils"
"github.com/status-im/status-go/transactions/fake"
"github.com/status-im/status-go/walletdatabase"
) )
func Test_entriesToDataPoints(t *testing.T) { func Test_entriesToDataPoints(t *testing.T) {
@ -336,3 +353,84 @@ func Test_entriesToDataPoints(t *testing.T) {
}) })
} }
} }
func Test_removeBalanceHistoryOnEventAccountRemoved(t *testing.T) {
appDB, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
require.NoError(t, err)
walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
accountsDB, err := accounts.NewDB(appDB)
require.NoError(t, err)
address := common.HexToAddress("0x1234")
accountFeed := event.Feed{}
walletFeed := event.Feed{}
chainID := uint64(1)
txServiceMockCtrl := gomock.NewController(t)
server, _ := fake.NewTestServer(txServiceMockCtrl)
client := gethrpc.DialInProc(server)
rpcClient, _ := rpc.NewClient(client, chainID, params.UpstreamRPCConfig{}, nil, nil)
rpcClient.UpstreamChainID = chainID
service := NewService(walletDB, accountsDB, &accountFeed, &walletFeed, rpcClient, nil, nil, nil)
// Insert balances for address
database := service.balance.db
err = database.add(&entry{
chainID: chainID,
address: address,
block: big.NewInt(1),
balance: big.NewInt(1),
timestamp: 1,
tokenSymbol: "ETH",
})
require.NoError(t, err)
err = database.add(&entry{
chainID: chainID,
address: address,
block: big.NewInt(2),
balance: big.NewInt(2),
tokenSymbol: "ETH",
timestamp: 2,
})
require.NoError(t, err)
entries, err := database.getNewerThan(&assetIdentity{chainID, []common.Address{address}, "ETH"}, 0)
require.NoError(t, err)
require.Len(t, entries, 2)
// Start service
service.startAccountWatcher()
// Watching accounts must start before sending event.
// To avoid running goroutine immediately and let the controller subscribe first,
// use any delay.
group := sync.WaitGroup{}
group.Add(1)
go func() {
defer group.Done()
time.Sleep(1 * time.Millisecond)
accountFeed.Send(accountsevent.Event{
Type: accountsevent.EventTypeRemoved,
Accounts: []common.Address{address},
})
require.NoError(t, utils.Eventually(func() error {
entries, err := database.getNewerThan(&assetIdentity{1, []common.Address{address}, "ETH"}, 0)
if err == nil && len(entries) == 0 {
return nil
}
return errors.New("data is not removed")
}, 100*time.Millisecond, 10*time.Millisecond))
}()
group.Wait()
// Stop service
txServiceMockCtrl.Finish()
server.Stop()
service.stopAccountWatcher()
}

View File

@ -113,7 +113,7 @@ func NewService(
coingecko := coingecko.NewClient() coingecko := coingecko.NewClient()
marketManager := market.NewManager(cryptoCompare, coingecko, feed) marketManager := market.NewManager(cryptoCompare, coingecko, feed)
reader := NewReader(rpcClient, tokenManager, marketManager, communityManager, accountsDB, NewPersistence(db), feed) reader := NewReader(rpcClient, tokenManager, marketManager, communityManager, accountsDB, NewPersistence(db), feed)
history := history.NewService(db, accountsDB, feed, rpcClient, tokenManager, marketManager, balanceCacher.Cache()) history := history.NewService(db, accountsDB, accountFeed, feed, rpcClient, tokenManager, marketManager, balanceCacher.Cache())
currency := currency.NewService(db, feed, tokenManager, marketManager) currency := currency.NewService(db, feed, tokenManager, marketManager)
openseaHTTPClient := opensea.NewHTTPClient() openseaHTTPClient := opensea.NewHTTPClient()