diff --git a/services/wallet/reader.go b/services/wallet/reader.go index 267090908..0c0486722 100644 --- a/services/wallet/reader.go +++ b/services/wallet/reader.go @@ -5,7 +5,6 @@ import ( "math" "math/big" "sync" - "sync/atomic" "time" "github.com/ethereum/go-ethereum/common" @@ -50,15 +49,14 @@ func belongsToMandatoryTokens(symbol string) bool { func NewReader(rpcClient *rpc.Client, tokenManager *token.Manager, marketManager *market.Manager, communityManager *community.Manager, accountsDB *accounts.Database, persistence *Persistence, walletFeed *event.Feed) *Reader { return &Reader{ - rpcClient: rpcClient, - tokenManager: tokenManager, - marketManager: marketManager, - communityManager: communityManager, - accountsDB: accountsDB, - persistence: persistence, - walletFeed: walletFeed, - lastWalletTokenUpdateTimestamp: atomic.Int64{}, - refreshBalanceCache: true, + rpcClient: rpcClient, + tokenManager: tokenManager, + marketManager: marketManager, + communityManager: communityManager, + accountsDB: accountsDB, + persistence: persistence, + walletFeed: walletFeed, + refreshBalanceCache: true, } } @@ -72,7 +70,7 @@ type Reader struct { walletFeed *event.Feed cancel context.CancelFunc walletEventsWatcher *walletevent.Watcher - lastWalletTokenUpdateTimestamp atomic.Int64 + lastWalletTokenUpdateTimestamp sync.Map reloadDelayTimer *time.Timer refreshBalanceCache bool rw sync.RWMutex @@ -185,7 +183,7 @@ func (r *Reader) Stop() { r.cancelDelayedWalletReload() - r.lastWalletTokenUpdateTimestamp.Store(0) + r.lastWalletTokenUpdateTimestamp = sync.Map{} } func (r *Reader) triggerWalletReload() { @@ -222,13 +220,15 @@ func (r *Reader) startWalletEventsWatcher() { return } - timecheck := r.lastWalletTokenUpdateTimestamp.Load() - activityReloadMarginSeconds - if event.At > timecheck { - r.triggerDelayedWalletReload() - } + for _, address := range event.Accounts { + timestamp, ok := r.lastWalletTokenUpdateTimestamp.Load(address) + timecheck := timestamp.(int64) - activityReloadMarginSeconds - if transfer.IsTransferDetectionEvent(event.Type) { - r.invalidateBalanceCache() + if !ok || event.At > timecheck { + r.triggerDelayedWalletReload() + r.invalidateBalanceCache() + break + } } } @@ -244,11 +244,42 @@ func (r *Reader) stopWalletEventsWatcher() { } } -func (r *Reader) isBalanceCacheValid() bool { +func (r *Reader) tokensCachedForAddresses(addresses []common.Address) bool { + for _, address := range addresses { + cachedTokens, err := r.GetCachedWalletTokensWithoutMarketData() + if err != nil { + return false + } + + _, ok := cachedTokens[address] + if !ok { + return false + } + } + + return true +} + +func (r *Reader) isCacheTimestampValidForAddress(address common.Address) bool { + _, ok := r.lastWalletTokenUpdateTimestamp.Load(address) + return ok +} + +func (r *Reader) areCacheTimestampsValid(addresses []common.Address) bool { + for _, address := range addresses { + if !r.isCacheTimestampValidForAddress(address) { + return false + } + } + + return true +} + +func (r *Reader) isBalanceCacheValid(addresses []common.Address) bool { r.rw.RLock() defer r.rw.RUnlock() - return !r.refreshBalanceCache + return !r.refreshBalanceCache && r.tokensCachedForAddresses(addresses) && r.areCacheTimestampsValid(addresses) } func (r *Reader) balanceRefreshed() { @@ -266,7 +297,7 @@ func (r *Reader) invalidateBalanceCache() { } func (r *Reader) FetchOrGetCachedWalletBalances(ctx context.Context, addresses []common.Address) (map[common.Address][]Token, error) { - if !r.isBalanceCacheValid() { + if !r.isBalanceCacheValid(addresses) { balances, err := r.GetWalletTokenBalances(ctx, addresses) if err != nil { return nil, err @@ -458,7 +489,7 @@ func (r *Reader) getWalletTokenBalances(ctx context.Context, addresses []common. } } - r.lastWalletTokenUpdateTimestamp.Store(time.Now().Unix()) + r.updateTokenUpdateTimestamp(addresses) return result, r.persistence.SaveTokens(result) } @@ -655,7 +686,7 @@ func (r *Reader) GetWalletToken(ctx context.Context, addresses []common.Address) } } - r.lastWalletTokenUpdateTimestamp.Store(time.Now().Unix()) + r.updateTokenUpdateTimestamp(addresses) return result, r.persistence.SaveTokens(result) } @@ -680,3 +711,9 @@ func (r *Reader) isCachedToken(cachedTokens map[common.Address][]Token, address func (r *Reader) GetCachedWalletTokensWithoutMarketData() (map[common.Address][]Token, error) { return r.persistence.GetTokens() } + +func (r *Reader) updateTokenUpdateTimestamp(addresses []common.Address) { + for _, address := range addresses { + r.lastWalletTokenUpdateTimestamp.Store(address, time.Now().Unix()) + } +}