fix(wallet): cleanup token_balances table on account removed

Updates #4937
This commit is contained in:
Ivan Belyakov 2024-03-22 10:45:43 +01:00 committed by IvanBelyakoff
parent cc839ad7bc
commit c21e6430a2
5 changed files with 194 additions and 18 deletions

View File

@ -464,7 +464,7 @@ func NewMessenger(
if c.tokenManager != nil { if c.tokenManager != nil {
managerOptions = append(managerOptions, communities.WithTokenManager(c.tokenManager)) managerOptions = append(managerOptions, communities.WithTokenManager(c.tokenManager))
} else if c.rpcClient != nil { } else if c.rpcClient != nil {
tokenManager := token.NewTokenManager(c.walletDb, c.rpcClient, community.NewManager(database, c.httpServer, nil), c.rpcClient.NetworkManager, database, c.httpServer, nil) tokenManager := token.NewTokenManager(c.walletDb, c.rpcClient, community.NewManager(database, c.httpServer, nil), c.rpcClient.NetworkManager, database, c.httpServer, nil, nil, nil)
managerOptions = append(managerOptions, communities.WithTokenManager(communities.NewDefaultTokenManager(tokenManager))) managerOptions = append(managerOptions, communities.WithTokenManager(communities.NewDefaultTokenManager(tokenManager)))
} }

View File

@ -102,7 +102,8 @@ func NewService(
communityManager := community.NewManager(db, mediaServer, feed) communityManager := community.NewManager(db, mediaServer, feed)
balanceCacher := balance.NewCacherWithTTL(5 * time.Minute) balanceCacher := balance.NewCacherWithTTL(5 * time.Minute)
tokenManager := token.NewTokenManager(db, rpcClient, communityManager, rpcClient.NetworkManager, appDB, mediaServer, feed) tokenManager := token.NewTokenManager(db, rpcClient, communityManager, rpcClient.NetworkManager, appDB, mediaServer, feed, accountFeed, accountsDB)
tokenManager.Start()
savedAddressesManager := &SavedAddressesManager{db: db} savedAddressesManager := &SavedAddressesManager{db: db}
transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB, pendingTxManager, feed) transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB, pendingTxManager, feed)
blockChainState := blockchainstate.NewBlockChainState() blockChainState := blockchainstate.NewBlockChainState()
@ -262,6 +263,7 @@ func (s *Service) Stop() error {
s.history.Stop() s.history.Stop()
s.activity.Stop() s.activity.Stop()
s.collectibles.Stop() s.collectibles.Stop()
s.tokenManager.Stop()
s.started = false s.started = false
log.Info("wallet stopped") log.Info("wallet stopped")
return nil return nil

View File

@ -21,12 +21,14 @@ import (
"github.com/status-im/status-go/contracts/ethscan" "github.com/status-im/status-go/contracts/ethscan"
"github.com/status-im/status-go/contracts/ierc20" "github.com/status-im/status-go/contracts/ierc20"
eth_node_types "github.com/status-im/status-go/eth-node/types" eth_node_types "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/protocol/communities/token" "github.com/status-im/status-go/protocol/communities/token"
"github.com/status-im/status-go/rpc" "github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/rpc/chain"
"github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/server" "github.com/status-im/status-go/server"
"github.com/status-im/status-go/services/accounts/accountsevent"
"github.com/status-im/status-go/services/communitytokens" "github.com/status-im/status-go/services/communitytokens"
"github.com/status-im/status-go/services/utils" "github.com/status-im/status-go/services/utils"
"github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/async"
@ -104,6 +106,9 @@ type Manager struct {
communityManager *community.Manager communityManager *community.Manager
mediaServer *server.MediaServer mediaServer *server.MediaServer
walletFeed *event.Feed walletFeed *event.Feed
accountFeed *event.Feed
accountWatcher *accountsevent.Watcher
accountsDB *accounts.Database
tokens []*Token tokens []*Token
@ -125,17 +130,7 @@ func mergeTokens(sliceLists [][]*Token) []*Token {
return res return res
} }
func NewTokenManager( func prepareTokens(networkManager *network.Manager, stores []store) []*Token {
db *sql.DB,
RPCClient *rpc.Client,
communityManager *community.Manager,
networkManager *network.Manager,
appDB *sql.DB,
mediaServer *server.MediaServer,
walletFeed *event.Feed,
) *Manager {
maker, _ := contracts.NewContractMaker(RPCClient)
stores := []store{newUniswapStore(), newDefaultStore()}
tokens := make([]*Token, 0) tokens := make([]*Token, 0)
networks, err := networkManager.GetAll() networks, err := networkManager.GetAll()
@ -158,6 +153,23 @@ func NewTokenManager(
tokens = mergeTokens([][]*Token{tokens, validTokens}) tokens = mergeTokens([][]*Token{tokens, validTokens})
} }
return tokens
}
func NewTokenManager(
db *sql.DB,
RPCClient *rpc.Client,
communityManager *community.Manager,
networkManager *network.Manager,
appDB *sql.DB,
mediaServer *server.MediaServer,
walletFeed *event.Feed,
accountFeed *event.Feed,
accountsDB *accounts.Database,
) *Manager {
maker, _ := contracts.NewContractMaker(RPCClient)
stores := []store{newUniswapStore(), newDefaultStore()}
tokens := prepareTokens(networkManager, stores)
return &Manager{ return &Manager{
db: db, db: db,
@ -170,6 +182,32 @@ func NewTokenManager(
tokens: tokens, tokens: tokens,
mediaServer: mediaServer, mediaServer: mediaServer,
walletFeed: walletFeed, walletFeed: walletFeed,
accountFeed: accountFeed,
accountsDB: accountsDB,
}
}
func (tm *Manager) Start() {
tm.startAccountsWatcher()
}
func (tm *Manager) startAccountsWatcher() {
if tm.accountWatcher != nil {
return
}
tm.accountWatcher = accountsevent.NewWatcher(tm.accountsDB, tm.accountFeed, tm.onAccountsChange)
tm.accountWatcher.Start()
}
func (tm *Manager) Stop() {
tm.stopAccountsWatcher()
}
func (tm *Manager) stopAccountsWatcher() {
if tm.accountWatcher != nil {
tm.accountWatcher.Stop()
tm.accountWatcher = nil
} }
} }
@ -314,6 +352,7 @@ func (tm *Manager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint6
} }
func (tm *Manager) MarkAsPreviouslyOwnedToken(token *Token, owner common.Address) (bool, error) { func (tm *Manager) MarkAsPreviouslyOwnedToken(token *Token, owner common.Address) (bool, error) {
log.Info("Marking token as previously owned", "token", token, "owner", owner)
if token == nil { if token == nil {
return false, errors.New("token is nil") return false, errors.New("token is nil")
} }
@ -876,3 +915,52 @@ func (tm *Manager) GetTokenHistoricalBalance(account common.Address, chainID uin
} }
return &balance, nil return &balance, nil
} }
func (tm *Manager) GetPreviouslyOwnedTokens() (map[common.Address][]*Token, error) {
tokenMap := make(map[common.Address][]*Token)
rows, err := tm.db.Query("SELECT user_address, token_name, token_symbol, token_address, token_decimals, chain_id FROM token_balances")
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
token := &Token{}
var addressStr, tokenAddressStr string
err := rows.Scan(&addressStr, &token.Name, &token.Symbol, &tokenAddressStr, &token.Decimals, &token.ChainID)
if err != nil {
return nil, err
}
address := common.HexToAddress(addressStr)
if (address == common.Address{}) {
continue
}
token.Address = common.HexToAddress(tokenAddressStr)
if (token.Address == common.Address{}) {
continue
}
if _, ok := tokenMap[address]; !ok {
tokenMap[address] = make([]*Token, 0)
}
tokenMap[address] = append(tokenMap[address], token)
}
return tokenMap, nil
}
func (tm *Manager) removeTokenBalances(account common.Address) error {
_, err := tm.db.Exec("DELETE FROM token_balances WHERE user_address = ?", account.String())
return err
}
func (tm *Manager) onAccountsChange(changedAddresses []common.Address, eventType accountsevent.EventType, currentAddresses []common.Address) {
if eventType == accountsevent.EventTypeRemoved {
for _, account := range changedAddresses {
err := tm.removeTokenBalances(account)
if err != nil {
log.Error("token.Manager: can't remove token balances", "error", err)
}
}
}
}

View File

@ -1,17 +1,31 @@
package token package token
import ( import (
"errors"
"math/big" "math/big"
"sync"
"testing" "testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/event"
gethrpc "github.com/ethereum/go-ethereum/rpc"
"github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/rpc/network"
mediaserver "github.com/status-im/status-go/server"
"github.com/status-im/status-go/services/accounts/accountsevent"
"github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/bigint"
"github.com/status-im/status-go/services/wallet/community" "github.com/status-im/status-go/services/wallet/community"
"github.com/status-im/status-go/t/helpers" "github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/t/utils"
"github.com/status-im/status-go/transactions/fake"
"github.com/status-im/status-go/walletdatabase" "github.com/status-im/status-go/walletdatabase"
) )
@ -297,3 +311,75 @@ func TestGetTokenHistoricalBalance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedBalance, balance) require.Equal(t, expectedBalance, balance)
} }
func Test_removeTokenBalanceOnEventAccountRemoved(t *testing.T) {
appDB, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
require.NoError(t, err)
walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
accountsDB, err := accounts.NewDB(appDB)
require.NoError(t, err)
address := common.HexToAddress("0x1234")
accountFeed := event.Feed{}
chainID := uint64(1)
txServiceMockCtrl := gomock.NewController(t)
server, _ := fake.NewTestServer(txServiceMockCtrl)
client := gethrpc.DialInProc(server)
rpcClient, _ := rpc.NewClient(client, chainID, params.UpstreamRPCConfig{}, nil, nil)
rpcClient.UpstreamChainID = chainID
nm := network.NewManager(appDB)
mediaServer, err := mediaserver.NewMediaServer(appDB, nil, nil, walletDB)
require.NoError(t, err)
manager := NewTokenManager(walletDB, rpcClient, nil, nm, appDB, mediaServer, nil, &accountFeed, accountsDB)
// Insert balances for address
marked, err := manager.MarkAsPreviouslyOwnedToken(&Token{
Address: common.HexToAddress("0x1234"),
Symbol: "Dummy",
Decimals: 18,
ChainID: 1,
}, address)
require.NoError(t, err)
require.True(t, marked)
tokenByAddress, err := manager.GetPreviouslyOwnedTokens()
require.NoError(t, err)
require.Len(t, tokenByAddress, 1)
// Start service
manager.startAccountsWatcher()
// Watching accounts must start before sending event.
// To avoid running goroutine immediately and let the controller subscribe first,
// use any delay.
group := sync.WaitGroup{}
group.Add(1)
go func() {
defer group.Done()
time.Sleep(1 * time.Millisecond)
accountFeed.Send(accountsevent.Event{
Type: accountsevent.EventTypeRemoved,
Accounts: []common.Address{address},
})
require.NoError(t, utils.Eventually(func() error {
tokenByAddress, err := manager.GetPreviouslyOwnedTokens()
if err == nil && len(tokenByAddress) == 0 {
return nil
}
return errors.New("Token not removed")
}, 100*time.Millisecond, 10*time.Millisecond))
}()
group.Wait()
// Stop service
txServiceMockCtrl.Finish()
server.Stop()
manager.stopAccountsWatcher()
}

View File

@ -1048,7 +1048,7 @@ func TestFindBlocksCommand(t *testing.T) {
} }
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc) client.SetClient(tc.NetworkID(), tc)
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil)
tokenManager.SetTokens([]*token.Token{ tokenManager.SetTokens([]*token.Token{
{ {
Address: tokenTXXAddress, Address: tokenTXXAddress,
@ -1182,7 +1182,7 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) {
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc) client.SetClient(tc.NetworkID(), tc)
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil)
tokenManager.SetTokens([]*token.Token{ tokenManager.SetTokens([]*token.Token{
{ {
@ -1303,7 +1303,7 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) {
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc) client.SetClient(tc.NetworkID(), tc)
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil)
tokenManager.SetTokens([]*token.Token{ tokenManager.SetTokens([]*token.Token{
{ {
@ -1385,7 +1385,7 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) {
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc) client.SetClient(tc.NetworkID(), tc)
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil)
wdb := NewDB(db) wdb := NewDB(db)
blockChannel := make(chan []*DBHeader, 10) blockChannel := make(chan []*DBHeader, 10)
@ -1502,7 +1502,7 @@ func TestFetchNewBlocksCommand(t *testing.T) {
client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db)
client.SetClient(tc.NetworkID(), tc) client.SetClient(tc.NetworkID(), tc)
tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil)
tokenManager.SetTokens([]*token.Token{ tokenManager.SetTokens([]*token.Token{
{ {