From ae9b697eda5e2fde120a0b1bf629960deb8ec1ea Mon Sep 17 00:00:00 2001 From: Stefan Date: Thu, 8 Feb 2024 22:55:33 -0300 Subject: [PATCH] feat(wallet) implement mixed incremental updates for activity filter Refactor the activity filter session API to account for the new structure Also: - Refactor test helpers to mock for different chains with mixed answers - Test implementation Updates status-desktop #12120 --- services/wallet/activity/TODO.md | 103 -------------------- services/wallet/activity/service.go | 1 - services/wallet/activity/service_test.go | 115 ++++++++++++++++++----- services/wallet/activity/session.go | 113 +++++++++++++--------- services/wallet/walletevent/events.go | 14 +++ transactions/pendingtxtracker.go | 1 - transactions/testhelpers.go | 103 ++++++++++++++------ 7 files changed, 253 insertions(+), 197 deletions(-) delete mode 100644 services/wallet/activity/TODO.md diff --git a/services/wallet/activity/TODO.md b/services/wallet/activity/TODO.md deleted file mode 100644 index d1dc26ae8..000000000 --- a/services/wallet/activity/TODO.md +++ /dev/null @@ -1,103 +0,0 @@ -# Provide dynamic activity updates - -Task: https://github.com/status-im/status-desktop/issues/12120 - -## Intro - -In the current approach only static paginated filtering is possible because the filtering is done in SQL - -The updated requirements need to support dynamic updates of the current visualized filter - -## Plan - -- [ ] Required common (runtime/SQL) infrastructure - - [-] Refactor into a session based filter - - [-] Keep a mirror of identities for session - - [-] Capture events (new downloaded and pending first) - - [-] Have the simplest filter to handle new and updated and emit wallet event - - [ ] Handle update filter events in UX and alter the model (add/remove) -- [ ] Asses how the runtime filter grows in complexity/risk -- [ ] Quick prototype of SQL only filter if still make sense -- [ ] Refactor the async handling to fit the session based better (use channels and goroutine) - -## How to - -I see two ways: - -- Keep a **runtime** (go/nim) dynamic in memory filter that is in sync with the SQL filter and use the filter to process transactions updates and propagate to the current visualized model - - The filter will push changes to the in memory model based on the sorting and filtering criteria - - If the filter is completely in sync withe the SQL one, then the dynamic updates to the model should have the same content as fetched from scratch from the DB - - *Advantages* - - Less memory and performance requirements - - *Disadvantages* - - Two sources of truth for the filter - - With tests for each event this can be mitigated - - Complexity around the multi-transaction/sub-transaction relation - - If we miss doing equivalent changes in bot filters (SQL and runtime) the filter might not be in sync with the SQL one and have errors in update -- **Refresh SQL filter** on every transaction (or bulk) update to DB and compare with the current visualized filter to extract differences and push as change notifications - - This approach is more expensive in terms of memory and performance but will use only one source of truth implementation - - This way we know for sure that the updated model is in sync with a newly fetched one - - *Advantages* - - Less complexity and less risk to be out of sync with the SQL filter - - *Disadvantages* - - More memory and performance requirements - - The real improvement will be to do the postponed refactoring of the activity in DB - -## Requirements - -Expected filter states to be addressed - -- Filter is set -- No Filter -- Filter is cleared - - How about if only partially cleared? - -Expected dynamic events - -- **New transactions** - - Pending - - Downloaded (external) - - Multi-transactions? -- **Transaction changed state** - - Pending to confirmed (new transaction/removed transaction) - -Filter criteria - -- time interval: start-end -- activity type (send/receive/buy/swap/bridge/contract_deploy/mint) -- status (pending/failed/confirmed/finalized) -- addresses -- tokens -- multi-transaction filtering transaction - -## Implementation - -### SQL filter - -For new events - -- keep a mirror of identities on status-go side (optional session based) -- on update events fetch identities and check against the mirror if any is new -- for new entries send the notification with the transaction details -- keep pending changes (not added) - - remove entries that were processed for this session - -For update? - -- check if entry is in the mirror and propagate update event - -### Mirror filter - -For new events - -- keep a mirror of identities -- on update events pass them through the filter and if they pass send updates - - the filter checks criteria and available mirror interval to dismiss from mirror -- sub-transactions challenge - - TODO -- token challenges - - TODO - -For update? - -- check if entry is in the mirror and propagate update event \ No newline at end of file diff --git a/services/wallet/activity/service.go b/services/wallet/activity/service.go index 238094b5e..b59af4e3d 100644 --- a/services/wallet/activity/service.go +++ b/services/wallet/activity/service.go @@ -70,7 +70,6 @@ type Service struct { // sessionsRWMutex is used to protect all sessions related members sessionsRWMutex sync.RWMutex - // TODO #12120: sort out session dependencies pendingTracker *transactions.PendingTxTracker } diff --git a/services/wallet/activity/service_test.go b/services/wallet/activity/service_test.go index fbcc38c7c..57e2ef63a 100644 --- a/services/wallet/activity/service_test.go +++ b/services/wallet/activity/service_test.go @@ -3,7 +3,6 @@ package activity import ( "context" "database/sql" - "encoding/json" "math/big" "testing" "time" @@ -25,6 +24,8 @@ import ( "github.com/stretchr/testify/require" ) +const shouldNotWaitTimeout = 19999 * time.Second + // mockCollectiblesManager implements the collectibles.ManagerInterface type mockCollectiblesManager struct { mock.Mock @@ -179,17 +180,16 @@ func TestService_UpdateCollectibleInfo(t *testing.T) { case res := <-ch: switch res.Type { case EventActivityFilteringDone: - var payload FilterResponse - err := json.Unmarshal([]byte(res.Message), &payload) + payload, err := walletevent.GetPayload[FilterResponse](res) require.NoError(t, err) require.Equal(t, ErrorCodeSuccess, payload.ErrorCode) require.Equal(t, 3, len(payload.Activities)) filterResponseCount++ case EventActivityFilteringUpdate: - err := json.Unmarshal([]byte(res.Message), &updates) + err := walletevent.ExtractPayload(res, &updates) require.NoError(t, err) } - case <-time.NewTimer(1 * time.Second).C: + case <-time.NewTimer(shouldNotWaitTimeout).C: require.Fail(t, "timeout while waiting for event") } } @@ -230,8 +230,7 @@ func TestService_UpdateCollectibleInfo_Error(t *testing.T) { case res := <-ch: switch res.Type { case EventActivityFilteringDone: - var payload FilterResponse - err := json.Unmarshal([]byte(res.Message), &payload) + payload, err := walletevent.GetPayload[FilterResponse](res) require.NoError(t, err) require.Equal(t, ErrorCodeSuccess, payload.ErrorCode) require.Equal(t, 2, len(payload.Activities)) @@ -255,13 +254,17 @@ func setupTransactions(t *testing.T, state testState, txCount int, testTxs []tra sub := state.eventFeed.Subscribe(ch) pendings = transactions.MockTestTransactions(t, state.chainClient, testTxs) + for _, p := range pendings { + allAddresses = append(allAddresses, p.From, p.To) + } txs, fromTrs, toTrs := transfer.GenerateTestTransfers(t, state.service.db, len(pendings), txCount) for i := range txs { transfer.InsertTestTransfer(t, state.service.db, txs[i].To, &txs[i]) } - allAddresses = append(append(fromTrs, toTrs...), pendings[0].From, pendings[0].To) + allAddresses = append(append(allAddresses, fromTrs...), toTrs...) + state.tokenMock.On("LookupTokenIdentity", mock.Anything, mock.Anything, mock.Anything).Return( &token.Token{ ChainID: 5, @@ -283,24 +286,35 @@ func setupTransactions(t *testing.T, state testState, txCount int, testTxs []tra } } -func validateSessionUpdateEvent(t *testing.T, ch chan walletevent.Event, filterResponseCount *int) (pendingTransactionUpdate, sessionUpdatesCount int) { - for sessionUpdatesCount < 1 { +func getValidateSessionUpdateHasNewOnTopFn(t *testing.T) func(payload SessionUpdate) bool { + return func(payload SessionUpdate) bool { + require.NotNil(t, payload.HasNewOnTop) + require.True(t, *payload.HasNewOnTop) + return false + } +} + +// validateSessionUpdateEvent expects will give up early if checkPayloadFn return true and not wait for expectCount +func validateSessionUpdateEvent(t *testing.T, ch chan walletevent.Event, filterResponseCount *int, expectCount int, checkPayloadFn func(payload SessionUpdate) bool) (pendingTransactionUpdate, sessionUpdatesCount int) { + for sessionUpdatesCount < expectCount { select { case res := <-ch: switch res.Type { case transactions.EventPendingTransactionUpdate: pendingTransactionUpdate++ case EventActivitySessionUpdated: - var payload SessionUpdate - err := json.Unmarshal([]byte(res.Message), &payload) + payload, err := walletevent.GetPayload[SessionUpdate](res) require.NoError(t, err) - require.NotNil(t, payload.HasNewEntries) - require.True(t, *payload.HasNewEntries) + + if checkPayloadFn != nil && checkPayloadFn(*payload) { + return + } + sessionUpdatesCount++ case EventActivityFilteringDone: (*filterResponseCount)++ } - case <-time.NewTimer(1 * time.Second).C: + case <-time.NewTimer(shouldNotWaitTimeout).C: require.Fail(t, "timeout while waiting for EventActivitySessionUpdated") } } @@ -333,8 +347,7 @@ func validateFilteringDone(t *testing.T, ch chan walletevent.Event, resCount int case res := <-ch: switch res.Type { case EventActivityFilteringDone: - var payload FilterResponse - err := json.Unmarshal([]byte(res.Message), &payload) + payload, err := walletevent.GetPayload[FilterResponse](res) require.NoError(t, err) expectOffset, expectErrorCode := getOptionalExpectations(extra) @@ -346,10 +359,10 @@ func validateFilteringDone(t *testing.T, ch chan walletevent.Event, resCount int filterResponseCount++ if checkPayloadFn != nil { - checkPayloadFn(payload) + checkPayloadFn(*payload) } } - case <-time.NewTimer(1 * time.Second).C: + case <-time.NewTimer(shouldNotWaitTimeout).C: require.Fail(t, "timeout while waiting for EventActivityFilteringDone") } } @@ -374,7 +387,8 @@ func TestService_IncrementalUpdateOnTop(t *testing.T) { err := state.pendingTracker.StoreAndTrackPendingTx(&exp) require.NoError(t, err) - pendingTransactionUpdate, sessionUpdatesCount := validateSessionUpdateEvent(t, ch, &filterResponseCount) + vFn := getValidateSessionUpdateHasNewOnTopFn(t) + pendingTransactionUpdate, sessionUpdatesCount := validateSessionUpdateEvent(t, ch, &filterResponseCount, 1, vFn) err = state.service.ResetFilterSession(sessionID, 5) require.NoError(t, err) @@ -424,6 +438,59 @@ func TestService_IncrementalUpdateOnTop(t *testing.T) { require.Equal(t, 1, eventActivityDoneCount) } +func TestService_IncrementalUpdateMixed(t *testing.T) { + state := setupTestService(t) + defer state.close() + + transactionCount := 5 + allAddresses, pendings, ch, cleanup := setupTransactions(t, state, transactionCount, + []transactions.TestTxSummary{ + {DontConfirm: true, Timestamp: 2}, + {DontConfirm: true, Timestamp: 4}, + {DontConfirm: true, Timestamp: 6}, + }, + ) + defer cleanup() + + sessionID := state.service.StartFilterSession(allAddresses, true, allNetworksFilter(), Filter{}, 5) + require.Greater(t, sessionID, SessionID(0)) + defer state.service.StopFilterSession(sessionID) + + filterResponseCount := validateFilteringDone(t, ch, 5, nil, nil) + + for i := range pendings { + err := state.pendingTracker.StoreAndTrackPendingTx(&pendings[i]) + require.NoError(t, err) + } + + pendingTransactionUpdate, sessionUpdatesCount := validateSessionUpdateEvent(t, ch, &filterResponseCount, 2, func(payload SessionUpdate) bool { + require.Nil(t, payload.HasNewOnTop) + require.NotEmpty(t, payload.New) + for _, update := range payload.New { + require.True(t, update.Entry.isNew) + foundIdx := -1 + for i, pTx := range pendings { + if pTx.Hash == update.Entry.transaction.Hash && pTx.ChainID == update.Entry.transaction.ChainID { + foundIdx = i + break + } + } + require.Greater(t, foundIdx, -1, "the updated transaction should be found in the pending list") + pendings = append(pendings[:foundIdx], pendings[foundIdx+1:]...) + } + return len(pendings) == 1 + }) + + // Validate that the last one (oldest) is out of the window + require.Equal(t, 1, len(pendings)) + require.Equal(t, uint64(2), pendings[0].Timestamp) + + require.Equal(t, 3, pendingTransactionUpdate) + require.LessOrEqual(t, sessionUpdatesCount, 3) + require.Equal(t, 1, filterResponseCount) + +} + func TestService_IncrementalUpdateFetchWindow(t *testing.T) { state := setupTestService(t) defer state.close() @@ -442,7 +509,8 @@ func TestService_IncrementalUpdateFetchWindow(t *testing.T) { err := state.pendingTracker.StoreAndTrackPendingTx(&exp) require.NoError(t, err) - pendingTransactionUpdate, sessionUpdatesCount := validateSessionUpdateEvent(t, ch, &filterResponseCount) + vFn := getValidateSessionUpdateHasNewOnTopFn(t) + pendingTransactionUpdate, sessionUpdatesCount := validateSessionUpdateEvent(t, ch, &filterResponseCount, 1, vFn) err = state.service.ResetFilterSession(sessionID, 2) require.NoError(t, err) @@ -493,7 +561,8 @@ func TestService_IncrementalUpdateFetchWindowNoReset(t *testing.T) { err := state.pendingTracker.StoreAndTrackPendingTx(&exp) require.NoError(t, err) - pendingTransactionUpdate, sessionUpdatesCount := validateSessionUpdateEvent(t, ch, &filterResponseCount) + vFn := getValidateSessionUpdateHasNewOnTopFn(t) + pendingTransactionUpdate, sessionUpdatesCount := validateSessionUpdateEvent(t, ch, &filterResponseCount, 1, vFn) require.Equal(t, 1, pendingTransactionUpdate) require.Equal(t, 1, filterResponseCount) require.Equal(t, 1, sessionUpdatesCount) @@ -501,7 +570,7 @@ func TestService_IncrementalUpdateFetchWindowNoReset(t *testing.T) { err = state.service.GetMoreForFilterSession(sessionID, 2) require.NoError(t, err) - // Validate that client doesn't anything of the internal state + // Validate that client continue loading the next window without being affected by the internal state of new eventActivityDoneCount := validateFilteringDone(t, ch, 2, func(payload FilterResponse) { require.False(t, payload.Activities[0].isNew) require.Equal(t, int64(transactionCount-2), payload.Activities[0].timestamp) diff --git a/services/wallet/activity/session.go b/services/wallet/activity/session.go index dd51282b7..db265f222 100644 --- a/services/wallet/activity/session.go +++ b/services/wallet/activity/session.go @@ -5,8 +5,6 @@ import ( "errors" "strconv" - "golang.org/x/exp/slices" - eth "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" @@ -25,11 +23,14 @@ type EntryIdentity struct { id transfer.MultiTransactionIDType } -// func (e EntryIdentity) same(a EntryIdentity) bool { -// return a.payloadType == e.payloadType && (a.transaction == e.transaction && (a.transaction == nil || (a.transaction.ChainID == e.transaction.ChainID && -// a.transaction.Hash == e.transaction.Hash && -// a.transaction.Address == e.transaction.Address))) && a.id == e.id -// } +func (e EntryIdentity) same(a EntryIdentity) bool { + return a.payloadType == e.payloadType && + ((a.transaction == nil && e.transaction == nil) || + (a.transaction.ChainID == e.transaction.ChainID && + a.transaction.Hash == e.transaction.Hash && + a.transaction.Address == e.transaction.Address)) && + a.id == e.id +} func (e EntryIdentity) key() string { txID := nilStr @@ -57,11 +58,16 @@ type Session struct { new []EntryIdentity } +type EntryUpdate struct { + Pos int `json:"pos"` + Entry *Entry `json:"entry"` +} + // SessionUpdate payload for EventActivitySessionUpdated type SessionUpdate struct { - HasNewEntries *bool `json:"hasNewEntries,omitempty"` - Removed []EntryIdentity `json:"removed,omitempty"` - Updated []Entry `json:"updated,omitempty"` + HasNewOnTop *bool `json:"hasNewOnTop,omitempty"` + New []*EntryUpdate `json:"new,omitempty"` + Removed []EntryIdentity `json:"removed,omitempty"` } type fullFilterParams struct { @@ -102,16 +108,6 @@ func (s *Service) internalFilter(f fullFilterParams, offset int, count int, proc func (s *Service) StartFilterSession(addresses []eth.Address, allAddresses bool, chainIDs []common.ChainID, filter Filter, firstPageCount int) SessionID { sessionID := s.nextSessionID() - // TODO #12120: sort rest of the filters - // TODO #12120: prettyfy this - slices.SortFunc(addresses, func(a eth.Address, b eth.Address) bool { - return a.Hex() < b.Hex() - }) - slices.Sort(chainIDs) - slices.SortFunc(filter.CounterpartyAddresses, func(a eth.Address, b eth.Address) bool { - return a.Hex() < b.Hex() - }) - s.sessionsRWMutex.Lock() subscribeToEvents := len(s.sessions) == 0 session := &Session{ @@ -248,55 +244,79 @@ func (s *Service) subscribeToEvents() { go s.processEvents() } -// TODO #12120: check that it exits on channel close +// processEvents runs only if more than one session is active func (s *Service) processEvents() { for event := range s.ch { - // TODO #12120: process rest of the events - // TODO #12120: debounce for 1s + // TODO #12120: process rest of the events transactions.EventPendingTransactionStatusChanged, transfer.EventNewTransfers + // TODO #12120: debounce for 1s and sum all events as extraCount to be sure we don't miss any change if event.Type == transactions.EventPendingTransactionUpdate { for sessionID := range s.sessions { session := s.sessions[sessionID] - activities, err := getActivityEntries(context.Background(), s.getDeps(), session.addresses, session.allAddresses, session.chainIDs, session.filter, 0, len(session.model)) + + extraCount := 1 + fetchLen := len(session.model) + extraCount + activities, err := getActivityEntries(context.Background(), s.getDeps(), session.addresses, session.allAddresses, session.chainIDs, session.filter, 0, fetchLen) if err != nil { log.Error("Error getting activity entries", "error", err) continue } s.sessionsRWMutex.RLock() - allData := append(session.model, session.new...) + allData := append(session.new, session.model...) new, _ /*removed*/ := findUpdates(allData, activities) s.sessionsRWMutex.RUnlock() s.sessionsRWMutex.Lock() lastProcessed := -1 + onTop := true + var mixed []*EntryUpdate for i, idRes := range new { - if i-lastProcessed > 1 { - // The events are not continuous, therefore these are not on top but mixed between existing entries - break + // Detect on top + if onTop { + // mixedIdentityResult.newPos includes session.new, therefore compensate for it + if ((idRes.newPos - len(session.new)) - lastProcessed) > 1 { + // From now on the events are not on top and continuous but mixed between existing entries + onTop = false + mixed = make([]*EntryUpdate, 0, len(new)-i) + } + lastProcessed = idRes.newPos } - lastProcessed = idRes.newPos - // TODO #12120: make it more generic to follow the detection function - // TODO #12120: hold the first few and only send mixed and removed - if session.new == nil { - session.new = make([]EntryIdentity, 0, len(new)) - } - session.new = append(session.new, idRes.id) - } - // TODO #12120: mixed + if onTop { + if session.new == nil { + session.new = make([]EntryIdentity, 0, len(new)) + } + session.new = append(session.new, idRes.id) + } else { + modelPos := idRes.newPos - len(session.new) + entry := activities[idRes.newPos] + entry.isNew = true + mixed = append(mixed, &EntryUpdate{ + Pos: modelPos, + Entry: &entry, + }) + // Insert in session model at modelPos index + session.model = append(session.model[:modelPos], append([]EntryIdentity{{payloadType: entry.payloadType, transaction: entry.transaction, id: entry.id}}, session.model[modelPos:]...)...) + } + } s.sessionsRWMutex.Unlock() - go notify(s.eventFeed, sessionID, len(session.new) > 0) + if len(session.new) > 0 || len(mixed) > 0 { + go notify(s.eventFeed, sessionID, len(session.new) > 0, mixed) + } } } } } -func notify(eventFeed *event.Feed, id SessionID, hasNewEntries bool) { - payload := SessionUpdate{} - if hasNewEntries { - payload.HasNewEntries = &hasNewEntries +func notify(eventFeed *event.Feed, id SessionID, hasNewOnTop bool, mixed []*EntryUpdate) { + payload := SessionUpdate{ + New: mixed, + } + + if hasNewOnTop { + payload.HasNewOnTop = &hasNewOnTop } sendResponseEvent(eventFeed, (*int32)(&id), EventActivitySessionUpdated, payload, nil) @@ -305,6 +325,8 @@ func notify(eventFeed *event.Feed, id SessionID, hasNewEntries bool) { // unsubscribeFromEvents should be called with sessionsRWMutex locked for writing func (s *Service) unsubscribeFromEvents() { s.subscriptions.Unsubscribe() + close(s.ch) + s.ch = nil s.subscriptions = nil } @@ -369,6 +391,9 @@ func entriesToMap(entries []Entry) map[string]Entry { // // implementation assumes the order of each identity doesn't change from old state (identities) and new state (updated); we have either add or removed. func findUpdates(identities []EntryIdentity, updated []Entry) (new []mixedIdentityResult, removed []EntryIdentity) { + if len(updated) == 0 { + return + } idsMap := entryIdsToMap(identities) updatedMap := entriesToMap(updated) @@ -381,6 +406,10 @@ func findUpdates(identities []EntryIdentity, updated []Entry) (new []mixedIdenti id: id, }) } + + if len(identities) > 0 && entry.getIdentity().same(identities[len(identities)-1]) { + break + } } // Account for new entries diff --git a/services/wallet/walletevent/events.go b/services/wallet/walletevent/events.go index f3d931088..e4e4986f9 100644 --- a/services/wallet/walletevent/events.go +++ b/services/wallet/walletevent/events.go @@ -1,6 +1,7 @@ package walletevent import ( + "encoding/json" "math/big" "strings" @@ -31,3 +32,16 @@ type Event struct { // For Internal events only, not serialized EventParams interface{} } + +func GetPayload[T any](e Event) (*T, error) { + var payload T + err := json.Unmarshal([]byte(e.Message), &payload) + if err != nil { + return nil, err + } + return &payload, nil +} + +func ExtractPayload[T any](e Event, payload *T) error { + return json.Unmarshal([]byte(e.Message), payload) +} diff --git a/transactions/pendingtxtracker.go b/transactions/pendingtxtracker.go index fd6fc6163..d1f480f13 100644 --- a/transactions/pendingtxtracker.go +++ b/transactions/pendingtxtracker.go @@ -55,7 +55,6 @@ const ( Keep AutoDeleteType = false ) -// TODO #12120: unify it with TransactionIdentity type TxIdentity struct { ChainID common.ChainID `json:"chainId"` Hash eth.Hash `json:"hash"` diff --git a/transactions/testhelpers.go b/transactions/testhelpers.go index f4481f95f..14e2ff921 100644 --- a/transactions/testhelpers.go +++ b/transactions/testhelpers.go @@ -82,6 +82,27 @@ func GenerateTestPendingTransactions(start int, count int) []PendingTransaction return txs } +// groupSliceInMap groups a slice of S into a map[K][]N using the getKeyValue function to extract the key and new value for each entry +func groupSliceInMap[S any, K comparable, N any](s []S, getKeyValue func(entry S, i int) (K, N)) map[K][]N { + m := make(map[K][]N) + for i, x := range s { + k, v := getKeyValue(x, i) + m[k] = append(m[k], v) + } + return m +} + +func keysInMap[K comparable, V any](m map[K]V) (res []K) { + if len(m) > 0 { + res = make([]K, 0, len(m)) + } + + for k := range m { + res = append(res, k) + } + return +} + type TestTxSummary struct { failStatus bool DontConfirm bool @@ -89,30 +110,53 @@ type TestTxSummary struct { Timestamp int } +type summaryTxPair struct { + summary TestTxSummary + tx PendingTransaction + answered bool +} + func MockTestTransactions(t *testing.T, chainClient *MockChainClient, testTxs []TestTxSummary) []PendingTransaction { - txs := GenerateTestPendingTransactions(0, len(testTxs)) - - for txIdx := range txs { - tx := &txs[txIdx] - if testTxs[txIdx].Timestamp > 0 { - tx.Timestamp = uint64(testTxs[txIdx].Timestamp) + genTxs := GenerateTestPendingTransactions(0, len(testTxs)) + for i, tx := range testTxs { + if tx.Timestamp > 0 { + genTxs[i].Timestamp = uint64(tx.Timestamp) } + } - // Mock the first call to getTransactionByHash - chainClient.SetAvailableClients([]common.ChainID{tx.ChainID}) - cl := chainClient.Clients[tx.ChainID] + grouped := groupSliceInMap(genTxs, func(tx PendingTransaction, i int) (common.ChainID, summaryTxPair) { + return tx.ChainID, summaryTxPair{ + summary: testTxs[i], + tx: tx, + } + }) + + chains := keysInMap(grouped) + chainClient.SetAvailableClients(chains) + + for chainID, chainSummaries := range grouped { + // Mock the one call to getTransactionReceipt + // It is expected that pending transactions manager will call out of order, therefore match based on hash + cl := chainClient.Clients[chainID] call := cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { - ok := len(b) == len(testTxs) - for i := range b { - ok = ok && b[i].Method == GetTransactionReceiptRPCName && b[i].Args[0] == tx.Hash + if len(b) > len(chainSummaries) { + return false } - return ok + for i := range b { + for _, sum := range chainSummaries { + tx := &sum.tx + if sum.answered { + continue + } + require.Equal(t, GetTransactionReceiptRPCName, b[i].Method) + if tx.Hash == b[i].Args[0].(eth.Hash) { + sum.answered = true + return true + } + } + } + return false })).Return(nil) - if testTxs[txIdx].DontConfirm { - call = call.Times(0) - } else { - call = call.Once() - } call.Run(func(args mock.Arguments) { elems := args.Get(1).([]rpc.BatchElem) @@ -121,19 +165,24 @@ func MockTestTransactions(t *testing.T, chainClient *MockChainClient, testTxs [] require.True(t, ok) require.NotNil(t, receiptWrapper) // Simulate parsing of eth_getTransactionReceipt response - if !testTxs[i].DontConfirm { - status := types.ReceiptStatusSuccessful - if testTxs[i].failStatus { - status = types.ReceiptStatusFailed - } + for _, sum := range chainSummaries { + tx := &sum.tx + if tx.Hash == elems[i].Args[0].(eth.Hash) { + if !sum.summary.DontConfirm { + status := types.ReceiptStatusSuccessful + if sum.summary.failStatus { + status = types.ReceiptStatusFailed + } - receiptWrapper.Receipt = &types.Receipt{ - BlockNumber: new(big.Int).SetUint64(1), - Status: status, + receiptWrapper.Receipt = &types.Receipt{ + BlockNumber: new(big.Int).SetUint64(1), + Status: status, + } + } } } } }) } - return txs + return genTxs }