Mirror of https://github.com/status-im/status-go.git (synced 2025-02-22 11:48:31 +00:00)
fix: merge balance history using block time
This change improves on the previous implementation, which merged on block number; block numbers are not comparable across incompatible blockchains, e.g. an L1 vs. an L2.

Closes: #9205
This commit is contained in:
parent c8994fe175
commit 90b39eeb41
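Before the diff, a minimal, runnable sketch of the stride-window merging idea the commit introduces. The types and names below (point, mergeByTime) and the plain int64 balances are simplified illustrations only, not the status-go code; the real implementation in the diff works on per-chain *DataPoint slices with hexutil.Big values and advances per-chain positions incrementally.

package main

import "fmt"

type point struct {
    timestamp int64 // unix seconds
    value     int64 // balance at that time (assumed non-negative)
}

// mergeByTime walks fixed-size time windows ("strides") starting at the oldest
// timestamp seen on any chain. For every window in which every chain has at
// least one point it sums the per-chain maxima; windows missing a chain are
// dropped, since a partial sum would misrepresent the combined balance.
func mergeByTime(chains map[uint64][]point, stride int64) []point {
    if len(chains) == 0 {
        return nil
    }
    first, last := int64(-1), int64(-1)
    for _, pts := range chains {
        for _, p := range pts {
            if first == -1 || p.timestamp < first {
                first = p.timestamp
            }
            if p.timestamp > last {
                last = p.timestamp
            }
        }
    }
    if first == -1 {
        return nil // no data points on any chain
    }
    var merged []point
    for start := first; start <= last; start += stride {
        end := start + stride
        sum, complete := int64(0), true
        for _, pts := range chains {
            best := int64(-1) // highest balance this chain reports inside the window
            for _, p := range pts {
                if p.timestamp >= start && p.timestamp < end && p.value > best {
                    best = p.value
                }
            }
            if best < 0 { // this chain has no point in the window: drop the window
                complete = false
                break
            }
            sum += best
        }
        if complete {
            merged = append(merged, point{timestamp: end, value: sum})
        }
    }
    return merged
}

func main() {
    data := map[uint64][]point{
        1:  {{timestamp: 100, value: 5}, {timestamp: 112, value: 6}}, // e.g. an L1
        10: {{timestamp: 103, value: 2}, {timestamp: 115, value: 3}}, // e.g. an L2
    }
    fmt.Println(mergeByTime(data, 10)) // [{110 7} {120 9}]
}

The diff's mergeDataPoints follows the same shape, but additionally records a merged BlockNumber only when every chain in the window reports the same block, and stamps each merged point with the end of its stride window.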
@@ -5,6 +5,7 @@ import (
    "database/sql"
    "errors"
    "math/big"
    "sort"
    "sync"
    "time"
@@ -48,6 +49,8 @@ type Service struct {
    visibleTokenSymbolsMutex sync.Mutex // Protects access to visibleSymbols
}

type chainIdentity uint64

func NewService(db *sql.DB, eventFeed *event.Feed, rpcClient *statusrpc.Client, tokenManager *token.Manager) *Service {
    return &Service{
        balance: NewBalance(NewBalanceDB(db)),
@@ -198,82 +201,180 @@ func (src *tokenChainClientSource) BalanceAt(ctx context.Context, account common
// GetBalanceHistory returns token count balance
// TODO: fetch token to FIAT exchange rates and return FIAT balance
func (s *Service) GetBalanceHistory(ctx context.Context, chainIDs []uint64, address common.Address, currency string, endTimestamp int64, timeInterval TimeInterval) ([]*DataPoint, error) {
    allData := make(map[uint64][]*DataPoint)
    allData := make(map[chainIdentity][]*DataPoint)
    for _, chainID := range chainIDs {
        data, err := s.balance.get(ctx, chainID, currency, address, endTimestamp, timeInterval)
        if err != nil {
            return nil, err
        }
        if len(data) > 0 {
            allData[chainID] = data
            allData[chainIdentity(chainID)] = data
        }
    }

    return mergeDataPoints(allData)
    return mergeDataPoints(allData, strideDuration(timeInterval))
}

// mergeDataPoints merges same block numbers from different chains which are incompatible due to different timelines
// TODO: use time-based intervals instead of block numbers
func mergeDataPoints(data map[uint64][]*DataPoint) ([]*DataPoint, error) {
// mergeDataPoints merges close in time block numbers. Drops the ones that are not in a stride duration
// this should improve merging balance data from different chains which are incompatible due to different timelines
// and block length
func mergeDataPoints(data map[chainIdentity][]*DataPoint, stride time.Duration) ([]*DataPoint, error) {
    if len(data) == 0 {
        return make([]*DataPoint, 0), nil
    }

    pos := make(map[uint64]int)
    pos := make(map[chainIdentity]int)
    for k := range data {
        pos[k] = 0
    }

    res := make([]*DataPoint, 0)
    done := false
    for !done {
        var minNo *big.Int
        var timestamp uint64
        // Take the smallest block number
        for k := range data {
            blockNo := new(big.Int).Set(data[k][pos[k]].BlockNumber.ToInt())
            if minNo == nil {
                minNo = new(big.Int).Set(blockNo)
                // We use it only if we have a full match
                timestamp = data[k][pos[k]].Timestamp
            } else if blockNo.Cmp(minNo) < 0 {
                minNo.Set(blockNo)
            }
        }
        // If all chains have the same block number sum it; also increment the processed position
        sumOfAll := big.NewInt(0)
        for k := range data {
            cur := data[k][pos[k]]
            if cur.BlockNumber.ToInt().Cmp(minNo) == 0 {
                pos[k]++
                if sumOfAll != nil {
                    sumOfAll.Add(sumOfAll, cur.Value.ToInt())
                }
            } else {
                sumOfAll = nil
            }
        }
        // If sum of all make sense add it to the result otherwise ignore it
        if sumOfAll != nil {
            // TODO: convert to FIAT value
            res = append(res, &DataPoint{
                BlockNumber: (*hexutil.Big)(minNo),
                Timestamp:   timestamp,
                Value:       (*hexutil.Big)(sumOfAll),
            })
        }
    strideStart := findFirstStrideWindow(data, stride)
    for {
        strideEnd := strideStart + int64(stride.Seconds())
        // - Gather all points in the stride window starting with current pos
        var strideIdentities map[chainIdentity][]timeIdentity
        strideIdentities, pos = dataInStrideWindowAndNextPos(data, pos, strideEnd)

        // Check if we reached the end of any chain
        // Check if all chains have data
        strideComplete := true
        for k := range data {
            if pos[k] == len(data[k]) {
                done = true
            _, strideComplete = strideIdentities[k]
            if !strideComplete {
                break
            }
        }
        if strideComplete {
            chainMaxBalance := make(map[chainIdentity]*DataPoint)
            for chainID, identities := range strideIdentities {
                for _, identity := range identities {
                    _, exists := chainMaxBalance[chainID]
                    if exists && (*big.Int)(identity.dataPoint(data).Value).Cmp((*big.Int)(chainMaxBalance[chainID].Value)) <= 0 {
                        continue
                    }
                    chainMaxBalance[chainID] = identity.dataPoint(data)
                }
            }
            balance := big.NewInt(0)
            for _, chainBalance := range chainMaxBalance {
                balance.Add(balance, (*big.Int)(chainBalance.Value))
            }
            res = append(res, &DataPoint{
                Timestamp:   uint64(strideEnd),
                Value:       (*hexutil.Big)(balance),
                BlockNumber: (*hexutil.Big)(getBlockID(chainMaxBalance)),
            })
        }

        if allPastEnd(data, pos) {
            return res, nil
        }

        strideStart = strideEnd
    }
}

func getBlockID(chainBalance map[chainIdentity]*DataPoint) *big.Int {
    var res *big.Int
    for _, balance := range chainBalance {
        if res == nil {
            res = new(big.Int).Set(balance.BlockNumber.ToInt())
        } else if res.Cmp(balance.BlockNumber.ToInt()) != 0 {
            return nil
        }
    }

    return res
}

type timeIdentity struct {
    chain chainIdentity
    index int
}

func (i timeIdentity) dataPoint(data map[chainIdentity][]*DataPoint) *DataPoint {
    return data[i.chain][i.index]
}

func (i timeIdentity) atEnd(data map[chainIdentity][]*DataPoint) bool {
    return (i.index + 1) == len(data[i.chain])
}

func (i timeIdentity) pastEnd(data map[chainIdentity][]*DataPoint) bool {
    return i.index >= len(data[i.chain])
}

func allPastEnd(data map[chainIdentity][]*DataPoint, pos map[chainIdentity]int) bool {
    for chainID := range pos {
        if !(timeIdentity{chainID, pos[chainID]}).pastEnd(data) {
            return false
        }
    }
    return true
}

// findFirstStrideWindow returns the start of the first stride window
// Tried to implement finding an optimal stride window but it was becoming too complicated and not worth it given that it will
// potentially save the first and last stride but it is not guaranteed. Current implementation should give good results
// as long as the DataPoints are regular enough
func findFirstStrideWindow(data map[chainIdentity][]*DataPoint, stride time.Duration) int64 {
    pos := make(map[chainIdentity]int)
    for k := range data {
        pos[k] = 0
    }

    // Identify the current oldest and newest block
    cur := sortTimeAsc(data, pos)
    return int64(cur[0].dataPoint(data).Timestamp)
}

func copyMap[K comparable, V any](original map[K]V) map[K]V {
    copy := make(map[K]V, len(original))
    for key, value := range original {
        copy[key] = value
    }
    return copy
}

// startPos might have indexes past the end of the data for a chain
func dataInStrideWindowAndNextPos(data map[chainIdentity][]*DataPoint, startPos map[chainIdentity]int, endT int64) (identities map[chainIdentity][]timeIdentity, nextPos map[chainIdentity]int) {
    pos := copyMap(startPos)
    identities = make(map[chainIdentity][]timeIdentity)

    // Identify the current oldest and newest block
    lastLen := int(-1)
    for lastLen < len(identities) {
        lastLen = len(identities)
        sorted := sortTimeAsc(data, pos)
        for _, identity := range sorted {
            if identity.dataPoint(data).Timestamp < uint64(endT) {
                identities[identity.chain] = append(identities[identity.chain], identity)
                pos[identity.chain]++
            }
        }
    }
    return identities, pos
}

// sortTimeAsc expects that indexes in pos may be past the end of the data for a chain
func sortTimeAsc(data map[chainIdentity][]*DataPoint, pos map[chainIdentity]int) []timeIdentity {
    res := make([]timeIdentity, 0, len(data))
    for k := range data {
        identity := timeIdentity{
            chain: k,
            index: pos[k],
        }
        if !identity.pastEnd(data) {
            res = append(res, identity)
        }
    }

    sort.Slice(res, func(i, j int) bool {
        return res[i].dataPoint(data).Timestamp < res[j].dataPoint(data).Timestamp
    })
    return res
}

// updateBalanceHistoryForAllEnabledNetworks iterates over all enabled and supported networks for the s.visibleTokenSymbol
// and updates the balance history for the given address
//
@@ -3,6 +3,7 @@ package history
import (
    "math/big"
    "testing"
    "time"

    "github.com/ethereum/go-ethereum/common/hexutil"

@@ -13,12 +14,12 @@ type TestDataPoint struct {
    value       int64
    timestamp   uint64
    blockNumber int64
    chainID uint64
    chainID chainIdentity
}

// generateTestDataForElementCount generates dummy consecutive blocks of data for the same chain_id, address and currency
func prepareTestData(data []TestDataPoint) map[uint64][]*DataPoint {
    res := make(map[uint64][]*DataPoint)
func prepareTestData(data []TestDataPoint) map[chainIdentity][]*DataPoint {
    res := make(map[chainIdentity][]*DataPoint)
    for i := 0; i < len(data); i++ {
        entry := data[i]
        _, found := res[entry.chainID]
@@ -34,11 +35,16 @@ func prepareTestData(data []TestDataPoint) map[uint64][]*DataPoint {
    return res
}

// getBlockNumbers returns -1 if block number is nil
func getBlockNumbers(data []*DataPoint) []int64 {
    res := make([]int64, 0)
    for _, entry := range data {
        if entry.BlockNumber == nil {
            res = append(res, -1)
        } else {
            res = append(res, entry.BlockNumber.ToInt().Int64())
        }
    }
    return res
}

@@ -58,7 +64,8 @@ func getTimestamps(data []*DataPoint) []int64 {
    return res
}

func TestServiceGetBalanceHistory(t *testing.T) {
func TestServiceMergeDataPoints(t *testing.T) {
    strideDuration := 5 * time.Second
    testData := prepareTestData([]TestDataPoint{
        // Drop 100
        {value: 1, timestamp: 100, blockNumber: 100, chainID: 1},
@@ -88,15 +95,16 @@ func TestServiceGetBalanceHistory(t *testing.T) {
        {value: 1, timestamp: 135, blockNumber: 135, chainID: 1},
    })

    res, err := mergeDataPoints(testData)
    res, err := mergeDataPoints(testData, strideDuration)
    require.NoError(t, err)
    require.Equal(t, 4, len(res))
    require.Equal(t, []int64{105, 115, 125, 130}, getBlockNumbers(res))
    require.Equal(t, []int64{3, 3 * 2, 3 * 3, 3 * 4}, getValues(res))
    require.Equal(t, []int64{105, 115, 125, 130}, getTimestamps(res))
    require.Equal(t, []int64{110, 120, 130, 135}, getTimestamps(res))
}

func TestServiceGetBalanceHistoryAllMatch(t *testing.T) {
func TestServiceMergeDataPointsAllMatch(t *testing.T) {
    strideDuration := 10 * time.Second
    testData := prepareTestData([]TestDataPoint{
        // Keep 105
        {value: 1, timestamp: 105, blockNumber: 105, chainID: 1},
@@ -116,15 +124,16 @@ func TestServiceGetBalanceHistoryAllMatch(t *testing.T) {
        {value: 4, timestamp: 135, blockNumber: 135, chainID: 3},
    })

    res, err := mergeDataPoints(testData)
    res, err := mergeDataPoints(testData, strideDuration)
    require.NoError(t, err)
    require.Equal(t, 4, len(res))
    require.Equal(t, []int64{105, 115, 125, 135}, getBlockNumbers(res))
    require.Equal(t, []int64{3, 3 * 2, 3 * 3, 3 * 4}, getValues(res))
    require.Equal(t, []int64{105, 115, 125, 135}, getTimestamps(res))
    require.Equal(t, []int64{115, 125, 135, 145}, getTimestamps(res))
}

func TestServiceGetBalanceHistoryOneChain(t *testing.T) {
func TestServiceMergeDataPointsOneChain(t *testing.T) {
    strideDuration := 10 * time.Second
    testData := prepareTestData([]TestDataPoint{
        // Keep 105
        {value: 1, timestamp: 105, blockNumber: 105, chainID: 1},
@@ -134,31 +143,74 @@ func TestServiceGetBalanceHistoryOneChain(t *testing.T) {
        {value: 3, timestamp: 125, blockNumber: 125, chainID: 1},
    })

    res, err := mergeDataPoints(testData)
    res, err := mergeDataPoints(testData, strideDuration)
    require.NoError(t, err)
    require.Equal(t, 3, len(res))
    require.Equal(t, []int64{105, 115, 125}, getBlockNumbers(res))
    require.Equal(t, []int64{1, 2, 3}, getValues(res))
    require.Equal(t, []int64{105, 115, 125}, getTimestamps(res))
    require.Equal(t, []int64{115, 125, 135}, getTimestamps(res))
}

func TestServiceGetBalanceHistoryDropAll(t *testing.T) {
func TestServiceMergeDataPointsDropAll(t *testing.T) {
    strideDuration := 10 * time.Second
    testData := prepareTestData([]TestDataPoint{
        {value: 1, timestamp: 100, blockNumber: 100, chainID: 1},
        {value: 1, timestamp: 100, blockNumber: 101, chainID: 2},
        {value: 1, timestamp: 100, blockNumber: 102, chainID: 3},
        {value: 1, timestamp: 100, blockNumber: 103, chainID: 4},
        {value: 1, timestamp: 110, blockNumber: 110, chainID: 2},
        {value: 1, timestamp: 120, blockNumber: 120, chainID: 3},
        {value: 1, timestamp: 130, blockNumber: 130, chainID: 4},
    })

    res, err := mergeDataPoints(testData)
    res, err := mergeDataPoints(testData, strideDuration)
    require.NoError(t, err)
    require.Equal(t, 0, len(res))
}

func TestServiceGetBalanceHistoryEmptyDB(t *testing.T) {
func TestServiceMergeDataPointsEmptyDB(t *testing.T) {
    testData := prepareTestData([]TestDataPoint{})

    res, err := mergeDataPoints(testData)
    strideDuration := 10 * time.Second

    res, err := mergeDataPoints(testData, strideDuration)
    require.NoError(t, err)
    require.Equal(t, 0, len(res))
}

func TestServiceFindFirstStrideWindowFirstForAllChainInOneStride(t *testing.T) {
    strideDuration := 10 * time.Second
    testData := prepareTestData([]TestDataPoint{
        {value: 1, timestamp: 103, blockNumber: 101, chainID: 2},
        {value: 1, timestamp: 106, blockNumber: 102, chainID: 3},
        {value: 1, timestamp: 100, blockNumber: 100, chainID: 1},
        {value: 1, timestamp: 110, blockNumber: 103, chainID: 1},
        {value: 1, timestamp: 110, blockNumber: 103, chainID: 2},
    })

    startTimestamp := findFirstStrideWindow(testData, strideDuration)
    require.Equal(t, testData[1][0].Timestamp, uint64(startTimestamp))
}

func TestServiceSortTimeAsc(t *testing.T) {
    testData := prepareTestData([]TestDataPoint{
        {value: 3, timestamp: 103, blockNumber: 103, chainID: 3},
        {value: 4, timestamp: 104, blockNumber: 104, chainID: 4},
        {value: 2, timestamp: 102, blockNumber: 102, chainID: 2},
        {value: 1, timestamp: 101, blockNumber: 101, chainID: 1},
    })

    sorted := sortTimeAsc(testData, map[chainIdentity]int{4: 0, 3: 0, 2: 0, 1: 0})
    require.Equal(t, []timeIdentity{{1, 0}, {2, 0}, {3, 0}, {4, 0}}, sorted)
}

func TestServiceAtEnd(t *testing.T) {
    testData := prepareTestData([]TestDataPoint{
        {value: 1, timestamp: 101, blockNumber: 101, chainID: 1},
        {value: 1, timestamp: 101, blockNumber: 101, chainID: 2},
        {value: 1, timestamp: 102, blockNumber: 102, chainID: 1},
    })

    sorted := sortTimeAsc(testData, map[chainIdentity]int{1: 0, 2: 0})
    require.False(t, sorted[0].atEnd(testData))
    require.True(t, sorted[1].atEnd(testData))
    sorted = sortTimeAsc(testData, map[chainIdentity]int{1: 1, 2: 0})
    require.True(t, sorted[1].atEnd(testData))
}
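For orientation, here is a sketch of how a caller might consume the updated API. Only the GetBalanceHistory signature and the DataPoint fields come from the diff above; the import path, the chain IDs, and the way the service and interval are obtained are assumptions for illustration.

package example

import (
    "context"
    "fmt"
    "time"

    "github.com/ethereum/go-ethereum/common"

    // assumed import path for the package shown in this commit
    "github.com/status-im/status-go/services/wallet/history"
)

// printHistory fetches balance history for one address across two chains
// (an L1 and an L2) and prints the time-bucketed points the service returns.
func printHistory(ctx context.Context, svc *history.Service, addr common.Address, interval history.TimeInterval) error {
    points, err := svc.GetBalanceHistory(ctx, []uint64{1, 10}, addr, "ETH", time.Now().Unix(), interval)
    if err != nil {
        return err
    }
    for _, p := range points {
        // Timestamp marks the end of the merged stride window; BlockNumber may
        // be nil when the chains do not agree on a single block for that window.
        fmt.Printf("t=%d balance=%s\n", p.Timestamp, p.Value.ToInt().String())
    }
    return nil
}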