fix: protect token list with a mutex

This commit is contained in:
Dario Gabriel Lipicar 2023-09-11 11:44:43 -03:00 committed by dlipicar
parent 7a9845c6e2
commit 8d1992d2e1
2 changed files with 61 additions and 24 deletions

View File

@ -54,14 +54,18 @@ type ManagerInterface interface {
// Manager is used for accessing token store. It changes the token store based on overridden tokens
type Manager struct {
db *sql.DB
RPCClient *rpc.Client
contractMaker *contracts.ContractMaker
networkManager *network.Manager
stores []store
db *sql.DB
RPCClient *rpc.Client
contractMaker *contracts.ContractMaker
networkManager *network.Manager
stores []store // Set on init, not changed afterwards
// member variables below are protected by mutex
tokenList []*Token
tokenMap storeMap
areTokensFetched bool
tokenLock sync.RWMutex
}
func NewTokenManager(
@ -72,14 +76,14 @@ func NewTokenManager(
maker, _ := contracts.NewContractMaker(RPCClient)
// Order of stores is important when merging token lists. The former prevale
return &Manager{
db,
RPCClient,
maker,
networkManager,
[]store{newUniswapStore(), newDefaultStore()},
nil,
nil,
false,
db: db,
RPCClient: RPCClient,
contractMaker: maker,
networkManager: networkManager,
stores: []store{newUniswapStore(), newDefaultStore()},
tokenList: nil,
tokenMap: nil,
areTokensFetched: false,
}
}
@ -125,7 +129,7 @@ func (tm *Manager) inStore(address common.Address, chainID uint64) bool {
tm.fetchTokens()
}
tokensMap, ok := tm.tokenMap[chainID]
tokensMap, ok := tm.getAddressTokenMap(chainID)
if !ok {
return false
}
@ -134,9 +138,32 @@ func (tm *Manager) inStore(address common.Address, chainID uint64) bool {
return ok
}
func (tm *Manager) getTokenList() []*Token {
tm.tokenLock.RLock()
defer tm.tokenLock.RUnlock()
return tm.tokenList
}
func (tm *Manager) getAddressTokenMap(chainID uint64) (addressTokenMap, bool) {
tm.tokenLock.RLock()
defer tm.tokenLock.RUnlock()
tokenMap, chainPresent := tm.tokenMap[chainID]
return tokenMap, chainPresent
}
func (tm *Manager) setTokens(tokens []*Token) {
tm.tokenLock.Lock()
defer tm.tokenLock.Unlock()
tm.tokenList = tokens
tm.tokenMap = toTokenMap(tokens)
tm.areTokensFetched = true
}
func (tm *Manager) fetchTokens() {
tm.tokenList = nil
tm.tokenMap = nil
tokenList := make([]*Token, 0)
networks, err := tm.networkManager.Get(false)
if err != nil {
@ -157,10 +184,10 @@ func (tm *Manager) fetchTokens() {
}
}
tm.tokenList = mergeTokenLists([][]*Token{tm.tokenList, validTokens})
tokenList = mergeTokenLists([][]*Token{tokenList, validTokens})
}
tm.areTokensFetched = true
tm.tokenMap = toTokenMap(tm.tokenList)
tm.setTokens(tokenList)
}
func (tm *Manager) getFullTokenList(chainID uint64) []*Token {
@ -306,7 +333,7 @@ func (tm *Manager) GetAllTokens() ([]*Token, error) {
log.Error("can't fetch custom tokens", "error", err)
}
tokens = append(tm.tokenList, tokens...)
tokens = append(tm.getTokenList(), tokens...)
overrideTokensInPlace(tm.networkManager.GetConfiguredNetworks(), tokens)
@ -330,7 +357,7 @@ func (tm *Manager) GetTokens(chainID uint64) ([]*Token, error) {
tm.fetchTokens()
}
tokensMap, ok := tm.tokenMap[chainID]
tokensMap, ok := tm.getAddressTokenMap(chainID)
if !ok {
return nil, errors.New("no tokens for this network")
}

View File

@ -16,9 +16,19 @@ import (
func setupTestTokenDB(t *testing.T) (*Manager, func()) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
return &Manager{db, nil, nil, nil, nil, nil, nil, false}, func() {
require.NoError(t, db.Close())
}
return &Manager{
db: db,
RPCClient: nil,
contractMaker: nil,
networkManager: nil,
stores: nil,
tokenList: nil,
tokenMap: nil,
areTokensFetched: false,
}, func() {
require.NoError(t, db.Close())
}
}
func TestCustoms(t *testing.T) {