feat(wallet): make filer API async

Refactor the filter interface to be an async call which returns
the result using a wallet event
A call to the filter API will cancel the ongoing filter and receive
an error result event

Closes status-desktop #10994
This commit is contained in:
Stefan 2023-06-09 01:52:45 +02:00 committed by Stefan Dunca
parent 1a2ca21070
commit d8eb038d7d
5 changed files with 186 additions and 42 deletions

View File

@ -1,6 +1,7 @@
package activity package activity
import ( import (
"context"
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -422,12 +423,12 @@ const (
noEntriesInTmpTableSQLValues = "(NULL)" noEntriesInTmpTableSQLValues = "(NULL)"
) )
// GetActivityEntries returns query the transfers, pending_transactions, and multi_transactions tables // getActivityEntries queries the transfers, pending_transactions, and multi_transactions tables
// based on filter parameters and arguments // based on filter parameters and arguments
// it returns metadata for all entries ordered by timestamp column // it returns metadata for all entries ordered by timestamp column
// //
// Adding a no-limit option was never considered or required. // Adding a no-limit option was never considered or required.
func GetActivityEntries(db *sql.DB, addresses []eth.Address, chainIDs []common.ChainID, filter Filter, offset int, limit int) ([]Entry, error) { func getActivityEntries(ctx context.Context, db *sql.DB, addresses []eth.Address, chainIDs []common.ChainID, filter Filter, offset int, limit int) ([]Entry, error) {
// TODO: filter collectibles after they are added to multi_transactions table // TODO: filter collectibles after they are added to multi_transactions table
if len(filter.Tokens.EnabledTypes) > 0 && !sliceContains(filter.Tokens.EnabledTypes, AssetTT) { if len(filter.Tokens.EnabledTypes) > 0 && !sliceContains(filter.Tokens.EnabledTypes, AssetTT) {
// For now we deal only with assets so return empty result // For now we deal only with assets so return empty result
@ -483,7 +484,7 @@ func GetActivityEntries(db *sql.DB, addresses []eth.Address, chainIDs []common.C
queryString := fmt.Sprintf(queryFormatString, involvedAddresses, toAddresses, assets, networks, queryString := fmt.Sprintf(queryFormatString, involvedAddresses, toAddresses, assets, networks,
joinedMTTypes) joinedMTTypes)
rows, err := db.Query(queryString, rows, err := db.QueryContext(ctx, queryString,
startFilterDisabled, filter.Period.StartTimestamp, endFilterDisabled, filter.Period.EndTimestamp, startFilterDisabled, filter.Period.StartTimestamp, endFilterDisabled, filter.Period.EndTimestamp,
filterActivityTypeAll, sliceContains(filter.Types, SendAT), sliceContains(filter.Types, ReceiveAT), filterActivityTypeAll, sliceContains(filter.Types, SendAT), sliceContains(filter.Types, ReceiveAT),
fromTrType, toTrType, fromTrType, toTrType,

View File

@ -1,6 +1,7 @@
package activity package activity
import ( import (
"context"
"database/sql" "database/sql"
"testing" "testing"
@ -92,7 +93,7 @@ func TestGetActivityEntriesAll(t *testing.T) {
td, fromAddresses, toAddresses := fillTestData(t, db) td, fromAddresses, toAddresses := fillTestData(t, db)
var filter Filter var filter Filter
entries, err := GetActivityEntries(db, append(toAddresses, fromAddresses...), []common.ChainID{}, filter, 0, 10) entries, err := getActivityEntries(context.Background(), db, append(toAddresses, fromAddresses...), []common.ChainID{}, filter, 0, 10)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 4, len(entries)) require.Equal(t, 4, len(entries))
@ -185,7 +186,7 @@ func TestGetActivityEntriesWithSameTransactionForSenderAndReceiverInDB(t *testin
transfer.InsertTestTransfer(t, db, &receiverTr) transfer.InsertTestTransfer(t, db, &receiverTr)
var filter Filter var filter Filter
entries, err := GetActivityEntries(db, []eth.Address{td.tr1.From, receiverTr.From}, []common.ChainID{}, filter, 0, 10) entries, err := getActivityEntries(context.Background(), db, []eth.Address{td.tr1.From, receiverTr.From}, []common.ChainID{}, filter, 0, 10)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, len(entries)) require.Equal(t, 2, len(entries))
@ -198,7 +199,7 @@ func TestGetActivityEntriesWithSameTransactionForSenderAndReceiverInDB(t *testin
require.NotEqual(t, eth.Address{}, entries[0].transaction.Address) require.NotEqual(t, eth.Address{}, entries[0].transaction.Address)
require.Equal(t, td.tr1.From, entries[0].transaction.Address) require.Equal(t, td.tr1.From, entries[0].transaction.Address)
entries, err = GetActivityEntries(db, []eth.Address{}, []common.ChainID{}, filter, 0, 10) entries, err = getActivityEntries(context.Background(), db, []eth.Address{}, []common.ChainID{}, filter, 0, 10)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 5, len(entries)) require.Equal(t, 5, len(entries))
@ -225,7 +226,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
var filter Filter var filter Filter
filter.Period.StartTimestamp = td.singletonMTr.Timestamp filter.Period.StartTimestamp = td.singletonMTr.Timestamp
filter.Period.EndTimestamp = NoLimitTimestampForPeriod filter.Period.EndTimestamp = NoLimitTimestampForPeriod
entries, err := GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err := getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 8, len(entries)) require.Equal(t, 8, len(entries))
// Check start and end content // Check start and end content
@ -250,7 +251,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
// Test complete interval // Test complete interval
filter.Period.EndTimestamp = trs[2].Timestamp filter.Period.EndTimestamp = trs[2].Timestamp
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 5, len(entries)) require.Equal(t, 5, len(entries))
// Check start and end content // Check start and end content
@ -275,7 +276,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
// Test end only // Test end only
filter.Period.StartTimestamp = NoLimitTimestampForPeriod filter.Period.StartTimestamp = NoLimitTimestampForPeriod
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 7, len(entries)) require.Equal(t, 7, len(entries))
// Check start and end content // Check start and end content
@ -313,14 +314,14 @@ func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
var filter Filter var filter Filter
// Get all // Get all
entries, err := GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 5) entries, err := getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 5)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 5, len(entries)) require.Equal(t, 5, len(entries))
// Get time based interval // Get time based interval
filter.Period.StartTimestamp = trs[2].Timestamp filter.Period.StartTimestamp = trs[2].Timestamp
filter.Period.EndTimestamp = trs[8].Timestamp filter.Period.EndTimestamp = trs[8].Timestamp
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 3) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 3)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, len(entries)) require.Equal(t, 3, len(entries))
// Check start and end content // Check start and end content
@ -344,7 +345,7 @@ func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
}, entries[2]) }, entries[2])
// Move window 2 entries forward // Move window 2 entries forward
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 2, 3) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 2, 3)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, len(entries)) require.Equal(t, 3, len(entries))
// Check start and end content // Check start and end content
@ -368,7 +369,7 @@ func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
}, entries[2]) }, entries[2])
// Move window 4 more entries to test filter cap // Move window 4 more entries to test filter cap
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 6, 3) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 6, 3)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(entries)) require.Equal(t, 1, len(entries))
// Check start and end content // Check start and end content
@ -431,12 +432,12 @@ func TestGetActivityEntriesFilterByType(t *testing.T) {
filter.Types = allActivityTypesFilter() filter.Types = allActivityTypesFilter()
// Set tr1 to Receive and pendingTr to Send; rest of two MT remain default Send // Set tr1 to Receive and pendingTr to Send; rest of two MT remain default Send
addresses := []eth_common.Address{td.tr1.To, td.pendingTr.From, td.singletonMTr.From, td.mTr.From, trs[0].From, trs[2].From, trs[4].From, trs[6].From, trs[8].From} addresses := []eth_common.Address{td.tr1.To, td.pendingTr.From, td.singletonMTr.From, td.mTr.From, trs[0].From, trs[2].From, trs[4].From, trs[6].From, trs[8].From}
entries, err := GetActivityEntries(db, addresses, []common.ChainID{}, filter, 0, 15) entries, err := getActivityEntries(context.Background(), db, addresses, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 9, len(entries)) require.Equal(t, 9, len(entries))
filter.Types = []Type{SendAT, SwapAT} filter.Types = []Type{SendAT, SwapAT}
entries, err = GetActivityEntries(db, addresses, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, addresses, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
// 3 from td Send + 2 trs MT Send + 1 (swap) // 3 from td Send + 2 trs MT Send + 1 (swap)
require.Equal(t, 6, len(entries)) require.Equal(t, 6, len(entries))
@ -449,7 +450,7 @@ func TestGetActivityEntriesFilterByType(t *testing.T) {
require.Equal(t, 0, bridgeCount) require.Equal(t, 0, bridgeCount)
filter.Types = []Type{BridgeAT, ReceiveAT} filter.Types = []Type{BridgeAT, ReceiveAT}
entries, err = GetActivityEntries(db, addresses, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, addresses, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, len(entries)) require.Equal(t, 3, len(entries))
@ -476,12 +477,12 @@ func TestGetActivityEntriesFilterByAddresses(t *testing.T) {
var filter Filter var filter Filter
addressesFilter := allAddressesFilter() addressesFilter := allAddressesFilter()
entries, err := GetActivityEntries(db, addressesFilter, []common.ChainID{}, filter, 0, 15) entries, err := getActivityEntries(context.Background(), db, addressesFilter, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 10, len(entries)) require.Equal(t, 10, len(entries))
addressesFilter = []eth_common.Address{td.mTr.To, trs[1].From, trs[4].To} addressesFilter = []eth_common.Address{td.mTr.To, trs[1].From, trs[4].To}
entries, err = GetActivityEntries(db, addressesFilter, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, addressesFilter, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, len(entries)) require.Equal(t, 3, len(entries))
require.Equal(t, Entry{ require.Equal(t, Entry{
@ -536,12 +537,12 @@ func TestGetActivityEntriesFilterByStatus(t *testing.T) {
var filter Filter var filter Filter
filter.Statuses = allActivityStatusesFilter() filter.Statuses = allActivityStatusesFilter()
entries, err := GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err := getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 11, len(entries)) require.Equal(t, 11, len(entries))
filter.Statuses = []Status{PendingAS} filter.Statuses = []Status{PendingAS}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, len(entries)) require.Equal(t, 3, len(entries))
require.Equal(t, td.pendingTr.Hash, entries[2].transaction.Hash) require.Equal(t, td.pendingTr.Hash, entries[2].transaction.Hash)
@ -549,24 +550,24 @@ func TestGetActivityEntriesFilterByStatus(t *testing.T) {
require.Equal(t, trs[1].Hash, entries[0].transaction.Hash) require.Equal(t, trs[1].Hash, entries[0].transaction.Hash)
filter.Statuses = []Status{FailedAS} filter.Statuses = []Status{FailedAS}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, len(entries)) require.Equal(t, 2, len(entries))
filter.Statuses = []Status{CompleteAS} filter.Statuses = []Status{CompleteAS}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 6, len(entries)) require.Equal(t, 6, len(entries))
// Finalized is treated as Complete, would need dynamic blockchain status to track the Finalized level // Finalized is treated as Complete, would need dynamic blockchain status to track the Finalized level
filter.Statuses = []Status{FinalizedAS} filter.Statuses = []Status{FinalizedAS}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 6, len(entries)) require.Equal(t, 6, len(entries))
// Combined filter // Combined filter
filter.Statuses = []Status{FailedAS, PendingAS} filter.Statuses = []Status{FailedAS, PendingAS}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 5, len(entries)) require.Equal(t, 5, len(entries))
} }
@ -588,35 +589,35 @@ func TestGetActivityEntriesFilterByTokenType(t *testing.T) {
var filter Filter var filter Filter
filter.Tokens = noAssetsFilter() filter.Tokens = noAssetsFilter()
entries, err := GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err := getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(entries)) require.Equal(t, 0, len(entries))
filter.Tokens = allTokensFilter() filter.Tokens = allTokensFilter()
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 10, len(entries)) require.Equal(t, 10, len(entries))
// Regression when collectibles is nil // Regression when collectibles is nil
filter.Tokens = Tokens{[]TokenCode{}, nil, []TokenType{}} filter.Tokens = Tokens{[]TokenCode{}, nil, []TokenType{}}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 10, len(entries)) require.Equal(t, 10, len(entries))
filter.Tokens = Tokens{Assets: []TokenCode{"ETH"}, EnabledTypes: []TokenType{AssetTT}} filter.Tokens = Tokens{Assets: []TokenCode{"ETH"}, EnabledTypes: []TokenType{AssetTT}}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, len(entries)) require.Equal(t, 3, len(entries))
// TODO: update tests after adding token type to transfers // TODO: update tests after adding token type to transfers
filter.Tokens = Tokens{Assets: []TokenCode{"USDC", "DAI"}, EnabledTypes: []TokenType{AssetTT}} filter.Tokens = Tokens{Assets: []TokenCode{"USDC", "DAI"}, EnabledTypes: []TokenType{AssetTT}}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(entries)) require.Equal(t, 1, len(entries))
// Regression when EnabledTypes ar empty // Regression when EnabledTypes ar empty
filter.Tokens = Tokens{Assets: []TokenCode{"USDC", "DAI"}, EnabledTypes: []TokenType{}} filter.Tokens = Tokens{Assets: []TokenCode{"USDC", "DAI"}, EnabledTypes: []TokenType{}}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(entries)) require.Equal(t, 1, len(entries))
} }
@ -637,22 +638,22 @@ func TestGetActivityEntriesFilterByToAddresses(t *testing.T) {
var filter Filter var filter Filter
filter.CounterpartyAddresses = allAddressesFilter() filter.CounterpartyAddresses = allAddressesFilter()
entries, err := GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err := getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 10, len(entries)) require.Equal(t, 10, len(entries))
filter.CounterpartyAddresses = []eth_common.Address{eth_common.HexToAddress("0x567890")} filter.CounterpartyAddresses = []eth_common.Address{eth_common.HexToAddress("0x567890")}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(entries)) require.Equal(t, 0, len(entries))
filter.CounterpartyAddresses = []eth_common.Address{td.pendingTr.To, td.mTr.To, trs[3].To} filter.CounterpartyAddresses = []eth_common.Address{td.pendingTr.To, td.mTr.To, trs[3].To}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 3, len(entries)) require.Equal(t, 3, len(entries))
filter.CounterpartyAddresses = []eth_common.Address{td.tr1.To, td.pendingTr.From, trs[3].From, trs[5].To} filter.CounterpartyAddresses = []eth_common.Address{td.tr1.To, td.pendingTr.From, trs[3].From, trs[5].To}
entries, err = GetActivityEntries(db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, len(entries)) require.Equal(t, 2, len(entries))
} }
@ -671,18 +672,18 @@ func TestGetActivityEntriesFilterByNetworks(t *testing.T) {
var filter Filter var filter Filter
chainIDs := allNetworksFilter() chainIDs := allNetworksFilter()
entries, err := GetActivityEntries(db, []eth_common.Address{}, chainIDs, filter, 0, 15) entries, err := getActivityEntries(context.Background(), db, []eth_common.Address{}, chainIDs, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 10, len(entries)) require.Equal(t, 10, len(entries))
chainIDs = []common.ChainID{5674839210} chainIDs = []common.ChainID{5674839210}
entries, err = GetActivityEntries(db, []eth_common.Address{}, chainIDs, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, chainIDs, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
// TODO: update after multi-transactions are filterable by ChainID // TODO: update after multi-transactions are filterable by ChainID
require.Equal(t, 2 /*0*/, len(entries)) require.Equal(t, 2 /*0*/, len(entries))
chainIDs = []common.ChainID{td.pendingTr.ChainID, td.mTr.ChainID, trs[3].ChainID} chainIDs = []common.ChainID{td.pendingTr.ChainID, td.mTr.ChainID, trs[3].ChainID}
entries, err = GetActivityEntries(db, []eth_common.Address{}, chainIDs, filter, 0, 15) entries, err = getActivityEntries(context.Background(), db, []eth_common.Address{}, chainIDs, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
// TODO: update after multi-transactions are filterable by ChainID // TODO: update after multi-transactions are filterable by ChainID
require.Equal(t, 4 /*3*/, len(entries)) require.Equal(t, 4 /*3*/, len(entries))
@ -704,7 +705,7 @@ func TestGetActivityEntriesCheckToAndFrom(t *testing.T) {
td.singletonMTr.From, td.mTr.To, trs[0].To, trs[1].To} td.singletonMTr.From, td.mTr.To, trs[0].To, trs[1].To}
var filter Filter var filter Filter
entries, err := GetActivityEntries(db, addresses, []common.ChainID{}, filter, 0, 15) entries, err := getActivityEntries(context.Background(), db, addresses, []common.ChainID{}, filter, 0, 15)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 6, len(entries)) require.Equal(t, 6, len(entries))
@ -729,3 +730,17 @@ func TestGetActivityEntriesCheckToAndFrom(t *testing.T) {
} }
// TODO test sub-transaction count for multi-transactions // TODO test sub-transaction count for multi-transactions
func TestGetActivityEntriesCheckContextCancellation(t *testing.T) {
db, close := setupTestActivityDB(t)
defer close()
_, _, _ = fillTestData(t, db)
cancellableCtx, cancelFn := context.WithCancel(context.Background())
cancelFn()
activities, err := getActivityEntries(cancellableCtx, db, []eth.Address{}, []common.ChainID{}, Filter{}, 0, 10)
require.ErrorIs(t, err, context.Canceled)
require.Equal(t, 0, len(activities))
}

View File

@ -0,0 +1,123 @@
package activity
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
w_common "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/walletevent"
)
const (
// FilterResponse json is sent as a message in the EventActivityFilteringDone event
EventActivityFilteringDone walletevent.EventType = "wallet-activity-filtering-done"
)
type Service struct {
db *sql.DB
eventFeed *event.Feed
context context.Context
cancelFn context.CancelFunc
wg sync.WaitGroup
mu sync.Mutex
}
func NewService(db *sql.DB, eventFeed *event.Feed) *Service {
return &Service{
db: db,
eventFeed: eventFeed,
}
}
type ErrorCode = int
const (
ErrorCodeSuccess ErrorCode = iota + 1
ErrorCodeFilterCanceled
ErrorCodeFilterFailed
)
type FilterResponse struct {
Activities []Entry `json:"activities"`
ThereMightBeMore bool `json:"thereMightBeMore"`
ErrorCode ErrorCode `json:"errorCode"`
}
// FilterActivityAsync allows only one filter task to run at a time
// and it cancels the current one if a new one is started
// All calls will trigger an EventActivityFilteringDone event with the result of the filtering
func (s *Service) FilterActivityAsync(ctx context.Context, addresses []common.Address, chainIDs []w_common.ChainID, filter Filter, offset int, limit int) error {
s.mu.Lock()
defer s.mu.Unlock()
// If a previous task is running, cancel it and wait to finish
if s.cancelFn != nil {
s.cancelFn()
s.wg.Wait()
}
if ctx.Err() != nil {
return fmt.Errorf("context error: %w", ctx.Err())
}
s.context, s.cancelFn = context.WithCancel(context.Background())
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
s.cancelFn = nil
}()
activities, err := getActivityEntries(s.context, s.db, addresses, chainIDs, filter, offset, limit)
res := FilterResponse{
ErrorCode: ErrorCodeFilterFailed,
}
if errors.Is(err, context.Canceled) {
res.ErrorCode = ErrorCodeFilterCanceled
} else if err == nil {
res.Activities = activities
res.ThereMightBeMore = len(activities) == limit
res.ErrorCode = ErrorCodeSuccess
}
s.sendResponseEvent(res)
}()
return nil
}
func (s *Service) Stop() {
s.mu.Lock()
defer s.mu.Unlock()
// If a previous task is running, cancel it and wait to finish
if s.cancelFn != nil {
s.cancelFn()
s.wg.Wait()
s.cancelFn = nil
}
}
func (s *Service) sendResponseEvent(response FilterResponse) {
payload, err := json.Marshal(response)
if err != nil {
log.Error("Error marshaling response: %v", err)
}
s.eventFeed.Send(walletevent.Event{
Type: EventActivityFilteringDone,
Message: string(payload),
})
}

View File

@ -20,7 +20,7 @@ import (
"github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/token"
"github.com/status-im/status-go/services/wallet/transfer" "github.com/status-im/status-go/services/wallet/transfer"
wallet_common "github.com/status-im/status-go/services/wallet/common" wcommon "github.com/status-im/status-go/services/wallet/common"
) )
func NewAPI(s *Service) *API { func NewAPI(s *Service) *API {
@ -528,7 +528,7 @@ func (api *API) FetchAllCurrencyFormats() (currency.FormatPerSymbol, error) {
return api.s.currency.FetchAllCurrencyFormats() return api.s.currency.FetchAllCurrencyFormats()
} }
func (api *API) GetActivityEntries(addresses []common.Address, chainIDs []wallet_common.ChainID, filter activity.Filter, offset int, limit int) ([]activity.Entry, error) { func (api *API) FilterActivityAsync(ctx context.Context, addresses []common.Address, chainIDs []wcommon.ChainID, filter activity.Filter, offset int, limit int) error {
log.Debug("call to GetActivityEntries") log.Debug("[WalletAPI:: FilterActivityAsync] addr.count", len(addresses), "chainIDs.count", len(chainIDs), "filter", filter, "offset", offset, "limit", limit)
return activity.GetActivityEntries(api.s.db, addresses, chainIDs, filter, offset, limit) return api.s.activity.FilterActivityAsync(ctx, addresses, chainIDs, filter, offset, limit)
} }

View File

@ -17,6 +17,7 @@ import (
"github.com/status-im/status-go/rpc" "github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/services/ens" "github.com/status-im/status-go/services/ens"
"github.com/status-im/status-go/services/stickers" "github.com/status-im/status-go/services/stickers"
"github.com/status-im/status-go/services/wallet/activity"
"github.com/status-im/status-go/services/wallet/collectibles" "github.com/status-im/status-go/services/wallet/collectibles"
"github.com/status-im/status-go/services/wallet/currency" "github.com/status-im/status-go/services/wallet/currency"
"github.com/status-im/status-go/services/wallet/history" "github.com/status-im/status-go/services/wallet/history"
@ -93,6 +94,7 @@ func NewService(
reader := NewReader(rpcClient, tokenManager, marketManager, accountsDB, NewPersistence(db), walletFeed) reader := NewReader(rpcClient, tokenManager, marketManager, accountsDB, NewPersistence(db), walletFeed)
history := history.NewService(db, walletFeed, rpcClient, tokenManager, marketManager) history := history.NewService(db, walletFeed, rpcClient, tokenManager, marketManager)
currency := currency.NewService(db, walletFeed, tokenManager, marketManager) currency := currency.NewService(db, walletFeed, tokenManager, marketManager)
activity := activity.NewService(db, walletFeed)
alchemyClient := alchemy.NewClient(config.WalletConfig.AlchemyAPIKeys) alchemyClient := alchemy.NewClient(config.WalletConfig.AlchemyAPIKeys)
infuraClient := infura.NewClient(config.WalletConfig.InfuraAPIKey, config.WalletConfig.InfuraAPIKeySecret) infuraClient := infura.NewClient(config.WalletConfig.InfuraAPIKey, config.WalletConfig.InfuraAPIKeySecret)
@ -118,6 +120,7 @@ func NewService(
reader: reader, reader: reader,
history: history, history: history,
currency: currency, currency: currency,
activity: activity,
} }
} }
@ -144,6 +147,7 @@ type Service struct {
reader *Reader reader *Reader
history *history.Service history *history.Service
currency *currency.Service currency *currency.Service
activity *activity.Service
} }
// Start signals transmitter. // Start signals transmitter.
@ -169,6 +173,7 @@ func (s *Service) Stop() error {
s.currency.Stop() s.currency.Stop()
s.reader.Stop() s.reader.Stop()
s.history.Stop() s.history.Stop()
s.activity.Stop()
s.started = false s.started = false
log.Info("wallet stopped") log.Info("wallet stopped")
return nil return nil