fix: protect token list with a mutex

This commit is contained in:
Dario Gabriel Lipicar 2023-09-11 11:44:43 -03:00 committed by dlipicar
parent 7a9845c6e2
commit 8d1992d2e1
2 changed files with 61 additions and 24 deletions

View File

@ -58,10 +58,14 @@ type Manager struct {
RPCClient *rpc.Client RPCClient *rpc.Client
contractMaker *contracts.ContractMaker contractMaker *contracts.ContractMaker
networkManager *network.Manager networkManager *network.Manager
stores []store stores []store // Set on init, not changed afterwards
// member variables below are protected by mutex
tokenList []*Token tokenList []*Token
tokenMap storeMap tokenMap storeMap
areTokensFetched bool areTokensFetched bool
tokenLock sync.RWMutex
} }
func NewTokenManager( func NewTokenManager(
@ -72,14 +76,14 @@ func NewTokenManager(
maker, _ := contracts.NewContractMaker(RPCClient) maker, _ := contracts.NewContractMaker(RPCClient)
// Order of stores is important when merging token lists. The former prevale // Order of stores is important when merging token lists. The former prevale
return &Manager{ return &Manager{
db, db: db,
RPCClient, RPCClient: RPCClient,
maker, contractMaker: maker,
networkManager, networkManager: networkManager,
[]store{newUniswapStore(), newDefaultStore()}, stores: []store{newUniswapStore(), newDefaultStore()},
nil, tokenList: nil,
nil, tokenMap: nil,
false, areTokensFetched: false,
} }
} }
@ -125,7 +129,7 @@ func (tm *Manager) inStore(address common.Address, chainID uint64) bool {
tm.fetchTokens() tm.fetchTokens()
} }
tokensMap, ok := tm.tokenMap[chainID] tokensMap, ok := tm.getAddressTokenMap(chainID)
if !ok { if !ok {
return false return false
} }
@ -134,9 +138,32 @@ func (tm *Manager) inStore(address common.Address, chainID uint64) bool {
return ok return ok
} }
func (tm *Manager) getTokenList() []*Token {
tm.tokenLock.RLock()
defer tm.tokenLock.RUnlock()
return tm.tokenList
}
func (tm *Manager) getAddressTokenMap(chainID uint64) (addressTokenMap, bool) {
tm.tokenLock.RLock()
defer tm.tokenLock.RUnlock()
tokenMap, chainPresent := tm.tokenMap[chainID]
return tokenMap, chainPresent
}
func (tm *Manager) setTokens(tokens []*Token) {
tm.tokenLock.Lock()
defer tm.tokenLock.Unlock()
tm.tokenList = tokens
tm.tokenMap = toTokenMap(tokens)
tm.areTokensFetched = true
}
func (tm *Manager) fetchTokens() { func (tm *Manager) fetchTokens() {
tm.tokenList = nil tokenList := make([]*Token, 0)
tm.tokenMap = nil
networks, err := tm.networkManager.Get(false) networks, err := tm.networkManager.Get(false)
if err != nil { if err != nil {
@ -157,10 +184,10 @@ func (tm *Manager) fetchTokens() {
} }
} }
tm.tokenList = mergeTokenLists([][]*Token{tm.tokenList, validTokens}) tokenList = mergeTokenLists([][]*Token{tokenList, validTokens})
} }
tm.areTokensFetched = true
tm.tokenMap = toTokenMap(tm.tokenList) tm.setTokens(tokenList)
} }
func (tm *Manager) getFullTokenList(chainID uint64) []*Token { func (tm *Manager) getFullTokenList(chainID uint64) []*Token {
@ -306,7 +333,7 @@ func (tm *Manager) GetAllTokens() ([]*Token, error) {
log.Error("can't fetch custom tokens", "error", err) log.Error("can't fetch custom tokens", "error", err)
} }
tokens = append(tm.tokenList, tokens...) tokens = append(tm.getTokenList(), tokens...)
overrideTokensInPlace(tm.networkManager.GetConfiguredNetworks(), tokens) overrideTokensInPlace(tm.networkManager.GetConfiguredNetworks(), tokens)
@ -330,7 +357,7 @@ func (tm *Manager) GetTokens(chainID uint64) ([]*Token, error) {
tm.fetchTokens() tm.fetchTokens()
} }
tokensMap, ok := tm.tokenMap[chainID] tokensMap, ok := tm.getAddressTokenMap(chainID)
if !ok { if !ok {
return nil, errors.New("no tokens for this network") return nil, errors.New("no tokens for this network")
} }

View File

@ -16,7 +16,17 @@ import (
func setupTestTokenDB(t *testing.T) (*Manager, func()) { func setupTestTokenDB(t *testing.T) (*Manager, func()) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err) require.NoError(t, err)
return &Manager{db, nil, nil, nil, nil, nil, nil, false}, func() {
return &Manager{
db: db,
RPCClient: nil,
contractMaker: nil,
networkManager: nil,
stores: nil,
tokenList: nil,
tokenMap: nil,
areTokensFetched: false,
}, func() {
require.NoError(t, db.Close()) require.NoError(t, db.Close())
} }
} }