diff --git a/protocol/messenger.go b/protocol/messenger.go index 37fcaa295..cbfb76a4d 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -464,7 +464,7 @@ func NewMessenger( if c.tokenManager != nil { managerOptions = append(managerOptions, communities.WithTokenManager(c.tokenManager)) } else if c.rpcClient != nil { - tokenManager := token.NewTokenManager(c.walletDb, c.rpcClient, community.NewManager(database, c.httpServer, nil), c.rpcClient.NetworkManager, database, c.httpServer, nil) + tokenManager := token.NewTokenManager(c.walletDb, c.rpcClient, community.NewManager(database, c.httpServer, nil), c.rpcClient.NetworkManager, database, c.httpServer, nil, nil, nil) managerOptions = append(managerOptions, communities.WithTokenManager(communities.NewDefaultTokenManager(tokenManager))) } diff --git a/services/wallet/service.go b/services/wallet/service.go index b4c0ce3b5..3e3c9bc6b 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -102,7 +102,8 @@ func NewService( communityManager := community.NewManager(db, mediaServer, feed) balanceCacher := balance.NewCacherWithTTL(5 * time.Minute) - tokenManager := token.NewTokenManager(db, rpcClient, communityManager, rpcClient.NetworkManager, appDB, mediaServer, feed) + tokenManager := token.NewTokenManager(db, rpcClient, communityManager, rpcClient.NetworkManager, appDB, mediaServer, feed, accountFeed, accountsDB) + tokenManager.Start() savedAddressesManager := &SavedAddressesManager{db: db} transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB, pendingTxManager, feed) blockChainState := blockchainstate.NewBlockChainState() @@ -262,6 +263,7 @@ func (s *Service) Stop() error { s.history.Stop() s.activity.Stop() s.collectibles.Stop() + s.tokenManager.Stop() s.started = false log.Info("wallet stopped") return nil diff --git a/services/wallet/token/token.go b/services/wallet/token/token.go index 8dbdb14e1..4ae319f4f 100644 --- a/services/wallet/token/token.go +++ b/services/wallet/token/token.go @@ -21,12 +21,14 @@ import ( "github.com/status-im/status-go/contracts/ethscan" "github.com/status-im/status-go/contracts/ierc20" eth_node_types "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/params" "github.com/status-im/status-go/protocol/communities/token" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/server" + "github.com/status-im/status-go/services/accounts/accountsevent" "github.com/status-im/status-go/services/communitytokens" "github.com/status-im/status-go/services/utils" "github.com/status-im/status-go/services/wallet/async" @@ -104,6 +106,9 @@ type Manager struct { communityManager *community.Manager mediaServer *server.MediaServer walletFeed *event.Feed + accountFeed *event.Feed + accountWatcher *accountsevent.Watcher + accountsDB *accounts.Database tokens []*Token @@ -125,17 +130,7 @@ func mergeTokens(sliceLists [][]*Token) []*Token { return res } -func NewTokenManager( - db *sql.DB, - RPCClient *rpc.Client, - communityManager *community.Manager, - networkManager *network.Manager, - appDB *sql.DB, - mediaServer *server.MediaServer, - walletFeed *event.Feed, -) *Manager { - maker, _ := contracts.NewContractMaker(RPCClient) - stores := []store{newUniswapStore(), newDefaultStore()} +func prepareTokens(networkManager *network.Manager, stores []store) []*Token { tokens := make([]*Token, 0) networks, err := networkManager.GetAll() @@ -158,6 +153,23 @@ func NewTokenManager( tokens = mergeTokens([][]*Token{tokens, validTokens}) } + return tokens +} + +func NewTokenManager( + db *sql.DB, + RPCClient *rpc.Client, + communityManager *community.Manager, + networkManager *network.Manager, + appDB *sql.DB, + mediaServer *server.MediaServer, + walletFeed *event.Feed, + accountFeed *event.Feed, + accountsDB *accounts.Database, +) *Manager { + maker, _ := contracts.NewContractMaker(RPCClient) + stores := []store{newUniswapStore(), newDefaultStore()} + tokens := prepareTokens(networkManager, stores) return &Manager{ db: db, @@ -170,6 +182,32 @@ func NewTokenManager( tokens: tokens, mediaServer: mediaServer, walletFeed: walletFeed, + accountFeed: accountFeed, + accountsDB: accountsDB, + } +} + +func (tm *Manager) Start() { + tm.startAccountsWatcher() +} + +func (tm *Manager) startAccountsWatcher() { + if tm.accountWatcher != nil { + return + } + + tm.accountWatcher = accountsevent.NewWatcher(tm.accountsDB, tm.accountFeed, tm.onAccountsChange) + tm.accountWatcher.Start() +} + +func (tm *Manager) Stop() { + tm.stopAccountsWatcher() +} + +func (tm *Manager) stopAccountsWatcher() { + if tm.accountWatcher != nil { + tm.accountWatcher.Stop() + tm.accountWatcher = nil } } @@ -314,6 +352,7 @@ func (tm *Manager) FindOrCreateTokenByAddress(ctx context.Context, chainID uint6 } func (tm *Manager) MarkAsPreviouslyOwnedToken(token *Token, owner common.Address) (bool, error) { + log.Info("Marking token as previously owned", "token", token, "owner", owner) if token == nil { return false, errors.New("token is nil") } @@ -876,3 +915,52 @@ func (tm *Manager) GetTokenHistoricalBalance(account common.Address, chainID uin } return &balance, nil } + +func (tm *Manager) GetPreviouslyOwnedTokens() (map[common.Address][]*Token, error) { + tokenMap := make(map[common.Address][]*Token) + rows, err := tm.db.Query("SELECT user_address, token_name, token_symbol, token_address, token_decimals, chain_id FROM token_balances") + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + token := &Token{} + var addressStr, tokenAddressStr string + err := rows.Scan(&addressStr, &token.Name, &token.Symbol, &tokenAddressStr, &token.Decimals, &token.ChainID) + if err != nil { + return nil, err + } + address := common.HexToAddress(addressStr) + if (address == common.Address{}) { + continue + } + token.Address = common.HexToAddress(tokenAddressStr) + if (token.Address == common.Address{}) { + continue + } + + if _, ok := tokenMap[address]; !ok { + tokenMap[address] = make([]*Token, 0) + } + tokenMap[address] = append(tokenMap[address], token) + } + + return tokenMap, nil +} + +func (tm *Manager) removeTokenBalances(account common.Address) error { + _, err := tm.db.Exec("DELETE FROM token_balances WHERE user_address = ?", account.String()) + return err +} + +func (tm *Manager) onAccountsChange(changedAddresses []common.Address, eventType accountsevent.EventType, currentAddresses []common.Address) { + if eventType == accountsevent.EventTypeRemoved { + for _, account := range changedAddresses { + err := tm.removeTokenBalances(account) + if err != nil { + log.Error("token.Manager: can't remove token balances", "error", err) + } + } + } +} diff --git a/services/wallet/token/token_test.go b/services/wallet/token/token_test.go index 70da89848..db64f91cb 100644 --- a/services/wallet/token/token_test.go +++ b/services/wallet/token/token_test.go @@ -1,17 +1,31 @@ package token import ( + "errors" "math/big" + "sync" "testing" + "time" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/event" + gethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/params" + "github.com/status-im/status-go/rpc" + "github.com/status-im/status-go/rpc/network" + mediaserver "github.com/status-im/status-go/server" + "github.com/status-im/status-go/services/accounts/accountsevent" "github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/community" "github.com/status-im/status-go/t/helpers" + "github.com/status-im/status-go/t/utils" + "github.com/status-im/status-go/transactions/fake" "github.com/status-im/status-go/walletdatabase" ) @@ -297,3 +311,75 @@ func TestGetTokenHistoricalBalance(t *testing.T) { require.NoError(t, err) require.Equal(t, expectedBalance, balance) } + +func Test_removeTokenBalanceOnEventAccountRemoved(t *testing.T) { + appDB, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) + require.NoError(t, err) + + walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) + require.NoError(t, err) + + accountsDB, err := accounts.NewDB(appDB) + require.NoError(t, err) + + address := common.HexToAddress("0x1234") + accountFeed := event.Feed{} + chainID := uint64(1) + txServiceMockCtrl := gomock.NewController(t) + server, _ := fake.NewTestServer(txServiceMockCtrl) + client := gethrpc.DialInProc(server) + rpcClient, _ := rpc.NewClient(client, chainID, params.UpstreamRPCConfig{}, nil, nil) + rpcClient.UpstreamChainID = chainID + nm := network.NewManager(appDB) + mediaServer, err := mediaserver.NewMediaServer(appDB, nil, nil, walletDB) + require.NoError(t, err) + + manager := NewTokenManager(walletDB, rpcClient, nil, nm, appDB, mediaServer, nil, &accountFeed, accountsDB) + + // Insert balances for address + marked, err := manager.MarkAsPreviouslyOwnedToken(&Token{ + Address: common.HexToAddress("0x1234"), + Symbol: "Dummy", + Decimals: 18, + ChainID: 1, + }, address) + require.NoError(t, err) + require.True(t, marked) + + tokenByAddress, err := manager.GetPreviouslyOwnedTokens() + require.NoError(t, err) + require.Len(t, tokenByAddress, 1) + + // Start service + manager.startAccountsWatcher() + + // Watching accounts must start before sending event. + // To avoid running goroutine immediately and let the controller subscribe first, + // use any delay. + group := sync.WaitGroup{} + group.Add(1) + go func() { + defer group.Done() + time.Sleep(1 * time.Millisecond) + + accountFeed.Send(accountsevent.Event{ + Type: accountsevent.EventTypeRemoved, + Accounts: []common.Address{address}, + }) + + require.NoError(t, utils.Eventually(func() error { + tokenByAddress, err := manager.GetPreviouslyOwnedTokens() + if err == nil && len(tokenByAddress) == 0 { + return nil + } + return errors.New("Token not removed") + }, 100*time.Millisecond, 10*time.Millisecond)) + }() + + group.Wait() + + // Stop service + txServiceMockCtrl.Finish() + server.Stop() + manager.stopAccountsWatcher() +} diff --git a/services/wallet/transfer/commands_sequential_test.go b/services/wallet/transfer/commands_sequential_test.go index f5be844c7..fe530a182 100644 --- a/services/wallet/transfer/commands_sequential_test.go +++ b/services/wallet/transfer/commands_sequential_test.go @@ -1048,7 +1048,7 @@ func TestFindBlocksCommand(t *testing.T) { } client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client.SetClient(tc.NetworkID(), tc) - tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) + tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil) tokenManager.SetTokens([]*token.Token{ { Address: tokenTXXAddress, @@ -1182,7 +1182,7 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) { client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client.SetClient(tc.NetworkID(), tc) - tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) + tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil) tokenManager.SetTokens([]*token.Token{ { @@ -1303,7 +1303,7 @@ func TestFetchNewBlocksCommand_findBlocksWithEthTransfers(t *testing.T) { client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client.SetClient(tc.NetworkID(), tc) - tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) + tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil) tokenManager.SetTokens([]*token.Token{ { @@ -1385,7 +1385,7 @@ func TestFetchNewBlocksCommand_nonceDetection(t *testing.T) { client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client.SetClient(tc.NetworkID(), tc) - tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) + tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil) wdb := NewDB(db) blockChannel := make(chan []*DBHeader, 10) @@ -1502,7 +1502,7 @@ func TestFetchNewBlocksCommand(t *testing.T) { client, _ := statusRpc.NewClient(nil, 1, params.UpstreamRPCConfig{Enabled: false, URL: ""}, []params.Network{}, db) client.SetClient(tc.NetworkID(), tc) - tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil) + tokenManager := token.NewTokenManager(db, client, community.NewManager(appdb, nil, nil), network.NewManager(appdb), appdb, mediaServer, nil, nil, nil) tokenManager.SetTokens([]*token.Token{ {