325 lines
10 KiB
Go
325 lines
10 KiB
Go
package balancefetcher
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"math/big"
|
|
"sync"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/ethereum/go-ethereum/accounts/abi/bind"
|
|
"github.com/ethereum/go-ethereum/common"
|
|
"github.com/ethereum/go-ethereum/common/hexutil"
|
|
"github.com/status-im/status-go/contracts"
|
|
"github.com/status-im/status-go/contracts/ethscan"
|
|
"github.com/status-im/status-go/contracts/ierc20"
|
|
"github.com/status-im/status-go/logutils"
|
|
"github.com/status-im/status-go/rpc/chain"
|
|
"github.com/status-im/status-go/services/wallet/async"
|
|
)
|
|
|
|
var NativeChainAddress = common.HexToAddress("0x")
|
|
var requestTimeout = 20 * time.Second
|
|
|
|
const (
|
|
tokenChunkSize = 500
|
|
)
|
|
|
|
type BalanceFetcher interface {
|
|
GetTokenBalanceAt(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address, blockNumber *big.Int) (*big.Int, error)
|
|
GetBalancesAtByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error)
|
|
GetBalancesByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error)
|
|
GetBalance(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address) (*big.Int, error)
|
|
GetChainBalance(ctx context.Context, client chain.ClientInterface, account common.Address) (*big.Int, error)
|
|
}
|
|
|
|
type DefaultBalanceFetcher struct {
|
|
contractMaker contracts.ContractMakerIface
|
|
}
|
|
|
|
func NewDefaultBalanceFetcher(contractMaker contracts.ContractMakerIface) *DefaultBalanceFetcher {
|
|
return &DefaultBalanceFetcher{
|
|
contractMaker: contractMaker,
|
|
}
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) fetchBalancesForChain(parent context.Context, client chain.ClientInterface, accounts, tokens []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
|
|
var (
|
|
group = async.NewAtomicGroup(parent)
|
|
mu sync.Mutex
|
|
)
|
|
|
|
balances := make(map[common.Address]map[common.Address]*hexutil.Big)
|
|
updateBalance := func(accTokenBalance map[common.Address]map[common.Address]*hexutil.Big) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
for account, tokenBalance := range accTokenBalance {
|
|
if _, ok := balances[account]; !ok {
|
|
balances[account] = make(map[common.Address]*hexutil.Big)
|
|
}
|
|
|
|
for token, balance := range tokenBalance {
|
|
balances[account][token] = balance
|
|
}
|
|
}
|
|
}
|
|
|
|
ethScanContract, availableAtBlock, err := bf.contractMaker.NewEthScan(client.NetworkID())
|
|
if err != nil {
|
|
logutils.ZapLogger().Error("error scanning contract", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
fetchChainBalance := false
|
|
|
|
for _, token := range tokens {
|
|
if token == NativeChainAddress {
|
|
fetchChainBalance = true
|
|
}
|
|
}
|
|
if fetchChainBalance {
|
|
group.Add(func(parent context.Context) error {
|
|
balances, err := bf.FetchChainBalances(parent, accounts, ethScanContract, atBlock)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
updateBalance(balances)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
tokenChunks := splitTokensToChunks(tokens, tokenChunkSize)
|
|
for accountIdx := range accounts {
|
|
// Keep the reference to the account. DO NOT USE A LOOP, the account will be overridden in the coroutine
|
|
account := accounts[accountIdx]
|
|
for idx := range tokenChunks {
|
|
// Keep the reference to the chunk. DO NOT USE A LOOP, the chunk will be overridden in the coroutine
|
|
chunk := tokenChunks[idx]
|
|
|
|
group.Add(func(parent context.Context) error {
|
|
ctx, cancel := context.WithTimeout(parent, requestTimeout)
|
|
defer cancel()
|
|
|
|
var accTokenBalance map[common.Address]map[common.Address]*hexutil.Big
|
|
var err error
|
|
if atBlock == nil || big.NewInt(int64(availableAtBlock)).Cmp(atBlock) < 0 {
|
|
accTokenBalance, err = bf.FetchTokenBalancesWithScanContract(ctx, ethScanContract, account, chunk, atBlock)
|
|
} else {
|
|
accTokenBalance, err = bf.fetchTokenBalancesWithTokenContracts(ctx, client, account, chunk, atBlock)
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
updateBalance(accTokenBalance)
|
|
return nil
|
|
})
|
|
}
|
|
}
|
|
|
|
select {
|
|
case <-group.WaitAsync():
|
|
case <-parent.Done():
|
|
return nil, parent.Err()
|
|
}
|
|
return balances, group.Error()
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) FetchChainBalances(parent context.Context, accounts []common.Address, ethScanContract ethscan.BalanceScannerIface, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
|
|
accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big)
|
|
|
|
ctx, cancel := context.WithTimeout(parent, requestTimeout)
|
|
defer cancel()
|
|
|
|
res, err := ethScanContract.EtherBalances(&bind.CallOpts{
|
|
Context: ctx,
|
|
BlockNumber: atBlock,
|
|
}, accounts)
|
|
if err != nil {
|
|
logutils.ZapLogger().Error("can't fetch chain balance 5", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
for idx, account := range accounts {
|
|
balance := new(big.Int)
|
|
balance.SetBytes(res[idx].Data)
|
|
|
|
if _, ok := accTokenBalance[account]; !ok {
|
|
accTokenBalance[account] = make(map[common.Address]*hexutil.Big)
|
|
}
|
|
|
|
accTokenBalance[account][NativeChainAddress] = (*hexutil.Big)(balance)
|
|
}
|
|
|
|
return accTokenBalance, nil
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) FetchTokenBalancesWithScanContract(ctx context.Context, ethScanContract ethscan.BalanceScannerIface, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
|
|
accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big)
|
|
res, err := ethScanContract.TokensBalance(&bind.CallOpts{
|
|
Context: ctx,
|
|
BlockNumber: atBlock,
|
|
}, account, chunk)
|
|
if err != nil {
|
|
logutils.ZapLogger().Error("can't fetch erc20 token balance 6", zap.Stringer("account", account), zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
if len(res) != len(chunk) {
|
|
logutils.ZapLogger().Error("can't fetch erc20 token balance 7",
|
|
zap.Stringer("account", account),
|
|
zap.Error(errors.New("response not complete")),
|
|
zap.Int("expected", len(chunk)),
|
|
zap.Int("got", len(res)),
|
|
)
|
|
return nil, errors.New("response not complete")
|
|
}
|
|
|
|
for idx, token := range chunk {
|
|
if !res[idx].Success {
|
|
continue
|
|
}
|
|
balance := new(big.Int)
|
|
balance.SetBytes(res[idx].Data)
|
|
|
|
if _, ok := accTokenBalance[account]; !ok {
|
|
accTokenBalance[account] = make(map[common.Address]*hexutil.Big)
|
|
}
|
|
|
|
accTokenBalance[account][token] = (*hexutil.Big)(balance)
|
|
}
|
|
return accTokenBalance, nil
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) fetchTokenBalancesWithTokenContracts(ctx context.Context, client chain.ClientInterface, account common.Address, chunk []common.Address, atBlock *big.Int) (map[common.Address]map[common.Address]*hexutil.Big, error) {
|
|
accTokenBalance := make(map[common.Address]map[common.Address]*hexutil.Big)
|
|
for _, token := range chunk {
|
|
balance, err := bf.GetTokenBalanceAt(ctx, client, account, token, atBlock)
|
|
if err != nil {
|
|
if err != bind.ErrNoCode {
|
|
logutils.ZapLogger().Error("can't fetch erc20 token balance 8",
|
|
zap.Stringer("account", account),
|
|
zap.Stringer("token", token),
|
|
zap.Error(errors.New("on fetching token balance")),
|
|
)
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if _, ok := accTokenBalance[account]; !ok {
|
|
accTokenBalance[account] = make(map[common.Address]*hexutil.Big)
|
|
}
|
|
|
|
accTokenBalance[account][token] = (*hexutil.Big)(balance)
|
|
}
|
|
|
|
return accTokenBalance, nil
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) GetTokenBalanceAt(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address, blockNumber *big.Int) (*big.Int, error) {
|
|
caller, err := bf.contractMaker.NewERC20Caller(client.NetworkID(), token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
balance, err := caller.BalanceOf(&bind.CallOpts{
|
|
Context: ctx,
|
|
BlockNumber: blockNumber,
|
|
}, account)
|
|
|
|
if err != nil {
|
|
if err != bind.ErrNoCode {
|
|
return nil, err
|
|
}
|
|
balance = big.NewInt(0)
|
|
}
|
|
|
|
return balance, nil
|
|
}
|
|
|
|
func splitTokensToChunks(tokens []common.Address, chunkSize int) [][]common.Address {
|
|
tokenChunks := make([][]common.Address, 0)
|
|
for i := 0; i < len(tokens); i += chunkSize {
|
|
end := i + chunkSize
|
|
if end > len(tokens) {
|
|
end = len(tokens)
|
|
}
|
|
|
|
tokenChunks = append(tokenChunks, tokens[i:end])
|
|
}
|
|
|
|
return tokenChunks
|
|
}
|
|
|
|
func (tm *DefaultBalanceFetcher) GetTokenBalance(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address) (*big.Int, error) {
|
|
caller, err := ierc20.NewIERC20Caller(token, client)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return caller.BalanceOf(&bind.CallOpts{
|
|
Context: ctx,
|
|
}, account)
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) GetChainBalance(ctx context.Context, client chain.ClientInterface, account common.Address) (*big.Int, error) {
|
|
return client.BalanceAt(ctx, account, nil)
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) GetBalance(ctx context.Context, client chain.ClientInterface, account common.Address, token common.Address) (*big.Int, error) {
|
|
if token == NativeChainAddress {
|
|
return bf.GetChainBalance(ctx, client, account)
|
|
}
|
|
|
|
return bf.GetTokenBalance(ctx, client, account, token)
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) GetBalancesByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) {
|
|
return bf.GetBalancesAtByChain(parent, clients, accounts, tokens, nil)
|
|
}
|
|
|
|
func (bf *DefaultBalanceFetcher) GetBalancesAtByChain(parent context.Context, clients map[uint64]chain.ClientInterface, accounts, tokens []common.Address, atBlocks map[uint64]*big.Int) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) {
|
|
var (
|
|
group = async.NewAtomicGroup(parent)
|
|
mu sync.Mutex
|
|
response = map[uint64]map[common.Address]map[common.Address]*hexutil.Big{}
|
|
)
|
|
|
|
updateBalance := func(chainID uint64, accTokenBalance map[common.Address]map[common.Address]*hexutil.Big) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if _, ok := response[chainID]; !ok {
|
|
response[chainID] = map[common.Address]map[common.Address]*hexutil.Big{}
|
|
}
|
|
|
|
for account, tokenBalance := range accTokenBalance {
|
|
response[chainID][account] = tokenBalance
|
|
}
|
|
}
|
|
|
|
for clientIdx := range clients {
|
|
// Keep the reference to the client. DO NOT USE A LOOP, the client will be overridden in the coroutine
|
|
client := clients[clientIdx]
|
|
|
|
group.Add(func(parent context.Context) error {
|
|
balances, err := bf.fetchBalancesForChain(parent, client, accounts, tokens, atBlocks[client.NetworkID()])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
updateBalance(client.NetworkID(), balances)
|
|
return nil
|
|
})
|
|
}
|
|
select {
|
|
case <-group.WaitAsync():
|
|
case <-parent.Done():
|
|
return nil, parent.Err()
|
|
}
|
|
return response, nil
|
|
}
|