feat(wallet): add GetMoreForFilterSession API method

Also fix StopFilterSession to always notify client

Updates #12120
This commit is contained in:
Stefan 2024-02-08 20:13:12 -03:00 committed by Stefan Dunca
parent 4fc9420efc
commit e9ff0fbefe
5 changed files with 240 additions and 214 deletions
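For context, here is a minimal sketch of how a client of the activity Service is expected to drive the session API after this change. The wrapper function, the nil chainIDs (all networks), and the page size of 10 are assumptions for illustration; only the Service methods, event name, and response fields come from this diff.

```go
package walletclient

import (
	"encoding/json"

	eth "github.com/ethereum/go-ethereum/common"

	"github.com/status-im/status-go/services/wallet/activity"
	"github.com/status-im/status-go/services/wallet/walletevent"
)

// drainActivitySession is a sketch only: svc and ch (already subscribed to the
// wallet event feed) are assumed to be wired up by the caller.
func drainActivitySession(svc *activity.Service, ch chan walletevent.Event, addresses []eth.Address) error {
	sessionID := svc.StartFilterSession(addresses, true, nil, activity.Filter{}, 10)
	defer svc.StopFilterSession(sessionID)

	for ev := range ch {
		if ev.Type != activity.EventActivityFilteringDone {
			continue
		}
		var page activity.FilterResponse
		if err := json.Unmarshal([]byte(ev.Message), &page); err != nil {
			return err
		}
		// ... hand page.Activities to the UI here ...
		if !page.HasMore {
			return nil
		}
		// Ask for the next page; it arrives as another EventActivityFilteringDone
		// whose Offset continues from what the client has already seen.
		if err := svc.GetMoreForFilterSession(sessionID, 10); err != nil {
			return err
		}
	}
	return nil
}
```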

View File

@ -69,7 +69,7 @@ type Entry struct {
transferType *TransferType
contractAddress *eth.Address
isNew bool
isNew bool // isNew indicates whether the entry is newer than the session start (a changed state counts as new too)
}
// Only used for JSON marshalling

View File

@ -307,7 +307,27 @@ func validateSessionUpdateEvent(t *testing.T, ch chan walletevent.Event, filterR
return
}
func validateSessionUpdateEventWithPending(t *testing.T, ch chan walletevent.Event) (filterResponseCount int) {
type extraExpect struct {
offset *int
errorCode *ErrorCode
}
func getOptionalExpectations(e *extraExpect) (expectOffset int, expectErrorCode ErrorCode) {
expectOffset = 0
expectErrorCode = ErrorCodeSuccess
if e != nil {
if e.offset != nil {
expectOffset = *e.offset
}
if e.errorCode != nil {
expectErrorCode = *e.errorCode
}
}
return
}
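For orientation, the tests below pass the optional expectations along these lines (common.NewAndSet is the pointer helper already used in this diff; nil extras keep the defaults of offset 0 and ErrorCodeSuccess):

```go
// Defaults: expect two activities at offset 0 with ErrorCodeSuccess.
filterResponseCount := validateFilteringDone(t, ch, 2, nil, nil)
require.Equal(t, 1, filterResponseCount)

// Expect the page to start at offset 2; nil keeps the default ErrorCodeSuccess.
filterResponseCount = validateFilteringDone(t, ch, 2, nil,
	common.NewAndSet(extraExpect{common.NewAndSet(2), nil}))
require.Equal(t, 1, filterResponseCount)
```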
func validateFilteringDone(t *testing.T, ch chan walletevent.Event, resCount int, checkPayloadFn func(payload FilterResponse), extra *extraExpect) (filterResponseCount int) {
for filterResponseCount < 1 {
select {
case res := <-ch:
@ -316,9 +336,18 @@ func validateSessionUpdateEventWithPending(t *testing.T, ch chan walletevent.Eve
var payload FilterResponse
err := json.Unmarshal([]byte(res.Message), &payload)
require.NoError(t, err)
require.Equal(t, ErrorCodeSuccess, payload.ErrorCode)
require.Equal(t, 2, len(payload.Activities))
expectOffset, expectErrorCode := getOptionalExpectations(extra)
require.Equal(t, expectErrorCode, payload.ErrorCode)
require.Equal(t, resCount, len(payload.Activities))
require.Equal(t, expectOffset, payload.Offset)
filterResponseCount++
if checkPayloadFn != nil {
checkPayloadFn(payload)
}
}
case <-time.NewTimer(1 * time.Second).C:
require.Fail(t, "timeout while waiting for EventActivityFilteringDone")
@ -331,14 +360,15 @@ func TestService_IncrementalUpdateOnTop(t *testing.T) {
state := setupTestService(t)
defer state.close()
allAddresses, pendings, ch, cleanup := setupTransactions(t, state, 2, []transactions.TestTxSummary{{DontConfirm: true, Timestamp: 3}})
transactionCount := 2
allAddresses, pendings, ch, cleanup := setupTransactions(t, state, transactionCount, []transactions.TestTxSummary{{DontConfirm: true, Timestamp: transactionCount + 1}})
defer cleanup()
sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 5)
require.Greater(t, sessionID, SessionID(0))
defer state.service.StopFilterSession(sessionID)
filterResponseCount := validateSessionUpdateEventWithPending(t, ch)
filterResponseCount := validateFilteringDone(t, ch, 2, nil, nil)
exp := pendings[0]
err := state.pendingTracker.StoreAndTrackPendingTx(&exp)
@ -350,52 +380,43 @@ func TestService_IncrementalUpdateOnTop(t *testing.T) {
require.NoError(t, err)
// Validate the reset data
eventActivityDoneCount := 0
for eventActivityDoneCount < 1 {
select {
case res := <-ch:
switch res.Type {
case EventActivityFilteringDone:
var payload FilterResponse
err := json.Unmarshal([]byte(res.Message), &payload)
require.NoError(t, err)
require.Equal(t, ErrorCodeSuccess, payload.ErrorCode)
require.Equal(t, 3, len(payload.Activities))
eventActivityDoneCount := validateFilteringDone(t, ch, 3, func(payload FilterResponse) {
require.True(t, payload.Activities[0].isNew)
require.False(t, payload.Activities[1].isNew)
require.False(t, payload.Activities[2].isNew)
tx := payload.Activities[0]
require.Equal(t, PendingTransactionPT, tx.payloadType)
// Check the new transaction data
newTx := payload.Activities[0]
require.Equal(t, PendingTransactionPT, newTx.payloadType)
// We don't keep type in the DB
require.Equal(t, (*int)(nil), tx.transferType)
require.Equal(t, SendAT, tx.activityType)
require.Equal(t, PendingAS, tx.activityStatus)
require.Equal(t, exp.ChainID, tx.transaction.ChainID)
require.Equal(t, exp.ChainID, *tx.chainIDOut)
require.Equal(t, (*common.ChainID)(nil), tx.chainIDIn)
require.Equal(t, exp.Hash, tx.transaction.Hash)
require.Equal(t, (*int)(nil), newTx.transferType)
require.Equal(t, SendAT, newTx.activityType)
require.Equal(t, PendingAS, newTx.activityStatus)
require.Equal(t, exp.ChainID, newTx.transaction.ChainID)
require.Equal(t, exp.ChainID, *newTx.chainIDOut)
require.Equal(t, (*common.ChainID)(nil), newTx.chainIDIn)
require.Equal(t, exp.Hash, newTx.transaction.Hash)
// Pending doesn't have address as part of identity
require.Equal(t, eth.Address{}, tx.transaction.Address)
require.Equal(t, exp.From, *tx.sender)
require.Equal(t, exp.To, *tx.recipient)
require.Equal(t, 0, exp.Value.Int.Cmp((*big.Int)(tx.amountOut)))
require.Equal(t, exp.Timestamp, uint64(tx.timestamp))
require.Equal(t, exp.Symbol, *tx.symbolOut)
require.Equal(t, (*string)(nil), tx.symbolIn)
require.Equal(t, eth.Address{}, newTx.transaction.Address)
require.Equal(t, exp.From, *newTx.sender)
require.Equal(t, exp.To, *newTx.recipient)
require.Equal(t, 0, exp.Value.Int.Cmp((*big.Int)(newTx.amountOut)))
require.Equal(t, exp.Timestamp, uint64(newTx.timestamp))
require.Equal(t, exp.Symbol, *newTx.symbolOut)
require.Equal(t, (*string)(nil), newTx.symbolIn)
require.Equal(t, &Token{
TokenType: Native,
ChainID: 5,
}, tx.tokenOut)
require.Equal(t, (*Token)(nil), tx.tokenIn)
require.Equal(t, (*eth.Address)(nil), tx.contractAddress)
eventActivityDoneCount++
}
case <-time.NewTimer(1 * time.Second).C:
require.Fail(t, "timeout while waiting for EventActivitySessionUpdated")
}
}
}, newTx.tokenOut)
require.Equal(t, (*Token)(nil), newTx.tokenIn)
require.Equal(t, (*eth.Address)(nil), newTx.contractAddress)
// Check the order of the remaining (older) transactions
require.Equal(t, SimpleTransactionPT, payload.Activities[1].payloadType)
require.Equal(t, int64(transactionCount), payload.Activities[1].timestamp)
require.Equal(t, SimpleTransactionPT, payload.Activities[2].payloadType)
require.Equal(t, int64(transactionCount-1), payload.Activities[2].timestamp)
}, nil)
require.Equal(t, 1, pendingTransactionUpdate)
require.Equal(t, 1, filterResponseCount)
@ -403,18 +424,19 @@ func TestService_IncrementalUpdateOnTop(t *testing.T) {
require.Equal(t, 1, eventActivityDoneCount)
}
func TestService_IncrementalUpdateFetchWindowRegression(t *testing.T) {
func TestService_IncrementalUpdateFetchWindow(t *testing.T) {
state := setupTestService(t)
defer state.close()
allAddresses, pendings, ch, cleanup := setupTransactions(t, state, 3, []transactions.TestTxSummary{{DontConfirm: true, Timestamp: 4}})
transactionCount := 5
allAddresses, pendings, ch, cleanup := setupTransactions(t, state, transactionCount, []transactions.TestTxSummary{{DontConfirm: true, Timestamp: transactionCount + 1}})
defer cleanup()
sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 2)
require.Greater(t, sessionID, SessionID(0))
defer state.service.StopFilterSession(sessionID)
filterResponseCount := validateSessionUpdateEventWithPending(t, ch)
filterResponseCount := validateFilteringDone(t, ch, 2, nil, nil)
exp := pendings[0]
err := state.pendingTracker.StoreAndTrackPendingTx(&exp)
@ -426,29 +448,65 @@ func TestService_IncrementalUpdateFetchWindowRegression(t *testing.T) {
require.NoError(t, err)
// Validate the reset data
eventActivityDoneCount := 0
for eventActivityDoneCount < 1 {
select {
case res := <-ch:
switch res.Type {
case EventActivityFilteringDone:
var payload FilterResponse
err := json.Unmarshal([]byte(res.Message), &payload)
require.NoError(t, err)
require.Equal(t, ErrorCodeSuccess, payload.ErrorCode)
require.Equal(t, 2, len(payload.Activities))
eventActivityDoneCount := validateFilteringDone(t, ch, 2, func(payload FilterResponse) {
require.True(t, payload.Activities[0].isNew)
require.Equal(t, int64(transactionCount+1), payload.Activities[0].timestamp)
require.False(t, payload.Activities[1].isNew)
eventActivityDoneCount++
}
case <-time.NewTimer(1 * time.Second).C:
require.Fail(t, "timeout while waiting for EventActivitySessionUpdated")
}
}
require.Equal(t, int64(transactionCount), payload.Activities[1].timestamp)
}, nil)
require.Equal(t, 1, pendingTransactionUpdate)
require.Equal(t, 1, filterResponseCount)
require.Equal(t, 1, sessionUpdatesCount)
require.Equal(t, 1, eventActivityDoneCount)
err = state.service.GetMoreForFilterSession(sessionID, 2)
require.NoError(t, err)
eventActivityDoneCount = validateFilteringDone(t, ch, 2, func(payload FilterResponse) {
require.False(t, payload.Activities[0].isNew)
require.Equal(t, int64(transactionCount-1), payload.Activities[0].timestamp)
require.False(t, payload.Activities[1].isNew)
require.Equal(t, int64(transactionCount-2), payload.Activities[1].timestamp)
}, common.NewAndSet(extraExpect{common.NewAndSet(2), nil}))
require.Equal(t, 1, eventActivityDoneCount)
}
func TestService_IncrementalUpdateFetchWindowNoReset(t *testing.T) {
state := setupTestService(t)
defer state.close()
transactionCount := 5
allAddresses, pendings, ch, cleanup := setupTransactions(t, state, transactionCount, []transactions.TestTxSummary{{DontConfirm: true, Timestamp: transactionCount + 1}})
defer cleanup()
sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 2)
require.Greater(t, sessionID, SessionID(0))
defer state.service.StopFilterSession(sessionID)
filterResponseCount := validateFilteringDone(t, ch, 2, func(payload FilterResponse) {
require.Equal(t, int64(transactionCount), payload.Activities[0].timestamp)
require.Equal(t, int64(transactionCount-1), payload.Activities[1].timestamp)
}, nil)
exp := pendings[0]
err := state.pendingTracker.StoreAndTrackPendingTx(&exp)
require.NoError(t, err)
pendingTransactionUpdate, sessionUpdatesCount := validateSessionUpdateEvent(t, ch, &filterResponseCount)
require.Equal(t, 1, pendingTransactionUpdate)
require.Equal(t, 1, filterResponseCount)
require.Equal(t, 1, sessionUpdatesCount)
err = state.service.GetMoreForFilterSession(sessionID, 2)
require.NoError(t, err)
// Validate that the client doesn't see anything of the internal state
eventActivityDoneCount := validateFilteringDone(t, ch, 2, func(payload FilterResponse) {
require.False(t, payload.Activities[0].isNew)
require.Equal(t, int64(transactionCount-2), payload.Activities[0].timestamp)
require.False(t, payload.Activities[1].isNew)
require.Equal(t, int64(transactionCount-3), payload.Activities[1].timestamp)
}, common.NewAndSet(extraExpect{common.NewAndSet(2), nil}))
require.Equal(t, 1, eventActivityDoneCount)
}

View File

@ -53,7 +53,7 @@ type Session struct {
// model is a mirror of the data model presentation has (sent by EventActivityFilteringDone)
model []EntryIdentity
// new holds the new entries until user requests update
// new holds the new entries until the user requests an update by calling ResetFilterSession
new []EntryIdentity
}
@ -72,7 +72,7 @@ type fullFilterParams struct {
filter Filter
}
func (s *Service) internalFilter(f fullFilterParams, offset int, count int, processResults func(entries []Entry)) {
func (s *Service) internalFilter(f fullFilterParams, offset int, count int, processResults func(entries []Entry) (offsetOverride int)) {
s.scheduler.Enqueue(int32(f.sessionID), filterTask, func(ctx context.Context) (interface{}, error) {
activities, err := getActivityEntries(ctx, s.getDeps(), f.addresses, f.allAddresses, f.chainIDs, f.filter, offset, count)
return activities, err
@ -86,11 +86,10 @@ func (s *Service) internalFilter(f fullFilterParams, offset int, count int, proc
} else if err == nil {
activities := result.([]Entry)
res.Activities = activities
res.Offset = 0
res.HasMore = len(activities) == count
res.ErrorCode = ErrorCodeSuccess
processResults(activities)
res.Offset = processResults(activities)
}
int32SessionID := int32(f.sessionID)
@ -132,13 +131,17 @@ func (s *Service) StartFilterSession(addresses []eth.Address, allAddresses bool,
}
s.sessionsRWMutex.Unlock()
s.internalFilter(fullFilterParams{
s.internalFilter(
fullFilterParams{
sessionID: sessionID,
addresses: addresses,
allAddresses: allAddresses,
chainIDs: chainIDs,
filter: filter,
}, 0, firstPageCount, func(entries []Entry) {
},
0,
firstPageCount,
func(entries []Entry) (offset int) {
// Mirror identities for update use
s.sessionsRWMutex.Lock()
defer s.sessionsRWMutex.Unlock()
@ -151,7 +154,9 @@ func (s *Service) StartFilterSession(addresses []eth.Address, allAddresses bool,
id: a.id,
})
}
})
return 0
},
)
return sessionID
}
@ -162,13 +167,17 @@ func (s *Service) ResetFilterSession(id SessionID, firstPageCount int) error {
return errors.New("session not found")
}
s.internalFilter(fullFilterParams{
s.internalFilter(
fullFilterParams{
sessionID: id,
addresses: session.addresses,
allAddresses: session.allAddresses,
chainIDs: session.chainIDs,
filter: session.filter,
}, 0, firstPageCount, func(entries []Entry) {
},
0,
firstPageCount,
func(entries []Entry) (offset int) {
s.sessionsRWMutex.Lock()
defer s.sessionsRWMutex.Unlock()
@ -189,12 +198,48 @@ func (s *Service) ResetFilterSession(id SessionID, firstPageCount int) error {
id: a.id,
})
}
})
return 0
},
)
return nil
}
// TODO #12120: extend the session based API
//func (s *Service) GetMoreForFilterSession(count int) {}
func (s *Service) GetMoreForFilterSession(id SessionID, pageCount int) error {
session, found := s.sessions[id]
if !found {
return errors.New("session not found")
}
prevModelLen := len(session.model)
s.internalFilter(
fullFilterParams{
sessionID: id,
addresses: session.addresses,
allAddresses: session.allAddresses,
chainIDs: session.chainIDs,
filter: session.filter,
},
prevModelLen+len(session.new),
pageCount,
func(entries []Entry) (offset int) {
s.sessionsRWMutex.Lock()
defer s.sessionsRWMutex.Unlock()
// Mirror client identities for checking updates
for _, a := range entries {
session.model = append(session.model, EntryIdentity{
payloadType: a.payloadType,
transaction: a.transaction,
id: a.id,
})
}
// Override the reported offset so it excludes the new entries that were skipped in the query but not yet delivered to the client
return prevModelLen
},
)
return nil
}
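The offset bookkeeping in GetMoreForFilterSession is easiest to follow with concrete numbers; the values below are illustrative, not taken from the tests:

```go
package main

import "fmt"

func main() {
	// Hypothetical session state at the moment GetMoreForFilterSession is called.
	prevModelLen := 2 // entries already delivered to the client (session.model)
	newOnTop := 1     // entries detected on top but not yet delivered (session.new)
	pageCount := 2    // page size requested by the client

	// The DB query skips everything the session already tracks...
	dbOffset := prevModelLen + newOnTop
	// ...while the FilterResponse reports the offset relative to what the client
	// has actually seen, so the new page appends cleanly to its mirrored model.
	clientOffset := prevModelLen

	fmt.Printf("query offset=%d, page size=%d, reported offset=%d\n", dbOffset, pageCount, clientOffset)
	// query offset=3, page size=2, reported offset=2
}
```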
// subscribeToEvents should be called with sessionsRWMutex locked for writing
func (s *Service) subscribeToEvents() {
@ -203,34 +248,6 @@ func (s *Service) subscribeToEvents() {
go s.processEvents()
}
// func (s *Service) processEvents() {
// for event := range s.ch {
// if event.Type == transactions.EventPendingTransactionUpdate {
// var p transactions.PendingTxUpdatePayload
// err := json.Unmarshal([]byte(event.Message), &p)
// if err != nil {
// log.Error("Error unmarshalling PendingTxUpdatePayload", "error", err)
// continue
// }
// for id := range s.sessions {
// s.sessionsRWMutex.RLock()
// pTx, pass := s.checkFilterForPending(s.sessions[id], p.TxIdentity)
// if pass {
// s.sessionsRWMutex.RUnlock()
// s.sessionsRWMutex.Lock()
// addOnTop(s.sessions[id], p.TxIdentity)
// s.sessionsRWMutex.Unlock()
// // TODO #12120: can't send events from an event handler
// go notify(s.eventFeed, id, *pTx)
// } else {
// s.sessionsRWMutex.RUnlock()
// }
// }
// }
// }
// }
// TODO #12120: check that it exits on channel close
func (s *Service) processEvents() {
for event := range s.ch {
@ -276,60 +293,6 @@ func (s *Service) processEvents() {
}
}
// // checkFilterForPending should be called with sessionsRWMutex locked for reading
// func (s *Service) checkFilterForPending(session *Session, id transactions.TxIdentity) (tr *transactions.PendingTransaction, pass bool) {
// allChains := len(session.chainIDs) == 0
// if !allChains {
// _, found := slices.BinarySearch(session.chainIDs, id.ChainID)
// if !found {
// return nil, false
// }
// }
// tr, err := s.pendingTracker.GetPendingEntry(id.ChainID, id.Hash)
// if err != nil {
// log.Error("Error getting pending entry", "error", err)
// return nil, false
// }
// if !session.allAddresses {
// _, found := slices.BinarySearchFunc(session.addresses, tr.From, func(a eth.Address, b eth.Address) int {
// // TODO #12120: optimize this
// if a.Hex() < b.Hex() {
// return -1
// }
// if a.Hex() > b.Hex() {
// return 1
// }
// return 0
// })
// if !found {
// return nil, false
// }
// }
// fl := session.filter
// if fl.Period.StartTimestamp != NoLimitTimestampForPeriod || fl.Period.EndTimestamp != NoLimitTimestampForPeriod {
// ts := int64(tr.Timestamp)
// if ts < fl.Period.StartTimestamp || ts > fl.Period.EndTimestamp {
// return nil, false
// }
// }
// // TODO #12120 check filter
// // Types []Type `json:"types"`
// // Statuses []Status `json:"statuses"`
// // CounterpartyAddresses []eth.Address `json:"counterpartyAddresses"`
// // // Tokens
// // Assets []Token `json:"assets"`
// // Collectibles []Token `json:"collectibles"`
// // FilterOutAssets bool `json:"filterOutAssets"`
// // FilterOutCollectibles bool `json:"filterOutCollectibles"`
// return tr, true
// }
func notify(eventFeed *event.Feed, id SessionID, hasNewEntries bool) {
payload := SessionUpdate{}
if hasNewEntries {
@ -356,9 +319,7 @@ func (s *Service) StopFilterSession(id SessionID) {
// Cancel any pending or ongoing task
s.scheduler.Enqueue(int32(id), filterTask, func(ctx context.Context) (interface{}, error) {
return nil, nil
}, func(result interface{}, taskType async.TaskType, err error) {
// Ignore result
})
}, func(result interface{}, taskType async.TaskType, err error) {})
}
func (s *Service) getActivityDetailsAsync(requestID int32, entries []Entry) {

View File

@ -4,7 +4,8 @@ import (
"reflect"
"testing"
"github.com/ethereum/go-ethereum/common"
eth "github.com/ethereum/go-ethereum/common"
"github.com/status-im/status-go/services/wallet/transfer"
)
@ -13,8 +14,8 @@ func TestFindUpdates(t *testing.T) {
txIds := []transfer.TransactionIdentity{
transfer.TransactionIdentity{
ChainID: 1,
Hash: common.HexToHash("0x1234"),
Address: common.HexToAddress("0x1234"),
Hash: eth.HexToHash("0x1234"),
Address: eth.HexToAddress("0x1234"),
},
}

View File

@ -602,12 +602,18 @@ func (api *API) StartActivityFilterSession(addresses []common.Address, allAddres
return api.s.activity.StartFilterSession(addresses, allAddresses, chainIDs, filter, firstPageCount), nil
}
func (api *API) ResetFilterSession(id activity.SessionID, firstPageCount int) error {
log.Debug("wallet.api.ResetFilterSession", "id", id, "firstPageCount", firstPageCount)
func (api *API) ResetActivityFilterSession(id activity.SessionID, firstPageCount int) error {
log.Debug("wallet.api.ResetActivityFilterSession", "id", id, "firstPageCount", firstPageCount)
return api.s.activity.ResetFilterSession(id, firstPageCount)
}
func (api *API) GetMoreForActivityFilterSession(id activity.SessionID, pageCount int) error {
log.Debug("wallet.api.GetMoreForActivityFilterSession", "id", id, "pageCount", pageCount)
return api.s.activity.GetMoreForFilterSession(id, pageCount)
}
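Assuming the usual status-go convention of exposing wallet API methods over JSON-RPC as wallet_<lowerCamelCaseMethod> (the registration itself is not part of this diff), a client would request the next page roughly like this:

```go
package walletclient

import "github.com/ethereum/go-ethereum/rpc"

// requestNextActivityPage is a sketch; sessionID and pageCount are whatever the
// client obtained/chose when starting the session, and the method name assumes
// the standard namespace_lowerCamelCase RPC naming.
func requestNextActivityPage(client *rpc.Client, sessionID, pageCount int) error {
	// The call returns only an error; the page itself is delivered asynchronously
	// through the wallet event feed as EventActivityFilteringDone.
	return client.Call(nil, "wallet_getMoreForActivityFilterSession", sessionID, pageCount)
}
```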
func (api *API) StopActivityFilterSession(id activity.SessionID) {
log.Debug("wallet.api.StopActivityFilterSession", "id", id)