From 8d1992d2e11943096fa0ae9c7ffe99a507b75715 Mon Sep 17 00:00:00 2001 From: Dario Gabriel Lipicar Date: Mon, 11 Sep 2023 11:44:43 -0300 Subject: [PATCH] fix: protect token list with a mutex --- services/wallet/token/token.go | 69 ++++++++++++++++++++--------- services/wallet/token/token_test.go | 16 +++++-- 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/services/wallet/token/token.go b/services/wallet/token/token.go index a1fc706ad..9507d3c57 100644 --- a/services/wallet/token/token.go +++ b/services/wallet/token/token.go @@ -54,14 +54,18 @@ type ManagerInterface interface { // Manager is used for accessing token store. It changes the token store based on overridden tokens type Manager struct { - db *sql.DB - RPCClient *rpc.Client - contractMaker *contracts.ContractMaker - networkManager *network.Manager - stores []store + db *sql.DB + RPCClient *rpc.Client + contractMaker *contracts.ContractMaker + networkManager *network.Manager + stores []store // Set on init, not changed afterwards + + // member variables below are protected by mutex tokenList []*Token tokenMap storeMap areTokensFetched bool + + tokenLock sync.RWMutex } func NewTokenManager( @@ -72,14 +76,14 @@ func NewTokenManager( maker, _ := contracts.NewContractMaker(RPCClient) // Order of stores is important when merging token lists. The former prevale return &Manager{ - db, - RPCClient, - maker, - networkManager, - []store{newUniswapStore(), newDefaultStore()}, - nil, - nil, - false, + db: db, + RPCClient: RPCClient, + contractMaker: maker, + networkManager: networkManager, + stores: []store{newUniswapStore(), newDefaultStore()}, + tokenList: nil, + tokenMap: nil, + areTokensFetched: false, } } @@ -125,7 +129,7 @@ func (tm *Manager) inStore(address common.Address, chainID uint64) bool { tm.fetchTokens() } - tokensMap, ok := tm.tokenMap[chainID] + tokensMap, ok := tm.getAddressTokenMap(chainID) if !ok { return false } @@ -134,9 +138,32 @@ func (tm *Manager) inStore(address common.Address, chainID uint64) bool { return ok } +func (tm *Manager) getTokenList() []*Token { + tm.tokenLock.RLock() + defer tm.tokenLock.RUnlock() + + return tm.tokenList +} + +func (tm *Manager) getAddressTokenMap(chainID uint64) (addressTokenMap, bool) { + tm.tokenLock.RLock() + defer tm.tokenLock.RUnlock() + + tokenMap, chainPresent := tm.tokenMap[chainID] + return tokenMap, chainPresent +} + +func (tm *Manager) setTokens(tokens []*Token) { + tm.tokenLock.Lock() + defer tm.tokenLock.Unlock() + + tm.tokenList = tokens + tm.tokenMap = toTokenMap(tokens) + tm.areTokensFetched = true +} + func (tm *Manager) fetchTokens() { - tm.tokenList = nil - tm.tokenMap = nil + tokenList := make([]*Token, 0) networks, err := tm.networkManager.Get(false) if err != nil { @@ -157,10 +184,10 @@ func (tm *Manager) fetchTokens() { } } - tm.tokenList = mergeTokenLists([][]*Token{tm.tokenList, validTokens}) + tokenList = mergeTokenLists([][]*Token{tokenList, validTokens}) } - tm.areTokensFetched = true - tm.tokenMap = toTokenMap(tm.tokenList) + + tm.setTokens(tokenList) } func (tm *Manager) getFullTokenList(chainID uint64) []*Token { @@ -306,7 +333,7 @@ func (tm *Manager) GetAllTokens() ([]*Token, error) { log.Error("can't fetch custom tokens", "error", err) } - tokens = append(tm.tokenList, tokens...) + tokens = append(tm.getTokenList(), tokens...) overrideTokensInPlace(tm.networkManager.GetConfiguredNetworks(), tokens) @@ -330,7 +357,7 @@ func (tm *Manager) GetTokens(chainID uint64) ([]*Token, error) { tm.fetchTokens() } - tokensMap, ok := tm.tokenMap[chainID] + tokensMap, ok := tm.getAddressTokenMap(chainID) if !ok { return nil, errors.New("no tokens for this network") } diff --git a/services/wallet/token/token_test.go b/services/wallet/token/token_test.go index 88e89165a..792e1c58f 100644 --- a/services/wallet/token/token_test.go +++ b/services/wallet/token/token_test.go @@ -16,9 +16,19 @@ import ( func setupTestTokenDB(t *testing.T) (*Manager, func()) { db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) - return &Manager{db, nil, nil, nil, nil, nil, nil, false}, func() { - require.NoError(t, db.Close()) - } + + return &Manager{ + db: db, + RPCClient: nil, + contractMaker: nil, + networkManager: nil, + stores: nil, + tokenList: nil, + tokenMap: nil, + areTokensFetched: false, + }, func() { + require.NoError(t, db.Close()) + } } func TestCustoms(t *testing.T) {