diff --git a/services/wallet/api.go b/services/wallet/api.go index 34479ff81..d3a188a7b 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -132,8 +132,8 @@ func (api *API) FetchDecodedTxData(ctx context.Context, data string) (*thirdpart } // GetBalanceHistory retrieves token balance history for token identity on multiple chains -func (api *API) GetBalanceHistory(ctx context.Context, chainIDs []uint64, address common.Address, tokenSymbol string, currencySymbol string, timeInterval history.TimeInterval) ([]*history.ValuePoint, error) { - log.Debug("wallet.api.GetBalanceHistory", "chainIDs", chainIDs, "address", address, "tokenSymbol", tokenSymbol, "currencySymbol", currencySymbol, "timeInterval", timeInterval) +func (api *API) GetBalanceHistory(ctx context.Context, chainIDs []uint64, addresses []common.Address, tokenSymbol string, currencySymbol string, timeInterval history.TimeInterval) ([]*history.ValuePoint, error) { + log.Debug("wallet.api.GetBalanceHistory", "chainIDs", chainIDs, "address", addresses, "tokenSymbol", tokenSymbol, "currencySymbol", currencySymbol, "timeInterval", timeInterval) var fromTimestamp uint64 now := uint64(time.Now().UTC().Unix()) @@ -152,14 +152,14 @@ func (api *API) GetBalanceHistory(ctx context.Context, chainIDs []uint64, addres return nil, fmt.Errorf("unknown time interval: %v", timeInterval) } - return api.GetBalanceHistoryRange(ctx, chainIDs, address, tokenSymbol, currencySymbol, fromTimestamp, now) + return api.GetBalanceHistoryRange(ctx, chainIDs, addresses, tokenSymbol, currencySymbol, fromTimestamp, now) } // GetBalanceHistoryRange retrieves token balance history for token identity on multiple chains for a time range // 'toTimestamp' is ignored for now, but will be used in the future to limit the range of the history -func (api *API) GetBalanceHistoryRange(ctx context.Context, chainIDs []uint64, address common.Address, tokenSymbol string, currencySymbol string, fromTimestamp uint64, _ uint64) ([]*history.ValuePoint, error) { - log.Debug("wallet.api.GetBalanceHistoryRange", "chainIDs", chainIDs, "address", address, "tokenSymbol", tokenSymbol, "currencySymbol", currencySymbol, "fromTimestamp", fromTimestamp) - return api.s.history.GetBalanceHistory(ctx, chainIDs, address, tokenSymbol, currencySymbol, fromTimestamp) +func (api *API) GetBalanceHistoryRange(ctx context.Context, chainIDs []uint64, addresses []common.Address, tokenSymbol string, currencySymbol string, fromTimestamp uint64, _ uint64) ([]*history.ValuePoint, error) { + log.Debug("wallet.api.GetBalanceHistoryRange", "chainIDs", chainIDs, "address", addresses, "tokenSymbol", tokenSymbol, "currencySymbol", currencySymbol, "fromTimestamp", fromTimestamp) + return api.s.history.GetBalanceHistory(ctx, chainIDs, addresses, tokenSymbol, currencySymbol, fromTimestamp) } func (api *API) GetTokenList(ctx context.Context) ([]*token.List, error) { diff --git a/services/wallet/history/balance.go b/services/wallet/history/balance.go index 5c8ace146..e573d0a5e 100644 --- a/services/wallet/history/balance.go +++ b/services/wallet/history/balance.go @@ -58,10 +58,10 @@ func NewBalance(db *BalanceDB) *Balance { } // get returns the balance history for the given address from the given timestamp till now -func (b *Balance) get(ctx context.Context, chainID uint64, currency string, address common.Address, fromTimestamp uint64) ([]*entry, error) { - log.Debug("Getting balance history", "chainID", chainID, "currency", currency, "address", address, "fromTimestamp", fromTimestamp) +func (b *Balance) get(ctx context.Context, chainID uint64, currency string, addresses []common.Address, fromTimestamp uint64) ([]*entry, error) { + log.Debug("Getting balance history", "chainID", chainID, "currency", currency, "address", addresses, "fromTimestamp", fromTimestamp) - cached, err := b.db.getNewerThan(&assetIdentity{chainID, address, currency}, fromTimestamp) + cached, err := b.db.getNewerThan(&assetIdentity{chainID, addresses, currency}, fromTimestamp) if err != nil { return nil, err } @@ -69,69 +69,85 @@ func (b *Balance) get(ctx context.Context, chainID uint64, currency string, addr return cached, nil } -func (b *Balance) addEdgePoints(chainID uint64, currency string, address common.Address, fromTimestamp, toTimestamp uint64, data []*entry) (res []*entry, err error) { - log.Debug("Adding edge points", "chainID", chainID, "currency", currency, "address", address, "fromTimestamp", fromTimestamp) +func (b *Balance) addEdgePoints(chainID uint64, currency string, addresses []common.Address, fromTimestamp, toTimestamp uint64, data []*entry) (res []*entry, err error) { + log.Debug("Adding edge points", "chainID", chainID, "currency", currency, "address", addresses, "fromTimestamp", fromTimestamp) - var firstEntry *entry + res = data - if len(data) > 0 { - firstEntry = data[0] - } else { - firstEntry = &entry{ - chainID: chainID, - address: address, - tokenSymbol: currency, - timestamp: int64(fromTimestamp), + for _, address := range addresses { + var firstEntry *entry + + if len(data) > 0 { + for _, entry := range data { + if entry.address == address { + firstEntry = entry + break + } + } } - } - - previous, err := b.db.getEntryPreviousTo(firstEntry) - if err != nil { - return nil, err - } - - firstTimestamp, lastTimestamp := timestampBoundaries(fromTimestamp, toTimestamp, data) - - if previous != nil { - previous.timestamp = int64(firstTimestamp) // We might need to use another minimal offset respecting the time interval - previous.block = nil - res = append([]*entry{previous}, data...) - } else { - // Add a zero point at the beginning to draw a line from - res = append([]*entry{ - { + if firstEntry == nil { + firstEntry = &entry{ chainID: chainID, address: address, tokenSymbol: currency, - timestamp: int64(firstTimestamp), - balance: big.NewInt(0), - }, - }, data...) - } + timestamp: int64(fromTimestamp), + } + } - if res[len(res)-1].timestamp < int64(lastTimestamp) { - // Add a last point to draw a line to - res = append(res, &entry{ - chainID: chainID, - address: address, - tokenSymbol: currency, - timestamp: int64(lastTimestamp), - balance: res[len(res)-1].balance, - }) + previous, err := b.db.getEntryPreviousTo(firstEntry) + if err != nil { + return nil, err + } + + firstTimestamp, lastTimestamp := timestampBoundaries(fromTimestamp, toTimestamp, address, data) + + if previous != nil { + previous.timestamp = int64(firstTimestamp) // We might need to use another minimal offset respecting the time interval + previous.block = nil + res = append([]*entry{previous}, res...) + } else { + // Add a zero point at the beginning to draw a line from + res = append([]*entry{ + { + chainID: chainID, + address: address, + tokenSymbol: currency, + timestamp: int64(firstTimestamp), + balance: big.NewInt(0), + }, + }, res...) + } + + if res[len(res)-1].timestamp < int64(lastTimestamp) { + // Add a last point to draw a line to + res = append(res, &entry{ + chainID: chainID, + address: address, + tokenSymbol: currency, + timestamp: int64(lastTimestamp), + balance: res[len(res)-1].balance, + }) + } } return res, nil } -func timestampBoundaries(fromTimestamp, toTimestamp uint64, data []*entry) (firstTimestamp, lastTimestamp uint64) { +func timestampBoundaries(fromTimestamp, toTimestamp uint64, address common.Address, data []*entry) (firstTimestamp, lastTimestamp uint64) { firstTimestamp = fromTimestamp if fromTimestamp == 0 { if len(data) > 0 { - if data[0].timestamp == 0 { - panic("data[0].timestamp must never be 0") + for _, entry := range data { + if entry.address == address { + if entry.timestamp == 0 { + panic("data[0].timestamp must never be 0") + } + firstTimestamp = uint64(entry.timestamp) - 1 + break + } } - firstTimestamp = uint64(data[0].timestamp) - 1 - } else { + } + if firstTimestamp == fromTimestamp { firstTimestamp = genesisTimestamp } } @@ -145,8 +161,8 @@ func timestampBoundaries(fromTimestamp, toTimestamp uint64, data []*entry) (firs return firstTimestamp, lastTimestamp } -func addPaddingPoints(currency string, address common.Address, toTimestamp uint64, data []*entry, limit int) (res []*entry, err error) { - log.Debug("addPaddingPoints start", "currency", currency, "address", address, "len(data)", len(data), "data", data, "limit", limit) +func addPaddingPoints(currency string, addresses []common.Address, toTimestamp uint64, data []*entry, limit int) (res []*entry, err error) { + log.Debug("addPaddingPoints start", "currency", currency, "address", addresses, "len(data)", len(data), "data", data, "limit", limit) if len(data) < 2 { // Edge points must be added separately during the previous step return nil, errors.New("slice is empty") @@ -162,6 +178,11 @@ func addPaddingPoints(currency string, address common.Address, toTimestamp uint6 res = make([]*entry, len(data)) copy(res, data) + var address common.Address + if len(addresses) > 0 { + address = addresses[0] + } + for i, j, index := 1, 0, 0; len(res) < limit; index++ { // Add a last point to draw a line to. For some cases we might not need it, // but when merging with points from other chains, we might get wrong balance if we don't have it. diff --git a/services/wallet/history/balance_db.go b/services/wallet/history/balance_db.go index cd6c6e2d9..b4a161865 100644 --- a/services/wallet/history/balance_db.go +++ b/services/wallet/history/balance_db.go @@ -2,6 +2,7 @@ package history import ( "database/sql" + "encoding/hex" "fmt" "math/big" @@ -33,10 +34,23 @@ type entry struct { type assetIdentity struct { ChainID uint64 - Address common.Address + Addresses []common.Address TokenSymbol string } +func (a *assetIdentity) addressesToString() string { + var addressesStr string + for i, address := range a.Addresses { + addressStr := hex.EncodeToString(address[:]) + if i == 0 { + addressesStr = "X'" + addressStr + "'" + } else { + addressesStr += ", X'" + addressStr + "'" + } + } + return addressesStr +} + func (e *entry) String() string { return fmt.Sprintf("chainID: %v, address: %v, tokenSymbol: %v, tokenAddress: %v, block: %v, timestamp: %v, balance: %v", e.chainID, e.address, e.tokenSymbol, e.tokenAddress, e.block, e.timestamp, e.balance) @@ -87,8 +101,9 @@ func (b *BalanceDB) getEntriesWithoutBalances(chainID uint64, address common.Add func (b *BalanceDB) getNewerThan(identity *assetIdentity, timestamp uint64) (entries []*entry, err error) { // DISTINCT removes duplicates that can happen when a block has multiple transfers of same token - rawQueryStr := "SELECT DISTINCT block, timestamp, balance FROM balance_history WHERE chain_id = ? AND address = ? AND currency = ? AND timestamp > ? ORDER BY timestamp" - rows, err := b.db.Query(rawQueryStr, identity.ChainID, identity.Address, identity.TokenSymbol, timestamp) + rawQueryStr := "SELECT DISTINCT block, timestamp, balance, address FROM balance_history WHERE chain_id = ? AND address IN (%s) AND currency = ? AND timestamp > ? ORDER BY timestamp" + queryString := fmt.Sprintf(rawQueryStr, identity.addressesToString()) + rows, err := b.db.Query(queryString, identity.ChainID, identity.TokenSymbol, timestamp) if err == sql.ErrNoRows { return nil, nil } else if err != nil { @@ -101,12 +116,11 @@ func (b *BalanceDB) getNewerThan(identity *assetIdentity, timestamp uint64) (ent for rows.Next() { entry := &entry{ chainID: identity.ChainID, - address: identity.Address, tokenSymbol: identity.TokenSymbol, block: new(big.Int), balance: new(big.Int), } - err := rows.Scan((*bigint.SQLBigInt)(entry.block), &entry.timestamp, (*bigint.SQLBigIntBytes)(entry.balance)) + err := rows.Scan((*bigint.SQLBigInt)(entry.block), &entry.timestamp, (*bigint.SQLBigIntBytes)(entry.balance), &entry.address) if err != nil { return nil, err } diff --git a/services/wallet/history/balance_test.go b/services/wallet/history/balance_test.go index 0295d2074..039cb8922 100644 --- a/services/wallet/history/balance_test.go +++ b/services/wallet/history/balance_test.go @@ -30,7 +30,7 @@ func dbWithEntries(t *testing.T, entries []*entry) *BalanceDB { func TestBalance_addPaddingPoints(t *testing.T) { type args struct { currency string - address common.Address + addresses []common.Address fromTimestamp uint64 currentTimestamp uint64 data []*entry @@ -46,7 +46,7 @@ func TestBalance_addPaddingPoints(t *testing.T) { name: "addOnePaddingPointAtMiddle", args: args{ currency: "ETH", - address: common.Address{1}, + addresses: []common.Address{common.Address{1}}, fromTimestamp: 0, currentTimestamp: 2, data: []*entry{ @@ -91,7 +91,7 @@ func TestBalance_addPaddingPoints(t *testing.T) { name: "noPaddingEqualsLimit", args: args{ currency: "ETH", - address: common.Address{1}, + addresses: []common.Address{common.Address{1}}, fromTimestamp: 0, currentTimestamp: 2, data: []*entry{ @@ -134,7 +134,7 @@ func TestBalance_addPaddingPoints(t *testing.T) { name: "limitLessThanDataSize", args: args{ currency: "ETH", - address: common.Address{1}, + addresses: []common.Address{common.Address{1}}, fromTimestamp: 0, currentTimestamp: 2, data: []*entry{ @@ -177,7 +177,7 @@ func TestBalance_addPaddingPoints(t *testing.T) { name: "addMultiplePaddingPoints", args: args{ currency: "ETH", - address: common.Address{1}, + addresses: []common.Address{common.Address{1}}, fromTimestamp: 1, currentTimestamp: 5, data: []*entry{ @@ -240,7 +240,7 @@ func TestBalance_addPaddingPoints(t *testing.T) { name: "addMultiplePaddingPointsDuplicateTimestamps", args: args{ currency: "ETH", - address: common.Address{1}, + addresses: []common.Address{common.Address{1}}, fromTimestamp: 1, currentTimestamp: 5, data: []*entry{ @@ -309,7 +309,7 @@ func TestBalance_addPaddingPoints(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRes, err := addPaddingPoints(tt.args.currency, tt.args.address, tt.args.currentTimestamp, tt.args.data, tt.args.limit) + gotRes, err := addPaddingPoints(tt.args.currency, tt.args.addresses, tt.args.currentTimestamp, tt.args.data, tt.args.limit) if (err != nil) != tt.wantErr { t.Errorf("Balance.addPaddingPoints() error = %v, wantErr %v", err, tt.wantErr) return @@ -331,7 +331,7 @@ func TestBalance_addEdgePoints(t *testing.T) { type args struct { chainID uint64 currency string - address common.Address + addresses []common.Address fromTimestamp uint64 toTimestamp uint64 data []*entry @@ -351,7 +351,7 @@ func TestBalance_addEdgePoints(t *testing.T) { args: args{ chainID: 111, currency: "SNT", - address: common.Address{1}, + addresses: []common.Address{common.Address{1}}, fromTimestamp: 1, toTimestamp: 2, data: []*entry{}, @@ -382,7 +382,7 @@ func TestBalance_addEdgePoints(t *testing.T) { args: args{ chainID: 111, currency: "SNT", - address: common.Address{1}, + addresses: []common.Address{common.Address{1}}, fromTimestamp: 0, // will set to genesisTimestamp toTimestamp: genesisTimestamp + 1, data: []*entry{}, @@ -422,7 +422,7 @@ func TestBalance_addEdgePoints(t *testing.T) { args: args{ chainID: 111, currency: "SNT", - address: common.Address{1}, + addresses: []common.Address{common.Address{1}}, fromTimestamp: 2, toTimestamp: 4, data: []*entry{ @@ -477,7 +477,7 @@ func TestBalance_addEdgePoints(t *testing.T) { b := &Balance{ db: tt.fields.db, } - gotRes, err := b.addEdgePoints(tt.args.chainID, tt.args.currency, tt.args.address, tt.args.fromTimestamp, tt.args.toTimestamp, tt.args.data) + gotRes, err := b.addEdgePoints(tt.args.chainID, tt.args.currency, tt.args.addresses, tt.args.fromTimestamp, tt.args.toTimestamp, tt.args.data) if (err != nil) != tt.wantErr { t.Errorf("Balance.addEdgePoints() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/services/wallet/history/service.go b/services/wallet/history/service.go index c7d153a87..102fe8e69 100644 --- a/services/wallet/history/service.go +++ b/services/wallet/history/service.go @@ -108,8 +108,8 @@ func (s *Service) Start() { }() } -func (s *Service) mergeChainsBalances(chainIDs []uint64, address common.Address, tokenSymbol string, fromTimestamp uint64, data map[uint64][]*entry) ([]*DataPoint, error) { - log.Debug("Merging balances", "address", address, "tokenSymbol", tokenSymbol, "fromTimestamp", fromTimestamp, "len(data)", len(data)) +func (s *Service) mergeChainsBalances(chainIDs []uint64, addresses []common.Address, tokenSymbol string, fromTimestamp uint64, data map[uint64][]*entry) ([]*DataPoint, error) { + log.Debug("Merging balances", "address", addresses, "tokenSymbol", tokenSymbol, "fromTimestamp", fromTimestamp, "len(data)", len(data)) toTimestamp := uint64(time.Now().UTC().Unix()) allData := make([]*entry, 0) @@ -118,7 +118,7 @@ func (s *Service) mergeChainsBalances(chainIDs []uint64, address common.Address, // Iterate over chainIDs param, not data keys, because data may not contain all the chains, but we need edge points for all of them for _, chainID := range chainIDs { // edge points are needed to properly calculate total balance, as they contain the balance for the first and last timestamp - chainData, err := s.balance.addEdgePoints(chainID, tokenSymbol, address, fromTimestamp, toTimestamp, data[chainID]) + chainData, err := s.balance.addEdgePoints(chainID, tokenSymbol, addresses, fromTimestamp, toTimestamp, data[chainID]) if err != nil { return nil, err } @@ -137,34 +137,60 @@ func (s *Service) mergeChainsBalances(chainIDs []uint64, address common.Address, // Add padding points to make chart look nice if len(allData) < minPointsForGraph { - allData, _ = addPaddingPoints(tokenSymbol, address, toTimestamp, allData, minPointsForGraph) + allData, _ = addPaddingPoints(tokenSymbol, addresses, toTimestamp, allData, minPointsForGraph) } - return entriesToDataPoints(chainIDs, allData) + return entriesToDataPoints(allData) } // Expects sorted data -func entriesToDataPoints(chainIDs []uint64, data []*entry) ([]*DataPoint, error) { +func entriesToDataPoints(data []*entry) ([]*DataPoint, error) { var resSlice []*DataPoint var groupedEntries []*entry // Entries with the same timestamp - sumBalances := func(entries []*entry) *big.Int { + type AddressKey struct { + Address common.Address + ChainID uint64 + } + + sumBalances := func(balanceMap map[AddressKey]*big.Int) *big.Int { + // Sum balances of all accounts and chains in current timestamp sum := big.NewInt(0) - for _, entry := range entries { - sum.Add(sum, entry.balance) + for _, balance := range balanceMap { + sum.Add(sum, balance) } return sum } - // calculate balance for entries with the same timestam and add a single point for them + updateBalanceMap := func(balanceMap map[AddressKey]*big.Int, entries []*entry) map[AddressKey]*big.Int { + // Update balance map for this timestamp + for _, entry := range entries { + if entry.chainID == 0 { + continue + } + key := AddressKey{ + Address: entry.address, + ChainID: entry.chainID, + } + balanceMap[key] = entry.balance + } + return balanceMap + } + + // Balance map always contains current balance for each address in specific timestamp + // It is required to sum up balances from previous timestamp from accounts not present in current timestamp + balanceMap := make(map[AddressKey]*big.Int) + for _, entry := range data { if len(groupedEntries) > 0 { if entry.timestamp == groupedEntries[0].timestamp { groupedEntries = append(groupedEntries, entry) continue } else { - // Calculate balance for the grouped entries - cumulativeBalance := sumBalances(groupedEntries) + // Split grouped entries into addresses + balanceMap = updateBalanceMap(balanceMap, groupedEntries) + // Calculate balance for all the addresses + cumulativeBalance := sumBalances(balanceMap) // Points in slice contain balances for all chains resSlice = appendPointToSlice(resSlice, &DataPoint{ Timestamp: uint64(groupedEntries[0].timestamp), @@ -182,7 +208,10 @@ func entriesToDataPoints(chainIDs []uint64, data []*entry) ([]*DataPoint, error) // If only edge points are present, groupedEntries will be non-empty if len(groupedEntries) > 0 { - cumulativeBalance := sumBalances(groupedEntries) + // Split grouped entries into addresses + balanceMap = updateBalanceMap(balanceMap, groupedEntries) + // Calculate balance for all the addresses + cumulativeBalance := sumBalances(balanceMap) resSlice = appendPointToSlice(resSlice, &DataPoint{ Timestamp: uint64(groupedEntries[0].timestamp), Balance: (*hexutil.Big)(cumulativeBalance), @@ -210,12 +239,12 @@ func appendPointToSlice(slice []*DataPoint, point *DataPoint) []*DataPoint { } // GetBalanceHistory returns token count balance -func (s *Service) GetBalanceHistory(ctx context.Context, chainIDs []uint64, address common.Address, tokenSymbol string, currencySymbol string, fromTimestamp uint64) ([]*ValuePoint, error) { - log.Debug("GetBalanceHistory", "chainIDs", chainIDs, "address", address.String(), "tokenSymbol", tokenSymbol, "currencySymbol", currencySymbol, "fromTimestamp", fromTimestamp) +func (s *Service) GetBalanceHistory(ctx context.Context, chainIDs []uint64, addresses []common.Address, tokenSymbol string, currencySymbol string, fromTimestamp uint64) ([]*ValuePoint, error) { + log.Debug("GetBalanceHistory", "chainIDs", chainIDs, "address", addresses, "tokenSymbol", tokenSymbol, "currencySymbol", currencySymbol, "fromTimestamp", fromTimestamp) chainDataMap := make(map[uint64][]*entry) for _, chainID := range chainIDs { - chainData, err := s.balance.get(ctx, chainID, tokenSymbol, address, fromTimestamp) // TODO Make chainID a slice? + chainData, err := s.balance.get(ctx, chainID, tokenSymbol, addresses, fromTimestamp) // TODO Make chainID a slice? if err != nil { return nil, err } @@ -228,7 +257,8 @@ func (s *Service) GetBalanceHistory(ctx context.Context, chainIDs []uint64, addr } // Need to get balance for all the chains for the first timestamp, otherwise total values will be incorrect - data, err := s.mergeChainsBalances(chainIDs, address, tokenSymbol, fromTimestamp, chainDataMap) + data, err := s.mergeChainsBalances(chainIDs, addresses, tokenSymbol, fromTimestamp, chainDataMap) + if err != nil { return nil, err } else if len(data) == 0 { diff --git a/services/wallet/history/service_test.go b/services/wallet/history/service_test.go index 835dbc278..6906f6aa1 100644 --- a/services/wallet/history/service_test.go +++ b/services/wallet/history/service_test.go @@ -5,13 +5,13 @@ import ( "reflect" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" ) func Test_entriesToDataPoints(t *testing.T) { type args struct { - chainIDs []uint64 - data []*entry + data []*entry } tests := []struct { name string @@ -22,7 +22,6 @@ func Test_entriesToDataPoints(t *testing.T) { { name: "zeroAllChainsSameTimestamp", args: args{ - chainIDs: []uint64{1, 2}, data: []*entry{ { chainID: 1, @@ -49,7 +48,6 @@ func Test_entriesToDataPoints(t *testing.T) { { name: "oneZeroAllChainsDifferentTimestamp", args: args{ - chainIDs: []uint64{1, 2}, data: []*entry{ { chainID: 2, @@ -80,7 +78,6 @@ func Test_entriesToDataPoints(t *testing.T) { { name: "nonZeroAllChainsDifferentTimestamp", args: args{ - chainIDs: []uint64{1, 2}, data: []*entry{ { chainID: 2, @@ -100,7 +97,7 @@ func Test_entriesToDataPoints(t *testing.T) { Timestamp: 1, }, { - Balance: (*hexutil.Big)(big.NewInt(2)), + Balance: (*hexutil.Big)(big.NewInt(3)), Timestamp: 2, }, }, @@ -109,7 +106,6 @@ func Test_entriesToDataPoints(t *testing.T) { { name: "sameChainDifferentTimestamp", args: args{ - chainIDs: []uint64{1, 2}, data: []*entry{ { chainID: 1, @@ -149,7 +145,6 @@ func Test_entriesToDataPoints(t *testing.T) { { name: "sameChainDifferentTimestampOtherChainsEmpty", args: args{ - chainIDs: []uint64{1, 2}, data: []*entry{ { chainID: 1, @@ -195,7 +190,6 @@ func Test_entriesToDataPoints(t *testing.T) { { name: "onlyEdgePointsOnManyChainsWithPadding", args: args{ - chainIDs: []uint64{1, 2, 3}, data: []*entry{ // Left edge - same timestamp { @@ -271,11 +265,67 @@ func Test_entriesToDataPoints(t *testing.T) { }, wantErr: false, }, + { + name: "multipleAddresses", + args: args{ + data: []*entry{ + { + chainID: 2, + balance: big.NewInt(5), + timestamp: 1, + address: common.Address{1}, + }, + { + chainID: 1, + balance: big.NewInt(6), + timestamp: 1, + address: common.Address{2}, + }, + { + chainID: 1, + balance: big.NewInt(1), + timestamp: 2, + address: common.Address{1}, + }, + { + chainID: 1, + balance: big.NewInt(2), + timestamp: 3, + address: common.Address{2}, + }, + { + chainID: 1, + balance: big.NewInt(4), + timestamp: 4, + address: common.Address{2}, + }, + }, + }, + want: []*DataPoint{ + { + Balance: (*hexutil.Big)(big.NewInt(11)), + Timestamp: 1, + }, + { + Balance: (*hexutil.Big)(big.NewInt(12)), + Timestamp: 2, + }, + { + Balance: (*hexutil.Big)(big.NewInt(8)), + Timestamp: 3, + }, + { + Balance: (*hexutil.Big)(big.NewInt(10)), + Timestamp: 4, + }, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := entriesToDataPoints(tt.args.chainIDs, tt.args.data) + got, err := entriesToDataPoints(tt.args.data) if (err != nil) != tt.wantErr { t.Errorf("entriesToDataPoints() error = %v, wantErr %v", err, tt.wantErr) return