package wallet import ( "context" "math" "math/big" "sync" "time" "golang.org/x/exp/maps" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/market" "github.com/status-im/status-go/services/wallet/thirdparty" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/transfer" "github.com/status-im/status-go/services/wallet/walletevent" ) // WalletTickReload emitted every 15mn to reload the wallet balance and history const EventWalletTickReload walletevent.EventType = "wallet-tick-reload" const EventWalletTickCheckConnected walletevent.EventType = "wallet-tick-check-connected" const ( walletTickReloadPeriod = 10 * time.Minute activityReloadDelay = 30 // Wait this many seconds after activity is detected before triggering a wallet reload activityReloadMarginSeconds = 30 // Trigger a wallet reload if activity is detected this many seconds before the last reload ) func getFixedCurrencies() []string { return []string{"USD"} } func belongsToMandatoryTokens(symbol string) bool { var mandatoryTokens = []string{"ETH", "DAI", "SNT", "STT"} for _, t := range mandatoryTokens { if t == symbol { return true } } return false } func NewReader(tokenManager token.ManagerInterface, marketManager *market.Manager, persistence token.TokenBalancesStorage, walletFeed *event.Feed) *Reader { return &Reader{ tokenManager: tokenManager, marketManager: marketManager, persistence: persistence, walletFeed: walletFeed, refreshBalanceCache: true, } } type Reader struct { tokenManager token.ManagerInterface marketManager *market.Manager persistence token.TokenBalancesStorage walletFeed *event.Feed cancel context.CancelFunc walletEventsWatcher *walletevent.Watcher lastWalletTokenUpdateTimestamp sync.Map reloadDelayTimer *time.Timer refreshBalanceCache bool rw sync.RWMutex } func splitVerifiedTokens(tokens []*token.Token) ([]*token.Token, []*token.Token) { verified := make([]*token.Token, 0) unverified := make([]*token.Token, 0) for _, t := range tokens { if t.Verified { verified = append(verified, t) } else { unverified = append(unverified, t) } } return verified, unverified } func getTokenBySymbols(tokens []*token.Token) map[string][]*token.Token { res := make(map[string][]*token.Token) for _, t := range tokens { if _, ok := res[t.Symbol]; !ok { res[t.Symbol] = make([]*token.Token, 0) } res[t.Symbol] = append(res[t.Symbol], t) } return res } func getTokenAddresses(tokens []*token.Token) []common.Address { set := make(map[common.Address]bool) for _, token := range tokens { set[token.Address] = true } res := make([]common.Address, 0) for address := range set { res = append(res, address) } return res } func (r *Reader) Start() error { ctx, cancel := context.WithCancel(context.Background()) r.cancel = cancel r.startWalletEventsWatcher() go func() { ticker := time.NewTicker(walletTickReloadPeriod) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: r.triggerWalletReload() } } }() return nil } func (r *Reader) Stop() { if r.cancel != nil { r.cancel() } r.stopWalletEventsWatcher() r.cancelDelayedWalletReload() r.lastWalletTokenUpdateTimestamp = sync.Map{} } func (r *Reader) triggerWalletReload() { r.cancelDelayedWalletReload() r.walletFeed.Send(walletevent.Event{ Type: EventWalletTickReload, }) } func (r *Reader) triggerDelayedWalletReload() { r.cancelDelayedWalletReload() r.reloadDelayTimer = time.AfterFunc(time.Duration(activityReloadDelay)*time.Second, r.triggerWalletReload) } func (r *Reader) cancelDelayedWalletReload() { if r.reloadDelayTimer != nil { r.reloadDelayTimer.Stop() r.reloadDelayTimer = nil } } func (r *Reader) startWalletEventsWatcher() { if r.walletEventsWatcher != nil { return } // Respond to ETH/Token transfers walletEventCb := func(event walletevent.Event) { if event.Type != transfer.EventInternalETHTransferDetected && event.Type != transfer.EventInternalERC20TransferDetected { return } for _, address := range event.Accounts { timestamp, ok := r.lastWalletTokenUpdateTimestamp.Load(address) timecheck := int64(0) if ok { timecheck = timestamp.(int64) - activityReloadMarginSeconds } if !ok || event.At > timecheck { r.triggerDelayedWalletReload() r.invalidateBalanceCache() break } } } r.walletEventsWatcher = walletevent.NewWatcher(r.walletFeed, walletEventCb) r.walletEventsWatcher.Start() } func (r *Reader) stopWalletEventsWatcher() { if r.walletEventsWatcher != nil { r.walletEventsWatcher.Stop() r.walletEventsWatcher = nil } } func (r *Reader) tokensCachedForAddresses(addresses []common.Address) bool { cachedTokens, err := r.getCachedWalletTokensWithoutMarketData() if err != nil { return false } for _, address := range addresses { _, ok := cachedTokens[address] if !ok { return false } } return true } func (r *Reader) isCacheTimestampValidForAddress(address common.Address) bool { _, ok := r.lastWalletTokenUpdateTimestamp.Load(address) return ok } func (r *Reader) areCacheTimestampsValid(addresses []common.Address) bool { for _, address := range addresses { if !r.isCacheTimestampValidForAddress(address) { return false } } return true } func (r *Reader) isBalanceCacheValid(addresses []common.Address) bool { r.rw.RLock() defer r.rw.RUnlock() return !r.refreshBalanceCache && r.tokensCachedForAddresses(addresses) && r.areCacheTimestampsValid(addresses) } func (r *Reader) balanceRefreshed() { r.rw.Lock() defer r.rw.Unlock() r.refreshBalanceCache = false } func (r *Reader) invalidateBalanceCache() { r.rw.Lock() defer r.rw.Unlock() r.refreshBalanceCache = true } func (r *Reader) FetchOrGetCachedWalletBalances(ctx context.Context, clients map[uint64]chain.ClientInterface, addresses []common.Address) (map[common.Address][]token.StorageToken, error) { needFetch := !r.isBalanceCacheValid(addresses) || r.isBalanceUpdateNeededAnyway(clients, addresses) if needFetch { return r.FetchBalances(ctx, clients, addresses) } return r.GetCachedBalances(clients, addresses) } func (r *Reader) isBalanceUpdateNeededAnyway(clients map[uint64]chain.ClientInterface, addresses []common.Address) bool { cachedTokens, err := r.getCachedWalletTokensWithoutMarketData() if err != nil { return true } chainIDs := maps.Keys(clients) updateAnyway := false for _, address := range addresses { if res, ok := cachedTokens[address]; !ok || len(res) == 0 { updateAnyway = true break } networkFound := map[uint64]bool{} for _, token := range cachedTokens[address] { for _, chain := range chainIDs { if _, ok := token.BalancesPerChain[chain]; ok { networkFound[chain] = true } } } for _, chain := range chainIDs { if !networkFound[chain] { updateAnyway = true return updateAnyway } } } return updateAnyway } func tokensToBalancesPerChain(cachedTokens map[common.Address][]token.StorageToken) map[uint64]map[common.Address]map[common.Address]*hexutil.Big { cachedBalancesPerChain := map[uint64]map[common.Address]map[common.Address]*hexutil.Big{} for address, tokens := range cachedTokens { for _, token := range tokens { for _, balance := range token.BalancesPerChain { if _, ok := cachedBalancesPerChain[balance.ChainID]; !ok { cachedBalancesPerChain[balance.ChainID] = map[common.Address]map[common.Address]*hexutil.Big{} } if _, ok := cachedBalancesPerChain[balance.ChainID][address]; !ok { cachedBalancesPerChain[balance.ChainID][address] = map[common.Address]*hexutil.Big{} } bigBalance, _ := new(big.Int).SetString(balance.RawBalance, 10) cachedBalancesPerChain[balance.ChainID][address][balance.Address] = (*hexutil.Big)(bigBalance) } } } return cachedBalancesPerChain } func (r *Reader) fetchBalances(ctx context.Context, clients map[uint64]chain.ClientInterface, addresses []common.Address, tokenAddresses []common.Address) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) { latestBalances, err := r.tokenManager.GetBalancesByChain(ctx, clients, addresses, tokenAddresses) if err != nil { log.Error("tokenManager.GetBalancesByChain error", "err", err) return nil, err } return latestBalances, nil } func toChainBalance( balances map[uint64]map[common.Address]map[common.Address]*hexutil.Big, tok *token.Token, address common.Address, decimals uint, cachedTokens map[common.Address][]token.StorageToken, hasError bool, isMandatoryToken bool, ) *token.ChainBalance { hexBalance := &big.Int{} if balances != nil { hexBalance = balances[tok.ChainID][address][tok.Address].ToInt() } balance := big.NewFloat(0.0) if hexBalance != nil { balance = new(big.Float).Quo( new(big.Float).SetInt(hexBalance), big.NewFloat(math.Pow(10, float64(decimals))), ) } isVisible := balance.Cmp(big.NewFloat(0.0)) > 0 || isCachedToken(cachedTokens, address, tok.Symbol, tok.ChainID) if !isVisible && !isMandatoryToken { return nil } return &token.ChainBalance{ RawBalance: hexBalance.String(), Balance: balance, Balance1DayAgo: "0", Address: tok.Address, ChainID: tok.ChainID, HasError: hasError, } } func (r *Reader) getBalance1DayAgo(balance *token.ChainBalance, dayAgoTimestamp int64, symbol string, address common.Address) (*big.Int, error) { balance1DayAgo, err := r.tokenManager.GetTokenHistoricalBalance(address, balance.ChainID, symbol, dayAgoTimestamp) if err != nil { log.Error("tokenManager.GetTokenHistoricalBalance error", "err", err) return nil, err } return balance1DayAgo, nil } func (r *Reader) balancesToTokensByAddress(connectedPerChain map[uint64]bool, addresses []common.Address, allTokens []*token.Token, balances map[uint64]map[common.Address]map[common.Address]*hexutil.Big, cachedTokens map[common.Address][]token.StorageToken) map[common.Address][]token.StorageToken { verifiedTokens, unverifiedTokens := splitVerifiedTokens(allTokens) result := make(map[common.Address][]token.StorageToken) dayAgoTimestamp := time.Now().Add(-24 * time.Hour).Unix() for _, address := range addresses { for _, tokenList := range [][]*token.Token{verifiedTokens, unverifiedTokens} { for symbol, tokens := range getTokenBySymbols(tokenList) { balancesPerChain := r.createBalancePerChainPerSymbol(address, balances, tokens, cachedTokens, connectedPerChain, dayAgoTimestamp) if balancesPerChain == nil { continue } walletToken := token.StorageToken{ Token: token.Token{ Name: tokens[0].Name, Symbol: symbol, Decimals: tokens[0].Decimals, PegSymbol: token.GetTokenPegSymbol(symbol), Verified: tokens[0].Verified, CommunityData: tokens[0].CommunityData, Image: tokens[0].Image, }, BalancesPerChain: balancesPerChain, } result[address] = append(result[address], walletToken) } } } return result } // For tokens with single symbol, create a chain balance for each chain func (r *Reader) createBalancePerChainPerSymbol( address common.Address, balances map[uint64]map[common.Address]map[common.Address]*hexutil.Big, tokens []*token.Token, cachedTokens map[common.Address][]token.StorageToken, clientConnectionPerChain map[uint64]bool, dayAgoTimestamp int64, ) map[uint64]token.ChainBalance { var balancesPerChain map[uint64]token.ChainBalance decimals := tokens[0].Decimals isMandatoryToken := belongsToMandatoryTokens(tokens[0].Symbol) // we expect all tokens in the list to have the same symbol for _, tok := range tokens { hasError := false if connected, ok := clientConnectionPerChain[tok.ChainID]; ok { hasError = !connected } // TODO: Avoid passing the entire balances map to toChainBalance. Iterate over the balances map once and pass the balance per address per token to toChainBalance balance := toChainBalance(balances, tok, address, decimals, cachedTokens, hasError, isMandatoryToken) if balance != nil { balance1DayAgo, _ := r.getBalance1DayAgo(balance, dayAgoTimestamp, tok.Symbol, address) // Ignore error if balance1DayAgo != nil { balance.Balance1DayAgo = balance1DayAgo.String() } if balancesPerChain == nil { balancesPerChain = make(map[uint64]token.ChainBalance) } balancesPerChain[tok.ChainID] = *balance } } return balancesPerChain } func (r *Reader) GetWalletToken(ctx context.Context, clients map[uint64]chain.ClientInterface, addresses []common.Address, currency string) (map[common.Address][]token.StorageToken, error) { cachedTokens, err := r.getCachedWalletTokensWithoutMarketData() if err != nil { return nil, err } chainIDs := maps.Keys(clients) currencies := make([]string, 0) currencies = append(currencies, currency) currencies = append(currencies, getFixedCurrencies()...) allTokens, err := r.tokenManager.GetTokensByChainIDs(chainIDs) if err != nil { return nil, err } tokenAddresses := getTokenAddresses(allTokens) balances, err := r.tokenManager.GetBalancesByChain(ctx, clients, addresses, tokenAddresses) if err != nil { log.Info("tokenManager.GetBalancesByChain error", "err", err) return nil, err } verifiedTokens, unverifiedTokens := splitVerifiedTokens(allTokens) tokenSymbols := make([]string, 0) result := make(map[common.Address][]token.StorageToken) for _, address := range addresses { for _, tokenList := range [][]*token.Token{verifiedTokens, unverifiedTokens} { for symbol, tokens := range getTokenBySymbols(tokenList) { balancesPerChain := make(map[uint64]token.ChainBalance) decimals := tokens[0].Decimals isVisible := false for _, tok := range tokens { hexBalance := balances[tok.ChainID][address][tok.Address] balance := big.NewFloat(0.0) if hexBalance != nil { balance = new(big.Float).Quo( new(big.Float).SetInt(hexBalance.ToInt()), big.NewFloat(math.Pow(10, float64(decimals))), ) } hasError := false if client, ok := clients[tok.ChainID]; ok { hasError = err != nil || !client.IsConnected() } if !isVisible { isVisible = balance.Cmp(big.NewFloat(0.0)) > 0 || isCachedToken(cachedTokens, address, tok.Symbol, tok.ChainID) } balancesPerChain[tok.ChainID] = token.ChainBalance{ RawBalance: hexBalance.ToInt().String(), Balance: balance, Address: tok.Address, ChainID: tok.ChainID, HasError: hasError, } } if !isVisible && !belongsToMandatoryTokens(symbol) { continue } walletToken := token.StorageToken{ Token: token.Token{ Name: tokens[0].Name, Symbol: symbol, Decimals: decimals, PegSymbol: token.GetTokenPegSymbol(symbol), Verified: tokens[0].Verified, CommunityData: tokens[0].CommunityData, Image: tokens[0].Image, }, BalancesPerChain: balancesPerChain, } tokenSymbols = append(tokenSymbols, symbol) result[address] = append(result[address], walletToken) } } } var ( group = async.NewAtomicGroup(ctx) prices = map[string]map[string]float64{} tokenDetails = map[string]thirdparty.TokenDetails{} tokenMarketValues = map[string]thirdparty.TokenMarketValues{} ) group.Add(func(parent context.Context) error { prices, err = r.marketManager.FetchPrices(tokenSymbols, currencies) if err != nil { log.Info("marketManager.FetchPrices err", err) } return nil }) group.Add(func(parent context.Context) error { tokenDetails, err = r.marketManager.FetchTokenDetails(tokenSymbols) if err != nil { log.Info("marketManager.FetchTokenDetails err", err) } return nil }) group.Add(func(parent context.Context) error { tokenMarketValues, err = r.marketManager.FetchTokenMarketValues(tokenSymbols, currency) if err != nil { log.Info("marketManager.FetchTokenMarketValues err", err) } return nil }) select { case <-group.WaitAsync(): case <-ctx.Done(): return nil, ctx.Err() } err = group.Error() if err != nil { return nil, err } for address, tokens := range result { for index, tok := range tokens { marketValuesPerCurrency := make(map[string]token.TokenMarketValues) for _, currency := range currencies { if _, ok := tokenMarketValues[tok.Symbol]; !ok { continue } marketValuesPerCurrency[currency] = token.TokenMarketValues{ MarketCap: tokenMarketValues[tok.Symbol].MKTCAP, HighDay: tokenMarketValues[tok.Symbol].HIGHDAY, LowDay: tokenMarketValues[tok.Symbol].LOWDAY, ChangePctHour: tokenMarketValues[tok.Symbol].CHANGEPCTHOUR, ChangePctDay: tokenMarketValues[tok.Symbol].CHANGEPCTDAY, ChangePct24hour: tokenMarketValues[tok.Symbol].CHANGEPCT24HOUR, Change24hour: tokenMarketValues[tok.Symbol].CHANGE24HOUR, Price: prices[tok.Symbol][currency], HasError: !r.marketManager.IsConnected, } } if _, ok := tokenDetails[tok.Symbol]; !ok { continue } result[address][index].Description = tokenDetails[tok.Symbol].Description result[address][index].AssetWebsiteURL = tokenDetails[tok.Symbol].AssetWebsiteURL result[address][index].BuiltOn = tokenDetails[tok.Symbol].BuiltOn result[address][index].MarketValuesPerCurrency = marketValuesPerCurrency } } r.updateTokenUpdateTimestamp(addresses) return result, r.persistence.SaveTokens(result) } func isCachedToken(cachedTokens map[common.Address][]token.StorageToken, address common.Address, symbol string, chainID uint64) bool { if tokens, ok := cachedTokens[address]; ok { for _, t := range tokens { if t.Symbol != symbol { continue } _, ok := t.BalancesPerChain[chainID] if ok { return true } } } return false } // getCachedWalletTokensWithoutMarketData returns the latest fetched balances, minus // price information func (r *Reader) getCachedWalletTokensWithoutMarketData() (map[common.Address][]token.StorageToken, error) { return r.persistence.GetTokens() } func (r *Reader) updateTokenUpdateTimestamp(addresses []common.Address) { for _, address := range addresses { r.lastWalletTokenUpdateTimestamp.Store(address, time.Now().Unix()) } } func (r *Reader) FetchBalances(ctx context.Context, clients map[uint64]chain.ClientInterface, addresses []common.Address) (map[common.Address][]token.StorageToken, error) { cachedTokens, err := r.getCachedWalletTokensWithoutMarketData() if err != nil { return nil, err } chainIDs := maps.Keys(clients) allTokens, err := r.tokenManager.GetTokensByChainIDs(chainIDs) if err != nil { return nil, err } connectedPerChain := map[uint64]bool{} for chainID, client := range clients { connectedPerChain[chainID] = client.IsConnected() } tokenAddresses := getTokenAddresses(allTokens) balances, err := r.fetchBalances(ctx, clients, addresses, tokenAddresses) if err != nil { log.Error("failed to update balances", "err", err) return nil, err } tokens := r.balancesToTokensByAddress(connectedPerChain, addresses, allTokens, balances, cachedTokens) err = r.persistence.SaveTokens(tokens) if err != nil { log.Error("failed to save tokens", "err", err) // Do not return error, as it is not critical } r.updateTokenUpdateTimestamp(addresses) r.balanceRefreshed() return tokens, err } func (r *Reader) GetCachedBalances(clients map[uint64]chain.ClientInterface, addresses []common.Address) (map[common.Address][]token.StorageToken, error) { cachedTokens, err := r.getCachedWalletTokensWithoutMarketData() if err != nil { return nil, err } chainIDs := maps.Keys(clients) allTokens, err := r.tokenManager.GetTokensByChainIDs(chainIDs) if err != nil { return nil, err } connectedPerChain := map[uint64]bool{} for chainID, client := range clients { connectedPerChain[chainID] = client.IsConnected() } balances := tokensToBalancesPerChain(cachedTokens) return r.balancesToTokensByAddress(connectedPerChain, addresses, allTokens, balances, cachedTokens), nil }