status-go/services/wallet/history/service.go

448 lines
13 KiB
Go

package history
import (
"context"
"database/sql"
"errors"
"math/big"
"sort"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
statustypes "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/multiaccounts/accounts"
statusrpc "github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/rpc/network"
"github.com/status-im/status-go/services/wallet/chain"
"github.com/status-im/status-go/services/wallet/token"
"github.com/status-im/status-go/services/wallet/walletevent"
)
// Event types sent on the wallet event feed to notify the UI about balance history updates:
// EventBalanceHistoryUpdateStarted when the update for an account begins,
// EventBalanceHistoryUpdateFinished when it completes (or the service is canceled) and
// EventBalanceHistoryUpdateFinishedWithError when an update run returns an error.
const (
EventBalanceHistoryUpdateStarted walletevent.EventType = "wallet-balance-history-update-started"
EventBalanceHistoryUpdateFinished walletevent.EventType = "wallet-balance-history-update-finished"
EventBalanceHistoryUpdateFinishedWithError walletevent.EventType = "wallet-balance-history-update-finished-with-error"
// balanceHistoryUpdateInterval is the delay between two periodic balance history updates
balanceHistoryUpdateInterval = 12 * time.Hour
)
// Service periodically updates the stored balance history of the wallet accounts for the visible tokens
// and answers balance history queries merged over several chains.
type Service struct {
balance *Balance
db *sql.DB
eventFeed *event.Feed // receives the balance history update events
rpcClient *statusrpc.Client
networkManager *network.Manager
tokenManager *token.Manager
serviceContext context.Context // created by StartBalanceHistory and canceled by Stop
cancelFn context.CancelFunc // cancels serviceContext; nil until StartBalanceHistory has set it
timer *time.Timer // schedules the next update; nil until StartBalanceHistory has set it
visibleTokenSymbols []string // symbols whose balance history gets updated; set by UpdateVisibleTokens
visibleTokenSymbolsMutex sync.Mutex // Protects access to visibleTokenSymbols
}
// chainIdentity is a chain ID used as key of the per-chain data point maps
type chainIdentity uint64
// NewService builds a Service on top of db. The network manager is taken from rpcClient.
// No background work is started here; call StartBalanceHistory for that.
func NewService(db *sql.DB, eventFeed *event.Feed, rpcClient *statusrpc.Client, tokenManager *token.Manager) *Service {
	balance := NewBalance(NewBalanceDB(db))

	svc := &Service{
		db:             db,
		eventFeed:      eventFeed,
		rpcClient:      rpcClient,
		networkManager: rpcClient.NetworkManager,
		tokenManager:   tokenManager,
	}
	svc.balance = balance
	return svc
}
// Stop cancels the service context. It does nothing when no cancel function has been set yet.
func (s *Service) Stop() {
	if s.cancelFn == nil {
		return
	}
	s.cancelFn()
}
// triggerEvent sends an event of the given type, carrying account as its only address
// and message as its payload, on the service event feed.
func (s *Service) triggerEvent(eventType walletevent.EventType, account statustypes.Address, message string) {
	evt := walletevent.Event{
		Type:     eventType,
		Accounts: []common.Address{common.Address(account)},
		Message:  message,
	}
	s.eventFeed.Send(evt)
}
// StartBalanceHistory starts the background goroutine that updates the balance history
// once immediately and then every time the service timer fires (every
// balanceHistoryUpdateInterval, or earlier when resetTimer is called).
//
// The context, cancel function and timer are created before the goroutine is spawned, so
// Stop and UpdateVisibleTokens called right after this method returns operate on
// initialised fields instead of racing with the goroutine start.
//
// The goroutine exits as soon as the service context is canceled, also while it is idle
// between two updates. When the cancellation is noticed during an update,
// EventBalanceHistoryUpdateFinished is sent with the message "Service canceled".
func (s *Service) StartBalanceHistory() {
	ctx, cancel := context.WithCancel(context.Background())
	timer := time.NewTimer(balanceHistoryUpdateInterval)
	s.serviceContext, s.cancelFn, s.timer = ctx, cancel, timer

	go func() {
		defer timer.Stop()

		// update runs one full update and reports whether the goroutine has to exit
		update := func() (exit bool) {
			err := s.updateBalanceHistoryForAllEnabledNetworks(ctx)
			if ctx.Err() != nil {
				s.triggerEvent(EventBalanceHistoryUpdateFinished, statustypes.Address{}, "Service canceled")
				return true
			}
			if err != nil {
				s.triggerEvent(EventBalanceHistoryUpdateFinishedWithError, statustypes.Address{}, err.Error())
			}
			return false
		}

		if update() {
			return
		}

		for {
			select {
			case <-ctx.Done():
				// Canceled while idle: no update is in progress, nothing to report
				return
			case <-timer.C:
				// The timer has fired and its channel is drained, so Reset is safe here
				timer.Reset(balanceHistoryUpdateInterval)
				if update() {
					return
				}
			}
		}
	}()
}
// resetTimer re-arms the service timer so that it fires after interval.
// It is a no-op while the timer has not been created.
func (s *Service) resetTimer(interval time.Duration) {
	timer := s.timer
	if timer == nil {
		return
	}
	timer.Stop()
	timer.Reset(interval)
}
// UpdateVisibleTokens replaces the list of token symbols whose balance history is updated.
// Going from no visible symbols to at least one re-arms the timer with a zero interval,
// which requests an update right away.
func (s *Service) UpdateVisibleTokens(symbols []string) {
	s.visibleTokenSymbolsMutex.Lock()
	defer s.visibleTokenSymbolsMutex.Unlock()

	hadNoSymbols := len(s.visibleTokenSymbols) == 0
	s.visibleTokenSymbols = symbols

	if hadNoSymbols && len(symbols) > 0 {
		s.resetTimer(0)
	}
}
// isTokenVisible reports whether tokenSymbol is in the list of visible token symbols.
func (s *Service) isTokenVisible(tokenSymbol string) bool {
	s.visibleTokenSymbolsMutex.Lock()
	defer s.visibleTokenSymbolsMutex.Unlock()

	visible := false
	for i := 0; i < len(s.visibleTokenSymbols) && !visible; i++ {
		visible = s.visibleTokenSymbols[i] == tokenSymbol
	}
	return visible
}
// chainClientSource is the native token implementation of the DataSource interface.
// Headers and balances are read through chainClient; currency is the symbol reported by Currency.
type chainClientSource struct {
chainClient *chain.Client
currency string
}
// HeaderByNumber returns the header of block blockNo by delegating to the chain client.
func (src *chainClientSource) HeaderByNumber(ctx context.Context, blockNo *big.Int) (*types.Header, error) {
return src.chainClient.HeaderByNumber(ctx, blockNo)
}
// BalanceAt returns the balance of account at block blockNo by delegating to the chain client.
func (src *chainClientSource) BalanceAt(ctx context.Context, account common.Address, blockNo *big.Int) (*big.Int, error) {
return src.chainClient.BalanceAt(ctx, account, blockNo)
}
// ChainID returns the chain ID of the underlying chain client.
func (src *chainClientSource) ChainID() uint64 {
return src.chainClient.ChainID
}
// Currency returns the currency symbol this source was created with.
func (src *chainClientSource) Currency() string {
return src.currency
}
// TimeNow returns the current time as a Unix timestamp in seconds.
// Unix time does not depend on the time zone, so no location conversion is needed.
func (src *chainClientSource) TimeNow() int64 {
	now := time.Now()
	return now.Unix()
}
// tokenChainClientSource is the DataSource used for non-native tokens. It embeds chainClientSource
// and overrides BalanceAt to query the token balance through TokenManager.
type tokenChainClientSource struct {
chainClientSource
TokenManager *token.Manager
NetworkManager *network.Manager
// Block number recorded the last time the balance query reported no contract code;
// BalanceAt answers requests for earlier blocks with zero without querying
firstUnavailableBlockNo *big.Int
}
// BalanceAt returns the balance of the source's token held by account at blockNumber.
//
// It fails with "network not found" or "token not found" when the chain or the token
// symbol cannot be resolved. A "no contract code at given address" error from the balance
// query is treated as "token contract not deployed yet": the balance is reported as zero
// and blockNumber is remembered, so that later requests for that block or any earlier one
// are answered with zero without another query.
//
// A nil blockNumber is forwarded to the token manager unchanged (presumably meaning the
// latest block, as in go-ethereum clients - TODO confirm); it never consults nor updates
// the remembered block.
func (src *tokenChainClientSource) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) {
	chainNetwork := src.NetworkManager.Find(src.chainClient.ChainID)
	if chainNetwork == nil {
		return nil, errors.New("network not found")
	}

	chainToken := src.TokenManager.FindToken(chainNetwork, src.currency)
	if chainToken == nil {
		return nil, errors.New("token not found")
	}

	// The remembered block itself is known to precede the contract deployment, hence <=
	if blockNumber != nil && src.firstUnavailableBlockNo != nil && blockNumber.Cmp(src.firstUnavailableBlockNo) <= 0 {
		return big.NewInt(0), nil
	}

	balance, err := src.TokenManager.GetTokenBalanceAt(ctx, src.chainClient, account, chainToken.Address, blockNumber)
	if err != nil {
		// NOTE(review): the error is matched by message; a sentinel error checked with
		// errors.Is would be more robust if the token manager exposes one
		if err.Error() == "no contract code at given address" {
			// Ignore requests before contract deployment and mark this state for future requests
			if blockNumber != nil {
				src.firstUnavailableBlockNo = new(big.Int).Set(blockNumber)
			}
			return big.NewInt(0), nil
		}
		return nil, err
	}
	return balance, nil
}
// GetBalanceHistory returns the balance history of address for currency, expressed in token count
// and merged over all chainIDs. Chains without stored data points are left out of the merge.
// The first error returned while reading a chain aborts the call.
// TODO: fetch token to FIAT exchange rates and return FIAT balance
func (s *Service) GetBalanceHistory(ctx context.Context, chainIDs []uint64, address common.Address, currency string, endTimestamp int64, timeInterval TimeInterval) ([]*DataPoint, error) {
	perChain := make(map[chainIdentity][]*DataPoint)
	for _, id := range chainIDs {
		points, err := s.balance.get(ctx, id, currency, address, endTimestamp, timeInterval)
		if err != nil {
			return nil, err
		}
		if len(points) == 0 {
			continue
		}
		perChain[chainIdentity(id)] = points
	}

	stride := strideDuration(timeInterval)
	return mergeDataPoints(perChain, stride)
}
// mergeDataPoints merges the data points of several chains that are close in time.
// Time is cut into consecutive windows of stride length (truncated to whole seconds),
// starting at the value returned by findFirstStrideWindow. A window is kept only when
// every chain contributes at least one data point to it; other windows are dropped.
// This should improve merging balance data from different chains which are incompatible
// due to different timelines and block length.
//
// For every kept window one DataPoint is emitted with:
//   - Timestamp: the end of the window
//   - Value: the sum over the chains of the highest balance each chain has in the window
//   - BlockNumber: the common block number of the selected points, nil if they differ
//
// An empty data map yields an empty, non-nil result, as does data in which no chain has
// any point. A stride shorter than one second is rejected with an error because the
// window would never advance.
func mergeDataPoints(data map[chainIdentity][]*DataPoint, stride time.Duration) ([]*DataPoint, error) {
	res := make([]*DataPoint, 0)
	if len(data) == 0 {
		return res, nil
	}

	strideSeconds := int64(stride.Seconds())
	if strideSeconds <= 0 {
		return nil, errors.New("stride must be at least one second")
	}

	pos := make(map[chainIdentity]int, len(data))
	for chainID := range data {
		pos[chainID] = 0
	}
	// Nothing to merge when no chain has a data point; this also keeps
	// findFirstStrideWindow from being asked for the start of non-existent data
	if allPastEnd(data, pos) {
		return res, nil
	}

	strideStart := findFirstStrideWindow(data, stride)
	for {
		strideEnd := strideStart + strideSeconds

		// Gather the points in the stride window starting with current pos
		var strideIdentities map[chainIdentity][]timeIdentity
		strideIdentities, pos = dataInStrideWindowAndNextPos(data, pos, strideEnd)

		// Check if all chains have data
		strideComplete := true
		for chainID := range data {
			if _, ok := strideIdentities[chainID]; !ok {
				strideComplete = false
				break
			}
		}

		if strideComplete {
			// Keep, per chain, the first point holding the highest balance of the window
			chainMaxBalance := make(map[chainIdentity]*DataPoint, len(strideIdentities))
			for chainID, identities := range strideIdentities {
				for _, identity := range identities {
					candidate := identity.dataPoint(data)
					current, found := chainMaxBalance[chainID]
					if !found || (*big.Int)(candidate.Value).Cmp((*big.Int)(current.Value)) > 0 {
						chainMaxBalance[chainID] = candidate
					}
				}
			}

			balance := big.NewInt(0)
			for _, chainBalance := range chainMaxBalance {
				balance.Add(balance, (*big.Int)(chainBalance.Value))
			}

			res = append(res, &DataPoint{
				Timestamp:   uint64(strideEnd),
				Value:       (*hexutil.Big)(balance),
				BlockNumber: (*hexutil.Big)(getBlockID(chainMaxBalance)),
			})
		}

		if allPastEnd(data, pos) {
			return res, nil
		}
		strideStart = strideEnd
	}
}
// getBlockID returns a copy of the block number shared by all data points of chainBalance.
// It returns nil when chainBalance is empty or when the block numbers are not all equal.
func getBlockID(chainBalance map[chainIdentity]*DataPoint) *big.Int {
	var shared *big.Int
	for _, point := range chainBalance {
		blockNo := point.BlockNumber.ToInt()
		if shared == nil {
			shared = new(big.Int).Set(blockNo)
			continue
		}
		if shared.Cmp(blockNo) != 0 {
			return nil
		}
	}
	return shared
}
// timeIdentity identifies one data point by its chain and its index in that chain's slice
type timeIdentity struct {
chain chainIdentity
index int
}
// dataPoint returns the data point i refers to in data. The index is not range checked,
// so callers must make sure i is not past the end of its chain.
func (i timeIdentity) dataPoint(data map[chainIdentity][]*DataPoint) *DataPoint {
return data[i.chain][i.index]
}
// atEnd reports whether i refers to the last data point of its chain
func (i timeIdentity) atEnd(data map[chainIdentity][]*DataPoint) bool {
return (i.index + 1) == len(data[i.chain])
}
// pastEnd reports whether the index of i lies beyond the last data point of its chain
func (i timeIdentity) pastEnd(data map[chainIdentity][]*DataPoint) bool {
return i.index >= len(data[i.chain])
}
// allPastEnd reports whether every position in pos lies beyond the last data point of its chain,
// i.e. whether all chains listed in pos have been fully consumed.
func allPastEnd(data map[chainIdentity][]*DataPoint, pos map[chainIdentity]int) bool {
	for chainID, index := range pos {
		if index < len(data[chainID]) {
			return false
		}
	}
	return true
}
// findFirstStrideWindow returns the start of the first stride window, which is the smallest
// timestamp among the first data points of all chains. It returns 0 when no chain has any
// data point. The stride parameter is currently not used.
//
// Tried to implement finding an optimal stride window but it was becoming too complicated and not worth it given that it will
// potentially save the first and last stride but it is not guaranteed. Current implementation should give good results
// as long as the DataPoints are regular enough
func findFirstStrideWindow(data map[chainIdentity][]*DataPoint, stride time.Duration) int64 {
	pos := make(map[chainIdentity]int, len(data))
	for chainID := range data {
		pos[chainID] = 0
	}

	// First data points of the chains, oldest first; chains without data are left out
	cur := sortTimeAsc(data, pos)
	if len(cur) == 0 {
		return 0
	}
	return int64(cur[0].dataPoint(data).Timestamp)
}
// copyMap returns a new map holding the same key/value pairs as original.
// The copy is shallow and is never nil, even for a nil original.
func copyMap[K comparable, V any](original map[K]V) map[K]V {
	duplicate := make(map[K]V, len(original))
	for k := range original {
		duplicate[k] = original[k]
	}
	return duplicate
}
// dataInStrideWindowAndNextPos gathers, for every chain, all consecutive data points starting
// at the chain's position in startPos whose timestamp is lower than endT.
//
// It returns the gathered points per chain (chains without a point in the window have no
// entry) and the positions to continue from; startPos itself is not modified.
// startPos might have indexes past the end of the data for a chain.
// The data points of each chain are expected to be ordered by ascending timestamp.
//
// Every point of the window is consumed in one call: progress is tracked per chain and per
// point, so a chain with many points in one window does not leak points into later windows.
func dataInStrideWindowAndNextPos(data map[chainIdentity][]*DataPoint, startPos map[chainIdentity]int, endT int64) (identities map[chainIdentity][]timeIdentity, nextPos map[chainIdentity]int) {
	pos := copyMap(startPos)
	identities = make(map[chainIdentity][]timeIdentity)

	for chainID, points := range data {
		first := pos[chainID]
		next := first
		for next < len(points) && points[next].Timestamp < uint64(endT) {
			identities[chainID] = append(identities[chainID], timeIdentity{chain: chainID, index: next})
			next++
		}
		// Only touch positions that moved, chains absent from startPos stay absent
		if next != first {
			pos[chainID] = next
		}
	}
	return identities, pos
}
// sortTimeAsc returns, for every chain of data, the identity of the data point at the chain's
// position in pos, ordered by ascending timestamp. Positions past the end of a chain's data
// are tolerated: such chains are left out of the result.
func sortTimeAsc(data map[chainIdentity][]*DataPoint, pos map[chainIdentity]int) []timeIdentity {
	pending := make([]timeIdentity, 0, len(data))
	for chainID := range data {
		candidate := timeIdentity{chain: chainID, index: pos[chainID]}
		if candidate.pastEnd(data) {
			continue
		}
		pending = append(pending, candidate)
	}

	sort.Slice(pending, func(a, b int) bool {
		left, right := pending[a].dataPoint(data), pending[b].dataPoint(data)
		return left.Timestamp < right.Timestamp
	})
	return pending
}
// updateBalanceHistoryForAllEnabledNetworks updates the balance history of every wallet address
// for every visible token (see isTokenVisible) of every network returned by the network manager,
// including each network's native token. All time intervals from BalanceHistoryAllTime down to
// BalanceHistory7Days are updated.
//
// EventBalanceHistoryUpdateStarted and EventBalanceHistoryUpdateFinished are sent around the
// processing of each address. Failures of a single balance update are logged and skipped;
// failing to read the addresses or networks, or to create a chain client, aborts with that error.
//
// expects ctx to have cancellation support and processing to be cancelled by the caller;
// a cancelled ctx makes the function return a "context cancelled" error
func (s *Service) updateBalanceHistoryForAllEnabledNetworks(ctx context.Context) error {
	accountsDB, err := accounts.NewDB(s.db)
	if err != nil {
		return err
	}

	walletAddresses, err := accountsDB.GetWalletAddresses()
	if err != nil {
		return err
	}

	enabledNetworks, err := s.networkManager.Get(true)
	if err != nil {
		return err
	}

	for _, walletAddress := range walletAddresses {
		s.triggerEvent(EventBalanceHistoryUpdateStarted, walletAddress, "")

		for _, enabledNetwork := range enabledNetworks {
			// A failure to list the tokens still leaves the native token to process
			chainTokens, tokensErr := s.tokenManager.GetTokens(enabledNetwork.ChainID)
			if tokensErr != nil {
				chainTokens = make([]*token.Token, 0)
			}
			chainTokens = append(chainTokens, s.tokenManager.ToToken(enabledNetwork))

			for _, chainToken := range chainTokens {
				if !s.isTokenVisible(chainToken.Symbol) {
					continue
				}

				chainClient, clientErr := chain.NewClient(s.rpcClient, enabledNetwork.ChainID)
				if clientErr != nil {
					return clientErr
				}

				nativeSource := chainClientSource{chainClient: chainClient, currency: chainToken.Symbol}
				var dataSource DataSource
				if chainToken.IsNative() {
					dataSource = &nativeSource
				} else {
					dataSource = &tokenChainClientSource{
						chainClientSource: nativeSource,
						TokenManager:      s.tokenManager,
						NetworkManager:    s.networkManager,
					}
				}

				for currentInterval := int(BalanceHistoryAllTime); currentInterval >= int(BalanceHistory7Days); currentInterval-- {
					select {
					case <-ctx.Done():
						return errors.New("context cancelled")
					default:
					}

					updateErr := s.balance.update(ctx, dataSource, common.Address(walletAddress), TimeInterval(currentInterval))
					if updateErr != nil {
						log.Warn("Error updating balance history", "chainID", dataSource.ChainID(), "currency", dataSource.Currency(), "address", walletAddress.String(), "interval", currentInterval, "err", updateErr)
					}
				}
			}
		}

		s.triggerEvent(EventBalanceHistoryUpdateFinished, walletAddress, "")
	}
	return nil
}