feat(Wallet): add activity filter api prototype

Add the possibility of retrieving the metadata of wallet activity based
on the given filter criteria.

Current implementation relies that after fetching the metadata, user
will follow up with more requests for details. However, after some
experimenting I'm considering extracting all required information
for the summary viewing while filtering. This way there will be no
need for another batch requests for transfers, multi-transactions and
pending transactions to show the summary. Only when user wants to see
the details for one will specifically request it.

For this first prototype, the filter criteria is limited to:
- time
- type
- addresses

Major changes:
- Add the filter definition to be used in propagating the filter
  information
- Add GetActivityEntries API to return the list of activity entries
  for the given addresses/chainIDs by a view in the complete list
- GetTransfersForIdentities to batch retrieve further details of the
  transfers
- GetPendingTransactionsForIdentities to batch retrieve further details
  of the pending transactions
- Added a new package testutils for tests.
- Added tests

Updates status-desktop #10366
Closes status-desktop #10633
This commit is contained in:
Stefan 2023-04-21 14:59:29 +03:00 committed by Stefan Dunca
parent 0197e6c484
commit c020222f1b
11 changed files with 987 additions and 29 deletions

View File

@ -0,0 +1,299 @@
package activity
import (
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"github.com/ethereum/go-ethereum/common"
"github.com/status-im/status-go/services/wallet/transfer"
)
type PayloadType = int
const (
MultiTransactionPT PayloadType = iota + 1
SimpleTransactionPT
PendingTransactionPT
)
type Entry struct {
// TODO: rename in payloadType
transactionType PayloadType
transaction *transfer.TransactionIdentity
id transfer.MultiTransactionIDType
timestamp int64
activityType Type
}
type jsonSerializationTemplate struct {
TransactionType PayloadType `json:"transactionType"`
Transaction *transfer.TransactionIdentity `json:"transaction"`
ID transfer.MultiTransactionIDType `json:"id"`
Timestamp int64 `json:"timestamp"`
ActivityType Type `json:"activityType"`
}
func (e *Entry) MarshalJSON() ([]byte, error) {
return json.Marshal(jsonSerializationTemplate{
TransactionType: e.transactionType,
Transaction: e.transaction,
ID: e.id,
Timestamp: e.timestamp,
ActivityType: e.activityType,
})
}
func (e *Entry) UnmarshalJSON(data []byte) error {
aux := jsonSerializationTemplate{}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
e.transactionType = aux.TransactionType
e.transaction = aux.Transaction
e.id = aux.ID
e.timestamp = aux.Timestamp
e.activityType = aux.ActivityType
return nil
}
func NewActivityEntryWithTransaction(transactionType PayloadType, transaction *transfer.TransactionIdentity, timestamp int64, activityType Type) Entry {
if transactionType != SimpleTransactionPT && transactionType != PendingTransactionPT {
panic("invalid transaction type")
}
return Entry{
transactionType: transactionType,
transaction: transaction,
id: 0,
timestamp: timestamp,
activityType: activityType,
}
}
func NewActivityEntryWithMultiTransaction(id transfer.MultiTransactionIDType, timestamp int64, activityType Type) Entry {
return Entry{
transactionType: MultiTransactionPT,
id: id,
timestamp: timestamp,
activityType: activityType,
}
}
func (e *Entry) TransactionType() PayloadType {
return e.transactionType
}
func multiTransactionTypeToActivityType(mtType transfer.MultiTransactionType) Type {
if mtType == transfer.MultiTransactionSend {
return SendAT
} else if mtType == transfer.MultiTransactionSwap {
return SwapAT
} else if mtType == transfer.MultiTransactionBridge {
return BridgeAT
}
panic("unknown multi transaction type")
}
func typesContain(slice []Type, item Type) bool {
for _, a := range slice {
if a == item {
return true
}
}
return false
}
func joinMTTypes(types []transfer.MultiTransactionType) string {
var sb strings.Builder
for i, val := range types {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString(strconv.Itoa(int(val)))
}
return sb.String()
}
func joinAddresses(addresses []common.Address) string {
var sb strings.Builder
for i, address := range addresses {
if i == 0 {
sb.WriteString("('")
} else {
sb.WriteString("'),('")
}
sb.WriteString(strings.ToUpper(hex.EncodeToString(address[:])))
}
sb.WriteString("')")
return sb.String()
}
func activityTypesToMultiTransactionTypes(trTypes []Type) []transfer.MultiTransactionType {
mtTypes := make([]transfer.MultiTransactionType, 0, len(trTypes))
for _, t := range trTypes {
var mtType transfer.MultiTransactionType
if t == SendAT {
mtType = transfer.MultiTransactionSend
} else if t == SwapAT {
mtType = transfer.MultiTransactionSwap
} else if t == BridgeAT {
mtType = transfer.MultiTransactionBridge
} else {
continue
}
mtTypes = append(mtTypes, mtType)
}
return mtTypes
}
// TODO: extend with SEND/RECEIVE for transfers and pending_transactions
// TODO: clarify if we include sender and receiver in pending_transactions as we do for transfers
// TODO optimization: consider implementing nullable []byte instead of using strings for addresses
// Query includes duplicates, will return multiple rows for the same transaction
const queryFormatString = `
WITH filter_conditions AS (
SELECT
? AS startFilterDisabled,
? AS startTimestamp,
? AS endFilterDisabled,
? AS endTimestamp,
? AS filterActivityTypeAll,
? AS filterActivityTypeSend,
? AS filterActivityTypeReceive,
? AS filterAllAddresses
),
filter_addresses(address) AS (
VALUES %s
)
SELECT
transfers.hash AS transfer_hash,
NULL AS pending_hash,
transfers.network_id AS network_id,
0 AS multi_tx_id,
transfers.timestamp AS timestamp,
NULL AS mt_type,
HEX(transfers.address) AS owner_address
FROM transfers, filter_conditions
WHERE transfers.multi_transaction_id = 0
AND ((startFilterDisabled OR timestamp >= startTimestamp) AND (endFilterDisabled OR timestamp <= endTimestamp))
AND (filterActivityTypeAll OR (filterActivityTypeSend AND (filterAllAddresses OR (HEX(transfers.sender) IN filter_addresses))) OR (filterActivityTypeReceive AND (filterAllAddresses OR (HEX(transfers.address) IN filter_addresses))))
AND (filterAllAddresses OR (HEX(transfers.sender) IN filter_addresses) OR (HEX(transfers.address) IN filter_addresses))
UNION ALL
SELECT
NULL AS transfer_hash,
pending_transactions.hash AS pending_hash,
pending_transactions.network_id AS network_id,
0 AS multi_tx_id,
pending_transactions.timestamp AS timestamp,
NULL AS mt_type,
NULL AS owner_address
FROM pending_transactions, filter_conditions
WHERE pending_transactions.multi_transaction_id = 0
AND ((startFilterDisabled OR timestamp >= startTimestamp) AND (endFilterDisabled OR timestamp <= endTimestamp))
AND (filterActivityTypeAll OR filterActivityTypeSend)
AND (filterAllAddresses OR (HEX(pending_transactions.from_address) IN filter_addresses) OR (HEX(pending_transactions.to_address) IN filter_addresses))
UNION ALL
SELECT
NULL AS transfer_hash,
NULL AS pending_hash,
NULL AS network_id,
multi_transactions.ROWID AS multi_tx_id,
multi_transactions.timestamp AS timestamp,
multi_transactions.type AS mt_type,
NULL AS owner_address
FROM multi_transactions, filter_conditions
WHERE ((startFilterDisabled OR timestamp >= startTimestamp) AND (endFilterDisabled OR timestamp <= endTimestamp))
AND (filterActivityTypeAll OR (multi_transactions.type IN (%s)))
AND (filterAllAddresses OR (HEX(multi_transactions.from_address) IN filter_addresses) OR (HEX(multi_transactions.to_address) IN filter_addresses))
ORDER BY timestamp DESC
LIMIT ? OFFSET ?`
func GetActivityEntries(db *sql.DB, addresses []common.Address, chainIDs []uint64, filter Filter, offset int, limit int) ([]Entry, error) {
// Query the transfers, pending_transactions, and multi_transactions tables ordered by timestamp column
// TODO: finish filter: chainIDs, statuses, tokenTypes, counterpartyAddresses
// TODO: use all accounts list for detecting SEND/RECEIVE instead of the current addresses list; also change activityType detection in transfer part
startFilterDisabled := !(filter.Period.StartTimestamp > 0)
endFilterDisabled := !(filter.Period.EndTimestamp > 0)
filterActivityTypeAll := typesContain(filter.Types, AllAT) || len(filter.Types) == 0
filterAllAddresses := len(addresses) == 0
//fmt.Println("@dd filter: timeEnabled", filter.Period.StartTimestamp, filter.Period.EndTimestamp, "; type", filter.Types, "offset", offset, "limit", limit)
joinedAddresses := "(NULL)"
if !filterAllAddresses {
joinedAddresses = joinAddresses(addresses)
}
mtTypes := activityTypesToMultiTransactionTypes(filter.Types)
joinedMTTypes := joinMTTypes(mtTypes)
queryString := fmt.Sprintf(queryFormatString, joinedAddresses, joinedMTTypes)
rows, err := db.Query(queryString,
startFilterDisabled, filter.Period.StartTimestamp, endFilterDisabled, filter.Period.EndTimestamp,
filterActivityTypeAll, typesContain(filter.Types, SendAT), typesContain(filter.Types, ReceiveAT),
filterAllAddresses,
limit, offset)
if err != nil {
return nil, err
}
defer rows.Close()
var entries []Entry
for rows.Next() {
var transferHash, pendingHash []byte
var chainID, multiTxID sql.NullInt64
var timestamp int64
var dbActivityType sql.NullByte
var dbAddress sql.NullString
err := rows.Scan(&transferHash, &pendingHash, &chainID, &multiTxID, &timestamp, &dbActivityType, &dbAddress)
if err != nil {
return nil, err
}
var entry Entry
if transferHash != nil && chainID.Valid {
var activityType Type = SendAT
thisAddress := common.HexToAddress(dbAddress.String)
for _, address := range addresses {
if address == thisAddress {
activityType = ReceiveAT
}
}
entry = NewActivityEntryWithTransaction(SimpleTransactionPT, &transfer.TransactionIdentity{ChainID: uint64(chainID.Int64), Hash: common.BytesToHash(transferHash), Address: thisAddress}, timestamp, activityType)
} else if pendingHash != nil && chainID.Valid {
var activityType Type = SendAT
entry = NewActivityEntryWithTransaction(PendingTransactionPT, &transfer.TransactionIdentity{ChainID: uint64(chainID.Int64), Hash: common.BytesToHash(pendingHash)}, timestamp, activityType)
} else if multiTxID.Valid {
activityType := multiTransactionTypeToActivityType(transfer.MultiTransactionType(dbActivityType.Byte))
entry = NewActivityEntryWithMultiTransaction(transfer.MultiTransactionIDType(multiTxID.Int64),
timestamp, activityType)
} else {
return nil, errors.New("invalid row data")
}
entries = append(entries, entry)
}
if err = rows.Err(); err != nil {
return nil, err
}
return entries, nil
}

View File

@ -0,0 +1,406 @@
package activity
import (
"database/sql"
"testing"
"github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/services/wallet/testutils"
"github.com/status-im/status-go/services/wallet/transfer"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require"
)
func setupTestActivityDB(t *testing.T) (db *sql.DB, close func()) {
db, err := appdatabase.SetupTestMemorySQLDB("wallet-activity-tests")
require.NoError(t, err)
return db, func() {
require.NoError(t, db.Close())
}
}
func insertTestPendingTransaction(t *testing.T, db *sql.DB, tr *transfer.TestTransaction) {
_, err := db.Exec(`
INSERT INTO pending_transactions (network_id, hash, timestamp, from_address, to_address,
symbol, gas_price, gas_limit, value, data, type, additional_data, multi_transaction_id
) VALUES (?, ?, ?, ?, ?, 'ETH', 0, 0, ?, '', 'test', '', ?)`,
tr.ChainID, tr.Hash, tr.Timestamp, tr.From, tr.To, tr.Value, tr.MultiTransactionID)
require.NoError(t, err)
}
type testData struct {
tr1 transfer.TestTransaction // index 1
pendingTr transfer.TestTransaction // index 2
singletonMTr transfer.TestTransaction // index 3
mTr transfer.TestTransaction // index 4
subTr transfer.TestTransaction // index 5
subPendingTr transfer.TestTransaction // index 6
singletonMTID transfer.MultiTransactionIDType
mTrID transfer.MultiTransactionIDType
}
// Generates and adds to the DB 6 transactions. 2 transactions, 2 pending and 2 multi transactions
// There are only 4 extractable transactions and multi-transaction with timestamps 1-4. The other 2 are associated with a multi-transaction
func fillTestData(t *testing.T, db *sql.DB) (td testData) {
trs := transfer.GenerateTestTransactions(t, db, 1, 6)
td.tr1 = trs[0]
transfer.InsertTestTransfer(t, db, &td.tr1)
td.pendingTr = trs[1]
insertTestPendingTransaction(t, db, &td.pendingTr)
td.singletonMTr = trs[2]
td.singletonMTID = transfer.InsertTestMultiTransaction(t, db, &td.singletonMTr)
td.mTr = trs[3]
td.mTrID = transfer.InsertTestMultiTransaction(t, db, &td.mTr)
td.subTr = trs[4]
td.subTr.MultiTransactionID = td.mTrID
transfer.InsertTestTransfer(t, db, &td.subTr)
td.subPendingTr = trs[5]
td.subPendingTr.MultiTransactionID = td.mTrID
insertTestPendingTransaction(t, db, &td.subPendingTr)
return
}
func TestGetActivityEntriesAll(t *testing.T) {
db, close := setupTestActivityDB(t)
defer close()
td := fillTestData(t, db)
var filter Filter
entries, err := GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 0, 10)
require.NoError(t, err)
require.Equal(t, 4, len(entries))
// Ensure we have the correct order
var curTimestamp int64 = 4
for _, entry := range entries {
require.Equal(t, curTimestamp, entry.timestamp, "entries are sorted by timestamp; expected %d, got %d", curTimestamp, entry.timestamp)
curTimestamp--
}
require.True(t, testutils.StructExistsInSlice(Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: td.tr1.ChainID, Hash: td.tr1.Hash, Address: td.tr1.To},
id: td.tr1.MultiTransactionID,
timestamp: td.tr1.Timestamp,
activityType: SendAT,
}, entries))
require.True(t, testutils.StructExistsInSlice(Entry{
transactionType: PendingTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: td.pendingTr.ChainID, Hash: td.pendingTr.Hash},
id: td.pendingTr.MultiTransactionID,
timestamp: td.pendingTr.Timestamp,
activityType: SendAT,
}, entries))
require.True(t, testutils.StructExistsInSlice(Entry{
transactionType: MultiTransactionPT,
transaction: nil,
id: td.singletonMTID,
timestamp: td.singletonMTr.Timestamp,
activityType: SendAT,
}, entries))
require.True(t, testutils.StructExistsInSlice(Entry{
transactionType: MultiTransactionPT,
transaction: nil,
id: td.mTrID,
timestamp: td.mTr.Timestamp,
activityType: SendAT,
}, entries))
// Ensure the sub-transactions of the multi-transactions are not returned
require.False(t, testutils.StructExistsInSlice(Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: td.subTr.ChainID, Hash: td.subTr.Hash, Address: td.subTr.To},
id: td.subTr.MultiTransactionID,
timestamp: td.subTr.Timestamp,
activityType: SendAT,
}, entries))
require.False(t, testutils.StructExistsInSlice(Entry{
transactionType: PendingTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: td.subPendingTr.ChainID, Hash: td.subPendingTr.Hash},
id: td.subPendingTr.MultiTransactionID,
timestamp: td.subPendingTr.Timestamp,
activityType: SendAT,
}, entries))
}
// TestGetActivityEntriesWithSenderFilter covers the issue with returning the same transaction
// twice when the sender and receiver have entries in the transfers table
func TestGetActivityEntriesWithSameTransactionForSenderAndReceiverInDB(t *testing.T) {
db, close := setupTestActivityDB(t)
defer close()
// Add 4 extractable transactions with timestamps 1-4
td := fillTestData(t, db)
// Add another transaction with sender and receiver reversed
receiverTr := td.tr1
prevTo := receiverTr.To
receiverTr.To = td.tr1.From
receiverTr.From = prevTo
transfer.InsertTestTransfer(t, db, &receiverTr)
var filter Filter
entries, err := GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 0, 10)
require.NoError(t, err)
// TODO: decide how should we handle this case filter out or include it in the result
// For now we include both. Can be changed by using UNION instead of UNION ALL in the query or by filtering out
require.Equal(t, 5, len(entries))
}
func TestGetActivityEntriesFilterByTime(t *testing.T) {
db, close := setupTestActivityDB(t)
defer close()
td := fillTestData(t, db)
// Add 6 extractable transactions with timestamps 6-12
trs := transfer.GenerateTestTransactions(t, db, 6, 6)
for i := range trs {
transfer.InsertTestTransfer(t, db, &trs[i])
}
// Test start only
var filter Filter
filter.Period.StartTimestamp = td.singletonMTr.Timestamp
entries, err := GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 0, 15)
require.NoError(t, err)
require.Equal(t, 8, len(entries))
// Check start and end content
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[5].ChainID, Hash: trs[5].Hash, Address: trs[5].To},
id: 0,
timestamp: trs[5].Timestamp,
activityType: SendAT,
}, entries[0])
require.Equal(t, Entry{
transactionType: MultiTransactionPT,
transaction: nil,
id: td.singletonMTID,
timestamp: td.singletonMTr.Timestamp,
activityType: SendAT,
}, entries[7])
// Test complete interval
filter.Period.EndTimestamp = trs[2].Timestamp
entries, err = GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 0, 15)
require.NoError(t, err)
require.Equal(t, 5, len(entries))
// Check start and end content
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[2].ChainID, Hash: trs[2].Hash, Address: trs[2].To},
id: 0,
timestamp: trs[2].Timestamp,
activityType: SendAT,
}, entries[0])
require.Equal(t, Entry{
transactionType: MultiTransactionPT,
transaction: nil,
id: td.singletonMTID,
timestamp: td.singletonMTr.Timestamp,
activityType: SendAT,
}, entries[4])
// Test end only
filter.Period.StartTimestamp = 0
entries, err = GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 0, 15)
require.NoError(t, err)
require.Equal(t, 7, len(entries))
// Check start and end content
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[2].ChainID, Hash: trs[2].Hash, Address: trs[2].To},
id: 0,
timestamp: trs[2].Timestamp,
activityType: SendAT,
}, entries[0])
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: td.tr1.ChainID, Hash: td.tr1.Hash, Address: td.tr1.To},
id: 0,
timestamp: td.tr1.Timestamp,
activityType: SendAT,
}, entries[6])
}
func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
db, close := setupTestActivityDB(t)
defer close()
// Add 10 extractable transactions with timestamps 1-10
trs := transfer.GenerateTestTransactions(t, db, 1, 10)
for i := range trs {
transfer.InsertTestTransfer(t, db, &trs[i])
}
var filter Filter
// Get all
entries, err := GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 0, 5)
require.NoError(t, err)
require.Equal(t, 5, len(entries))
// Get time based interval
filter.Period.StartTimestamp = trs[2].Timestamp
filter.Period.EndTimestamp = trs[8].Timestamp
entries, err = GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 0, 3)
require.NoError(t, err)
require.Equal(t, 3, len(entries))
// Check start and end content
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[8].ChainID, Hash: trs[8].Hash, Address: trs[8].To},
id: 0,
timestamp: trs[8].Timestamp,
activityType: SendAT,
}, entries[0])
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[6].ChainID, Hash: trs[6].Hash, Address: trs[6].To},
id: 0,
timestamp: trs[6].Timestamp,
activityType: SendAT,
}, entries[2])
// Move window 2 entries forward
entries, err = GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 2, 3)
require.NoError(t, err)
require.Equal(t, 3, len(entries))
// Check start and end content
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[6].ChainID, Hash: trs[6].Hash, Address: trs[6].To},
id: 0,
timestamp: trs[6].Timestamp,
activityType: SendAT,
}, entries[0])
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[4].ChainID, Hash: trs[4].Hash, Address: trs[4].To},
id: 0,
timestamp: trs[4].Timestamp,
activityType: SendAT,
}, entries[2])
// Move window 4 more entries to test filter cap
entries, err = GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 6, 3)
require.NoError(t, err)
require.Equal(t, 1, len(entries))
// Check start and end content
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[2].ChainID, Hash: trs[2].Hash, Address: trs[2].To},
id: 0,
timestamp: trs[2].Timestamp,
activityType: SendAT,
}, entries[0])
}
func TestGetActivityEntriesFilterByType(t *testing.T) {
db, close := setupTestActivityDB(t)
defer close()
// Adds 4 extractable transactions
fillTestData(t, db)
// Add 6 extractable transactions: one MultiTransactionSwap, two MultiTransactionBridge rest Send
trs := transfer.GenerateTestTransactions(t, db, 6, 6)
trs[1].MultiTransactionType = transfer.MultiTransactionBridge
trs[3].MultiTransactionType = transfer.MultiTransactionSwap
trs[5].MultiTransactionType = transfer.MultiTransactionBridge
for i := range trs {
if trs[i].MultiTransactionType != transfer.MultiTransactionSend {
transfer.InsertTestMultiTransaction(t, db, &trs[i])
} else {
transfer.InsertTestTransfer(t, db, &trs[i])
}
}
// Test filtering out without address involved
var filter Filter
// TODO: add more types to cover all cases
filter.Types = []Type{SendAT, SwapAT}
entries, err := GetActivityEntries(db, []common.Address{}, []uint64{}, filter, 0, 15)
require.NoError(t, err)
require.Equal(t, 8, len(entries))
swapCount := 0
sendCount := 0
for _, entry := range entries {
if entry.activityType == SendAT {
sendCount++
}
if entry.activityType == SwapAT {
swapCount++
}
}
require.Equal(t, 7, sendCount)
require.Equal(t, 1, swapCount)
// Test filtering out with address involved
filter.Types = []Type{BridgeAT, ReceiveAT}
// Include one "to" from transfers to be detected as receive
addresses := []common.Address{trs[0].To, trs[1].To, trs[2].From, trs[3].From, trs[5].From}
entries, err = GetActivityEntries(db, addresses, []uint64{}, filter, 0, 15)
require.NoError(t, err)
require.Equal(t, 3, len(entries))
bridgeCount := 0
receiveCount := 0
for _, entry := range entries {
if entry.activityType == BridgeAT {
bridgeCount++
}
if entry.activityType == ReceiveAT {
receiveCount++
}
}
require.Equal(t, 2, bridgeCount)
require.Equal(t, 1, receiveCount)
}
func TestGetActivityEntriesFilterByAddress(t *testing.T) {
db, close := setupTestActivityDB(t)
defer close()
// Adds 4 extractable transactions
td := fillTestData(t, db)
// Add 6 extractable transactions: one MultiTransactionSwap, two MultiTransactionBridge rest Send
trs := transfer.GenerateTestTransactions(t, db, 7, 6)
for i := range trs {
transfer.InsertTestTransfer(t, db, &trs[i])
}
var filter Filter
addressesFilter := []common.Address{td.mTr.To, trs[1].From, trs[4].To}
entries, err := GetActivityEntries(db, addressesFilter, []uint64{}, filter, 0, 15)
require.NoError(t, err)
require.Equal(t, 3, len(entries))
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[4].ChainID, Hash: trs[4].Hash, Address: trs[4].To},
id: 0,
timestamp: trs[4].Timestamp,
activityType: ReceiveAT,
}, entries[0])
require.Equal(t, Entry{
transactionType: SimpleTransactionPT,
transaction: &transfer.TransactionIdentity{ChainID: trs[1].ChainID, Hash: trs[1].Hash, Address: trs[1].To},
id: 0,
timestamp: trs[1].Timestamp,
activityType: SendAT,
}, entries[1])
require.Equal(t, Entry{
transactionType: MultiTransactionPT,
transaction: nil,
id: td.mTrID,
timestamp: td.mTr.Timestamp,
activityType: SendAT,
}, entries[2])
}

View File

@ -0,0 +1,46 @@
package activity
import "github.com/ethereum/go-ethereum/common"
type Period struct {
// 0 means no limit
StartTimestamp int64 `json:"startTimestamp"`
EndTimestamp int64 `json:"endTimestamp"`
}
type Type int
const (
AllAT Type = iota
SendAT
ReceiveAT
BuyAT
SwapAT
BridgeAT
)
type Status int
const (
AllAS Status = iota
FailedAS
PendingAS
CompleteAS
FinalizedAS
)
type TokenType int
const (
AllTT TokenType = iota
AssetTT
CollectiblesTT
)
type Filter struct {
Period Period `json:"period"`
Types []Type `json:"types"`
Statuses []Status `json:"statuses"`
TokenTypes []TokenType `json:"tokenTypes"`
CounterpartyAddresses []common.Address `json:"counterpartyAddresses"`
}

View File

@ -11,6 +11,7 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/services/wallet/activity"
"github.com/status-im/status-go/services/wallet/bridge" "github.com/status-im/status-go/services/wallet/bridge"
"github.com/status-im/status-go/services/wallet/currency" "github.com/status-im/status-go/services/wallet/currency"
"github.com/status-im/status-go/services/wallet/history" "github.com/status-im/status-go/services/wallet/history"
@ -97,6 +98,12 @@ func (api *API) GetTransfersByAddressAndChainID(ctx context.Context, chainID uin
return api.s.transferController.GetTransfersByAddress(ctx, chainID, address, hexBigToBN(toBlock), limit.ToInt().Int64(), fetchMore) return api.s.transferController.GetTransfersByAddress(ctx, chainID, address, hexBigToBN(toBlock), limit.ToInt().Int64(), fetchMore)
} }
func (api *API) GetTransfersForIdentities(ctx context.Context, identities []transfer.TransactionIdentity) ([]transfer.View, error) {
log.Debug("[Wallet: GetTransfersForIdentities] count", len(identities))
return api.s.transferController.GetTransfersForIdentities(ctx, identities)
}
// Deprecated: GetCachedBalances is deprecated. Use GetTokensBalances instead // Deprecated: GetCachedBalances is deprecated. Use GetTokensBalances instead
func (api *API) GetCachedBalances(ctx context.Context, addresses []common.Address) ([]transfer.BlockView, error) { func (api *API) GetCachedBalances(ctx context.Context, addresses []common.Address) ([]transfer.BlockView, error) {
return api.s.transferController.GetCachedBalances(ctx, api.s.rpcClient.UpstreamChainID, addresses) return api.s.transferController.GetCachedBalances(ctx, api.s.rpcClient.UpstreamChainID, addresses)
@ -230,6 +237,20 @@ func (api *API) GetPendingTransactionsByChainIDs(ctx context.Context, chainIDs [
return rst, err return rst, err
} }
func (api *API) GetPendingTransactionsForIdentities(ctx context.Context, identities []transfer.TransactionIdentity) (result []*transfer.PendingTransaction, err error) {
log.Debug("call to GetPendingTransactionsForIdentities")
result = make([]*transfer.PendingTransaction, 0, len(identities))
var pt *transfer.PendingTransaction
for _, identity := range identities {
pt, err = api.s.transactionManager.GetPendingEntry(identity.ChainID, identity.Hash)
result = append(result, pt)
}
log.Debug("result from GetPendingTransactionsForIdentities", "len", len(result))
return
}
func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ([]*transfer.PendingTransaction, error) { func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ([]*transfer.PendingTransaction, error) {
log.Debug("call to get pending outbound transactions by address") log.Debug("call to get pending outbound transactions by address")
rst, err := api.s.transactionManager.GetPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address) rst, err := api.s.transactionManager.GetPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address)
@ -495,3 +516,8 @@ func (api *API) FetchAllCurrencyFormats() (currency.FormatPerSymbol, error) {
log.Debug("call to FetchAllCurrencyFormats") log.Debug("call to FetchAllCurrencyFormats")
return api.s.currency.FetchAllCurrencyFormats() return api.s.currency.FetchAllCurrencyFormats()
} }
func (api *API) GetActivityEntries(addresses []common.Address, chainIDs []uint64, filter activity.Filter, offset int, limit int) ([]activity.Entry, error) {
log.Debug("call to GetActivityEntries")
return activity.GetActivityEntries(api.s.db, addresses, chainIDs, filter, offset, limit)
}

View File

@ -0,0 +1,12 @@
package testutils
import "reflect"
func StructExistsInSlice[T any](target T, slice []T) bool {
for _, item := range slice {
if reflect.DeepEqual(target, item) {
return true
}
}
return false
}

View File

@ -219,6 +219,16 @@ func (c *Controller) GetTransfersByAddress(ctx context.Context, chainID uint64,
return castToTransferViews(rst), nil return castToTransferViews(rst), nil
} }
func (c *Controller) GetTransfersForIdentities(ctx context.Context, identities []TransactionIdentity) ([]View, error) {
rst, err := c.db.GetTransfersForIdentities(ctx, identities)
if err != nil {
log.Error("[transfer.Controller.GetTransfersByAddress] DB err", err)
return nil, err
}
return castToTransferViews(rst), nil
}
func (c *Controller) GetCachedBalances(ctx context.Context, chainID uint64, addresses []common.Address) ([]BlockView, error) { func (c *Controller) GetCachedBalances(ctx context.Context, chainID uint64, addresses []common.Address) ([]BlockView, error) {
result, error := c.blockDAO.getLastKnownBlocks(chainID, addresses) result, error := c.blockDAO.getLastKnownBlocks(chainID, addresses)
if error != nil { if error != nil {

View File

@ -1,6 +1,7 @@
package transfer package transfer
import ( import (
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
@ -262,7 +263,7 @@ func (db *Database) GetTransfersByAddressAndBlock(chainID uint64, address common
return query.Scan(rows) return query.Scan(rows)
} }
// GetTransfers load transfers transfer betweeen two blocks. // GetTransfers load transfers transfer between two blocks.
func (db *Database) GetTransfers(chainID uint64, start, end *big.Int) (rst []Transfer, err error) { func (db *Database) GetTransfers(chainID uint64, start, end *big.Int) (rst []Transfer, err error) {
query := newTransfersQuery().FilterNetwork(chainID).FilterStart(start).FilterEnd(end).FilterLoaded(1) query := newTransfersQuery().FilterNetwork(chainID).FilterStart(start).FilterEnd(end).FilterLoaded(1)
rows, err := db.client.Query(query.String(), query.Args()...) rows, err := db.client.Query(query.String(), query.Args()...)
@ -273,6 +274,22 @@ func (db *Database) GetTransfers(chainID uint64, start, end *big.Int) (rst []Tra
return query.Scan(rows) return query.Scan(rows)
} }
func (db *Database) GetTransfersForIdentities(ctx context.Context, identities []TransactionIdentity) (rst []Transfer, err error) {
query := newTransfersQuery()
for _, identity := range identities {
subQuery := newSubQuery()
// TODO optimization: consider using tuples in sqlite and IN operator
subQuery = subQuery.FilterNetwork(identity.ChainID).FilterTransactionHash(identity.Hash).FilterAddress(identity.Address)
query.addSubQuery(subQuery, OrSeparator)
}
rows, err := db.client.QueryContext(ctx, query.String(), query.Args()...)
if err != nil {
return
}
defer rows.Close()
return query.Scan(rows)
}
func (db *Database) GetPreloadedTransactions(chainID uint64, address common.Address, blockHash common.Hash) (rst []Transfer, err error) { func (db *Database) GetPreloadedTransactions(chainID uint64, address common.Address, blockHash common.Hash) (rst []Transfer, err error) {
query := newTransfersQuery(). query := newTransfersQuery().
FilterNetwork(chainID). FilterNetwork(chainID).

View File

@ -1,6 +1,7 @@
package transfer package transfer
import ( import (
"context"
"math/big" "math/big"
"testing" "testing"
@ -10,11 +11,10 @@ import (
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/status-im/status-go/appdatabase" "github.com/status-im/status-go/appdatabase"
"github.com/status-im/status-go/sqlite"
) )
func setupTestDB(t *testing.T) (*Database, *BlockDAO, func()) { func setupTestDB(t *testing.T) (*Database, *BlockDAO, func()) {
db, err := appdatabase.InitializeDB(sqlite.InMemoryPath, "wallet-tests", sqlite.ReducedKDFIterationsNumber) db, err := appdatabase.SetupTestMemorySQLDB("wallet-transfer-tests")
require.NoError(t, err) require.NoError(t, err)
return NewDB(db), &BlockDAO{db}, func() { return NewDB(db), &BlockDAO{db}, func() {
require.NoError(t, db.Close()) require.NoError(t, db.Close())
@ -197,5 +197,34 @@ func TestDBGetTransfersFromBlock(t *testing.T) {
rst, err = db.GetTransfers(777, big.NewInt(2), big.NewInt(5)) rst, err = db.GetTransfers(777, big.NewInt(2), big.NewInt(5))
require.NoError(t, err) require.NoError(t, err)
require.Len(t, rst, 4) require.Len(t, rst, 4)
}
func TestGetTransfersForIdentities(t *testing.T) {
db, _, stop := setupTestDB(t)
defer stop()
trs := GenerateTestTransactions(t, db.client, 1, 4)
for i := range trs {
InsertTestTransfer(t, db.client, &trs[i])
}
entries, err := db.GetTransfersForIdentities(context.Background(), []TransactionIdentity{
TransactionIdentity{trs[1].ChainID, trs[1].Hash, trs[1].To},
TransactionIdentity{trs[3].ChainID, trs[3].Hash, trs[3].To}})
require.NoError(t, err)
require.Equal(t, 2, len(entries))
require.Equal(t, trs[1].Hash, entries[0].ID)
require.Equal(t, trs[3].Hash, entries[1].ID)
require.Equal(t, trs[1].From, entries[0].From)
require.Equal(t, trs[3].From, entries[1].From)
require.Equal(t, trs[1].To, entries[0].Address)
require.Equal(t, trs[3].To, entries[1].Address)
require.Equal(t, big.NewInt(trs[1].BlkNumber), entries[0].BlockNumber)
require.Equal(t, big.NewInt(trs[3].BlkNumber), entries[1].BlockNumber)
require.Equal(t, uint64(trs[1].Timestamp), entries[0].Timestamp)
require.Equal(t, uint64(trs[3].Timestamp), entries[1].Timestamp)
require.Equal(t, trs[1].ChainID, entries[0].NetworkID)
require.Equal(t, trs[3].ChainID, entries[1].NetworkID)
require.Equal(t, MultiTransactionIDType(0), entries[0].MultiTransactionID)
require.Equal(t, MultiTransactionIDType(0), entries[1].MultiTransactionID)
} }

View File

@ -12,30 +12,66 @@ import (
const baseTransfersQuery = "SELECT hash, type, blk_hash, blk_number, timestamp, address, tx, sender, receipt, log, network_id, base_gas_fee, COALESCE(multi_transaction_id, 0) FROM transfers" const baseTransfersQuery = "SELECT hash, type, blk_hash, blk_number, timestamp, address, tx, sender, receipt, log, network_id, base_gas_fee, COALESCE(multi_transaction_id, 0) FROM transfers"
type transfersQuery struct {
buf *bytes.Buffer
args []interface{}
whereAdded bool
subQuery bool
}
func newTransfersQuery() *transfersQuery { func newTransfersQuery() *transfersQuery {
newQuery := newEmptyQuery()
newQuery.buf.WriteString(baseTransfersQuery)
return newQuery
}
func newSubQuery() *transfersQuery {
newQuery := newEmptyQuery()
newQuery.subQuery = true
return newQuery
}
func newEmptyQuery() *transfersQuery {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
buf.WriteString(baseTransfersQuery)
return &transfersQuery{buf: buf} return &transfersQuery{buf: buf}
} }
type transfersQuery struct { func (q *transfersQuery) addWhereSeparator(separator SeparatorType) {
buf *bytes.Buffer if !q.whereAdded {
args []interface{} if !q.subQuery {
added bool q.buf.WriteString(" WHERE")
}
q.whereAdded = true
} else if separator == OrSeparator {
q.buf.WriteString(" OR")
} else if separator == AndSeparator {
q.buf.WriteString(" AND")
} else if separator != NoSeparator {
panic("Unknown separator. Need to handle current SeparatorType value")
}
} }
func (q *transfersQuery) andOrWhere() { type SeparatorType int
if q.added {
q.buf.WriteString(" AND") const (
} else { NoSeparator SeparatorType = iota + 1
q.buf.WriteString(" WHERE") OrSeparator
} AndSeparator
)
// addSubQuery adds where clause formed as: WHERE/<separator> (<subQuery>)
func (q *transfersQuery) addSubQuery(subQuery *transfersQuery, separator SeparatorType) *transfersQuery {
q.addWhereSeparator(separator)
q.buf.WriteString(" (")
q.buf.Write(subQuery.buf.Bytes())
q.buf.WriteString(")")
q.args = append(q.args, subQuery.args...)
return q
} }
func (q *transfersQuery) FilterStart(start *big.Int) *transfersQuery { func (q *transfersQuery) FilterStart(start *big.Int) *transfersQuery {
if start != nil { if start != nil {
q.andOrWhere() q.addWhereSeparator(AndSeparator)
q.added = true
q.buf.WriteString(" blk_number >= ?") q.buf.WriteString(" blk_number >= ?")
q.args = append(q.args, (*bigint.SQLBigInt)(start)) q.args = append(q.args, (*bigint.SQLBigInt)(start))
} }
@ -44,8 +80,7 @@ func (q *transfersQuery) FilterStart(start *big.Int) *transfersQuery {
func (q *transfersQuery) FilterEnd(end *big.Int) *transfersQuery { func (q *transfersQuery) FilterEnd(end *big.Int) *transfersQuery {
if end != nil { if end != nil {
q.andOrWhere() q.addWhereSeparator(AndSeparator)
q.added = true
q.buf.WriteString(" blk_number <= ?") q.buf.WriteString(" blk_number <= ?")
q.args = append(q.args, (*bigint.SQLBigInt)(end)) q.args = append(q.args, (*bigint.SQLBigInt)(end))
} }
@ -53,8 +88,7 @@ func (q *transfersQuery) FilterEnd(end *big.Int) *transfersQuery {
} }
func (q *transfersQuery) FilterLoaded(loaded int) *transfersQuery { func (q *transfersQuery) FilterLoaded(loaded int) *transfersQuery {
q.andOrWhere() q.addWhereSeparator(AndSeparator)
q.added = true
q.buf.WriteString(" loaded = ? ") q.buf.WriteString(" loaded = ? ")
q.args = append(q.args, loaded) q.args = append(q.args, loaded)
@ -62,32 +96,35 @@ func (q *transfersQuery) FilterLoaded(loaded int) *transfersQuery {
} }
func (q *transfersQuery) FilterNetwork(network uint64) *transfersQuery { func (q *transfersQuery) FilterNetwork(network uint64) *transfersQuery {
q.andOrWhere() q.addWhereSeparator(AndSeparator)
q.added = true
q.buf.WriteString(" network_id = ?") q.buf.WriteString(" network_id = ?")
q.args = append(q.args, network) q.args = append(q.args, network)
return q return q
} }
func (q *transfersQuery) FilterAddress(address common.Address) *transfersQuery { func (q *transfersQuery) FilterAddress(address common.Address) *transfersQuery {
q.andOrWhere() q.addWhereSeparator(AndSeparator)
q.added = true
q.buf.WriteString(" address = ?") q.buf.WriteString(" address = ?")
q.args = append(q.args, address) q.args = append(q.args, address)
return q return q
} }
func (q *transfersQuery) FilterTransactionHash(hash common.Hash) *transfersQuery {
q.addWhereSeparator(AndSeparator)
q.buf.WriteString(" hash = ?")
q.args = append(q.args, hash)
return q
}
func (q *transfersQuery) FilterBlockHash(blockHash common.Hash) *transfersQuery { func (q *transfersQuery) FilterBlockHash(blockHash common.Hash) *transfersQuery {
q.andOrWhere() q.addWhereSeparator(AndSeparator)
q.added = true
q.buf.WriteString(" blk_hash = ?") q.buf.WriteString(" blk_hash = ?")
q.args = append(q.args, blockHash) q.args = append(q.args, blockHash)
return q return q
} }
func (q *transfersQuery) FilterBlockNumber(blockNumber *big.Int) *transfersQuery { func (q *transfersQuery) FilterBlockNumber(blockNumber *big.Int) *transfersQuery {
q.andOrWhere() q.addWhereSeparator(AndSeparator)
q.added = true
q.buf.WriteString(" blk_number = ?") q.buf.WriteString(" blk_number = ?")
q.args = append(q.args, (*bigint.SQLBigInt)(blockNumber)) q.args = append(q.args, (*bigint.SQLBigInt)(blockNumber))
return q return q

View File

@ -0,0 +1,68 @@
package transfer
import (
"database/sql"
"fmt"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require"
)
type TestTransaction struct {
Hash common.Hash
ChainID uint64
From common.Address // [sender]
To common.Address // [address]
Timestamp int64
Value int64
BlkNumber int64
MultiTransactionID MultiTransactionIDType
MultiTransactionType MultiTransactionType
}
func GenerateTestTransactions(t *testing.T, db *sql.DB, firstStartIndex int, count int) (result []TestTransaction) {
for i := firstStartIndex; i < (firstStartIndex + count); i++ {
tr := TestTransaction{
Hash: common.HexToHash(fmt.Sprintf("0x1%d", i)),
ChainID: uint64(i),
From: common.HexToAddress(fmt.Sprintf("0x2%d", i)),
To: common.HexToAddress(fmt.Sprintf("0x3%d", i)),
Timestamp: int64(i),
Value: int64(i),
BlkNumber: int64(i),
MultiTransactionID: NoMultiTransactionID,
MultiTransactionType: MultiTransactionSend,
}
result = append(result, tr)
}
return
}
func InsertTestTransfer(t *testing.T, db *sql.DB, tr *TestTransaction) {
// Respect `FOREIGN KEY(network_id,address,blk_hash)` of `transfers` table
blkHash := common.HexToHash("4")
_, err := db.Exec(`
INSERT OR IGNORE INTO blocks(
network_id, address, blk_number, blk_hash
) VALUES (?, ?, ?, ?);
INSERT INTO transfers (network_id, hash, address, blk_hash, tx,
sender, receipt, log, type, blk_number, timestamp, loaded,
multi_transaction_id, base_gas_fee
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, "test", ?, ?, 0, ?, 0)`,
tr.ChainID, tr.To, tr.BlkNumber, blkHash,
tr.ChainID, tr.Hash, tr.To, blkHash, &JSONBlob{}, tr.From, &JSONBlob{}, &JSONBlob{}, tr.BlkNumber, tr.Timestamp, tr.MultiTransactionID)
require.NoError(t, err)
}
func InsertTestMultiTransaction(t *testing.T, db *sql.DB, tr *TestTransaction) MultiTransactionIDType {
result, err := db.Exec(`
INSERT INTO multi_transactions (from_address, from_asset, from_amount, to_address, to_asset, type, timestamp
) VALUES (?, 'ETH', 0, ?, 'SNT', ?, ?)`,
tr.From, tr.To, tr.MultiTransactionType, tr.Timestamp)
require.NoError(t, err)
rowID, err := result.LastInsertId()
require.NoError(t, err)
return MultiTransactionIDType(rowID)
}

View File

@ -44,6 +44,7 @@ func NewTransactionManager(db *sql.DB, gethManager *account.GethManager, transac
type MultiTransactionType uint8 type MultiTransactionType uint8
// TODO: extend with know types
const ( const (
MultiTransactionSend = iota MultiTransactionSend = iota
MultiTransactionSwap MultiTransactionSwap
@ -94,6 +95,12 @@ type PendingTransaction struct {
MultiTransactionID MultiTransactionIDType `json:"multi_transaction_id"` MultiTransactionID MultiTransactionIDType `json:"multi_transaction_id"`
} }
type TransactionIdentity struct {
ChainID uint64 `json:"chainId"`
Hash common.Hash `json:"hash"`
Address common.Address `json:"address"`
}
const selectFromPending = `SELECT hash, timestamp, value, from_address, to_address, data, const selectFromPending = `SELECT hash, timestamp, value, from_address, to_address, data,
symbol, gas_price, gas_limit, type, additional_data, symbol, gas_price, gas_limit, type, additional_data,
network_id, COALESCE(multi_transaction_id, 0) network_id, COALESCE(multi_transaction_id, 0)
@ -173,6 +180,7 @@ func (tm *TransactionManager) GetPendingByAddress(chainIDs []uint64, address com
} }
// GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity // GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity
// TODO: consider using address also in case we expect to have also for the receiver
func (tm *TransactionManager) GetPendingEntry(chainID uint64, hash common.Hash) (*PendingTransaction, error) { func (tm *TransactionManager) GetPendingEntry(chainID uint64, hash common.Hash) (*PendingTransaction, error) {
row := tm.db.QueryRow(`SELECT timestamp, value, from_address, to_address, data, row := tm.db.QueryRow(`SELECT timestamp, value, from_address, to_address, data,
symbol, gas_price, gas_limit, type, additional_data, symbol, gas_price, gas_limit, type, additional_data,