fix: merge balance history using block time

This change improves on the previous implementation, which merged data points by
block number; block numbers are not comparable across blockchains with different timelines, e.g. L1 vs. L2

Closes: #9205
This commit is contained in:
Stefan 2023-01-20 20:06:50 +04:00 committed by Stefan Dunca
parent c8994fe175
commit 90b39eeb41
2 changed files with 221 additions and 68 deletions

View File

@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"math/big"
"sort"
"sync"
"time"
@ -48,6 +49,8 @@ type Service struct {
visibleTokenSymbolsMutex sync.Mutex // Protects access to visibleSymbols
}
// chainIdentity is the map key type used to group balance data points per chain;
// GetBalanceHistory converts the uint64 chain IDs it receives to this type
type chainIdentity uint64
func NewService(db *sql.DB, eventFeed *event.Feed, rpcClient *statusrpc.Client, tokenManager *token.Manager) *Service {
return &Service{
balance: NewBalance(NewBalanceDB(db)),
@ -198,80 +201,178 @@ func (src *tokenChainClientSource) BalanceAt(ctx context.Context, account common
// GetBalanceHistory returns token count balance
// It reads the stored points of every chain in chainIDs through s.balance.get, keeps only the chains that
// returned at least one point and hands the per-chain series to mergeDataPoints to build the result.
// The first error reported by s.balance.get is returned as is, together with a nil result.
// TODO: fetch token to FIAT exchange rates and return FIAT balance
func (s *Service) GetBalanceHistory(ctx context.Context, chainIDs []uint64, address common.Address, currency string, endTimestamp int64, timeInterval TimeInterval) ([]*DataPoint, error) {
allData := make(map[uint64][]*DataPoint)
allData := make(map[chainIdentity][]*DataPoint)
for _, chainID := range chainIDs {
data, err := s.balance.get(ctx, chainID, currency, address, endTimestamp, timeInterval)
if err != nil {
return nil, err
}
// Chains without any stored point are left out of the merge input
if len(data) > 0 {
allData[chainID] = data
allData[chainIdentity(chainID)] = data
}
}
return mergeDataPoints(allData)
return mergeDataPoints(allData, strideDuration(timeInterval))
}
// mergeDataPoints merges same block numbers from different chains which are incompatible due to different timelines
// TODO: use time-based intervals instead of block numbers
func mergeDataPoints(data map[uint64][]*DataPoint) ([]*DataPoint, error) {
// mergeDataPoints merges the per-chain series in data into one series using consecutive time windows.
// The first window starts at the value returned by findFirstStrideWindow and every window is
// int64(stride.Seconds()) seconds long, i.e. stride truncated to whole seconds.
// A window produces one DataPoint only if every chain has at least one point in it; windows that miss
// a chain are dropped. For a complete window the Value is the sum of each chain's largest balance in the window,
// the Timestamp is the window end and the BlockNumber is whatever getBlockID returns for the chosen points.
// This should improve merging balance data from different chains which are incompatible due to different timelines
// and block length
func mergeDataPoints(data map[chainIdentity][]*DataPoint, stride time.Duration) ([]*DataPoint, error) {
if len(data) == 0 {
return make([]*DataPoint, 0), nil
}
pos := make(map[uint64]int)
pos := make(map[chainIdentity]int)
for k := range data {
pos[k] = 0
}
res := make([]*DataPoint, 0)
done := false
for !done {
var minNo *big.Int
var timestamp uint64
// Take the smallest block number
for k := range data {
blockNo := new(big.Int).Set(data[k][pos[k]].BlockNumber.ToInt())
if minNo == nil {
minNo = new(big.Int).Set(blockNo)
// We use it only if we have a full match
timestamp = data[k][pos[k]].Timestamp
} else if blockNo.Cmp(minNo) < 0 {
minNo.Set(blockNo)
}
}
// If all chains have the same block number sum it; also increment the processed position
sumOfAll := big.NewInt(0)
for k := range data {
cur := data[k][pos[k]]
if cur.BlockNumber.ToInt().Cmp(minNo) == 0 {
pos[k]++
if sumOfAll != nil {
sumOfAll.Add(sumOfAll, cur.Value.ToInt())
}
} else {
sumOfAll = nil
}
}
// If sum of all make sense add it to the result otherwise ignore it
if sumOfAll != nil {
// TODO: convert to FIAT value
res = append(res, &DataPoint{
BlockNumber: (*hexutil.Big)(minNo),
Timestamp: timestamp,
Value: (*hexutil.Big)(sumOfAll),
})
}
strideStart := findFirstStrideWindow(data, stride)
for {
strideEnd := strideStart + int64(stride.Seconds())
// - Gather all points in the stride window starting with current pos
var strideIdentities map[chainIdentity][]timeIdentity
strideIdentities, pos = dataInStrideWindowAndNextPos(data, pos, strideEnd)
// Check if we reached the end of any chain
// Check if all chains have data
strideComplete := true
for k := range data {
if pos[k] == len(data[k]) {
done = true
_, strideComplete = strideIdentities[k]
if !strideComplete {
break
}
}
if strideComplete {
// Keep, per chain, the point with the largest balance in this window; on equal balances the first one seen stays
chainMaxBalance := make(map[chainIdentity]*DataPoint)
for chainID, identities := range strideIdentities {
for _, identity := range identities {
_, exists := chainMaxBalance[chainID]
if exists && (*big.Int)(identity.dataPoint(data).Value).Cmp((*big.Int)(chainMaxBalance[chainID].Value)) <= 0 {
continue
}
chainMaxBalance[chainID] = identity.dataPoint(data)
}
}
// The merged balance is the sum of the per-chain maxima
balance := big.NewInt(0)
for _, chainBalance := range chainMaxBalance {
balance.Add(balance, (*big.Int)(chainBalance.Value))
}
res = append(res, &DataPoint{
Timestamp: uint64(strideEnd),
Value: (*hexutil.Big)(balance),
BlockNumber: (*hexutil.Big)(getBlockID(chainMaxBalance)),
})
}
// Stop once the position of every chain is past its last point
if allPastEnd(data, pos) {
return res, nil
}
strideStart = strideEnd
}
return res, nil
}
// getBlockID returns a copy of the block number shared by all points in chainBalance.
// It returns nil when chainBalance is empty or when at least two points carry different block numbers.
func getBlockID(chainBalance map[chainIdentity]*DataPoint) *big.Int {
	var shared *big.Int
	for _, point := range chainBalance {
		number := point.BlockNumber.ToInt()
		switch {
		case shared == nil:
			// First point seen: remember a private copy of its block number
			shared = new(big.Int).Set(number)
		case shared.Cmp(number) != 0:
			// A single mismatch means there is no common block number
			return nil
		}
	}
	return shared
}
// timeIdentity addresses one DataPoint: the chain it belongs to and its index in that chain's slice
type timeIdentity struct {
	chain chainIdentity
	index int
}

// dataPoint returns the DataPoint the identity refers to in data.
// The index must be within the chain's slice, otherwise the slice access panics.
func (id timeIdentity) dataPoint(data map[chainIdentity][]*DataPoint) *DataPoint {
	points := data[id.chain]
	return points[id.index]
}

// atEnd reports whether the identity refers to the last point of its chain
func (id timeIdentity) atEnd(data map[chainIdentity][]*DataPoint) bool {
	points := data[id.chain]
	return len(points) == id.index+1
}

// pastEnd reports whether the index lies beyond the last point of its chain;
// this is also the case for a chain that has no points in data
func (id timeIdentity) pastEnd(data map[chainIdentity][]*DataPoint) bool {
	points := data[id.chain]
	return len(points) <= id.index
}
// allPastEnd reports whether every position in pos lies beyond the last point of its chain in data.
// It returns true when pos is empty.
func allPastEnd(data map[chainIdentity][]*DataPoint, pos map[chainIdentity]int) bool {
	for chain, index := range pos {
		cursor := timeIdentity{chain: chain, index: index}
		if cursor.pastEnd(data) {
			continue
		}
		// At least one chain still has unprocessed points
		return false
	}
	return true
}
// findFirstStrideWindow returns the start of the first stride window: the smallest timestamp among the
// first points of all chains in data. Chains without points are ignored and 0 is returned when no chain
// has any point (instead of panicking on an empty input).
// Tried to implement finding an optimal stride window but it was becoming too complicated and not worth it given that it will
// potentially save the first and last stride but it is not guaranteed. Current implementation should give good results
// as long as the DataPoints are regular enough
// stride is currently not used; it is kept so callers do not have to change.
func findFirstStrideWindow(data map[chainIdentity][]*DataPoint, stride time.Duration) int64 {
	var (
		oldest uint64
		found  bool
	)
	// Only the first point of each chain is inspected, a plain minimum is enough here
	for _, points := range data {
		if len(points) == 0 {
			continue
		}
		if first := points[0].Timestamp; !found || first < oldest {
			oldest, found = first, true
		}
	}
	return int64(oldest)
}
// copyMap returns a new map holding the same key/value pairs as original (values are not deep copied).
// The result is never nil, even when original is nil.
func copyMap[K comparable, V any](original map[K]V) map[K]V {
	duplicate := make(map[K]V, len(original))
	for k, v := range original {
		duplicate[k] = v
	}
	return duplicate
}
// dataInStrideWindowAndNextPos gathers, for every chain in data, all consecutive points starting at the
// chain's index in startPos whose timestamp is strictly before endT.
//
// identities holds the gathered points per chain in index order; chains that contributed nothing have no entry.
// nextPos is a copy of startPos in which every chain that contributed was advanced past its gathered points;
// startPos itself is not modified.
// startPos might have indexes past the end of the data for a chain, such a chain contributes nothing.
// Scanning a chain stops at its first point that is not before endT, so this assumes each chain's points
// are ordered by ascending timestamp.
//
// The previous implementation repeated one-point-per-chain passes only while the number of chains in
// identities kept growing, which left points of the current window behind for the next one.
func dataInStrideWindowAndNextPos(data map[chainIdentity][]*DataPoint, startPos map[chainIdentity]int, endT int64) (identities map[chainIdentity][]timeIdentity, nextPos map[chainIdentity]int) {
	nextPos = copyMap(startPos)
	identities = make(map[chainIdentity][]timeIdentity)
	windowEnd := uint64(endT)
	for chain, points := range data {
		first := nextPos[chain]
		next := first
		for next < len(points) && points[next].Timestamp < windowEnd {
			identities[chain] = append(identities[chain], timeIdentity{chain: chain, index: next})
			next++
		}
		// Only touch the position of chains that actually advanced
		if next != first {
			nextPos[chain] = next
		}
	}
	return identities, nextPos
}
// sortTimeAsc returns, for every chain in data, the identity of the point at the chain's index in pos,
// ordered by ascending timestamp. Points with equal timestamps are ordered by ascending chain ID, so the
// result does not depend on map iteration order.
// Indexes in pos past the end of the data for a chain are tolerated: such chains are left out of the result.
func sortTimeAsc(data map[chainIdentity][]*DataPoint, pos map[chainIdentity]int) []timeIdentity {
	res := make([]timeIdentity, 0, len(data))
	for k := range data {
		identity := timeIdentity{
			chain: k,
			index: pos[k],
		}
		if !identity.pastEnd(data) {
			res = append(res, identity)
		}
	}
	sort.Slice(res, func(i, j int) bool {
		ti, tj := res[i].dataPoint(data).Timestamp, res[j].dataPoint(data).Timestamp
		if ti != tj {
			return ti < tj
		}
		// sort.Slice is not stable and res was filled in random map order: break ties explicitly
		return res[i].chain < res[j].chain
	})
	return res
}
// updateBalanceHistoryForAllEnabledNetworks iterates over all enabled and supported networks for the s.visibleTokenSymbol

View File

@ -3,6 +3,7 @@ package history
import (
"math/big"
"testing"
"time"
"github.com/ethereum/go-ethereum/common/hexutil"
@ -13,12 +14,12 @@ type TestDataPoint struct {
value int64
timestamp uint64
blockNumber int64
chainID uint64
chainID chainIdentity
}
// generateTestDataForElementCount generates dummy consecutive blocks of data for the same chain_id, address and currency
func prepareTestData(data []TestDataPoint) map[uint64][]*DataPoint {
res := make(map[uint64][]*DataPoint)
func prepareTestData(data []TestDataPoint) map[chainIdentity][]*DataPoint {
res := make(map[chainIdentity][]*DataPoint)
for i := 0; i < len(data); i++ {
entry := data[i]
_, found := res[entry.chainID]
@ -34,10 +35,15 @@ func prepareTestData(data []TestDataPoint) map[uint64][]*DataPoint {
return res
}
// getBlockNumbers lists the block number of every entry in data, in the same order;
// it returns -1 if block number is nil
func getBlockNumbers(data []*DataPoint) []int64 {
res := make([]int64, 0)
for _, entry := range data {
res = append(res, entry.BlockNumber.ToInt().Int64())
if entry.BlockNumber == nil {
res = append(res, -1)
} else {
res = append(res, entry.BlockNumber.ToInt().Int64())
}
}
return res
}
@ -58,7 +64,8 @@ func getTimestamps(data []*DataPoint) []int64 {
return res
}
func TestServiceGetBalanceHistory(t *testing.T) {
func TestServiceMergeDataPoints(t *testing.T) {
strideDuration := 5 * time.Second
testData := prepareTestData([]TestDataPoint{
// Drop 100
{value: 1, timestamp: 100, blockNumber: 100, chainID: 1},
@ -88,15 +95,16 @@ func TestServiceGetBalanceHistory(t *testing.T) {
{value: 1, timestamp: 135, blockNumber: 135, chainID: 1},
})
res, err := mergeDataPoints(testData)
res, err := mergeDataPoints(testData, strideDuration)
require.NoError(t, err)
require.Equal(t, 4, len(res))
require.Equal(t, []int64{105, 115, 125, 130}, getBlockNumbers(res))
require.Equal(t, []int64{3, 3 * 2, 3 * 3, 3 * 4}, getValues(res))
require.Equal(t, []int64{105, 115, 125, 130}, getTimestamps(res))
require.Equal(t, []int64{110, 120, 130, 135}, getTimestamps(res))
}
func TestServiceGetBalanceHistoryAllMatch(t *testing.T) {
func TestServiceMergeDataPointsAllMatch(t *testing.T) {
strideDuration := 10 * time.Second
testData := prepareTestData([]TestDataPoint{
// Keep 105
{value: 1, timestamp: 105, blockNumber: 105, chainID: 1},
@ -116,15 +124,16 @@ func TestServiceGetBalanceHistoryAllMatch(t *testing.T) {
{value: 4, timestamp: 135, blockNumber: 135, chainID: 3},
})
res, err := mergeDataPoints(testData)
res, err := mergeDataPoints(testData, strideDuration)
require.NoError(t, err)
require.Equal(t, 4, len(res))
require.Equal(t, []int64{105, 115, 125, 135}, getBlockNumbers(res))
require.Equal(t, []int64{3, 3 * 2, 3 * 3, 3 * 4}, getValues(res))
require.Equal(t, []int64{105, 115, 125, 135}, getTimestamps(res))
require.Equal(t, []int64{115, 125, 135, 145}, getTimestamps(res))
}
func TestServiceGetBalanceHistoryOneChain(t *testing.T) {
func TestServiceMergeDataPointsOneChain(t *testing.T) {
strideDuration := 10 * time.Second
testData := prepareTestData([]TestDataPoint{
// Keep 105
{value: 1, timestamp: 105, blockNumber: 105, chainID: 1},
@ -134,31 +143,74 @@ func TestServiceGetBalanceHistoryOneChain(t *testing.T) {
{value: 3, timestamp: 125, blockNumber: 125, chainID: 1},
})
res, err := mergeDataPoints(testData)
res, err := mergeDataPoints(testData, strideDuration)
require.NoError(t, err)
require.Equal(t, 3, len(res))
require.Equal(t, []int64{105, 115, 125}, getBlockNumbers(res))
require.Equal(t, []int64{1, 2, 3}, getValues(res))
require.Equal(t, []int64{105, 115, 125}, getTimestamps(res))
require.Equal(t, []int64{115, 125, 135}, getTimestamps(res))
}
// The DropAll test expects mergeDataPoints to return no error and an empty result for the data set below
func TestServiceGetBalanceHistoryDropAll(t *testing.T) {
func TestServiceMergeDataPointsDropAll(t *testing.T) {
strideDuration := 10 * time.Second
testData := prepareTestData([]TestDataPoint{
{value: 1, timestamp: 100, blockNumber: 100, chainID: 1},
{value: 1, timestamp: 100, blockNumber: 101, chainID: 2},
{value: 1, timestamp: 100, blockNumber: 102, chainID: 3},
{value: 1, timestamp: 100, blockNumber: 103, chainID: 4},
{value: 1, timestamp: 110, blockNumber: 110, chainID: 2},
{value: 1, timestamp: 120, blockNumber: 120, chainID: 3},
{value: 1, timestamp: 130, blockNumber: 130, chainID: 4},
})
res, err := mergeDataPoints(testData)
res, err := mergeDataPoints(testData, strideDuration)
require.NoError(t, err)
require.Equal(t, 0, len(res))
}
// The EmptyDB test expects mergeDataPoints to return no error and an empty result when there is no input data
func TestServiceGetBalanceHistoryEmptyDB(t *testing.T) {
func TestServiceMergeDataPointsEmptyDB(t *testing.T) {
testData := prepareTestData([]TestDataPoint{})
res, err := mergeDataPoints(testData)
strideDuration := 10 * time.Second
res, err := mergeDataPoints(testData, strideDuration)
require.NoError(t, err)
require.Equal(t, 0, len(res))
}
func TestServiceFindFirstStrideWindowFirstForAllChainInOneStride(t *testing.T) {
strideDuration := 10 * time.Second
testData := prepareTestData([]TestDataPoint{
{value: 1, timestamp: 103, blockNumber: 101, chainID: 2},
{value: 1, timestamp: 106, blockNumber: 102, chainID: 3},
{value: 1, timestamp: 100, blockNumber: 100, chainID: 1},
{value: 1, timestamp: 110, blockNumber: 103, chainID: 1},
{value: 1, timestamp: 110, blockNumber: 103, chainID: 2},
})
startTimestamp := findFirstStrideWindow(testData, strideDuration)
require.Equal(t, testData[1][0].Timestamp, uint64(startTimestamp))
}
// TestServiceSortTimeAsc checks that the first points of four chains, given out of order,
// come back ordered by ascending timestamp
func TestServiceSortTimeAsc(t *testing.T) {
	chains := prepareTestData([]TestDataPoint{
		{value: 3, timestamp: 103, blockNumber: 103, chainID: 3},
		{value: 4, timestamp: 104, blockNumber: 104, chainID: 4},
		{value: 2, timestamp: 102, blockNumber: 102, chainID: 2},
		{value: 1, timestamp: 101, blockNumber: 101, chainID: 1},
	})
	start := map[chainIdentity]int{4: 0, 3: 0, 2: 0, 1: 0}
	want := []timeIdentity{
		{chain: 1, index: 0},
		{chain: 2, index: 0},
		{chain: 3, index: 0},
		{chain: 4, index: 0},
	}

	require.Equal(t, want, sortTimeAsc(chains, start))
}
// TestServiceAtEnd checks timeIdentity.atEnd on identities produced by sortTimeAsc.
// All timestamps are distinct on purpose: with equal timestamps the order of the two chains in the sorted
// result would depend on map iteration order and the assertions on sorted[0] / sorted[1] would be flaky.
func TestServiceAtEnd(t *testing.T) {
	testData := prepareTestData([]TestDataPoint{
		{value: 1, timestamp: 101, blockNumber: 101, chainID: 1},
		{value: 1, timestamp: 103, blockNumber: 103, chainID: 2},
		{value: 1, timestamp: 102, blockNumber: 102, chainID: 1},
	})

	// Chain 1 is at its first of two points, chain 2 at its only point
	sorted := sortTimeAsc(testData, map[chainIdentity]int{1: 0, 2: 0})
	require.False(t, sorted[0].atEnd(testData))
	require.True(t, sorted[1].atEnd(testData))

	// Both chains are now at their last point
	sorted = sortTimeAsc(testData, map[chainIdentity]int{1: 1, 2: 0})
	require.True(t, sorted[0].atEnd(testData))
	require.True(t, sorted[1].atEnd(testData))
}