feat(wallet): remove transfers data on account removal

Closes:  #4394
This commit is contained in:
Ivan Belyakov 2023-11-28 15:23:03 +01:00 committed by IvanBelyakoff
parent 5e6768a42a
commit c0f2f76e9a
8 changed files with 381 additions and 78 deletions

View File

@ -368,6 +368,16 @@ func deleteRange(chainID uint64, creator statementCreator, account common.Addres
return err
}
func deleteAllRanges(creator statementCreator, account common.Address) error {
delete, err := creator.Prepare(`DELETE FROM blocks_ranges WHERE address = ?`)
if err != nil {
return err
}
_, err = delete.Exec(account)
return err
}
func insertRange(chainID uint64, creator statementCreator, account common.Address, from *big.Int, to *big.Int) error {
log.Info("insert blocks range", "account", account, "network", chainID, "from", from, "to", to)
insert, err := creator.Prepare("INSERT INTO blocks_ranges (network_id, address, blk_from, blk_to) VALUES (?, ?, ?, ?)")

View File

@ -47,20 +47,15 @@ func (b *BlockRangeSequentialDAO) getBlockRange(chainID uint64, address common.A
return nil, nil
}
// TODO call it when account is removed
//
//lint:ignore U1000 Ignore unused function temporarily
func (b *BlockRangeSequentialDAO) deleteRange(chainID uint64, account common.Address) error {
log.Debug("delete blocks range", "account", account, "network", chainID)
delete, err := b.db.Prepare(`DELETE FROM blocks_ranges_sequential
WHERE address = ?
AND network_id = ?`)
func (b *BlockRangeSequentialDAO) deleteRange(account common.Address) error {
log.Debug("delete blocks range", "account", account)
delete, err := b.db.Prepare(`DELETE FROM blocks_ranges_sequential WHERE address = ?`)
if err != nil {
log.Error("Failed to prepare deletion of sequential block range", "error", err)
return err
}
_, err = delete.Exec(account, chainID)
_, err = delete.Exec(account)
return err
}

View File

@ -514,14 +514,14 @@ func loadTransfersLoop(ctx context.Context, account common.Address, blockDAO *Bl
}
func newLoadBlocksAndTransfersCommand(account common.Address, db *Database,
blockDAO *BlockDAO, chainClient chain.ClientInterface, feed *event.Feed,
blockDAO *BlockDAO, blockRangesSeqDAO *BlockRangeSequentialDAO, chainClient chain.ClientInterface, feed *event.Feed,
transactionManager *TransactionManager, pendingTxManager *transactions.PendingTxTracker,
tokenManager *token.Manager, balanceCacher balance.Cacher, omitHistory bool) *loadBlocksAndTransfersCommand {
return &loadBlocksAndTransfersCommand{
account: account,
db: db,
blockRangeDAO: &BlockRangeSequentialDAO{db.client},
blockRangeDAO: blockRangesSeqDAO,
blockDAO: blockDAO,
chainClient: chainClient,
feed: feed,

View File

@ -6,6 +6,8 @@ import (
"fmt"
"math/big"
"golang.org/x/exp/slices" // since 1.21, this is in the standard library
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/event"
@ -13,9 +15,7 @@ import (
statusaccounts "github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/multiaccounts/settings"
"github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/rpc/chain"
"github.com/status-im/status-go/services/accounts/accountsevent"
"github.com/status-im/status-go/services/wallet/async"
"github.com/status-im/status-go/services/wallet/balance"
"github.com/status-im/status-go/services/wallet/token"
"github.com/status-im/status-go/transactions"
@ -26,10 +26,11 @@ type Controller struct {
accountsDB *statusaccounts.Database
rpcClient *rpc.Client
blockDAO *BlockDAO
blockRangesSeqDAO *BlockRangeSequentialDAO
reactor *Reactor
accountFeed *event.Feed
TransferFeed *event.Feed
group *async.Group
accWatcher *accountsevent.Watcher
transactionManager *TransactionManager
pendingTxManager *transactions.PendingTxTracker
tokenManager *token.Manager
@ -45,6 +46,7 @@ func NewTransferController(db *sql.DB, accountsDB *statusaccounts.Database, rpcC
db: NewDB(db),
accountsDB: accountsDB,
blockDAO: blockDAO,
blockRangesSeqDAO: &BlockRangeSequentialDAO{db},
rpcClient: rpcClient,
accountFeed: accountFeed,
TransferFeed: transferFeed,
@ -56,7 +58,7 @@ func NewTransferController(db *sql.DB, accountsDB *statusaccounts.Database, rpcC
}
func (c *Controller) Start() {
c.group = async.NewGroup(context.Background())
go func() { _ = c.cleanupAccountsLeftovers() }()
}
func (c *Controller) Stop() {
@ -64,10 +66,9 @@ func (c *Controller) Stop() {
c.reactor.stop()
}
if c.group != nil {
c.group.Stop()
c.group.Wait()
c.group = nil
if c.accWatcher != nil {
c.accWatcher.Stop()
c.accWatcher = nil
}
}
@ -109,7 +110,7 @@ func (c *Controller) CheckRecentHistory(chainIDs []uint64, accounts []common.Add
}
}
c.reactor = NewReactor(c.db, c.blockDAO, c.TransferFeed, c.transactionManager,
c.reactor = NewReactor(c.db, c.blockDAO, c.blockRangesSeqDAO, c.TransferFeed, c.transactionManager,
c.pendingTxManager, c.tokenManager, c.balanceCacher, omitHistory)
err = c.reactor.start(chainClients, accounts)
@ -117,68 +118,45 @@ func (c *Controller) CheckRecentHistory(chainIDs []uint64, accounts []common.Add
return err
}
c.group.Add(func(ctx context.Context) error {
return watchAccountsChanges(ctx, c.accountFeed, c.reactor, chainClients, accounts)
})
c.startAccountWatcher(chainIDs)
}
return nil
}
// watchAccountsChanges subscribes to a feed and watches for changes in accounts list. If there are new or removed accounts
// reactor will be restarted.
func watchAccountsChanges(ctx context.Context, accountFeed *event.Feed, reactor *Reactor,
chainClients map[uint64]chain.ClientInterface, initial []common.Address) error {
ch := make(chan accountsevent.Event, 1) // it may block if the rate of updates will be significantly higher
sub := accountFeed.Subscribe(ch)
defer sub.Unsubscribe()
listen := make(map[common.Address]struct{}, len(initial))
for _, address := range initial {
listen[address] = struct{}{}
}
for {
select {
case <-ctx.Done():
return nil
case err := <-sub.Err():
if err != nil {
log.Error("accounts watcher subscription failed", "error", err)
}
case ev := <-ch:
restart := false
for _, address := range ev.Accounts {
_, exist := listen[address]
if ev.Type == accountsevent.EventTypeAdded && !exist {
listen[address] = struct{}{}
restart = true
} else if ev.Type == accountsevent.EventTypeRemoved && exist {
delete(listen, address)
restart = true
}
}
if !restart {
continue
}
listenList := mapToList(listen)
log.Debug("list of accounts was changed from a previous version. reactor will be restarted", "new", listenList)
err := reactor.restart(chainClients, listenList)
if err != nil {
log.Error("failed to restart reactor with new accounts", "error", err)
}
}
func (c *Controller) startAccountWatcher(chainIDs []uint64) {
if c.accWatcher == nil {
c.accWatcher = accountsevent.NewWatcher(c.accountsDB, c.accountFeed, func(changedAddresses []common.Address, eventType accountsevent.EventType, currentAddresses []common.Address) {
c.onAccountsChanged(changedAddresses, eventType, currentAddresses, chainIDs)
})
c.accWatcher.Start()
}
}
func mapToList(m map[common.Address]struct{}) []common.Address {
rst := make([]common.Address, 0, len(m))
for address := range m {
rst = append(rst, address)
func (c *Controller) onAccountsChanged(changedAddresses []common.Address, eventType accountsevent.EventType, currentAddresses []common.Address, chainIDs []uint64) {
if eventType == accountsevent.EventTypeRemoved {
for _, address := range changedAddresses {
c.cleanUpRemovedAccount(address)
}
}
if c.reactor == nil {
log.Warn("reactor is not initialized")
return
}
if eventType == accountsevent.EventTypeAdded || eventType == accountsevent.EventTypeRemoved {
log.Debug("list of accounts was changed from a previous version. reactor will be restarted", "new", currentAddresses)
chainClients, err := c.rpcClient.EthClients(chainIDs)
if err != nil {
return
}
err = c.reactor.restart(chainClients, currentAddresses)
if err != nil {
log.Error("failed to restart reactor with new accounts", "error", err)
}
}
return rst
}
// Only used by status-mobile
@ -250,3 +228,58 @@ func (c *Controller) GetCachedBalances(ctx context.Context, chainID uint64, addr
return blocksToViews(result), nil
}
func (c *Controller) cleanUpRemovedAccount(address common.Address) {
// Transfers will be deleted by foreign key constraint by cascade
err := deleteBlocks(c.db.client, address)
if err != nil {
log.Error("Failed to delete blocks", "error", err)
}
err = deleteAllRanges(c.db.client, address)
if err != nil {
log.Error("Failed to delete old blocks ranges", "error", err)
}
err = c.blockRangesSeqDAO.deleteRange(address)
if err != nil {
log.Error("Failed to delete blocks ranges sequential", "error", err)
}
}
func (c *Controller) cleanupAccountsLeftovers() error {
// We clean up accounts that were deleted and soft removed
accounts, err := c.accountsDB.GetWalletAddresses()
if err != nil {
log.Error("Failed to get accounts", "error", err)
return err
}
existingAddresses := make([]common.Address, len(accounts))
for i, account := range accounts {
existingAddresses[i] = (common.Address)(account)
}
addressesInWalletDB, err := getAddresses(c.db.client)
if err != nil {
log.Error("Failed to get addresses from wallet db", "error", err)
return err
}
missing := findMissingItems(addressesInWalletDB, existingAddresses)
for _, address := range missing {
c.cleanUpRemovedAccount(address)
}
return nil
}
// find items from one slice that are not in another
func findMissingItems(slice1 []common.Address, slice2 []common.Address) []common.Address {
var missing []common.Address
for _, item := range slice1 {
if !slices.Contains(slice2, item) {
missing = append(missing, item)
}
}
return missing
}

View File

@ -0,0 +1,224 @@
package transfer
import (
"math/big"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/event"
"github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/services/accounts/accountsevent"
"github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/walletdatabase"
)
func TestController_watchAccountsChanges(t *testing.T) {
appDB, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
require.NoError(t, err)
accountsDB, err := accounts.NewDB(appDB)
require.NoError(t, err)
walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
accountFeed := &event.Feed{}
c := NewTransferController(
walletDB,
accountsDB,
nil, // rpcClient
accountFeed,
nil, // transferFeed
nil, // transactionManager
nil, // pendingTxManager
nil, // tokenManager
nil, // balanceCacher
)
address := common.HexToAddress("0x1234")
chainID := uint64(777)
// Insert blocks
database := NewDB(walletDB)
err = database.SaveBlocks(chainID, address, []*DBHeader{
{
Number: big.NewInt(1),
Hash: common.Hash{1},
Network: chainID,
Address: address,
Loaded: false,
},
})
require.NoError(t, err)
// Insert transfers
err = saveTransfersMarkBlocksLoaded(walletDB, chainID, address, []Transfer{
{
ID: common.Hash{1},
BlockHash: common.Hash{1},
BlockNumber: big.NewInt(1),
Address: address,
NetworkID: chainID,
},
}, []*big.Int{big.NewInt(1)})
require.NoError(t, err)
// Insert block ranges
blockRangesDAO := &BlockRangeSequentialDAO{walletDB}
err = blockRangesDAO.upsertRange(chainID, address, NewBlockRange())
require.NoError(t, err)
ranges, err := blockRangesDAO.getBlockRange(chainID, address)
require.NoError(t, err)
require.NotNil(t, ranges)
ch := make(chan accountsevent.Event)
// Subscribe for account changes
accountFeed.Subscribe(ch)
// Watching accounts must start before sending event.
// To avoid running goroutine immediately, use any delay.
go func() {
time.Sleep(1 * time.Millisecond)
accountFeed.Send(accountsevent.Event{
Type: accountsevent.EventTypeRemoved,
Accounts: []common.Address{address},
})
}()
c.startAccountWatcher([]uint64{chainID})
// Wait for event
<-ch
// Wait for DB to be cleaned up
c.accWatcher.Stop()
// Check that transfers, blocks and block ranges were deleted
transfers, err := database.GetTransfersByAddress(chainID, address, big.NewInt(2), 1)
require.NoError(t, err)
require.Len(t, transfers, 0)
blocksDAO := &BlockDAO{walletDB}
block, err := blocksDAO.GetLastBlockByAddress(chainID, address, 1)
require.NoError(t, err)
require.Nil(t, block)
ranges, err = blockRangesDAO.getBlockRange(chainID, address)
require.NoError(t, err)
require.Nil(t, ranges)
}
func TestController_cleanupAccountLeftovers(t *testing.T) {
appDB, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
require.NoError(t, err)
walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
accountsDB, err := accounts.NewDB(appDB)
require.NoError(t, err)
removedAddr := common.HexToAddress("0x5678")
existingAddr := types.HexToAddress("0x1234")
accounts := []*accounts.Account{
{Address: existingAddr, Chat: false, Wallet: true},
}
err = accountsDB.SaveOrUpdateAccounts(accounts, false)
require.NoError(t, err)
storedAccs, err := accountsDB.GetWalletAddresses()
require.NoError(t, err)
require.Len(t, storedAccs, 1)
c := NewTransferController(
walletDB,
accountsDB,
nil, // rpcClient
nil, // accountFeed
nil, // transferFeed
nil, // transactionManager
nil, // pendingTxManager
nil, // tokenManager
nil, // balanceCacher
)
chainID := uint64(777)
// Insert blocks
database := NewDB(walletDB)
err = database.SaveBlocks(chainID, removedAddr, []*DBHeader{
{
Number: big.NewInt(1),
Hash: common.Hash{1},
Network: chainID,
Address: removedAddr,
Loaded: false,
},
})
require.NoError(t, err)
err = database.SaveBlocks(chainID, common.Address(existingAddr), []*DBHeader{
{
Number: big.NewInt(2),
Hash: common.Hash{2},
Network: chainID,
Address: common.Address(existingAddr),
Loaded: false,
},
})
require.NoError(t, err)
blocksDAO := &BlockDAO{walletDB}
block, err := blocksDAO.GetLastBlockByAddress(chainID, removedAddr, 1)
require.NoError(t, err)
require.NotNil(t, block)
block, err = blocksDAO.GetLastBlockByAddress(chainID, common.Address(existingAddr), 1)
require.NoError(t, err)
require.NotNil(t, block)
// Insert transfers
err = saveTransfersMarkBlocksLoaded(walletDB, chainID, removedAddr, []Transfer{
{
ID: common.Hash{1},
BlockHash: common.Hash{1},
BlockNumber: big.NewInt(1),
Address: removedAddr,
NetworkID: chainID,
},
}, []*big.Int{big.NewInt(1)})
require.NoError(t, err)
err = saveTransfersMarkBlocksLoaded(walletDB, chainID, common.Address(existingAddr), []Transfer{
{
ID: common.Hash{2},
BlockHash: common.Hash{2},
BlockNumber: big.NewInt(2),
Address: common.Address(existingAddr),
NetworkID: chainID,
},
}, []*big.Int{big.NewInt(2)})
require.NoError(t, err)
err = c.cleanupAccountsLeftovers()
require.NoError(t, err)
// Check that transfers and blocks of removed account were deleted
transfers, err := database.GetTransfers(chainID, big.NewInt(1), big.NewInt(2))
require.NoError(t, err)
require.Len(t, transfers, 1)
require.Equal(t, transfers[0].Address, common.Address(existingAddr))
block, err = blocksDAO.GetLastBlockByAddress(chainID, removedAddr, 1)
require.NoError(t, err)
require.Nil(t, block)
// Make sure that transfers and blocks of existing account were not deleted
existingBlock, err := blocksDAO.GetLastBlockByAddress(chainID, common.Address(existingAddr), 1)
require.NoError(t, err)
require.NotNil(t, existingBlock)
}

View File

@ -170,7 +170,7 @@ func (db *Database) ProcessTransfers(chainID uint64, transfers []Transfer, remov
return
}
func saveTransfersMarkBlocksLoaded(tx *sql.Tx, chainID uint64, address common.Address, transfers []Transfer, blocks []*big.Int) (err error) {
func saveTransfersMarkBlocksLoaded(tx statementCreator, chainID uint64, address common.Address, transfers []Transfer, blocks []*big.Int) (err error) {
err = updateOrInsertTransfers(chainID, tx, transfers)
if err != nil {
return
@ -592,3 +592,39 @@ func GetOwnedMultiTransactionID(tx *sql.Tx, chainID w_common.ChainID, id common.
}
return mTID, nil
}
// Delete blocks for address and chainID
// Transfers will be deleted by cascade
func deleteBlocks(creator statementCreator, address common.Address) error {
delete, err := creator.Prepare("DELETE FROM blocks WHERE address = ?")
if err != nil {
return err
}
_, err = delete.Exec(address)
return err
}
func getAddresses(creator statementCreator) (rst []common.Address, err error) {
stmt, err := creator.Prepare(`SELECT address FROM transfers UNION SELECT address FROM blocks UNION
SELECT address FROM blocks_ranges_sequential UNION SELECT address FROM blocks_ranges`)
if err != nil {
return
}
rows, err := stmt.Query()
if err != nil {
return nil, err
}
defer rows.Close()
address := common.Address{}
for rows.Next() {
err = rows.Scan(&address)
if err != nil {
return nil, err
}
rst = append(rst, address)
}
return rst, nil
}

View File

@ -48,6 +48,7 @@ type HistoryFetcher interface {
type Reactor struct {
db *Database
blockDAO *BlockDAO
blockRangesSeqDAO *BlockRangeSequentialDAO
feed *event.Feed
transactionManager *TransactionManager
pendingTxManager *transactions.PendingTxTracker
@ -57,12 +58,13 @@ type Reactor struct {
omitHistory bool
}
func NewReactor(db *Database, blockDAO *BlockDAO, feed *event.Feed, tm *TransactionManager,
func NewReactor(db *Database, blockDAO *BlockDAO, blockRangesSeqDAO *BlockRangeSequentialDAO, feed *event.Feed, tm *TransactionManager,
pendingTxManager *transactions.PendingTxTracker, tokenManager *token.Manager,
balanceCacher balance.Cacher, omitHistory bool) *Reactor {
return &Reactor{
db: db,
blockDAO: blockDAO,
blockRangesSeqDAO: blockRangesSeqDAO,
feed: feed,
transactionManager: tm,
pendingTxManager: pendingTxManager,
@ -98,6 +100,7 @@ func (r *Reactor) createFetchStrategy(chainClients map[uint64]chain.ClientInterf
return NewSequentialFetchStrategy(
r.db,
r.blockDAO,
r.blockRangesSeqDAO,
r.feed,
r.transactionManager,
r.pendingTxManager,

View File

@ -16,7 +16,7 @@ import (
"github.com/status-im/status-go/transactions"
)
func NewSequentialFetchStrategy(db *Database, blockDAO *BlockDAO, feed *event.Feed,
func NewSequentialFetchStrategy(db *Database, blockDAO *BlockDAO, blockRangesSeqDAO *BlockRangeSequentialDAO, feed *event.Feed,
transactionManager *TransactionManager, pendingTxManager *transactions.PendingTxTracker,
tokenManager *token.Manager,
chainClients map[uint64]chain.ClientInterface,
@ -28,6 +28,7 @@ func NewSequentialFetchStrategy(db *Database, blockDAO *BlockDAO, feed *event.Fe
return &SequentialFetchStrategy{
db: db,
blockDAO: blockDAO,
blockRangesSeqDAO: blockRangesSeqDAO,
feed: feed,
transactionManager: transactionManager,
pendingTxManager: pendingTxManager,
@ -42,6 +43,7 @@ func NewSequentialFetchStrategy(db *Database, blockDAO *BlockDAO, feed *event.Fe
type SequentialFetchStrategy struct {
db *Database
blockDAO *BlockDAO
blockRangesSeqDAO *BlockRangeSequentialDAO
feed *event.Feed
mu sync.Mutex
group *async.Group
@ -57,7 +59,7 @@ type SequentialFetchStrategy struct {
func (s *SequentialFetchStrategy) newCommand(chainClient chain.ClientInterface,
account common.Address) async.Commander {
return newLoadBlocksAndTransfersCommand(account, s.db, s.blockDAO, chainClient, s.feed,
return newLoadBlocksAndTransfersCommand(account, s.db, s.blockDAO, s.blockRangesSeqDAO, chainClient, s.feed,
s.transactionManager, s.pendingTxManager, s.tokenManager, s.balanceCacher, s.omitHistory)
}