diff --git a/services/wallet/api.go b/services/wallet/api.go index 356013bc8..f19dc11e3 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -183,7 +183,7 @@ func (api *API) GetBalanceHistoryRange(ctx context.Context, chainIDs []uint64, a func (api *API) GetTokens(ctx context.Context, chainID uint64) ([]*token.Token, error) { log.Debug("call to get tokens") - rst, err := api.s.tokenManager.GetTokens(chainID) + rst, err := api.s.tokenManager.GetTokens(chainID, true) log.Debug("result from token store", "len", len(rst)) return rst, err } diff --git a/services/wallet/reader.go b/services/wallet/reader.go index 3dab614b0..241be3aae 100644 --- a/services/wallet/reader.go +++ b/services/wallet/reader.go @@ -192,7 +192,7 @@ func (r *Reader) GetWalletToken(ctx context.Context, addresses []common.Address) } currencies = append(currencies, currency) currencies = append(currencies, getFixedCurrencies()...) - allTokens, err := r.tokenManager.GetTokensByChainIDs(chainIDs) + allTokens, err := r.tokenManager.GetTokensByChainIDs(chainIDs, true) if err != nil { return nil, err diff --git a/services/wallet/token/token.go b/services/wallet/token/token.go index 124c9c92a..7e7d7f872 100644 --- a/services/wallet/token/token.go +++ b/services/wallet/token/token.go @@ -197,12 +197,12 @@ func (tm *Manager) fetchTokens() { } func (tm *Manager) getFullTokenList(chainID uint64) []*Token { - tokens, err := tm.GetTokens(chainID) + tokens, err := tm.GetTokens(chainID, false) if err != nil { return nil } - customTokens, err := tm.GetCustomsByChainID(chainID) + customTokens, err := tm.GetCustomsByChainID(chainID, false) if err != nil { return nil } @@ -347,7 +347,7 @@ func (tm *Manager) discoverTokenCommunityID(ctx context.Context, token *Token, a } func (tm *Manager) FindSNT(chainID uint64) *Token { - tokens, err := tm.GetTokens(chainID) + tokens, err := tm.GetTokens(chainID, false) if err != nil { return nil } @@ -396,10 +396,10 @@ func (tm *Manager) GetAllTokens() ([]*Token, error) { return tokens, nil } -func (tm *Manager) GetTokensByChainIDs(chainIDs []uint64) ([]*Token, error) { +func (tm *Manager) GetTokensByChainIDs(chainIDs []uint64, onlyCommunityCustoms bool) ([]*Token, error) { tokens := make([]*Token, 0) for _, chainID := range chainIDs { - t, err := tm.GetTokens(chainID) + t, err := tm.GetTokens(chainID, onlyCommunityCustoms) if err != nil { return nil, err } @@ -408,7 +408,7 @@ func (tm *Manager) GetTokensByChainIDs(chainIDs []uint64) ([]*Token, error) { return tokens, nil } -func (tm *Manager) GetTokens(chainID uint64) ([]*Token, error) { +func (tm *Manager) GetDefaultTokens(chainID uint64) ([]*Token, error) { if !tm.areTokensFetched { tm.fetchTokens() } @@ -423,8 +423,16 @@ func (tm *Manager) GetTokens(chainID uint64) ([]*Token, error) { for _, token := range tokensMap { res = append(res, token) } + return res, nil +} - tokens, err := tm.GetCustomsByChainID(chainID) +func (tm *Manager) GetTokens(chainID uint64, onlyCommunityCustoms bool) ([]*Token, error) { + res, err := tm.GetDefaultTokens(chainID) + if err != nil { + return nil, err + } + + tokens, err := tm.GetCustomsByChainID(chainID, onlyCommunityCustoms) if err != nil { return nil, err } @@ -498,8 +506,11 @@ func (tm *Manager) GetCustoms() ([]*Token, error) { return tm.getTokens("SELECT address, name, symbol, decimals, color, network_id, community_id FROM tokens") } -func (tm *Manager) GetCustomsByChainID(chainID uint64) ([]*Token, error) { - return tm.getTokens("SELECT address, name, symbol, decimals, color, network_id, community_id FROM tokens where network_id=?", chainID) +func (tm *Manager) GetCustomsByChainID(chainID uint64, onlyCommunityCustoms bool) ([]*Token, error) { + if onlyCommunityCustoms { + return tm.getTokens("SELECT address, name, symbol, decimals, color, network_id, community_id FROM tokens WHERE network_id=? AND community_id IS NOT NULL AND community_id != ''", chainID) + } + return tm.getTokens("SELECT address, name, symbol, decimals, color, network_id, community_id FROM tokens WHERE network_id=?", chainID) } func (tm *Manager) IsTokenVisible(chainID uint64, address common.Address) (bool, error) { @@ -575,7 +586,7 @@ func (tm *Manager) GetVisible(chainIDs []uint64) (map[uint64][]*Token, error) { } found := false - tokens, err := tm.GetTokens(chainID) + tokens, err := tm.GetTokens(chainID, false) if err != nil { continue } diff --git a/services/wallet/transfer/commands_sequential.go b/services/wallet/transfer/commands_sequential.go index 3d584d536..d14ee7049 100644 --- a/services/wallet/transfer/commands_sequential.go +++ b/services/wallet/transfer/commands_sequential.go @@ -164,7 +164,7 @@ func (c *findBlocksCommand) ERC20ScanByBalance(parent context.Context, fromBlock func (c *findBlocksCommand) checkERC20Tail(parent context.Context) ([]*DBHeader, error) { log.Debug("checkERC20Tail", "account", c.account, "to block", c.startBlockNumber, "from", c.resFromBlock.Number) - tokens, err := c.tokenManager.GetTokens(c.chainClient.NetworkID()) + tokens, err := c.tokenManager.GetTokens(c.chainClient.NetworkID(), false) if err != nil { return nil, err }