diff --git a/services/wallet/api.go b/services/wallet/api.go index 7036f34a5..21c767ba1 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -6,6 +6,7 @@ import ( "fmt" "math/big" "strings" + "time" "github.com/rmg/iso4217" @@ -17,6 +18,7 @@ import ( "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/bridge" "github.com/status-im/status-go/services/wallet/chain" + "github.com/status-im/status-go/services/wallet/history" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/transfer" ) @@ -41,6 +43,10 @@ func (api *API) CheckConnected(ctx context.Context) *ConnectedResult { return api.s.CheckConnected(ctx) } +func (api *API) StopWallet(ctx context.Context) error { + return api.s.Stop() +} + func (api *API) GetWalletToken(ctx context.Context, addresses []common.Address) (map[common.Address][]Token, error) { return api.reader.GetWalletToken(ctx, addresses) } @@ -123,9 +129,27 @@ func (api *API) GetTokensBalancesForChainIDs(ctx context.Context, chainIDs []uin return api.s.tokenManager.GetBalances(ctx, clients, accounts, addresses) } -// GetBalanceHistory retrieves native token. Will be extended later to support token balance history -func (api *API) GetBalanceHistory(ctx context.Context, chainID uint64, address common.Address, timeInterval transfer.BalanceHistoryTimeInterval) ([]transfer.BalanceState, error) { - return api.s.transferController.GetBalanceHistory(ctx, chainID, address, timeInterval) +func (api *API) StartBalanceHistory(ctx context.Context) error { + api.s.history.StartBalanceHistory() + return nil +} + +func (api *API) StopBalanceHistory(ctx context.Context) error { + api.s.history.Stop() + return nil +} + +func (api *API) UpdateVisibleTokens(ctx context.Context, symbols []string) error { + api.s.history.UpdateVisibleTokens(symbols) + return nil +} + +// GetBalanceHistory retrieves token balance history for token identity on multiple chains +// TODO: pass parameters by GetBalanceHistoryParameters struct +// TODO: expose endTimestamp parameter +func (api *API) GetBalanceHistory(ctx context.Context, chainIDs []uint64, address common.Address, currency string, timeInterval history.TimeInterval) ([]*history.DataPoint, error) { + endTimestamp := time.Now().UTC().Unix() + return api.s.history.GetBalanceHistory(ctx, chainIDs, address, currency, endTimestamp, timeInterval) } func (api *API) GetTokens(ctx context.Context, chainID uint64) ([]*token.Token, error) { diff --git a/services/wallet/history/balance.go b/services/wallet/history/balance.go new file mode 100644 index 000000000..a563d8b04 --- /dev/null +++ b/services/wallet/history/balance.go @@ -0,0 +1,350 @@ +package history + +import ( + "context" + "errors" + "math/big" + "time" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core/types" +) + +type Balance struct { + db *BalanceDB +} + +type blocksStride = int + +const ( + blockTime = time.Duration(12) * time.Second + aDay = time.Duration(24) * time.Hour +) + +// Must have a common divisor to share common blocks and increase the cache hit +const ( + twiceADayStride blocksStride = blocksStride((time.Duration(12) * time.Hour) / blockTime) + weekStride = 14 * twiceADayStride + fourMonthsStride = 4 /*months*/ * 4 * weekStride +) + +// bitsetFilters used to fetch relevant data points in one batch and to increase cache hit +const ( + filterAllTime bitsetFilter = 1 + filterWeekly bitsetFilter = 1 << 3 + filterTwiceADay bitsetFilter = 1 << 5 +) + +type TimeInterval int + +// Specific time intervals for which balance history can be fetched +const ( + BalanceHistory7Days TimeInterval = iota + 1 + BalanceHistory1Month + BalanceHistory6Months + BalanceHistory1Year + BalanceHistoryAllTime +) + +var timeIntervalDuration = map[TimeInterval]time.Duration{ + BalanceHistory7Days: time.Duration(7) * aDay, + BalanceHistory1Month: time.Duration(30) * aDay, + BalanceHistory6Months: time.Duration(6*30) * aDay, + BalanceHistory1Year: time.Duration(365) * aDay, +} + +var timeIntervalToBitsetFilter = map[TimeInterval]bitsetFilter{ + BalanceHistory7Days: filterTwiceADay, + BalanceHistory1Month: filterTwiceADay, + BalanceHistory6Months: filterWeekly, + BalanceHistory1Year: filterWeekly, + BalanceHistoryAllTime: filterAllTime, +} + +var timeIntervalToStride = map[TimeInterval]blocksStride{ + BalanceHistory7Days: twiceADayStride, + BalanceHistory1Month: twiceADayStride, + BalanceHistory6Months: weekStride, + BalanceHistory1Year: weekStride, + BalanceHistoryAllTime: fourMonthsStride, +} + +func NewBalance(db *BalanceDB) *Balance { + return &Balance{ + db: db, + } +} + +// DataSource used as an abstraction to fetch required data from a specific blockchain +type DataSource interface { + HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) + BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) + ChainID() uint64 + Currency() string + TimeNow() int64 +} + +type DataPoint struct { + Value *hexutil.Big `json:"value"` + Timestamp uint64 `json:"time"` + BlockNumber *hexutil.Big `json:"blockNumber"` +} + +func strideDuration(timeInterval TimeInterval) time.Duration { + return time.Duration(timeIntervalToStride[timeInterval]) * blockTime +} + +// fetchAndCache will process the last available block if blocNo is nil +// reuses previous fetched blocks timestamp to avoid fetching block headers again +func (b *Balance) fetchAndCache(ctx context.Context, source DataSource, address common.Address, blockNo *big.Int, bitset bitsetFilter) (*DataPoint, *big.Int, error) { + var outEntry *entry + var err error + if blockNo != nil { + cached, bitsetList, err := b.db.get(&assetIdentity{source.ChainID(), address, source.Currency()}, blockNo, 1, asc) + if err != nil { + return nil, nil, err + } + if len(cached) > 0 && cached[0].block.Cmp(blockNo) == 0 { + // found a match update bitset + err := b.db.updateBitset(&assetIdentity{source.ChainID(), address, source.Currency()}, blockNo, bitset|bitsetList[0]) + if err != nil { + return nil, nil, err + } + return &DataPoint{ + Value: (*hexutil.Big)(cached[0].balance), + Timestamp: uint64(cached[0].timestamp), + BlockNumber: (*hexutil.Big)(cached[0].block), + }, blockNo, nil + } + + // otherwise try fetch any to get the timestamp info + outEntry, _, err = b.db.getFirst(source.ChainID(), blockNo) + if err != nil { + return nil, nil, err + } + } + var timestamp int64 + if outEntry != nil { + timestamp = outEntry.timestamp + } else { + header, err := source.HeaderByNumber(ctx, blockNo) + if err != nil { + return nil, nil, err + } + blockNo = new(big.Int).Set(header.Number) + timestamp = int64(header.Time) + } + + currentBalance, err := source.BalanceAt(ctx, address, blockNo) + if err != nil { + return nil, nil, err + } + + entry := entry{ + chainID: source.ChainID(), + address: address, + tokenSymbol: source.Currency(), + block: new(big.Int).Set(blockNo), + balance: currentBalance, + timestamp: timestamp, + } + err = b.db.add(&entry, bitset) + if err != nil { + return nil, nil, err + } + + var dataPoint DataPoint + dataPoint.Value = (*hexutil.Big)(currentBalance) + dataPoint.Timestamp = uint64(timestamp) + return &dataPoint, blockNo, nil +} + +// update fetches the balance history for a given asset from DB first and missing information from the blockchain to minimize the RPC calls +// if context is cancelled it will return with error +func (b *Balance) update(ctx context.Context, source DataSource, address common.Address, timeInterval TimeInterval) error { + startTimestamp := int64(0) + fetchTimestamp := int64(0) + endTime := source.TimeNow() + if timeInterval != BalanceHistoryAllTime { + // Ensure we always get the complete range by fetching the next block also + startTimestamp = endTime - int64(timeIntervalDuration[timeInterval].Seconds()) + fetchTimestamp = startTimestamp - int64(strideDuration(timeInterval).Seconds()) + } + identity := &assetIdentity{source.ChainID(), address, source.Currency()} + firstCached, err := b.firstCachedStartingAt(identity, fetchTimestamp, timeInterval) + if err != nil { + return err + } + + var oldestCached *big.Int + var oldestTimestamp int64 + var newestCached *big.Int + if firstCached != nil { + oldestCached = new(big.Int).Set(firstCached.block) + oldestTimestamp = firstCached.timestamp + lastCached, err := b.lastCached(identity, timeInterval) + if err != nil { + return err + } + newestCached = new(big.Int).Set(lastCached.block) + } else { + var fetchBlock *big.Int + lastEntry, _, err := b.db.getLastEntryForChain(source.ChainID()) + if err != nil { + return err + } + if lastEntry != nil { + fetchBlock = new(big.Int).Set(lastEntry.block) + } + mostRecentDataPoint, mostRecentBlock, err := b.fetchAndCache(ctx, source, address, fetchBlock, timeIntervalToBitsetFilter[timeInterval]) + if err != nil { + return err + } + + oldestCached = new(big.Int).Set(mostRecentBlock) + oldestTimestamp = int64(mostRecentDataPoint.Timestamp) + newestCached = new(big.Int).Set(mostRecentBlock) + } + + if oldestTimestamp > startTimestamp { + err := b.fetchBackwardAndCache(ctx, source, address, oldestCached, startTimestamp, timeInterval) + if err != nil { + return err + } + } + + // Fetch forward if didn't update in a stride duration + err = b.fetchForwardAndCache(ctx, source, address, newestCached, timeInterval) + if err != nil { + return err + } + + return nil +} + +// get returns the balance history for the given address and time interval until endTimestamp +func (b *Balance) get(ctx context.Context, chainID uint64, currency string, address common.Address, endTimestamp int64, timeInterval TimeInterval) ([]*DataPoint, error) { + startTimestamp := int64(0) + fetchTimestamp := int64(0) + if timeInterval != BalanceHistoryAllTime { + // Ensure we always get the complete range by fetching the next block also + startTimestamp = endTimestamp - int64(timeIntervalDuration[timeInterval].Seconds()) + fetchTimestamp = startTimestamp - int64(strideDuration(timeInterval).Seconds()) + } + cached, _, err := b.db.filter(&assetIdentity{chainID, address, currency}, nil, &balanceFilter{fetchTimestamp, endTimestamp, expandFlag(timeIntervalToBitsetFilter[timeInterval])}, 200, asc) + if err != nil { + return nil, err + } + + points := make([]*DataPoint, 0, len(cached)+1) + for _, entry := range cached { + dataPoint := DataPoint{ + Value: (*hexutil.Big)(entry.balance), + Timestamp: uint64(entry.timestamp), + BlockNumber: (*hexutil.Big)(entry.block), + } + points = append(points, &dataPoint) + } + + lastCached, _, err := b.db.get(&assetIdentity{chainID, address, currency}, nil, 1, desc) + if err != nil { + return nil, err + } + if len(lastCached) > 0 && len(cached) > 0 && lastCached[0].block.Cmp(cached[len(cached)-1].block) > 0 { + points = append(points, &DataPoint{ + Value: (*hexutil.Big)(lastCached[0].balance), + Timestamp: uint64(lastCached[0].timestamp), + BlockNumber: (*hexutil.Big)(lastCached[0].block), + }) + } + + return points, nil +} + +// fetchBackwardAndCache fetches and adds to DB balance entries starting one stride before the endBlock and stops +// when reaching a block timestamp older than startTimestamp or genesis block +// relies on the approximation of a block length to be blockTime for sampling the data +func (b *Balance) fetchBackwardAndCache(ctx context.Context, source DataSource, address common.Address, endBlock *big.Int, startTimestamp int64, timeInterval TimeInterval) error { + stride := timeIntervalToStride[timeInterval] + nextBlock := new(big.Int).Set(endBlock) + for nextBlock.Cmp(big.NewInt(1)) > 0 { + if shouldCancel(ctx) { + return errors.New("context cancelled") + } + + nextBlock.Sub(nextBlock, big.NewInt(int64(stride))) + if nextBlock.Cmp(big.NewInt(0)) <= 0 { + // we reached the genesis block which doesn't have a usable timestamp, fetch next + nextBlock.SetUint64(1) + } + + dataPoint, _, err := b.fetchAndCache(ctx, source, address, nextBlock, timeIntervalToBitsetFilter[timeInterval]) + if err != nil { + return err + } + + // Allow to go back one stride to match the requested interval + if int64(dataPoint.Timestamp) < startTimestamp { + return nil + } + } + + return nil +} + +// fetchForwardAndCache fetches and adds to DB balance entries starting one stride before the startBlock and stops +// when block not found +// relies on the approximation of a block length to be blockTime +func (b *Balance) fetchForwardAndCache(ctx context.Context, source DataSource, address common.Address, startBlock *big.Int, timeInterval TimeInterval) error { + stride := timeIntervalToStride[timeInterval] + nextBlock := new(big.Int).Set(startBlock) + for { + if shouldCancel(ctx) { + return errors.New("context cancelled") + } + + nextBlock.Add(nextBlock, big.NewInt(int64(stride))) + _, _, err := b.fetchAndCache(ctx, source, address, nextBlock, timeIntervalToBitsetFilter[timeInterval]) + if err != nil { + if err == ethereum.NotFound { + // We overshoot, stop and return what we have + return nil + } + return err + } + } +} + +// firstCachedStartingAt returns first cached entry for the given identity and time interval starting at fetchTimestamp or nil if none found +func (b *Balance) firstCachedStartingAt(identity *assetIdentity, startTimestamp int64, timeInterval TimeInterval) (first *entry, err error) { + entries, _, err := b.db.filter(identity, nil, &balanceFilter{startTimestamp, maxAllRangeTimestamp, expandFlag(timeIntervalToBitsetFilter[timeInterval])}, 1, desc) + if err != nil { + return nil, err + } else if len(entries) == 0 { + return nil, nil + } + return entries[0], nil +} + +// lastCached returns last cached entry for the given identity and time interval or nil if none found +func (b *Balance) lastCached(identity *assetIdentity, timeInterval TimeInterval) (first *entry, err error) { + entries, _, err := b.db.filter(identity, nil, &balanceFilter{minAllRangeTimestamp, maxAllRangeTimestamp, expandFlag(timeIntervalToBitsetFilter[timeInterval])}, 1, desc) + if err != nil { + return nil, err + } else if len(entries) == 0 { + return nil, nil + } + return entries[0], nil +} + +// shouldCancel returns true if the context has been cancelled and task should be aborted +func shouldCancel(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + } + return false +} diff --git a/services/wallet/history/balance_db.go b/services/wallet/history/balance_db.go new file mode 100644 index 000000000..9c9b855f1 --- /dev/null +++ b/services/wallet/history/balance_db.go @@ -0,0 +1,175 @@ +package history + +import ( + "database/sql" + "fmt" + "math" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/status-im/status-go/services/wallet/bigint" +) + +type BalanceDB struct { + db *sql.DB +} + +func NewBalanceDB(sqlDb *sql.DB) *BalanceDB { + return &BalanceDB{ + db: sqlDb, + } +} + +// entry represents a single row in the balance_history table +type entry struct { + chainID uint64 + address common.Address + tokenSymbol string + block *big.Int + timestamp int64 + balance *big.Int +} + +// bitsetFilter stores the time interval for which the data points are matching +type bitsetFilter int + +const ( + minAllRangeTimestamp = 0 + maxAllRangeTimestamp = math.MaxInt64 + bitsetFilterFlagCount = 30 +) + +// expandFlag will generate a bitset that matches all lower value flags (fills the less significant bits of the flag with 1; e.g. 0b1000 -> 0b1111) +func expandFlag(flag bitsetFilter) bitsetFilter { + return (flag << 1) - 1 +} + +func (b *BalanceDB) add(entry *entry, bitset bitsetFilter) error { + _, err := b.db.Exec("INSERT INTO balance_history (chain_id, address, currency, block, timestamp, bitset, balance) VALUES (?, ?, ?, ?, ?, ?, ?)", entry.chainID, entry.address, entry.tokenSymbol, (*bigint.SQLBigInt)(entry.block), entry.timestamp, int(bitset), (*bigint.SQLBigIntBytes)(entry.balance)) + return err +} + +type sortDirection = int + +const ( + asc sortDirection = 0 + desc sortDirection = 1 +) + +type assetIdentity struct { + ChainID uint64 + Address common.Address + TokenSymbol string +} + +// bitset is used so higher values can include lower values to simulate time interval levels and high granularity intervals include lower ones +// minTimestamp and maxTimestamp interval filter the results by timestamp. +type balanceFilter struct { + minTimestamp int64 + maxTimestamp int64 + bitset bitsetFilter +} + +// filters returns a sorted list of entries, empty array if none is found for the given input or nil if error +// if startingAtBlock is provided, the result will start with the provided block number or the next available one +// if startingAtBlock is NOT provided the result will begin from the first available block that matches filter.minTimestamp +// sort defines the order of the result by block number (which correlates also with timestamp) +func (b *BalanceDB) filter(identity *assetIdentity, startingAtBlock *big.Int, filter *balanceFilter, maxEntries int, sort sortDirection) (entries []*entry, bitsetList []bitsetFilter, err error) { + // Start from the first block in case a specific one was not provided + if startingAtBlock == nil { + startingAtBlock = big.NewInt(0) + } + // We are interested in order by timestamp, but we request by block number that correlates to the order of timestamp and it is indexed + var queryStr string + rawQueryStr := "SELECT block, timestamp, balance, bitset FROM balance_history WHERE chain_id = ? AND address = ? AND currency = ? AND block >= ? AND timestamp BETWEEN ? AND ? AND (bitset & ?) > 0 ORDER BY block %s LIMIT ?" + if sort == asc { + queryStr = fmt.Sprintf(rawQueryStr, "ASC") + } else { + queryStr = fmt.Sprintf(rawQueryStr, "DESC") + } + rows, err := b.db.Query(queryStr, identity.ChainID, identity.Address, identity.TokenSymbol, (*bigint.SQLBigInt)(startingAtBlock), filter.minTimestamp, filter.maxTimestamp, filter.bitset, maxEntries) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + result := make([]*entry, 0) + for rows.Next() { + entry := &entry{ + chainID: 0, + address: identity.Address, + tokenSymbol: identity.TokenSymbol, + block: new(big.Int), + balance: new(big.Int), + } + var bitset int + err := rows.Scan((*bigint.SQLBigInt)(entry.block), &entry.timestamp, (*bigint.SQLBigIntBytes)(entry.balance), &bitset) + if err != nil { + return nil, nil, err + } + entry.chainID = identity.ChainID + result = append(result, entry) + bitsetList = append(bitsetList, bitsetFilter(bitset)) + } + return result, bitsetList, nil +} + +// get calls filter that matches all entries +func (b *BalanceDB) get(identity *assetIdentity, startingAtBlock *big.Int, maxEntries int, sort sortDirection) (entries []*entry, bitsetList []bitsetFilter, err error) { + return b.filter(identity, startingAtBlock, &balanceFilter{ + minTimestamp: minAllRangeTimestamp, + maxTimestamp: maxAllRangeTimestamp, + bitset: expandFlag(1 << bitsetFilterFlagCount), + }, maxEntries, sort) +} + +// getFirst returns the first entry for the block or nil if no entry is found +func (b *BalanceDB) getFirst(chainID uint64, block *big.Int) (res *entry, bitset bitsetFilter, err error) { + res = &entry{ + chainID: chainID, + block: new(big.Int).Set(block), + balance: new(big.Int), + } + + queryStr := "SELECT address, currency, timestamp, balance, bitset FROM balance_history WHERE chain_id = ? AND block = ?" + row := b.db.QueryRow(queryStr, chainID, (*bigint.SQLBigInt)(block)) + var bitsetRaw int + + err = row.Scan(&res.address, &res.tokenSymbol, &res.timestamp, (*bigint.SQLBigIntBytes)(res.balance), &bitsetRaw) + if err == sql.ErrNoRows { + return nil, 0, nil + } else if err != nil { + return nil, 0, err + } + + return res, bitsetFilter(bitsetRaw), nil +} + +// getFirst returns the last entry for the chainID or nil if no entry is found +func (b *BalanceDB) getLastEntryForChain(chainID uint64) (res *entry, bitset bitsetFilter, err error) { + res = &entry{ + chainID: chainID, + block: new(big.Int), + balance: new(big.Int), + } + + queryStr := "SELECT address, currency, timestamp, block, balance, bitset FROM balance_history WHERE chain_id = ? ORDER BY block DESC" + row := b.db.QueryRow(queryStr, chainID) + var bitsetRaw int + + err = row.Scan(&res.address, &res.tokenSymbol, &res.timestamp, (*bigint.SQLBigInt)(res.block), (*bigint.SQLBigIntBytes)(res.balance), &bitsetRaw) + if err == sql.ErrNoRows { + return nil, 0, nil + } else if err != nil { + return nil, 0, err + } + + return res, bitsetFilter(bitsetRaw), nil +} + +func (b *BalanceDB) updateBitset(asset *assetIdentity, block *big.Int, newBitset bitsetFilter) error { + // Updating bitset value in place doesn't work. + // Tried "INSERT INTO balance_history ... ON CONFLICT(chain_id, address, currency, block) DO UPDATE SET timestamp=excluded.timestamp, bitset=(bitset | excluded.bitset), balance=excluded.balance" + _, err := b.db.Exec("UPDATE balance_history SET bitset = ? WHERE chain_id = ? AND address = ? AND currency = ? AND block = ?", int(newBitset), asset.ChainID, asset.Address, asset.TokenSymbol, (*bigint.SQLBigInt)(block)) + return err +} diff --git a/services/wallet/history/balance_db_test.go b/services/wallet/history/balance_db_test.go new file mode 100644 index 000000000..9b9c08f8d --- /dev/null +++ b/services/wallet/history/balance_db_test.go @@ -0,0 +1,328 @@ +package history + +import ( + "database/sql" + "math/big" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ethereum/go-ethereum/common" + + "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/services/wallet/bigint" +) + +func setupBalanceDBTest(t *testing.T) (*BalanceDB, func()) { + db, err := appdatabase.InitializeDB(":memory:", "wallet-history-balance_db-tests", 1) + require.NoError(t, err) + return NewBalanceDB(db), func() { + require.NoError(t, db.Close()) + } +} + +// generateTestDataForElementCount generates dummy consecutive blocks of data for the same chain_id, address and currency +func generateTestDataForElementCount(count int) (result []*entry) { + baseDataPoint := entry{ + chainID: 777, + address: common.Address{7}, + tokenSymbol: "ETH", + block: big.NewInt(11), + balance: big.NewInt(101), + timestamp: 11, + } + + result = make([]*entry, 0, count) + for i := 0; i < count; i++ { + newDataPoint := baseDataPoint + newDataPoint.block = new(big.Int).Add(baseDataPoint.block, big.NewInt(int64(i))) + newDataPoint.balance = new(big.Int).Add(baseDataPoint.balance, big.NewInt(int64(i))) + newDataPoint.timestamp += int64(i) + result = append(result, &newDataPoint) + } + return result +} + +func TestBalanceDBAddDataPoint(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoint := generateTestDataForElementCount(1)[0] + + err := bDB.add(testDataPoint, filterWeekly) + require.NoError(t, err) + + outDataPoint := entry{ + chainID: 0, + block: big.NewInt(0), + balance: big.NewInt(0), + } + rows, err := bDB.db.Query("SELECT * FROM balance_history") + require.NoError(t, err) + + ok := rows.Next() + require.True(t, ok) + + bitset := 0 + err = rows.Scan(&outDataPoint.chainID, &outDataPoint.address, &outDataPoint.tokenSymbol, (*bigint.SQLBigInt)(outDataPoint.block), &outDataPoint.timestamp, &bitset, (*bigint.SQLBigIntBytes)(outDataPoint.balance)) + require.NoError(t, err) + require.NotEqual(t, err, sql.ErrNoRows) + require.Equal(t, testDataPoint, &outDataPoint) + + ok = rows.Next() + require.False(t, ok) +} + +func TestBalanceDBGetOldestDataPoint(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoints := generateTestDataForElementCount(5) + for i := len(testDataPoints) - 1; i >= 0; i-- { + err := bDB.add(testDataPoints[i], 1) + require.NoError(t, err) + } + + outDataPoints, _, err := bDB.get(&assetIdentity{testDataPoints[0].chainID, testDataPoints[0].address, testDataPoints[0].tokenSymbol}, nil, 1, asc) + require.NoError(t, err) + require.NotEqual(t, outDataPoints, nil) + require.Equal(t, outDataPoints[0], testDataPoints[0]) +} + +func TestBalanceDBGetLatestDataPoint(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoints := generateTestDataForElementCount(5) + for i := 0; i < len(testDataPoints); i++ { + err := bDB.add(testDataPoints[i], 1) + require.NoError(t, err) + } + + outDataPoints, _, err := bDB.get(&assetIdentity{testDataPoints[0].chainID, testDataPoints[0].address, testDataPoints[0].tokenSymbol}, nil, 1, desc) + require.NoError(t, err) + require.NotEqual(t, outDataPoints, nil) + require.Equal(t, outDataPoints[0], testDataPoints[len(testDataPoints)-1]) +} + +func TestBalanceDBGetFirst(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoints := generateTestDataForElementCount(5) + for i := 0; i < len(testDataPoints); i++ { + err := bDB.add(testDataPoints[i], 1) + require.NoError(t, err) + } + + duplicateIndex := 2 + newDataPoint := entry{ + chainID: testDataPoints[duplicateIndex].chainID, + address: common.Address{77}, + tokenSymbol: testDataPoints[duplicateIndex].tokenSymbol, + block: new(big.Int).Set(testDataPoints[duplicateIndex].block), + balance: big.NewInt(102), + timestamp: testDataPoints[duplicateIndex].timestamp, + } + err := bDB.add(&newDataPoint, 2) + require.NoError(t, err) + + outDataPoint, _, err := bDB.getFirst(testDataPoints[duplicateIndex].chainID, testDataPoints[duplicateIndex].block) + require.NoError(t, err) + require.NotEqual(t, nil, outDataPoint) + require.Equal(t, testDataPoints[duplicateIndex], outDataPoint) +} + +func TestBalanceDBGetLastEntryForChain(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoints := generateTestDataForElementCount(5) + for i := 0; i < len(testDataPoints); i++ { + err := bDB.add(testDataPoints[i], 1) + require.NoError(t, err) + } + + // Same data with different addresses + for i := 0; i < len(testDataPoints); i++ { + newDataPoint := testDataPoints[i] + newDataPoint.address = common.Address{77} + err := bDB.add(newDataPoint, 1) + require.NoError(t, err) + } + + outDataPoint, _, err := bDB.getLastEntryForChain(testDataPoints[0].chainID) + require.NoError(t, err) + require.NotEqual(t, nil, outDataPoint) + + expectedDataPoint := testDataPoints[len(testDataPoints)-1] + require.Equal(t, expectedDataPoint.chainID, outDataPoint.chainID) + require.Equal(t, expectedDataPoint.tokenSymbol, outDataPoint.tokenSymbol) + require.Equal(t, expectedDataPoint.block, outDataPoint.block) + require.Equal(t, expectedDataPoint.timestamp, outDataPoint.timestamp) + require.Equal(t, expectedDataPoint.balance, outDataPoint.balance) +} + +func TestBalanceDBGetDataPointsInTimeRange(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoints := generateTestDataForElementCount(5) + for i := 0; i < len(testDataPoints); i++ { + err := bDB.add(testDataPoints[i], 1) + require.NoError(t, err) + } + + startIndex := 1 + endIndex := 3 + outDataPoints, _, err := bDB.filter(&assetIdentity{testDataPoints[0].chainID, testDataPoints[0].address, testDataPoints[0].tokenSymbol}, nil, &balanceFilter{testDataPoints[startIndex].timestamp, testDataPoints[endIndex].timestamp, 1}, 100, asc) + require.NoError(t, err) + require.NotEqual(t, outDataPoints, nil) + require.Equal(t, len(outDataPoints), endIndex-startIndex+1) + for i := startIndex; i <= endIndex; i++ { + require.Equal(t, outDataPoints[i-startIndex], testDataPoints[i]) + } +} + +func TestBalanceDBGetClosestDataPointToTimestamp(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoints := generateTestDataForElementCount(5) + for i := 0; i < len(testDataPoints); i++ { + err := bDB.add(testDataPoints[i], 1) + require.NoError(t, err) + } + + itemToGetIndex := 2 + outDataPoints, _, err := bDB.filter(&assetIdentity{testDataPoints[0].chainID, testDataPoints[0].address, testDataPoints[0].tokenSymbol}, nil, &balanceFilter{testDataPoints[itemToGetIndex].timestamp, maxAllRangeTimestamp, 1}, 1, asc) + require.NoError(t, err) + require.NotEqual(t, outDataPoints, nil) + require.Equal(t, len(outDataPoints), 1) + require.Equal(t, outDataPoints[0], testDataPoints[itemToGetIndex]) +} + +func TestBalanceDBUpdateUpdateBitset(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoints := generateTestDataForElementCount(1) + + err := bDB.add(testDataPoints[0], 1) + require.NoError(t, err) + err = bDB.add(testDataPoints[0], 2) + require.Error(t, err, "Expected \"UNIQUE constraint failed: ...\"") + err = bDB.updateBitset(&assetIdentity{testDataPoints[0].chainID, testDataPoints[0].address, testDataPoints[0].tokenSymbol}, testDataPoints[0].block, 2) + require.NoError(t, err) + + outDataPoint := entry{ + chainID: 0, + block: big.NewInt(0), + balance: big.NewInt(0), + } + rows, err := bDB.db.Query("SELECT * FROM balance_history") + require.NoError(t, err) + + ok := rows.Next() + require.True(t, ok) + + bitset := 0 + err = rows.Scan(&outDataPoint.chainID, &outDataPoint.address, &outDataPoint.tokenSymbol, (*bigint.SQLBigInt)(outDataPoint.block), &outDataPoint.timestamp, &bitset, (*bigint.SQLBigIntBytes)(outDataPoint.balance)) + require.NoError(t, err) + require.NotEqual(t, err, sql.ErrNoRows) + require.Equal(t, testDataPoints[0], &outDataPoint) + require.Equal(t, 2, bitset) + + ok = rows.Next() + require.False(t, ok) +} + +func TestBalanceDBCheckMissingDataPoint(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + testDataPoint := generateTestDataForElementCount(1)[0] + + err := bDB.add(testDataPoint, 1) + require.NoError(t, err) + + missingDataPoint := testDataPoint + missingDataPoint.block = big.NewInt(12) + + outDataPoints, bitset, err := bDB.get(&assetIdentity{missingDataPoint.chainID, missingDataPoint.address, missingDataPoint.tokenSymbol}, missingDataPoint.block, 1, asc) + require.NoError(t, err) + require.Equal(t, 0, len(outDataPoints)) + require.Equal(t, 0, len(bitset)) +} + +func TestBalanceDBBitsetFilter(t *testing.T) { + bDB, cleanDB := setupBalanceDBTest(t) + defer cleanDB() + + data := generateTestDataForElementCount(3) + + for i := 0; i < len(data); i++ { + err := bDB.add(data[i], 1< 0 { + return nil, ethereum.NotFound + } else { + require.Greater(src.t, number.Int64(), int64(0)) + blockNo = number.Int64() + } + timestamp := src.blockNumberToTimestamp(blockNo) + + if _, contains := src.requestedBlocks[blockNo]; contains { + src.requestedBlocks[blockNo].headerInfoRequests++ + } else { + src.requestedBlocks[blockNo] = &requestedBlock{ + time: uint64(timestamp), + headerInfoRequests: 1, + } + } + + return src.generateBlockInfo(blockNo, uint64(timestamp)), nil +} + +func (src *chainClientTestSource) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { + return src.balanceAtFn(ctx, account, blockNumber) +} + +func weiInEth() *big.Int { + res, _ := new(big.Int).SetString("1000000000000000000", 0) + return res +} + +func (src *chainClientTestSource) BalanceAtMock(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { + var blockNo int64 + if blockNumber == nil { + // Last block was requested + blockNo = src.blocksCount() + } else if blockNumber.Cmp(big.NewInt(src.blocksCount())) > 0 { + return nil, ethereum.NotFound + } else { + require.Greater(src.t, blockNumber.Int64(), int64(0)) + blockNo = blockNumber.Int64() + } + + if _, contains := src.requestedBlocks[blockNo]; contains { + src.requestedBlocks[blockNo].balanceRequests++ + } else { + src.requestedBlocks[blockNo] = &requestedBlock{ + time: uint64(src.blockNumberToTimestamp(blockNo)), + balanceRequests: 1, + } + } + + return new(big.Int).Mul(big.NewInt(blockNo), weiInEth()), nil +} + +func (src *chainClientTestSource) ChainID() uint64 { + return 777 +} + +func (src *chainClientTestSource) Currency() string { + return "eth" +} + +func (src *chainClientTestSource) TimeNow() int64 { + if src.firstTimeRequest == 0 { + src.firstTimeRequest = time.Now().UTC().Unix() + } + return src.mockTime + (time.Now().UTC().Unix() - src.firstTimeRequest) +} + +// extractTestData returns reqBlkNos sorted in ascending order +func extractTestData(dataSource *chainClientTestSource) (reqBlkNos []int64, infoRequests map[int64]int, balanceRequests map[int64]int) { + reqBlkNos = make([]int64, 0, len(dataSource.requestedBlocks)) + for blockNo := range dataSource.requestedBlocks { + reqBlkNos = append(reqBlkNos, blockNo) + } + sort.Slice(reqBlkNos, func(i, j int) bool { + return reqBlkNos[i] < reqBlkNos[j] + }) + + infoRequests = make(map[int64]int) + balanceRequests = make(map[int64]int, len(reqBlkNos)) + for i := 0; i < len(reqBlkNos); i++ { + n := reqBlkNos[i] + rB := dataSource.requestedBlocks[n] + + if rB.headerInfoRequests > 0 { + infoRequests[n] = rB.headerInfoRequests + } + if rB.balanceRequests > 0 { + balanceRequests[n] = rB.balanceRequests + } + } + return +} + +func minimumExpectedDataPoints(interval TimeInterval) int { + return int(math.Ceil(float64(timeIntervalDuration[interval]) / float64(strideDuration(interval)))) +} + +func getTimeError(dataSource *chainClientTestSource, data []*DataPoint, interval TimeInterval) int64 { + timeRange := int64(data[len(data)-1].Timestamp - data[0].Timestamp) + var expectedDuration int64 + if interval != BalanceHistoryAllTime { + expectedDuration = int64(timeIntervalDuration[interval].Seconds()) + } else { + expectedDuration = int64((time.Duration(dataSource.availableYears()) * oneYear).Seconds()) + } + return timeRange - expectedDuration +} + +func TestBalanceHistoryGetWithoutFetch(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + dataSource := newTestSource(t, 20 /*years*/) + currentTimestamp := dataSource.TimeNow() + + testData := []struct { + name string + interval TimeInterval + }{ + {"Week", BalanceHistory7Days}, + {"Month", BalanceHistory1Month}, + {"HalfYear", BalanceHistory6Months}, + {"Year", BalanceHistory1Year}, + {"AllTime", BalanceHistoryAllTime}, + } + for _, testInput := range testData { + t.Run(testInput.name, func(t *testing.T) { + balanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, currentTimestamp, testInput.interval) + require.NoError(t, err) + require.Equal(t, 0, len(balanceData)) + }) + } +} + +func TestBalanceHistoryGetWithoutOverlappingFetch(t *testing.T) { + testData := []struct { + name string + interval TimeInterval + }{ + {"Week", BalanceHistory7Days}, + {"Month", BalanceHistory1Month}, + {"HalfYear", BalanceHistory6Months}, + {"Year", BalanceHistory1Year}, + {"AllTime", BalanceHistoryAllTime}, + } + for _, testInput := range testData { + t.Run(testInput.name, func(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + dataSource := newTestSource(t, 20 /*years*/) + currentTimestamp := dataSource.TimeNow() + getUntilTimestamp := currentTimestamp - int64((400 /*days*/ * 24 * time.Hour).Seconds()) + + fetchInterval := testInput.interval + 3 + if fetchInterval > BalanceHistoryAllTime { + fetchInterval = BalanceHistory7Days + BalanceHistoryAllTime - testInput.interval + } + err := bh.update(context.Background(), dataSource, common.Address{7}, fetchInterval) + require.NoError(t, err) + + balanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, getUntilTimestamp, testInput.interval) + require.NoError(t, err) + require.Equal(t, 0, len(balanceData)) + }) + } +} + +func TestBalanceHistoryGetWithOverlappingFetch(t *testing.T) { + testData := []struct { + name string + interval TimeInterval + lessDaysToGet int + }{ + {"Week", BalanceHistory7Days, 6}, + {"Month", BalanceHistory1Month, 1}, + {"HalfYear", BalanceHistory6Months, 8}, + {"Year", BalanceHistory1Year, 16}, + {"AllTime", BalanceHistoryAllTime, 130}, + } + for _, testInput := range testData { + t.Run(testInput.name, func(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + dataSource := newTestSource(t, 20 /*years*/) + currentTimestamp := dataSource.TimeNow() + olderUntilTimestamp := currentTimestamp - int64((time.Duration(testInput.lessDaysToGet) * 24 * time.Hour).Seconds()) + + err := bh.update(context.Background(), dataSource, common.Address{7}, testInput.interval) + require.NoError(t, err) + + balanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, currentTimestamp, testInput.interval) + require.NoError(t, err) + require.GreaterOrEqual(t, len(balanceData), minimumExpectedDataPoints(testInput.interval)) + + olderBalanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, olderUntilTimestamp, testInput.interval) + require.NoError(t, err) + require.Less(t, len(olderBalanceData), len(balanceData)) + }) + } +} + +func TestBalanceHistoryFetchFirstTime(t *testing.T) { + testData := []struct { + name string + interval TimeInterval + }{ + {"Week", BalanceHistory7Days}, + {"Month", BalanceHistory1Month}, + {"HalfYear", BalanceHistory6Months}, + {"Year", BalanceHistory1Year}, + {"AllTime", BalanceHistoryAllTime}, + } + for _, testInput := range testData { + t.Run(testInput.name, func(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + dataSource := newTestSource(t, 20 /*years*/) + currentTimestamp := dataSource.TimeNow() + + err := bh.update(context.Background(), dataSource, common.Address{7}, testInput.interval) + require.NoError(t, err) + + balanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, currentTimestamp, testInput.interval) + require.NoError(t, err) + require.GreaterOrEqual(t, len(balanceData), minimumExpectedDataPoints(testInput.interval)) + + reqBlkNos, headerInfos, balances := extractTestData(dataSource) + require.Equal(t, len(balanceData), len(balances)) + + // Ensure we don't request the same info twice + for block, count := range headerInfos { + require.Equal(t, 1, count, "block %d has one info request", block) + if balanceCount, contains := balances[block]; contains { + require.Equal(t, 1, balanceCount, "block %d has one balance request", block) + } + } + for block, count := range balances { + require.Equal(t, 1, count, "block %d has one request", block) + } + + resIdx := 0 + for i := 0; i < len(reqBlkNos); i++ { + n := reqBlkNos[i] + rB := dataSource.requestedBlocks[n] + + if _, contains := balances[n]; contains { + require.Equal(t, rB.time, balanceData[resIdx].Timestamp) + if resIdx > 0 { + require.Greater(t, balanceData[resIdx].Timestamp, balanceData[resIdx-1].Timestamp, "result timestamps are in order") + } + resIdx++ + } + } + + errorFromIdeal := getTimeError(dataSource, balanceData, testInput.interval) + require.Less(t, math.Abs(float64(errorFromIdeal)), strideDuration(testInput.interval).Seconds(), "Duration error [%d s] is within 1 stride [%.f s] for interval [%#v]", errorFromIdeal, strideDuration(testInput.interval).Seconds(), testInput.interval) + }) + } +} + +func TestBalanceHistoryFetchError(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + dataSource := newTestSource(t, 20 /*years*/) + bkFn := dataSource.balanceAtFn + // Fail first request + dataSource.balanceAtFn = func(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { + return nil, errors.New("test error") + } + currentTimestamp := dataSource.TimeNow() + err := bh.update(context.Background(), dataSource, common.Address{7}, BalanceHistory1Year) + require.Error(t, err, "Expect \"test error\"") + + balanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, currentTimestamp, BalanceHistory1Year) + require.NoError(t, err) + require.Equal(t, 0, len(balanceData)) + + _, headerInfos, balances := extractTestData(dataSource) + require.Equal(t, 0, len(balances)) + require.Equal(t, 1, len(headerInfos)) + + dataSource.resetStats() + // Fail later + dataSource.balanceAtFn = func(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { + if len(dataSource.requestedBlocks) == 15 { + return nil, errors.New("test error") + } + return dataSource.BalanceAtMock(ctx, account, blockNumber) + } + err = bh.update(context.Background(), dataSource, common.Address{7}, BalanceHistory1Year) + require.Error(t, err, "Expect \"test error\"") + + balanceData, err = bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, currentTimestamp, BalanceHistory1Year) + require.NoError(t, err) + require.Equal(t, 14, len(balanceData)) + + reqBlkNos, headerInfos, balances := extractTestData(dataSource) + // The request for block info is made before the balance request + require.Equal(t, 1, dataSource.requestedBlocks[reqBlkNos[0]].headerInfoRequests) + require.Equal(t, 0, dataSource.requestedBlocks[reqBlkNos[0]].balanceRequests) + require.Equal(t, 14, len(balances)) + require.Equal(t, len(balances), len(headerInfos)-1) + + dataSource.resetStats() + dataSource.balanceAtFn = bkFn + err = bh.update(context.Background(), dataSource, common.Address{7}, BalanceHistory1Year) + require.NoError(t, err) + + balanceData, err = bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, currentTimestamp, BalanceHistory1Year) + require.NoError(t, err) + require.GreaterOrEqual(t, len(balanceData), minimumExpectedDataPoints(BalanceHistory1Year)) + + _, headerInfos, balances = extractTestData(dataSource) + // Account for cache hits + require.Equal(t, len(balanceData)-14, len(balances)) + require.Equal(t, len(balances), len(headerInfos)) + + for i := 1; i < len(balanceData); i++ { + require.Greater(t, balanceData[i].Timestamp, balanceData[i-1].Timestamp, "result timestamps are in order") + } + + errorFromIdeal := getTimeError(dataSource, balanceData, BalanceHistory1Year) + require.Less(t, math.Abs(float64(errorFromIdeal)), strideDuration(BalanceHistory1Year).Seconds(), "Duration error [%d s] is within 1 stride [%.f s] for interval [%#v]", errorFromIdeal, strideDuration(BalanceHistory1Year).Seconds(), BalanceHistory1Year) +} + +func TestBalanceHistoryValidateBalanceValuesAndCacheHit(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + dataSource := newTestSource(t, 20 /*years*/) + currentTimestamp := dataSource.TimeNow() + requestedBalance := make(map[int64]*big.Int) + dataSource.balanceAtFn = func(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { + balance, err := dataSource.BalanceAtMock(ctx, account, blockNumber) + requestedBalance[blockNumber.Int64()] = new(big.Int).Set(balance) + return balance, err + } + + testData := []struct { + name string + interval TimeInterval + }{ + {"Week", BalanceHistory7Days}, + {"Month", BalanceHistory1Month}, + {"HalfYear", BalanceHistory6Months}, + {"Year", BalanceHistory1Year}, + {"AllTime", BalanceHistoryAllTime}, + } + for _, testInput := range testData { + t.Run(testInput.name, func(t *testing.T) { + dataSource.resetStats() + err := bh.update(context.Background(), dataSource, common.Address{7}, testInput.interval) + require.NoError(t, err) + + balanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, currentTimestamp, testInput.interval) + require.NoError(t, err) + require.GreaterOrEqual(t, len(balanceData), minimumExpectedDataPoints(testInput.interval)) + + reqBlkNos, headerInfos, _ := extractTestData(dataSource) + // Only first run is not affected by cache + if testInput.interval == BalanceHistory7Days { + require.Equal(t, len(balanceData), len(requestedBalance)) + require.Equal(t, len(balanceData), len(headerInfos)) + } else { + require.Greater(t, len(balanceData), len(requestedBalance)) + require.Greater(t, len(balanceData), len(headerInfos)) + } + + resIdx := 0 + // Check that balance values are the one requested + for i := 0; i < len(reqBlkNos); i++ { + n := reqBlkNos[i] + + if value, contains := requestedBalance[n]; contains { + require.Equal(t, value.Cmp(balanceData[resIdx].Value.ToInt()), 0) + resIdx++ + } + blockHeaderRequestCount := dataSource.requestedBlocks[n].headerInfoRequests + require.Less(t, blockHeaderRequestCount, 2) + blockBalanceRequestCount := dataSource.requestedBlocks[n].balanceRequests + require.Less(t, blockBalanceRequestCount, 2) + } + + // Check that balance values are in order + for i := 1; i < len(balanceData); i++ { + require.Greater(t, balanceData[i].Value.ToInt().Cmp(balanceData[i-1].Value.ToInt()), 0, "expected balanceData[%d] > balanceData[%d] for interval %d", i, i-1, testInput.interval) + } + requestedBalance = make(map[int64]*big.Int) + }) + } +} + +func TestGetBalanceHistoryUpdateLater(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + currentTime := getTestTime(t) + initialTime := currentTime + moreThanADay := 24*time.Hour + 15*time.Minute + moreThanAMonth := 401 * moreThanADay + initialTime = initialTime.Add(-moreThanADay - moreThanAMonth) + dataSource := newTestSourceWithCurrentTime(t, 20 /*years*/, initialTime.Unix()) + + err := bh.update(context.Background(), dataSource, common.Address{7}, BalanceHistory1Month) + require.NoError(t, err) + + prevBalanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, dataSource.TimeNow(), BalanceHistory1Month) + require.NoError(t, err) + require.GreaterOrEqual(t, len(prevBalanceData), minimumExpectedDataPoints(BalanceHistory1Month)) + + // Advance little bit more than a day + later := initialTime + later = later.Add(moreThanADay) + dataSource.setCurrentTime(later.Unix()) + dataSource.resetStats() + + err = bh.update(context.Background(), dataSource, common.Address{7}, BalanceHistory1Month) + require.NoError(t, err) + + updatedBalanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, dataSource.TimeNow(), BalanceHistory1Month) + require.NoError(t, err) + require.GreaterOrEqual(t, len(updatedBalanceData), minimumExpectedDataPoints(BalanceHistory1Month)) + + reqBlkNos, blockInfos, balances := extractTestData(dataSource) + require.Equal(t, 2, len(reqBlkNos)) + require.Equal(t, len(reqBlkNos), len(blockInfos)) + require.Equal(t, len(blockInfos), len(balances)) + + for block, count := range balances { + require.Equal(t, 1, count, "block %d has one request", block) + } + + resIdx := len(updatedBalanceData) - 2 + for i := 0; i < len(reqBlkNos); i++ { + rB := dataSource.requestedBlocks[reqBlkNos[i]] + + // Ensure block approximation error doesn't exceed 10 blocks + require.Equal(t, 0.0, math.Abs(float64(int64(rB.time)-int64(updatedBalanceData[resIdx].Timestamp)))) + if resIdx > 0 { + // Ensure result timestamps are in order + require.Greater(t, updatedBalanceData[resIdx].Timestamp, updatedBalanceData[resIdx-1].Timestamp) + } + resIdx++ + } + + errorFromIdeal := getTimeError(dataSource, updatedBalanceData, BalanceHistory1Month) + require.Less(t, math.Abs(float64(errorFromIdeal)), strideDuration(BalanceHistory1Month).Seconds(), "Duration error [%d s] is within 1 stride [%.f s] for interval [%#v]", errorFromIdeal, strideDuration(BalanceHistory1Month).Seconds(), BalanceHistory1Month) + + // Advance little bit more than a month + dataSource.setCurrentTime(currentTime.Unix()) + dataSource.resetStats() + + err = bh.update(context.Background(), dataSource, common.Address{7}, BalanceHistory1Month) + require.NoError(t, err) + + newBalanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, dataSource.TimeNow(), BalanceHistory1Month) + require.NoError(t, err) + require.GreaterOrEqual(t, len(newBalanceData), minimumExpectedDataPoints(BalanceHistory1Month)) + + _, headerInfos, balances := extractTestData(dataSource) + require.Greater(t, len(balances), len(newBalanceData), "Expected more balance requests due to missing time catch up") + + // Ensure we don't request the same info twice + for block, count := range headerInfos { + require.Equal(t, 1, count, "block %d has one info request", block) + if balanceCount, contains := balances[block]; contains { + require.Equal(t, 1, balanceCount, "block %d has one balance request", block) + } + } + for block, count := range balances { + require.Equal(t, 1, count, "block %d has one request", block) + } + + for i := 1; i < len(newBalanceData); i++ { + require.Greater(t, newBalanceData[i].Timestamp, newBalanceData[i-1].Timestamp, "result timestamps are in order") + } + + errorFromIdeal = getTimeError(dataSource, newBalanceData, BalanceHistory1Month) + require.Less(t, math.Abs(float64(errorFromIdeal)), strideDuration(BalanceHistory1Month).Seconds(), "Duration error [%d s] is within 1 stride [%.f s] for interval [%#v]", errorFromIdeal, strideDuration(BalanceHistory1Month).Seconds(), BalanceHistory1Month) +} + +func TestGetBalanceHistoryFetchMultipleAccounts(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + sevenDataSource := newTestSource(t, 5 /*years*/) + + err := bh.update(context.Background(), sevenDataSource, common.Address{7}, BalanceHistory1Month) + require.NoError(t, err) + + sevenBalanceData, err := bh.get(context.Background(), sevenDataSource.ChainID(), sevenDataSource.Currency(), common.Address{7}, sevenDataSource.TimeNow(), BalanceHistory1Month) + require.NoError(t, err) + require.GreaterOrEqual(t, len(sevenBalanceData), minimumExpectedDataPoints(BalanceHistory1Month)) + + _, sevenBlockInfos, _ := extractTestData(sevenDataSource) + require.Greater(t, len(sevenBlockInfos), 0) + + nineDataSource := newTestSource(t, 5 /*years*/) + err = bh.update(context.Background(), nineDataSource, common.Address{9}, BalanceHistory1Month) + require.NoError(t, err) + + nineBalanceData, err := bh.get(context.Background(), nineDataSource.ChainID(), nineDataSource.Currency(), common.Address{7}, nineDataSource.TimeNow(), BalanceHistory1Month) + require.NoError(t, err) + require.GreaterOrEqual(t, len(nineBalanceData), minimumExpectedDataPoints(BalanceHistory1Month)) + + _, nineBlockInfos, nineBalances := extractTestData(nineDataSource) + require.Equal(t, 0, len(nineBlockInfos)) + require.Equal(t, len(nineBalanceData), len(nineBalances)) +} + +func TestGetBalanceHistoryUpdateCancellation(t *testing.T) { + bh, cleanDB := setupBalanceTest(t) + defer cleanDB() + + dataSource := newTestSource(t, 5 /*years*/) + ctx, cancelFn := context.WithCancel(context.Background()) + bkFn := dataSource.balanceAtFn + // Fail after 15 requests + dataSource.balanceAtFn = func(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { + if len(dataSource.requestedBlocks) == 15 { + cancelFn() + } + return dataSource.BalanceAtMock(ctx, account, blockNumber) + } + err := bh.update(ctx, dataSource, common.Address{7}, BalanceHistory1Year) + require.Error(t, ctx.Err(), "Service canceled") + require.Error(t, err, "context cancelled") + + balanceData, err := bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, dataSource.TimeNow(), BalanceHistory1Year) + require.NoError(t, err) + require.Equal(t, 15, len(balanceData)) + + _, blockInfos, balances := extractTestData(dataSource) + // The request for block info is made before the balance fails + require.Equal(t, 15, len(balances)) + require.Equal(t, 15, len(blockInfos)) + + dataSource.balanceAtFn = bkFn + ctx, cancelFn = context.WithCancel(context.Background()) + + err = bh.update(ctx, dataSource, common.Address{7}, BalanceHistory1Year) + require.NoError(t, ctx.Err()) + require.NoError(t, err) + + balanceData, err = bh.get(context.Background(), dataSource.ChainID(), dataSource.Currency(), common.Address{7}, dataSource.TimeNow(), BalanceHistory1Year) + require.NoError(t, err) + require.GreaterOrEqual(t, len(balanceData), minimumExpectedDataPoints(BalanceHistory1Year)) + cancelFn() +} + +func TestBlockStrideHaveCommonDivisor(t *testing.T) { + values := make([]blocksStride, 0, len(timeIntervalToStride)) + for _, blockCount := range timeIntervalToStride { + values = append(values, blockCount) + } + sort.Slice(values, func(i, j int) bool { + return values[i] < values[j] + }) + for i := 1; i < len(values); i++ { + require.Equal(t, blocksStride(0), values[i]%values[i-1], " %d value from index %d is divisible with previous %d", values[i], i, values[i-1]) + } +} + +func TestBlockStrideMatchesBitsetFilter(t *testing.T) { + filterToStrideEquivalence := map[bitsetFilter]blocksStride{ + filterAllTime: fourMonthsStride, + filterWeekly: weekStride, + filterTwiceADay: twiceADayStride, + } + + for interval, bitsetFiler := range timeIntervalToBitsetFilter { + stride, found := timeIntervalToStride[interval] + require.True(t, found) + require.Equal(t, stride, filterToStrideEquivalence[bitsetFiler]) + } +} + +func TestTimeIntervalToBitsetFilterAreConsecutiveFlags(t *testing.T) { + values := make([]int, 0, len(timeIntervalToBitsetFilter)) + for i := BalanceHistoryAllTime; i >= BalanceHistory7Days; i-- { + values = append(values, int(timeIntervalToBitsetFilter[i])) + } + + for i := 0; i < len(values); i++ { + // count number of bits set + count := 0 + for j := 0; j <= 30; j++ { + if values[i]&(1< 0 { + require.GreaterOrEqual(t, values[i], values[i-1], "%b value from index %d is higher then previous %d", values[i], i, values[i-1]) + } + } +} diff --git a/services/wallet/history/service.go b/services/wallet/history/service.go new file mode 100644 index 000000000..0a15352aa --- /dev/null +++ b/services/wallet/history/service.go @@ -0,0 +1,346 @@ +package history + +import ( + "context" + "database/sql" + "errors" + "math/big" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/log" + + statustypes "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/multiaccounts/accounts" + statusrpc "github.com/status-im/status-go/rpc" + "github.com/status-im/status-go/rpc/network" + + "github.com/status-im/status-go/services/wallet/chain" + "github.com/status-im/status-go/services/wallet/token" + "github.com/status-im/status-go/services/wallet/walletevent" +) + +// EventBalanceHistoryUpdateStarted and EventBalanceHistoryUpdateDone are used to notify the UI that balance history is being updated +const ( + EventBalanceHistoryUpdateStarted walletevent.EventType = "wallet-balance-history-update-started" + EventBalanceHistoryUpdateFinished walletevent.EventType = "wallet-balance-history-update-finished" + EventBalanceHistoryUpdateFinishedWithError walletevent.EventType = "wallet-balance-history-update-finished-with-error" + + balanceHistoryUpdateInterval = 12 * time.Hour +) + +type Service struct { + balance *Balance + db *sql.DB + eventFeed *event.Feed + rpcClient *statusrpc.Client + networkManager *network.Manager + tokenManager *token.Manager + serviceContext context.Context + cancelFn context.CancelFunc + + timer *time.Timer + visibleTokenSymbols []string + visibleTokenSymbolsMutex sync.Mutex // Protects access to visibleSymbols +} + +func NewService(db *sql.DB, eventFeed *event.Feed, rpcClient *statusrpc.Client, tokenManager *token.Manager) *Service { + return &Service{ + balance: NewBalance(NewBalanceDB(db)), + db: db, + eventFeed: eventFeed, + rpcClient: rpcClient, + networkManager: rpcClient.NetworkManager, + tokenManager: tokenManager, + } +} + +func (s *Service) Stop() { + if s.cancelFn != nil { + s.cancelFn() + } +} + +func (s *Service) triggerEvent(eventType walletevent.EventType, account statustypes.Address, message string) { + s.eventFeed.Send(walletevent.Event{ + Type: eventType, + Accounts: []common.Address{ + common.Address(account), + }, + Message: message, + }) +} + +func (s *Service) StartBalanceHistory() { + go func() { + s.serviceContext, s.cancelFn = context.WithCancel(context.Background()) + s.timer = time.NewTimer(balanceHistoryUpdateInterval) + + update := func() (exit bool) { + err := s.updateBalanceHistoryForAllEnabledNetworks(s.serviceContext) + if s.serviceContext.Err() != nil { + s.triggerEvent(EventBalanceHistoryUpdateFinished, statustypes.Address{}, "Service canceled") + s.timer.Stop() + return true + } + if err != nil { + s.triggerEvent(EventBalanceHistoryUpdateFinishedWithError, statustypes.Address{}, err.Error()) + } + return false + } + + if update() { + return + } + + for range s.timer.C { + s.resetTimer(balanceHistoryUpdateInterval) + + if update() { + return + } + } + }() +} + +func (s *Service) resetTimer(interval time.Duration) { + if s.timer != nil { + s.timer.Stop() + s.timer.Reset(interval) + } +} + +func (s *Service) UpdateVisibleTokens(symbols []string) { + s.visibleTokenSymbolsMutex.Lock() + defer s.visibleTokenSymbolsMutex.Unlock() + + startUpdate := len(s.visibleTokenSymbols) == 0 && len(symbols) > 0 + s.visibleTokenSymbols = symbols + if startUpdate { + s.resetTimer(0) + } +} + +func (s *Service) isTokenVisible(tokenSymbol string) bool { + s.visibleTokenSymbolsMutex.Lock() + defer s.visibleTokenSymbolsMutex.Unlock() + + for _, visibleSymbol := range s.visibleTokenSymbols { + if visibleSymbol == tokenSymbol { + return true + } + } + return false +} + +// Native token implementation of DataSource interface +type chainClientSource struct { + chainClient *chain.Client + currency string +} + +func (src *chainClientSource) HeaderByNumber(ctx context.Context, blockNo *big.Int) (*types.Header, error) { + return src.chainClient.HeaderByNumber(ctx, blockNo) +} + +func (src *chainClientSource) BalanceAt(ctx context.Context, account common.Address, blockNo *big.Int) (*big.Int, error) { + return src.chainClient.BalanceAt(ctx, account, blockNo) +} + +func (src *chainClientSource) ChainID() uint64 { + return src.chainClient.ChainID +} + +func (src *chainClientSource) Currency() string { + return src.currency +} + +func (src *chainClientSource) TimeNow() int64 { + return time.Now().UTC().Unix() +} + +type tokenChainClientSource struct { + chainClientSource + TokenManager *token.Manager + NetworkManager *network.Manager + + firstUnavailableBlockNo *big.Int +} + +func (src *tokenChainClientSource) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { + network := src.NetworkManager.Find(src.chainClient.ChainID) + if network == nil { + return nil, errors.New("network not found") + } + token := src.TokenManager.FindToken(network, src.currency) + if token == nil { + return nil, errors.New("token not found") + } + if src.firstUnavailableBlockNo != nil && blockNumber.Cmp(src.firstUnavailableBlockNo) < 0 { + return big.NewInt(0), nil + } + balance, err := src.TokenManager.GetTokenBalanceAt(ctx, src.chainClient, account, token.Address, blockNumber) + if err != nil { + if err.Error() == "no contract code at given address" { + // Ignore requests before contract deployment and mark this state for future requests + src.firstUnavailableBlockNo = new(big.Int).Set(blockNumber) + return big.NewInt(0), nil + } + return nil, err + } + return balance, err +} + +// GetBalanceHistory returns token count balance +// TODO: fetch token to FIAT exchange rates and return FIAT balance +func (s *Service) GetBalanceHistory(ctx context.Context, chainIDs []uint64, address common.Address, currency string, endTimestamp int64, timeInterval TimeInterval) ([]*DataPoint, error) { + allData := make(map[uint64][]*DataPoint) + for _, chainID := range chainIDs { + data, err := s.balance.get(ctx, chainID, currency, address, endTimestamp, timeInterval) + if err != nil { + return nil, err + } + if len(data) > 0 { + allData[chainID] = data + } + } + + return mergeDataPoints(allData) +} + +// mergeDataPoints merges same block numbers from different chains which are incompatible due to different timelines +// TODO: use time-based intervals instead of block numbers +func mergeDataPoints(data map[uint64][]*DataPoint) ([]*DataPoint, error) { + if len(data) == 0 { + return make([]*DataPoint, 0), nil + } + + pos := make(map[uint64]int) + for k := range data { + pos[k] = 0 + } + + res := make([]*DataPoint, 0) + done := false + for !done { + var minNo *big.Int + var timestamp uint64 + // Take the smallest block number + for k := range data { + blockNo := new(big.Int).Set(data[k][pos[k]].BlockNumber.ToInt()) + if minNo == nil { + minNo = new(big.Int).Set(blockNo) + // We use it only if we have a full match + timestamp = data[k][pos[k]].Timestamp + } else if blockNo.Cmp(minNo) < 0 { + minNo.Set(blockNo) + } + } + // If all chains have the same block number sum it; also increment the processed position + sumOfAll := big.NewInt(0) + for k := range data { + cur := data[k][pos[k]] + if cur.BlockNumber.ToInt().Cmp(minNo) == 0 { + pos[k]++ + if sumOfAll != nil { + sumOfAll.Add(sumOfAll, cur.Value.ToInt()) + } + } else { + sumOfAll = nil + } + } + // If sum of all make sense add it to the result otherwise ignore it + if sumOfAll != nil { + // TODO: convert to FIAT value + res = append(res, &DataPoint{ + BlockNumber: (*hexutil.Big)(minNo), + Timestamp: timestamp, + Value: (*hexutil.Big)(sumOfAll), + }) + } + + // Check if we reached the end of any chain + for k := range data { + if pos[k] == len(data[k]) { + done = true + break + } + } + } + return res, nil +} + +// updateBalanceHistoryForAllEnabledNetworks iterates over all enabled and supported networks for the s.visibleTokenSymbol +// and updates the balance history for the given address +// +// expects ctx to have cancellation support and processing to be cancelled by the caller +func (s *Service) updateBalanceHistoryForAllEnabledNetworks(ctx context.Context) error { + accountsDB, err := accounts.NewDB(s.db) + if err != nil { + return err + } + + addresses, err := accountsDB.GetWalletAddresses() + if err != nil { + return err + } + + networks, err := s.networkManager.Get(true) + if err != nil { + return err + } + + for _, address := range addresses { + s.triggerEvent(EventBalanceHistoryUpdateStarted, address, "") + + for _, network := range networks { + tokensForChain, err := s.tokenManager.GetTokens(network.ChainID) + if err != nil { + tokensForChain = make([]*token.Token, 0) + } + tokensForChain = append(tokensForChain, s.tokenManager.ToToken(network)) + + for _, token := range tokensForChain { + if !s.isTokenVisible(token.Symbol) { + continue + } + + var dataSource DataSource + chainClient, err := chain.NewClient(s.rpcClient, network.ChainID) + if err != nil { + return err + } + if token.IsNative() { + dataSource = &chainClientSource{chainClient, token.Symbol} + } else { + dataSource = &tokenChainClientSource{ + chainClientSource: chainClientSource{ + chainClient: chainClient, + currency: token.Symbol, + }, + TokenManager: s.tokenManager, + NetworkManager: s.networkManager, + } + } + + for currentInterval := int(BalanceHistoryAllTime); currentInterval >= int(BalanceHistory7Days); currentInterval-- { + select { + case <-ctx.Done(): + return errors.New("context cancelled") + default: + } + err = s.balance.update(ctx, dataSource, common.Address(address), TimeInterval(currentInterval)) + if err != nil { + log.Warn("Error updating balance history", "chainID", dataSource.ChainID(), "currency", dataSource.Currency(), "address", address.String(), "interval", currentInterval, "err", err) + } + } + } + } + s.triggerEvent(EventBalanceHistoryUpdateFinished, address, "") + } + return nil +} diff --git a/services/wallet/history/service_test.go b/services/wallet/history/service_test.go new file mode 100644 index 000000000..929d5813d --- /dev/null +++ b/services/wallet/history/service_test.go @@ -0,0 +1,164 @@ +package history + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + + "github.com/stretchr/testify/require" +) + +type TestDataPoint struct { + value int64 + timestamp uint64 + blockNumber int64 + chainID uint64 +} + +// generateTestDataForElementCount generates dummy consecutive blocks of data for the same chain_id, address and currency +func prepareTestData(data []TestDataPoint) map[uint64][]*DataPoint { + res := make(map[uint64][]*DataPoint) + for i := 0; i < len(data); i++ { + entry := data[i] + _, found := res[entry.chainID] + if !found { + res[entry.chainID] = make([]*DataPoint, 0) + } + res[entry.chainID] = append(res[entry.chainID], &DataPoint{ + BlockNumber: (*hexutil.Big)(big.NewInt(data[i].blockNumber)), + Timestamp: data[i].timestamp, + Value: (*hexutil.Big)(big.NewInt(data[i].value)), + }) + } + return res +} + +func getBlockNumbers(data []*DataPoint) []int64 { + res := make([]int64, 0) + for _, entry := range data { + res = append(res, entry.BlockNumber.ToInt().Int64()) + } + return res +} + +func getValues(data []*DataPoint) []int64 { + res := make([]int64, 0) + for _, entry := range data { + res = append(res, entry.Value.ToInt().Int64()) + } + return res +} + +func getTimestamps(data []*DataPoint) []int64 { + res := make([]int64, 0) + for _, entry := range data { + res = append(res, int64(entry.Timestamp)) + } + return res +} + +func TestServiceGetBalanceHistory(t *testing.T) { + testData := prepareTestData([]TestDataPoint{ + // Drop 100 + {value: 1, timestamp: 100, blockNumber: 100, chainID: 1}, + {value: 1, timestamp: 100, blockNumber: 100, chainID: 2}, + // Keep 105 + {value: 1, timestamp: 105, blockNumber: 105, chainID: 1}, + {value: 1, timestamp: 105, blockNumber: 105, chainID: 2}, + {value: 1, timestamp: 105, blockNumber: 105, chainID: 3}, + // Drop 110 + {value: 1, timestamp: 105, blockNumber: 105, chainID: 2}, + {value: 1, timestamp: 105, blockNumber: 105, chainID: 3}, + // Keep 115 + {value: 2, timestamp: 115, blockNumber: 115, chainID: 1}, + {value: 2, timestamp: 115, blockNumber: 115, chainID: 2}, + {value: 2, timestamp: 115, blockNumber: 115, chainID: 3}, + // Drop 120 + {value: 1, timestamp: 120, blockNumber: 120, chainID: 3}, + // Keep 125 + {value: 3, timestamp: 125, blockNumber: 125, chainID: 1}, + {value: 3, timestamp: 125, blockNumber: 125, chainID: 2}, + {value: 3, timestamp: 125, blockNumber: 125, chainID: 3}, + // Keep 130 + {value: 4, timestamp: 130, blockNumber: 130, chainID: 1}, + {value: 4, timestamp: 130, blockNumber: 130, chainID: 2}, + {value: 4, timestamp: 130, blockNumber: 130, chainID: 3}, + // Drop 135 + {value: 1, timestamp: 135, blockNumber: 135, chainID: 1}, + }) + + res, err := mergeDataPoints(testData) + require.NoError(t, err) + require.Equal(t, 4, len(res)) + require.Equal(t, []int64{105, 115, 125, 130}, getBlockNumbers(res)) + require.Equal(t, []int64{3, 3 * 2, 3 * 3, 3 * 4}, getValues(res)) + require.Equal(t, []int64{105, 115, 125, 130}, getTimestamps(res)) +} + +func TestServiceGetBalanceHistoryAllMatch(t *testing.T) { + testData := prepareTestData([]TestDataPoint{ + // Keep 105 + {value: 1, timestamp: 105, blockNumber: 105, chainID: 1}, + {value: 1, timestamp: 105, blockNumber: 105, chainID: 2}, + {value: 1, timestamp: 105, blockNumber: 105, chainID: 3}, + // Keep 115 + {value: 2, timestamp: 115, blockNumber: 115, chainID: 1}, + {value: 2, timestamp: 115, blockNumber: 115, chainID: 2}, + {value: 2, timestamp: 115, blockNumber: 115, chainID: 3}, + // Keep 125 + {value: 3, timestamp: 125, blockNumber: 125, chainID: 1}, + {value: 3, timestamp: 125, blockNumber: 125, chainID: 2}, + {value: 3, timestamp: 125, blockNumber: 125, chainID: 3}, + // Keep 135 + {value: 4, timestamp: 135, blockNumber: 135, chainID: 1}, + {value: 4, timestamp: 135, blockNumber: 135, chainID: 2}, + {value: 4, timestamp: 135, blockNumber: 135, chainID: 3}, + }) + + res, err := mergeDataPoints(testData) + require.NoError(t, err) + require.Equal(t, 4, len(res)) + require.Equal(t, []int64{105, 115, 125, 135}, getBlockNumbers(res)) + require.Equal(t, []int64{3, 3 * 2, 3 * 3, 3 * 4}, getValues(res)) + require.Equal(t, []int64{105, 115, 125, 135}, getTimestamps(res)) +} + +func TestServiceGetBalanceHistoryOneChain(t *testing.T) { + testData := prepareTestData([]TestDataPoint{ + // Keep 105 + {value: 1, timestamp: 105, blockNumber: 105, chainID: 1}, + // Keep 115 + {value: 2, timestamp: 115, blockNumber: 115, chainID: 1}, + // Keep 125 + {value: 3, timestamp: 125, blockNumber: 125, chainID: 1}, + }) + + res, err := mergeDataPoints(testData) + require.NoError(t, err) + require.Equal(t, 3, len(res)) + require.Equal(t, []int64{105, 115, 125}, getBlockNumbers(res)) + require.Equal(t, []int64{1, 2, 3}, getValues(res)) + require.Equal(t, []int64{105, 115, 125}, getTimestamps(res)) +} + +func TestServiceGetBalanceHistoryDropAll(t *testing.T) { + testData := prepareTestData([]TestDataPoint{ + {value: 1, timestamp: 100, blockNumber: 100, chainID: 1}, + {value: 1, timestamp: 100, blockNumber: 101, chainID: 2}, + {value: 1, timestamp: 100, blockNumber: 102, chainID: 3}, + {value: 1, timestamp: 100, blockNumber: 103, chainID: 4}, + }) + + res, err := mergeDataPoints(testData) + require.NoError(t, err) + require.Equal(t, 0, len(res)) +} + +func TestServiceGetBalanceHistoryEmptyDB(t *testing.T) { + testData := prepareTestData([]TestDataPoint{}) + + res, err := mergeDataPoints(testData) + require.NoError(t, err) + require.Equal(t, 0, len(res)) +} diff --git a/services/wallet/service.go b/services/wallet/service.go index 3459f6fc7..56d41f86c 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -16,6 +16,7 @@ import ( "github.com/status-im/status-go/services/ens" "github.com/status-im/status-go/services/stickers" "github.com/status-im/status-go/services/wallet/chain" + "github.com/status-im/status-go/services/wallet/history" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/transfer" "github.com/status-im/status-go/services/wallet/walletevent" @@ -55,6 +56,7 @@ func NewService( cryptoCompare := NewCryptoCompare() priceManager := NewPriceManager(db, cryptoCompare) reader := NewReader(rpcClient, tokenManager, priceManager, cryptoCompare, accountsDB, walletFeed) + history := history.NewService(db, walletFeed, rpcClient, tokenManager) return &Service{ db: db, accountsDB: accountsDB, @@ -75,6 +77,7 @@ func NewService( signals: signals, reader: reader, cryptoCompare: cryptoCompare, + history: history, } } @@ -100,6 +103,7 @@ type Service struct { signals *walletevent.SignalsTransmitter reader *Reader cryptoCompare *CryptoCompare + history *history.Service } // Start signals transmitter. @@ -121,6 +125,7 @@ func (s *Service) Stop() error { s.signals.Stop() s.transferController.Stop() s.reader.Stop() + s.history.Stop() s.started = false log.Info("wallet stopped") return nil diff --git a/services/wallet/token/token.go b/services/wallet/token/token.go index d2ad8f5d6..01aad13f5 100644 --- a/services/wallet/token/token.go +++ b/services/wallet/token/token.go @@ -371,6 +371,18 @@ func (tm *Manager) GetTokenBalance(ctx context.Context, client *chain.Client, ac }, account) } +func (tm *Manager) GetTokenBalanceAt(ctx context.Context, client *chain.Client, account common.Address, token common.Address, blockNumber *big.Int) (*big.Int, error) { + caller, err := ierc20.NewIERC20Caller(token, client) + if err != nil { + return nil, err + } + + return caller.BalanceOf(&bind.CallOpts{ + Context: ctx, + BlockNumber: blockNumber, + }, account) +} + func (tm *Manager) GetChainBalance(ctx context.Context, client *chain.Client, account common.Address) (*big.Int, error) { return client.BalanceAt(ctx, account, nil) } diff --git a/services/wallet/transfer/balance_cache.go b/services/wallet/transfer/balance_cache.go index e6b252f94..7aa18f929 100644 --- a/services/wallet/transfer/balance_cache.go +++ b/services/wallet/transfer/balance_cache.go @@ -15,12 +15,6 @@ type nonceRange struct { min *big.Int } -// balanceHistoryCache is used temporary until we cache balance history in DB -type balanceHistoryCache struct { - lastBlockNo *big.Int - lastBlockTimestamp int64 -} - type balanceCache struct { // balances maps an address to a map of a block number and the balance of this particular address balances map[common.Address]map[*big.Int]*big.Int @@ -28,7 +22,6 @@ type balanceCache struct { nonceRanges map[common.Address]map[int64]nonceRange sortedRanges map[common.Address][]nonceRange rw sync.RWMutex - history *balanceHistoryCache } type BalanceCache interface { diff --git a/services/wallet/transfer/controller.go b/services/wallet/transfer/controller.go index 5a99095c0..fa043d2c7 100644 --- a/services/wallet/transfer/controller.go +++ b/services/wallet/transfer/controller.go @@ -4,10 +4,8 @@ import ( "context" "database/sql" "math/big" - "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" @@ -291,125 +289,3 @@ func (c *Controller) GetCachedBalances(ctx context.Context, chainID uint64, addr return blocksToViews(result), nil } - -type BalanceState struct { - Value *hexutil.Big `json:"value"` - Timestamp uint64 `json:"time"` -} - -type BalanceHistoryTimeInterval int - -const ( - BalanceHistory7Hours BalanceHistoryTimeInterval = iota + 1 - BalanceHistory1Month - BalanceHistory6Months - BalanceHistory1Year - BalanceHistoryAllTime -) - -var balanceHistoryTimeIntervalToHoursPerStep = map[BalanceHistoryTimeInterval]int64{ - BalanceHistory7Hours: 2, - BalanceHistory1Month: 12, - BalanceHistory6Months: (24 * 7) / 2, - BalanceHistory1Year: 24 * 7, -} - -var balanceHistoryTimeIntervalToSampleNo = map[BalanceHistoryTimeInterval]int64{ - BalanceHistory7Hours: 84, - BalanceHistory1Month: 60, - BalanceHistory6Months: 52, - BalanceHistory1Year: 52, - BalanceHistoryAllTime: 50, -} - -// GetBalanceHistory expect a time precision of +/- average block time (~12s) -// implementation relies that a block has constant time length to save block header requests -func (c *Controller) GetBalanceHistory(ctx context.Context, chainID uint64, address common.Address, timeInterval BalanceHistoryTimeInterval) ([]BalanceState, error) { - chainClient, err := chain.NewClient(c.rpcClient, chainID) - if err != nil { - return nil, err - } - - if c.balanceCache == nil { - c.balanceCache = newBalanceCache() - } - - if c.balanceCache.history == nil { - c.balanceCache.history = new(balanceHistoryCache) - } - - currentTimestamp := time.Now().Unix() - lastBlockNo := big.NewInt(0) - var lastBlockTimestamp int64 - if (currentTimestamp - c.balanceCache.history.lastBlockTimestamp) >= (12 * 60 * 60) { - lastBlock, err := chainClient.BlockByNumber(ctx, nil) - if err != nil { - return nil, err - } - lastBlockNo.Set(lastBlock.Number()) - lastBlockTimestamp = int64(lastBlock.Time()) - c.balanceCache.history.lastBlockNo = big.NewInt(0).Set(lastBlockNo) - c.balanceCache.history.lastBlockTimestamp = lastBlockTimestamp - } else { - lastBlockNo.Set(c.balanceCache.history.lastBlockNo) - lastBlockTimestamp = c.balanceCache.history.lastBlockTimestamp - } - - initialBlock, err := chainClient.BlockByNumber(ctx, big.NewInt(1)) - if err != nil { - return nil, err - } - initialBlockNo := big.NewInt(0).Set(initialBlock.Number()) - initialBlockTimestamp := int64(initialBlock.Time()) - - allTimeBlockCount := big.NewInt(0).Sub(lastBlockNo, initialBlockNo) - allTimeInterval := lastBlockTimestamp - initialBlockTimestamp - - // Expected to be around 12 - blockDuration := float64(allTimeInterval) / float64(allTimeBlockCount.Int64()) - - lastBlockTime := time.Unix(lastBlockTimestamp, 0) - // Snap to the beginning of the day or half day which is the closest to the last block - hour := 0 - if lastBlockTime.Hour() >= 12 { - hour = 12 - } - lastTime := time.Date(lastBlockTime.Year(), lastBlockTime.Month(), lastBlockTime.Day(), hour, 0, 0, 0, lastBlockTime.Location()) - endBlockTimestamp := lastTime.Unix() - blockGaps := big.NewInt(int64(float64(lastBlockTimestamp-endBlockTimestamp) / blockDuration)) - endBlockNo := big.NewInt(0).Sub(lastBlockNo, blockGaps) - - totalBlockCount, startTimestamp := int64(0), int64(0) - if timeInterval == BalanceHistoryAllTime { - startTimestamp = initialBlockTimestamp - totalBlockCount = endBlockNo.Int64() - } else { - secondsToNow := balanceHistoryTimeIntervalToHoursPerStep[timeInterval] * 3600 * (balanceHistoryTimeIntervalToSampleNo[timeInterval]) - startTimestamp = endBlockTimestamp - secondsToNow - totalBlockCount = int64(float64(secondsToNow) / blockDuration) - } - blocksInStep := totalBlockCount / (balanceHistoryTimeIntervalToSampleNo[timeInterval]) - stepDuration := int64(float64(blocksInStep) * blockDuration) - - points := make([]BalanceState, 0) - - nextBlockNumber := big.NewInt(0).Set(endBlockNo) - nextTimestamp := endBlockTimestamp - for nextTimestamp >= startTimestamp && nextBlockNumber.Cmp(initialBlockNo) >= 0 && nextBlockNumber.Cmp(big.NewInt(0)) > 0 { - newBlockNo := big.NewInt(0).Set(nextBlockNumber) - currentBalance, err := c.balanceCache.BalanceAt(ctx, chainClient, address, newBlockNo) - if err != nil { - return nil, err - } - - var currentBalanceState BalanceState - currentBalanceState.Value = (*hexutil.Big)(currentBalance) - currentBalanceState.Timestamp = uint64(nextTimestamp) - points = append([]BalanceState{currentBalanceState}, points...) - - // decrease block number and timestamp - nextTimestamp -= stepDuration - nextBlockNumber.Sub(nextBlockNumber, big.NewInt(blocksInStep)) - } - return points, nil -}