fix(wallet) fix pending transactions notification

Also, add regression test.
This commit is contained in:
Stefan 2023-08-30 17:14:57 +01:00 committed by Stefan Dunca
parent 0f58d01cc4
commit 71800a19f1
8 changed files with 292 additions and 124 deletions

View File

@ -422,7 +422,7 @@ func (b *StatusNode) ensService(timesource func() time.Time) *ens.Service {
func (b *StatusNode) pendingTrackerService(walletFeed *event.Feed) *transactions.PendingTxTracker { func (b *StatusNode) pendingTrackerService(walletFeed *event.Feed) *transactions.PendingTxTracker {
if b.pendingTracker == nil { if b.pendingTracker == nil {
b.pendingTracker = transactions.NewPendingTxTracker(b.walletDB, b.rpcClient, b.rpcFiltersSrvc, walletFeed) b.pendingTracker = transactions.NewPendingTxTracker(b.walletDB, b.rpcClient, b.rpcFiltersSrvc, walletFeed, transactions.PendingCheckInterval)
} }
return b.pendingTracker return b.pendingTracker
} }

View File

@ -136,12 +136,6 @@ func (e *Entry) UnmarshalJSON(data []byte) error {
return nil return nil
} }
func newAndSet[T any](v T) *T {
res := new(T)
*res = v
return res
}
func newActivityEntryWithPendingTransaction(transaction *transfer.TransactionIdentity, timestamp int64, activityType Type, activityStatus Status) Entry { func newActivityEntryWithPendingTransaction(transaction *transfer.TransactionIdentity, timestamp int64, activityType Type, activityStatus Status) Entry {
return newActivityEntryWithTransaction(true, transaction, timestamp, activityType, activityStatus) return newActivityEntryWithTransaction(true, transaction, timestamp, activityType, activityStatus)
} }
@ -848,11 +842,11 @@ func getActivityEntries(ctx context.Context, deps FilterDependencies, addresses
// Extract tokens // Extract tokens
if fromTokenCode.Valid { if fromTokenCode.Valid {
entry.tokenOut = deps.tokenFromSymbol(outChainID, fromTokenCode.String) entry.tokenOut = deps.tokenFromSymbol(outChainID, fromTokenCode.String)
entry.symbolOut = newAndSet(fromTokenCode.String) entry.symbolOut = common.NewAndSet(fromTokenCode.String)
} }
if toTokenCode.Valid { if toTokenCode.Valid {
entry.tokenIn = deps.tokenFromSymbol(inChainID, toTokenCode.String) entry.tokenIn = deps.tokenFromSymbol(inChainID, toTokenCode.String)
entry.symbolIn = newAndSet(toTokenCode.String) entry.symbolIn = common.NewAndSet(toTokenCode.String)
} }
// Complete the data // Complete the data
@ -987,13 +981,13 @@ func lookupAndFillInTokens(deps FilterDependencies, tokenOut *Token, tokenIn *To
if tokenOut != nil { if tokenOut != nil {
symbol := deps.tokenSymbol(*tokenOut) symbol := deps.tokenSymbol(*tokenOut)
if len(symbol) > 0 { if len(symbol) > 0 {
symbolOut = newAndSet(symbol) symbolOut = common.NewAndSet(symbol)
} }
} }
if tokenIn != nil { if tokenIn != nil {
symbol := deps.tokenSymbol(*tokenIn) symbol := deps.tokenSymbol(*tokenIn)
if len(symbol) > 0 { if len(symbol) > 0 {
symbolIn = newAndSet(symbol) symbolIn = common.NewAndSet(symbol)
} }
} }
return symbolOut, symbolIn return symbolOut, symbolIn

View File

@ -225,7 +225,7 @@ func TestGetActivityEntriesAll(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &td.tr1.TestTransaction), tokenOut: TTrToToken(t, &td.tr1.TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("ETH"), symbolOut: common.NewAndSet("ETH"),
symbolIn: nil, symbolIn: nil,
sender: &td.tr1.From, sender: &td.tr1.From,
recipient: &td.tr1.To, recipient: &td.tr1.To,
@ -244,7 +244,7 @@ func TestGetActivityEntriesAll(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &td.pendingTr.TestTransaction), tokenOut: TTrToToken(t, &td.pendingTr.TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("ETH"), symbolOut: common.NewAndSet("ETH"),
symbolIn: nil, symbolIn: nil,
sender: &td.pendingTr.From, sender: &td.pendingTr.From,
recipient: &td.pendingTr.To, recipient: &td.pendingTr.To,
@ -263,8 +263,8 @@ func TestGetActivityEntriesAll(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(td.multiTx1.ToAmount)), amountIn: (*hexutil.Big)(big.NewInt(td.multiTx1.ToAmount)),
tokenOut: tokenFromSymbol(nil, td.multiTx1.FromToken), tokenOut: tokenFromSymbol(nil, td.multiTx1.FromToken),
tokenIn: tokenFromSymbol(nil, td.multiTx1.ToToken), tokenIn: tokenFromSymbol(nil, td.multiTx1.ToToken),
symbolOut: newAndSet("USDC"), symbolOut: common.NewAndSet("USDC"),
symbolIn: newAndSet("DAI"), symbolIn: common.NewAndSet("DAI"),
sender: &td.multiTx1.FromAddress, sender: &td.multiTx1.FromAddress,
recipient: &td.multiTx1.ToAddress, recipient: &td.multiTx1.ToAddress,
}, entries[1]) }, entries[1])
@ -277,8 +277,8 @@ func TestGetActivityEntriesAll(t *testing.T) {
activityStatus: PendingAS, activityStatus: PendingAS,
amountOut: (*hexutil.Big)(big.NewInt(td.multiTx2.FromAmount)), amountOut: (*hexutil.Big)(big.NewInt(td.multiTx2.FromAmount)),
amountIn: (*hexutil.Big)(big.NewInt(td.multiTx2.ToAmount)), amountIn: (*hexutil.Big)(big.NewInt(td.multiTx2.ToAmount)),
symbolOut: newAndSet("USDC"), symbolOut: common.NewAndSet("USDC"),
symbolIn: newAndSet("SNT"), symbolIn: common.NewAndSet("SNT"),
tokenOut: tokenFromSymbol(nil, td.multiTx2.FromToken), tokenOut: tokenFromSymbol(nil, td.multiTx2.FromToken),
tokenIn: tokenFromSymbol(nil, td.multiTx2.ToToken), tokenIn: tokenFromSymbol(nil, td.multiTx2.ToToken),
sender: &td.multiTx2.FromAddress, sender: &td.multiTx2.FromAddress,
@ -366,7 +366,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[5].TestTransaction), tokenOut: TTrToToken(t, &trs[5].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("USDC"), symbolOut: common.NewAndSet("USDC"),
symbolIn: nil, symbolIn: nil,
sender: &trs[5].From, sender: &trs[5].From,
recipient: &trs[5].To, recipient: &trs[5].To,
@ -385,8 +385,8 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(td.multiTx1.ToAmount)), amountIn: (*hexutil.Big)(big.NewInt(td.multiTx1.ToAmount)),
tokenOut: tokenFromSymbol(nil, td.multiTx1.FromToken), tokenOut: tokenFromSymbol(nil, td.multiTx1.FromToken),
tokenIn: tokenFromSymbol(nil, td.multiTx1.ToToken), tokenIn: tokenFromSymbol(nil, td.multiTx1.ToToken),
symbolOut: newAndSet("USDC"), symbolOut: common.NewAndSet("USDC"),
symbolIn: newAndSet("DAI"), symbolIn: common.NewAndSet("DAI"),
sender: &td.multiTx1.FromAddress, sender: &td.multiTx1.FromAddress,
recipient: &td.multiTx1.ToAddress, recipient: &td.multiTx1.ToAddress,
chainIDOut: nil, chainIDOut: nil,
@ -412,7 +412,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[2].TestTransaction), tokenOut: TTrToToken(t, &trs[2].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("ETH"), symbolOut: common.NewAndSet("ETH"),
symbolIn: nil, symbolIn: nil,
sender: &trs[2].From, sender: &trs[2].From,
recipient: &trs[2].To, recipient: &trs[2].To,
@ -431,8 +431,8 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(td.multiTx1.ToAmount)), amountIn: (*hexutil.Big)(big.NewInt(td.multiTx1.ToAmount)),
tokenOut: tokenFromSymbol(nil, td.multiTx1.FromToken), tokenOut: tokenFromSymbol(nil, td.multiTx1.FromToken),
tokenIn: tokenFromSymbol(nil, td.multiTx1.ToToken), tokenIn: tokenFromSymbol(nil, td.multiTx1.ToToken),
symbolOut: newAndSet("USDC"), symbolOut: common.NewAndSet("USDC"),
symbolIn: newAndSet("DAI"), symbolIn: common.NewAndSet("DAI"),
sender: &td.multiTx1.FromAddress, sender: &td.multiTx1.FromAddress,
recipient: &td.multiTx1.ToAddress, recipient: &td.multiTx1.ToAddress,
chainIDOut: nil, chainIDOut: nil,
@ -457,7 +457,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[2].TestTransaction), tokenOut: TTrToToken(t, &trs[2].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("ETH"), symbolOut: common.NewAndSet("ETH"),
symbolIn: nil, symbolIn: nil,
sender: &trs[2].From, sender: &trs[2].From,
recipient: &trs[2].To, recipient: &trs[2].To,
@ -476,7 +476,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &td.tr1.TestTransaction), tokenOut: TTrToToken(t, &td.tr1.TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("ETH"), symbolOut: common.NewAndSet("ETH"),
symbolIn: nil, symbolIn: nil,
sender: &td.tr1.From, sender: &td.tr1.From,
recipient: &td.tr1.To, recipient: &td.tr1.To,
@ -522,7 +522,7 @@ func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[8].TestTransaction), tokenOut: TTrToToken(t, &trs[8].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("ETH"), symbolOut: common.NewAndSet("ETH"),
symbolIn: nil, symbolIn: nil,
sender: &trs[8].From, sender: &trs[8].From,
recipient: &trs[8].To, recipient: &trs[8].To,
@ -541,7 +541,7 @@ func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[6].TestTransaction), tokenOut: TTrToToken(t, &trs[6].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("DAI"), symbolOut: common.NewAndSet("DAI"),
symbolIn: nil, symbolIn: nil,
sender: &trs[6].From, sender: &trs[6].From,
recipient: &trs[6].To, recipient: &trs[6].To,
@ -566,7 +566,7 @@ func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[6].TestTransaction), tokenOut: TTrToToken(t, &trs[6].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("DAI"), symbolOut: common.NewAndSet("DAI"),
symbolIn: nil, symbolIn: nil,
sender: &trs[6].From, sender: &trs[6].From,
recipient: &trs[6].To, recipient: &trs[6].To,
@ -585,7 +585,7 @@ func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[4].TestTransaction), tokenOut: TTrToToken(t, &trs[4].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("USDC"), symbolOut: common.NewAndSet("USDC"),
symbolIn: nil, symbolIn: nil,
sender: &trs[4].From, sender: &trs[4].From,
recipient: &trs[4].To, recipient: &trs[4].To,
@ -610,7 +610,7 @@ func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[2].TestTransaction), tokenOut: TTrToToken(t, &trs[2].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("USDC"), symbolOut: common.NewAndSet("USDC"),
symbolIn: nil, symbolIn: nil,
sender: &trs[2].From, sender: &trs[2].From,
recipient: &trs[2].To, recipient: &trs[2].To,
@ -733,7 +733,7 @@ func TestGetActivityEntriesFilterByAddresses(t *testing.T) {
tokenOut: nil, tokenOut: nil,
tokenIn: TTrToToken(t, &trs[4].TestTransaction), tokenIn: TTrToToken(t, &trs[4].TestTransaction),
symbolOut: nil, symbolOut: nil,
symbolIn: newAndSet("USDC"), symbolIn: common.NewAndSet("USDC"),
sender: &trs[4].From, sender: &trs[4].From,
recipient: &trs[4].To, recipient: &trs[4].To,
chainIDOut: nil, chainIDOut: nil,
@ -751,7 +751,7 @@ func TestGetActivityEntriesFilterByAddresses(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(0)), amountIn: (*hexutil.Big)(big.NewInt(0)),
tokenOut: TTrToToken(t, &trs[1].TestTransaction), tokenOut: TTrToToken(t, &trs[1].TestTransaction),
tokenIn: nil, tokenIn: nil,
symbolOut: newAndSet("ETH"), symbolOut: common.NewAndSet("ETH"),
symbolIn: nil, symbolIn: nil,
sender: &trs[1].From, sender: &trs[1].From,
recipient: &trs[1].To, recipient: &trs[1].To,
@ -770,8 +770,8 @@ func TestGetActivityEntriesFilterByAddresses(t *testing.T) {
amountIn: (*hexutil.Big)(big.NewInt(td.multiTx2.ToAmount)), amountIn: (*hexutil.Big)(big.NewInt(td.multiTx2.ToAmount)),
tokenOut: tokenFromSymbol(nil, td.multiTx2.FromToken), tokenOut: tokenFromSymbol(nil, td.multiTx2.FromToken),
tokenIn: tokenFromSymbol(nil, td.multiTx2.ToToken), tokenIn: tokenFromSymbol(nil, td.multiTx2.ToToken),
symbolOut: newAndSet("USDC"), symbolOut: common.NewAndSet("USDC"),
symbolIn: newAndSet("SNT"), symbolIn: common.NewAndSet("SNT"),
sender: &td.multiTx2.FromAddress, sender: &td.multiTx2.FromAddress,
recipient: &td.multiTx2.ToAddress, recipient: &td.multiTx2.ToAddress,
chainIDOut: nil, chainIDOut: nil,

View File

@ -252,7 +252,9 @@ func (api *API) GetPendingTransactionsForIdentities(ctx context.Context, identit
func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) (err error) { func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) (err error) {
log.Debug("wallet.api.WatchTransactionByChainID", "chainID", chainID, "transactionHash", transactionHash) log.Debug("wallet.api.WatchTransactionByChainID", "chainID", chainID, "transactionHash", transactionHash)
var status *transactions.TxStatus var status *transactions.TxStatus
defer log.Debug("wallet.api.WatchTransactionByChainID return", "err", err, "chainID", chainID, "transactionHash", transactionHash) defer func() {
log.Debug("wallet.api.WatchTransactionByChainID return", "err", err, "chainID", chainID, "transactionHash", transactionHash)
}()
// Workaround to keep the blocking call until the clients use the PendingTxTracker APIs // Workaround to keep the blocking call until the clients use the PendingTxTracker APIs
eventChan := make(chan walletevent.Event, 2) eventChan := make(chan walletevent.Event, 2)

View File

@ -0,0 +1,7 @@
package common
func NewAndSet[T any](v T) *T {
res := new(T)
*res = v
return res
}

View File

@ -23,10 +23,10 @@ type ConditionalRepeater struct {
task TaskFunc task TaskFunc
// nil if not running // nil if not running
ctx context.Context ctx context.Context
ctxMu sync.Mutex
cancel context.CancelFunc cancel context.CancelFunc
runNowCh chan bool runNowCh chan bool
runNowMu sync.Mutex runNowMu sync.Mutex
onceMu sync.Mutex
} }
func NewConditionalRepeater(interval time.Duration, task TaskFunc) *ConditionalRepeater { func NewConditionalRepeater(interval time.Duration, task TaskFunc) *ConditionalRepeater {
@ -41,12 +41,14 @@ func NewConditionalRepeater(interval time.Duration, task TaskFunc) *ConditionalR
// interval until the task returns true. Can be called multiple times but it // interval until the task returns true. Can be called multiple times but it
// does not allow multiple concurrent executions of the task. // does not allow multiple concurrent executions of the task.
func (t *ConditionalRepeater) RunUntilDone() { func (t *ConditionalRepeater) RunUntilDone() {
t.onceMu.Lock() t.ctxMu.Lock()
defer func() { defer func() {
t.runNowMu.Lock()
if len(t.runNowCh) == 0 { if len(t.runNowCh) == 0 {
t.runNowCh <- true t.runNowCh <- true
} }
t.onceMu.Unlock() t.runNowMu.Unlock()
t.ctxMu.Unlock()
}() }()
if t.ctx != nil { if t.ctx != nil {
@ -56,8 +58,8 @@ func (t *ConditionalRepeater) RunUntilDone() {
go func() { go func() {
defer func() { defer func() {
t.runNowMu.Lock() t.ctxMu.Lock()
defer t.runNowMu.Unlock() defer t.ctxMu.Unlock()
t.cancel() t.cancel()
t.ctx = nil t.ctx = nil
}() }()
@ -79,7 +81,12 @@ func (t *ConditionalRepeater) RunUntilDone() {
case <-t.runNowCh: case <-t.runNowCh:
ticker.Reset(t.interval) ticker.Reset(t.interval)
if t.task(t.ctx) { if t.task(t.ctx) {
return t.runNowMu.Lock()
if len(t.runNowCh) == 0 {
t.runNowMu.Unlock()
return
}
t.runNowMu.Unlock()
} }
} }
} }
@ -88,8 +95,15 @@ func (t *ConditionalRepeater) RunUntilDone() {
// Stop forcefully stops the running task by canceling its context. // Stop forcefully stops the running task by canceling its context.
func (t *ConditionalRepeater) Stop() { func (t *ConditionalRepeater) Stop() {
t.onceMu.Lock() t.ctxMu.Lock()
defer t.onceMu.Unlock() defer t.ctxMu.Unlock()
t.cancel() if t.ctx != nil {
t.ctx = nil t.cancel()
}
}
func (t *ConditionalRepeater) IsRunning() bool {
t.ctxMu.Lock()
defer t.ctxMu.Unlock()
return t.ctx != nil
} }

View File

@ -29,7 +29,7 @@ const (
// Caries StatusChangedPayload in message // Caries StatusChangedPayload in message
EventPendingTransactionStatusChanged walletevent.EventType = "pending-transaction-status-changed" EventPendingTransactionStatusChanged walletevent.EventType = "pending-transaction-status-changed"
pendingCheckInterval = 10 * time.Second PendingCheckInterval = 10 * time.Second
) )
var ( var (
@ -65,17 +65,19 @@ type PendingTxTracker struct {
eventFeed *event.Feed eventFeed *event.Feed
taskRunner *ConditionalRepeater taskRunner *ConditionalRepeater
log log.Logger
} }
func NewPendingTxTracker(db *sql.DB, rpcClient rpc.ClientInterface, rpcFilter *rpcfilters.Service, eventFeed *event.Feed) *PendingTxTracker { func NewPendingTxTracker(db *sql.DB, rpcClient rpc.ClientInterface, rpcFilter *rpcfilters.Service, eventFeed *event.Feed, checkInterval time.Duration) *PendingTxTracker {
tm := &PendingTxTracker{ tm := &PendingTxTracker{
db: db, db: db,
rpcClient: rpcClient, rpcClient: rpcClient,
eventFeed: eventFeed, eventFeed: eventFeed,
rpcFilter: rpcFilter, rpcFilter: rpcFilter,
log: log.New("package", "status-go/transactions.PendingTxTracker"),
} }
tm.taskRunner = NewConditionalRepeater(pendingCheckInterval, func(ctx context.Context) bool { tm.taskRunner = NewConditionalRepeater(checkInterval, func(ctx context.Context) bool {
return tm.fetchTransactions(ctx) return tm.fetchAndUpdateDB(ctx)
}) })
return tm return tm
} }
@ -86,14 +88,15 @@ type txStatusRes struct {
hash eth.Hash hash eth.Hash
} }
func (tm *PendingTxTracker) fetchTransactions(ctx context.Context) bool { func (tm *PendingTxTracker) fetchAndUpdateDB(ctx context.Context) bool {
res := WorkDone res := WorkNotDone
txs, err := tm.GetAllPending() txs, err := tm.GetAllPending()
if err != nil { if err != nil {
log.Error("Failed to get pending transactions", "error", err) tm.log.Error("Failed to get pending transactions", "error", err)
return WorkDone return WorkDone
} }
tm.log.Debug("Checking for PT status", "count", len(txs))
txsMap := make(map[common.ChainID][]eth.Hash) txsMap := make(map[common.ChainID][]eth.Hash)
for _, tx := range txs { for _, tx := range txs {
@ -101,32 +104,44 @@ func (tm *PendingTxTracker) fetchTransactions(ctx context.Context) bool {
txsMap[chainID] = append(txsMap[chainID], tx.Hash) txsMap[chainID] = append(txsMap[chainID], tx.Hash)
} }
doneCount := 0
// Batch request for each chain // Batch request for each chain
for chainID, txs := range txsMap { for chainID, txs := range txsMap {
log.Debug("Processing pending transactions", "chainID", chainID, "count", len(txs)) tm.log.Debug("Processing PTs", "chainID", chainID, "count", len(txs))
batchRes, err := fetchBatchTxStatus(ctx, tm.rpcClient, chainID, txs) batchRes, err := fetchBatchTxStatus(ctx, tm.rpcClient, chainID, txs, tm.log)
if err != nil { if err != nil {
log.Error("Failed to batch fetch pending transactions status for", "chainID", chainID, "error", err) tm.log.Error("Failed to batch fetch pending transactions status for", "chainID", chainID, "error", err)
continue continue
} }
if len(batchRes) == 0 {
tm.log.Debug("No change to PTs status", "chainID", chainID)
continue
}
tm.log.Debug("PTs done", "chainID", chainID, "count", len(batchRes))
doneCount += len(batchRes)
updateRes, err := tm.updateDBStatus(ctx, chainID, batchRes) updateRes, err := tm.updateDBStatus(ctx, chainID, batchRes)
if err != nil { if err != nil {
log.Error("Failed to update pending transactions status for", "chainID", chainID, "error", err) tm.log.Error("Failed to update pending transactions status for", "chainID", chainID, "error", err)
continue continue
} }
if len(updateRes) != len(batchRes) { tm.log.Debug("Emit notifications for PTs", "chainID", chainID, "count", len(updateRes))
res = WorkNotDone
}
tm.emitNotifications(chainID, updateRes) tm.emitNotifications(chainID, updateRes)
} }
if len(txs) == doneCount {
res = WorkDone
}
tm.log.Debug("Done PTs iteration", "count", doneCount, "completed", res)
return res return res
} }
// fetchBatchTxStatus will exclude the still pending or errored request from the result // fetchBatchTxStatus returns not pending transactions (confirmed or errored)
func fetchBatchTxStatus(ctx context.Context, rpcClient rpc.ClientInterface, chainID common.ChainID, hashes []eth.Hash) ([]txStatusRes, error) { // it excludes the still pending or errored request from the result
func fetchBatchTxStatus(ctx context.Context, rpcClient rpc.ClientInterface, chainID common.ChainID, hashes []eth.Hash, log log.Logger) ([]txStatusRes, error) {
chainClient, err := rpcClient.AbstractEthClient(chainID) chainClient, err := rpcClient.AbstractEthClient(chainID)
if err != nil { if err != nil {
log.Error("Failed to get chain client", "error", err) log.Error("Failed to get chain client", "error", err)
@ -210,9 +225,9 @@ func (tm *PendingTxTracker) updateDBStatus(ctx context.Context, chainID common.C
err = row.Scan(&autoDel) err = row.Scan(&autoDel)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
log.Warn("Missing entry while checking for auto_delete", "hash", br.hash) tm.log.Warn("Missing entry while checking for auto_delete", "hash", br.hash)
} else { } else {
log.Error("Failed to retrieve auto_delete for pending transaction", "error", err, "hash", br.hash) tm.log.Error("Failed to retrieve auto_delete for pending transaction", "error", err, "hash", br.hash)
} }
continue continue
} }
@ -220,7 +235,7 @@ func (tm *PendingTxTracker) updateDBStatus(ctx context.Context, chainID common.C
if autoDel { if autoDel {
notifyFn, err := tm.DeleteBySQLTx(tx, chainID, br.hash) notifyFn, err := tm.DeleteBySQLTx(tx, chainID, br.hash)
if err != nil && err != ErrStillPending { if err != nil && err != ErrStillPending {
log.Error("Failed to delete pending transaction", "error", err, "hash", br.hash) tm.log.Error("Failed to delete pending transaction", "error", err, "hash", br.hash)
continue continue
} }
notifyFunctions = append(notifyFunctions, notifyFn) notifyFunctions = append(notifyFunctions, notifyFn)
@ -231,17 +246,17 @@ func (tm *PendingTxTracker) updateDBStatus(ctx context.Context, chainID common.C
res, err := updateStmt.ExecContext(ctx, txStatus, chainID, br.hash) res, err := updateStmt.ExecContext(ctx, txStatus, chainID, br.hash)
if err != nil { if err != nil {
log.Error("Failed to update pending transaction status", "error", err, "hash", br.hash) tm.log.Error("Failed to update pending transaction status", "error", err, "hash", br.hash)
continue continue
} }
affected, err := res.RowsAffected() affected, err := res.RowsAffected()
if err != nil { if err != nil {
log.Error("Failed to get updated rows", "error", err, "hash", br.hash) tm.log.Error("Failed to get updated rows", "error", err, "hash", br.hash)
continue continue
} }
if affected == 0 { if affected == 0 {
log.Warn("Missing entry to update for", "hash", br.hash) tm.log.Warn("Missing entry to update for", "hash", br.hash)
continue continue
} }
} }
@ -272,7 +287,7 @@ func (tm *PendingTxTracker) emitNotifications(chainID common.ChainID, changes []
jsonPayload, err := json.Marshal(status) jsonPayload, err := json.Marshal(status)
if err != nil { if err != nil {
log.Error("Failed to marshal pending transaction status", "error", err, "hash", hash) tm.log.Error("Failed to marshal pending transaction status", "error", err, "hash", hash)
continue continue
} }
tm.eventFeed.Send(walletevent.Event{ tm.eventFeed.Send(walletevent.Event{
@ -408,8 +423,6 @@ func rowsToTransactions(rows *sql.Rows) (transactions []*PendingTransaction, err
} }
func (tm *PendingTxTracker) GetAllPending() ([]*PendingTransaction, error) { func (tm *PendingTxTracker) GetAllPending() ([]*PendingTransaction, error) {
log.Debug("Getting all pending transactions")
rows, err := tm.db.Query(selectFromPending+"WHERE status = ?", Pending) rows, err := tm.db.Query(selectFromPending+"WHERE status = ?", Pending)
if err != nil { if err != nil {
return nil, err return nil, err
@ -420,10 +433,8 @@ func (tm *PendingTxTracker) GetAllPending() ([]*PendingTransaction, error) {
} }
func (tm *PendingTxTracker) GetPendingByAddress(chainIDs []uint64, address eth.Address) ([]*PendingTransaction, error) { func (tm *PendingTxTracker) GetPendingByAddress(chainIDs []uint64, address eth.Address) ([]*PendingTransaction, error) {
log.Debug("Getting pending transaction by address", "chainIDs", chainIDs, "address", address)
if len(chainIDs) == 0 { if len(chainIDs) == 0 {
return nil, errors.New("at least 1 chainID is required") return nil, errors.New("GetPendingByAddress: at least 1 chainID is required")
} }
inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?" inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?"
@ -445,8 +456,6 @@ func (tm *PendingTxTracker) GetPendingByAddress(chainIDs []uint64, address eth.A
// GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity // GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity
func (tm *PendingTxTracker) GetPendingEntry(chainID common.ChainID, hash eth.Hash) (*PendingTransaction, error) { func (tm *PendingTxTracker) GetPendingEntry(chainID common.ChainID, hash eth.Hash) (*PendingTransaction, error) {
log.Debug("Getting pending transaction", "chainID", chainID, "hash", hash)
rows, err := tm.db.Query(selectFromPending+"WHERE network_id = ? AND hash = ?", chainID, hash) rows, err := tm.db.Query(selectFromPending+"WHERE network_id = ? AND hash = ?", chainID, hash)
if err != nil { if err != nil {
return nil, err return nil, err
@ -566,8 +575,6 @@ func GetTransferData(tx *sql.Tx, chainID common.ChainID, hash eth.Hash) (txType
// Watch returns sql.ErrNoRows if no pending transaction is found for the given identity // Watch returns sql.ErrNoRows if no pending transaction is found for the given identity
// tx.Status is not nill if err is nil // tx.Status is not nill if err is nil
func (tm *PendingTxTracker) Watch(ctx context.Context, chainID common.ChainID, hash eth.Hash) (*TxStatus, error) { func (tm *PendingTxTracker) Watch(ctx context.Context, chainID common.ChainID, hash eth.Hash) (*TxStatus, error) {
log.Debug("Watching transaction", "chainID", chainID, "hash", hash)
tx, err := tm.GetPendingEntry(chainID, hash) tx, err := tm.GetPendingEntry(chainID, hash)
if err != nil { if err != nil {
return nil, err return nil, err
@ -579,8 +586,6 @@ func (tm *PendingTxTracker) Watch(ctx context.Context, chainID common.ChainID, h
// Delete returns ErrStillPending if the deleted transaction was still pending // Delete returns ErrStillPending if the deleted transaction was still pending
// The transactions are suppose to be deleted by the client only after they are confirmed // The transactions are suppose to be deleted by the client only after they are confirmed
func (tm *PendingTxTracker) Delete(ctx context.Context, chainID common.ChainID, transactionHash eth.Hash) error { func (tm *PendingTxTracker) Delete(ctx context.Context, chainID common.ChainID, transactionHash eth.Hash) error {
log.Debug("Delete pending transaction to confirm it", "chainID", chainID, "hash", transactionHash)
tx, err := tm.db.BeginTx(ctx, nil) tx, err := tm.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err) return fmt.Errorf("failed to begin transaction: %w", err)

View File

@ -2,9 +2,11 @@ package transactions
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/big" "math/big"
"sync"
"testing" "testing"
"time" "time"
@ -60,25 +62,35 @@ func (m *MockChainClient) AbstractEthClient(chainID common.ChainID) (chain.Clien
return m.clients[chainID], nil return m.clients[chainID], nil
} }
func setupTestTransactionDB(t *testing.T) (*PendingTxTracker, func(), *MockChainClient, *event.Feed) { // setupTestTransactionDB will use the default pending check interval if checkInterval is nil
func setupTestTransactionDB(t *testing.T, checkInterval *time.Duration) (*PendingTxTracker, func(), *MockChainClient, *event.Feed) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err) require.NoError(t, err)
chainClient := newMockChainClient() chainClient := newMockChainClient()
eventFeed := &event.Feed{} eventFeed := &event.Feed{}
return NewPendingTxTracker(db, chainClient, nil, eventFeed), func() { pendingCheckInterval := PendingCheckInterval
if checkInterval != nil {
pendingCheckInterval = *checkInterval
}
return NewPendingTxTracker(db, chainClient, nil, eventFeed, pendingCheckInterval), func() {
require.NoError(t, db.Close()) require.NoError(t, db.Close())
}, chainClient, eventFeed }, chainClient, eventFeed
} }
func waitForTaskToStop(pt *PendingTxTracker) {
for pt.taskRunner.IsRunning() {
time.Sleep(1 * time.Microsecond)
}
}
const ( const (
transactionSuccessStatus = "0x1" transactionBlockNo = "0x1"
transactionFailStatus = "0x0"
transactionByHashRPCName = "eth_getTransactionByHash" transactionByHashRPCName = "eth_getTransactionByHash"
) )
func TestPendingTxTracker_ValidateConfirmed(t *testing.T) { func TestPendingTxTracker_ValidateConfirmed(t *testing.T) {
m, stop, chainClient, eventFeed := setupTestTransactionDB(t) m, stop, chainClient, eventFeed := setupTestTransactionDB(t, nil)
defer stop() defer stop()
txs := generateTestTransactions(1) txs := generateTestTransactions(1)
@ -91,10 +103,10 @@ func TestPendingTxTracker_ValidateConfirmed(t *testing.T) {
})).Return(nil).Once().Run(func(args mock.Arguments) { })).Return(nil).Once().Run(func(args mock.Arguments) {
elems := args.Get(1).([]rpc.BatchElem) elems := args.Get(1).([]rpc.BatchElem)
res := elems[0].Result.(*map[string]interface{}) res := elems[0].Result.(*map[string]interface{})
(*res)["blockNumber"] = transactionSuccessStatus (*res)["blockNumber"] = transactionBlockNo
}) })
eventChan := make(chan walletevent.Event, 2) eventChan := make(chan walletevent.Event, 3)
sub := eventFeed.Subscribe(eventChan) sub := eventFeed.Subscribe(eventChan)
err := m.StoreAndTrackPendingTx(&txs[0]) err := m.StoreAndTrackPendingTx(&txs[0])
@ -123,6 +135,8 @@ func TestPendingTxTracker_ValidateConfirmed(t *testing.T) {
err = m.Stop() err = m.Stop()
require.NoError(t, err) require.NoError(t, err)
waitForTaskToStop(m)
res, err := m.GetAllPending() res, err := m.GetAllPending()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(res)) require.Equal(t, 0, len(res))
@ -131,7 +145,7 @@ func TestPendingTxTracker_ValidateConfirmed(t *testing.T) {
} }
func TestPendingTxTracker_InterruptWatching(t *testing.T) { func TestPendingTxTracker_InterruptWatching(t *testing.T) {
m, stop, chainClient, eventFeed := setupTestTransactionDB(t) m, stop, chainClient, eventFeed := setupTestTransactionDB(t, nil)
defer stop() defer stop()
txs := generateTestTransactions(2) txs := generateTestTransactions(2)
@ -143,10 +157,11 @@ func TestPendingTxTracker_InterruptWatching(t *testing.T) {
return (len(b) == 2 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash && b[1].Method == transactionByHashRPCName && b[1].Args[0] == txs[1].Hash) return (len(b) == 2 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash && b[1].Method == transactionByHashRPCName && b[1].Args[0] == txs[1].Hash)
})).Return(nil).Once().Run(func(args mock.Arguments) { })).Return(nil).Once().Run(func(args mock.Arguments) {
elems := args.Get(1).([]rpc.BatchElem) elems := args.Get(1).([]rpc.BatchElem)
res := elems[0].Result.(*map[string]interface{})
(*res)["blockNumber"] = nil // Simulate still pending by excluding "blockNumber" in elems[0]
res = elems[1].Result.(*map[string]interface{})
(*res)["blockNumber"] = transactionFailStatus res := elems[1].Result.(*map[string]interface{})
(*res)["blockNumber"] = transactionBlockNo
}) })
eventChan := make(chan walletevent.Event, 2) eventChan := make(chan walletevent.Event, 2)
@ -193,6 +208,8 @@ func TestPendingTxTracker_InterruptWatching(t *testing.T) {
err = m.Stop() err = m.Stop()
require.NoError(t, err) require.NoError(t, err)
waitForTaskToStop(m)
res, err := m.GetAllPending() res, err := m.GetAllPending()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(res), "should have only one pending tx") require.Equal(t, 1, len(res), "should have only one pending tx")
@ -204,7 +221,7 @@ func TestPendingTxTracker_InterruptWatching(t *testing.T) {
})).Return(nil).Once().Run(func(args mock.Arguments) { })).Return(nil).Once().Run(func(args mock.Arguments) {
elems := args.Get(1).([]rpc.BatchElem) elems := args.Get(1).([]rpc.BatchElem)
res := elems[0].Result.(*map[string]interface{}) res := elems[0].Result.(*map[string]interface{})
(*res)["blockNumber"] = transactionSuccessStatus (*res)["blockNumber"] = transactionBlockNo
}) })
err = m.Start() err = m.Start()
@ -232,6 +249,8 @@ func TestPendingTxTracker_InterruptWatching(t *testing.T) {
err = m.Stop() err = m.Stop()
require.NoError(t, err) require.NoError(t, err)
waitForTaskToStop(m)
res, err = m.GetAllPending() res, err = m.GetAllPending()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(res)) require.Equal(t, 0, len(res))
@ -240,7 +259,7 @@ func TestPendingTxTracker_InterruptWatching(t *testing.T) {
} }
func TestPendingTxTracker_MultipleClients(t *testing.T) { func TestPendingTxTracker_MultipleClients(t *testing.T) {
m, stop, chainClient, eventFeed := setupTestTransactionDB(t) m, stop, chainClient, eventFeed := setupTestTransactionDB(t, nil)
defer stop() defer stop()
txs := generateTestTransactions(2) txs := generateTestTransactions(2)
@ -254,7 +273,7 @@ func TestPendingTxTracker_MultipleClients(t *testing.T) {
})).Return(nil).Once().Run(func(args mock.Arguments) { })).Return(nil).Once().Run(func(args mock.Arguments) {
elems := args.Get(1).([]rpc.BatchElem) elems := args.Get(1).([]rpc.BatchElem)
res := elems[0].Result.(*map[string]interface{}) res := elems[0].Result.(*map[string]interface{})
(*res)["blockNumber"] = transactionFailStatus (*res)["blockNumber"] = transactionBlockNo
}) })
cl = chainClient.clients[txs[1].ChainID] cl = chainClient.clients[txs[1].ChainID]
cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool {
@ -262,42 +281,55 @@ func TestPendingTxTracker_MultipleClients(t *testing.T) {
})).Return(nil).Once().Run(func(args mock.Arguments) { })).Return(nil).Once().Run(func(args mock.Arguments) {
elems := args.Get(1).([]rpc.BatchElem) elems := args.Get(1).([]rpc.BatchElem)
res := elems[0].Result.(*map[string]interface{}) res := elems[0].Result.(*map[string]interface{})
(*res)["blockNumber"] = transactionSuccessStatus (*res)["blockNumber"] = transactionBlockNo
}) })
eventChan := make(chan walletevent.Event, 6)
sub := eventFeed.Subscribe(eventChan)
for i := range txs { for i := range txs {
err := m.TrackPendingTransaction(txs[i].ChainID, txs[i].Hash, txs[i].From, txs[i].Type, true) err := m.TrackPendingTransaction(txs[i].ChainID, txs[i].Hash, txs[i].From, txs[i].Type, AutoDelete)
require.NoError(t, err) require.NoError(t, err)
} }
eventChan := make(chan walletevent.Event)
sub := eventFeed.Subscribe(eventChan)
err := m.Start() err := m.Start()
require.NoError(t, err) require.NoError(t, err)
storeEventCount := 0
statusEventCount := 0
validateStatusChange := func(we *walletevent.Event) {
if we.Type == EventPendingTransactionUpdate {
storeEventCount++
} else if we.Type == EventPendingTransactionStatusChanged {
statusEventCount++
require.Equal(t, EventPendingTransactionStatusChanged, we.Type)
var p StatusChangedPayload
err := json.Unmarshal([]byte(we.Message), &p)
require.NoError(t, err)
require.Nil(t, p.Status)
}
}
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
for j := 0; j < 2; j++ { for j := 0; j < 3; j++ {
select { select {
case we := <-eventChan: case we := <-eventChan:
if j == 0 { validateStatusChange(&we)
require.Equal(t, EventPendingTransactionUpdate, we.Type)
} else {
require.Equal(t, EventPendingTransactionStatusChanged, we.Type)
var p StatusChangedPayload
err := json.Unmarshal([]byte(we.Message), &p)
require.NoError(t, err)
require.Nil(t, p.Status)
}
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for event") t.Fatal("timeout waiting for event", i, j, storeEventCount, statusEventCount)
} }
} }
} }
require.Equal(t, 4, storeEventCount)
require.Equal(t, 2, statusEventCount)
err = m.Stop() err = m.Stop()
require.NoError(t, err) require.NoError(t, err)
waitForTaskToStop(m)
res, err := m.GetAllPending() res, err := m.GetAllPending()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(res)) require.Equal(t, 0, len(res))
@ -306,33 +338,33 @@ func TestPendingTxTracker_MultipleClients(t *testing.T) {
} }
func TestPendingTxTracker_Watch(t *testing.T) { func TestPendingTxTracker_Watch(t *testing.T) {
m, stop, chainClient, eventFeed := setupTestTransactionDB(t) m, stop, chainClient, eventFeed := setupTestTransactionDB(t, nil)
defer stop() defer stop()
txs := generateTestTransactions(2) txs := generateTestTransactions(2)
// Make the second already confirmed // Make the second already confirmed
*txs[1].Status = Done *txs[0].Status = Done
// Mock the first call to getTransactionByHash // Mock the first call to getTransactionByHash
chainClient.setAvailableClients([]common.ChainID{txs[0].ChainID}) chainClient.setAvailableClients([]common.ChainID{txs[0].ChainID})
cl := chainClient.clients[txs[0].ChainID] cl := chainClient.clients[txs[0].ChainID]
cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool {
return len(b) == 1 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash return len(b) == 1 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[1].Hash
})).Return(nil).Once().Run(func(args mock.Arguments) { })).Return(nil).Once().Run(func(args mock.Arguments) {
elems := args.Get(1).([]rpc.BatchElem) elems := args.Get(1).([]rpc.BatchElem)
res := elems[0].Result.(*map[string]interface{}) res := elems[0].Result.(*map[string]interface{})
(*res)["blockNumber"] = transactionFailStatus (*res)["blockNumber"] = transactionBlockNo
}) })
eventChan := make(chan walletevent.Event, 2) eventChan := make(chan walletevent.Event, 3)
sub := eventFeed.Subscribe(eventChan) sub := eventFeed.Subscribe(eventChan)
// Track the first transaction // Track the first transaction
err := m.TrackPendingTransaction(txs[0].ChainID, txs[0].Hash, txs[0].From, txs[0].Type, false) err := m.TrackPendingTransaction(txs[1].ChainID, txs[1].Hash, txs[1].From, txs[1].Type, Keep)
require.NoError(t, err) require.NoError(t, err)
// Store the confirmed already // Store the confirmed already
err = m.StoreAndTrackPendingTx(&txs[1]) err = m.StoreAndTrackPendingTx(&txs[0])
require.NoError(t, err) require.NoError(t, err)
storeEventCount := 0 storeEventCount := 0
@ -347,8 +379,8 @@ func TestPendingTxTracker_Watch(t *testing.T) {
var p StatusChangedPayload var p StatusChangedPayload
err := json.Unmarshal([]byte(we.Message), &p) err := json.Unmarshal([]byte(we.Message), &p)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, txs[0].ChainID, p.ChainID) require.Equal(t, txs[1].ChainID, p.ChainID)
require.Equal(t, txs[0].Hash, p.Hash) require.Equal(t, txs[1].Hash, p.Hash)
require.Nil(t, p.Status) require.Nil(t, p.Status)
} }
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
@ -362,15 +394,17 @@ func TestPendingTxTracker_Watch(t *testing.T) {
err = m.Stop() err = m.Stop()
require.NoError(t, err) require.NoError(t, err)
waitForTaskToStop(m)
res, err := m.GetAllPending() res, err := m.GetAllPending()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(res), "should have only one pending tx") require.Equal(t, 0, len(res), "should have no pending tx")
status, err := m.Watch(context.Background(), txs[0].ChainID, txs[0].Hash) status, err := m.Watch(context.Background(), txs[1].ChainID, txs[1].Hash)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, Pending, status) require.NotEqual(t, Pending, status)
err = m.Delete(context.Background(), txs[0].ChainID, txs[0].Hash) err = m.Delete(context.Background(), txs[1].ChainID, txs[1].Hash)
require.NoError(t, err) require.NoError(t, err)
select { select {
@ -383,8 +417,120 @@ func TestPendingTxTracker_Watch(t *testing.T) {
sub.Unsubscribe() sub.Unsubscribe()
} }
func TestPendingTxTracker_Watch_StatusChangeIncrementally(t *testing.T) {
m, stop, chainClient, eventFeed := setupTestTransactionDB(t, common.NewAndSet(1*time.Nanosecond))
defer stop()
txs := generateTestTransactions(2)
var firsDoneWG sync.WaitGroup
firsDoneWG.Add(1)
// Mock the first call to getTransactionByHash
chainClient.setAvailableClients([]common.ChainID{txs[0].ChainID})
cl := chainClient.clients[txs[0].ChainID]
cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool {
if len(cl.Calls) == 0 {
res := len(b) > 0 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash
// If the first processing call picked up the second validate this case also
if len(b) == 2 {
res = res && b[1].Method == transactionByHashRPCName && b[1].Args[0] == txs[1].Hash
}
return res
}
// Second call we expect only one left
return len(b) == 1 && (b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[1].Hash)
})).Return(nil).Twice().Run(func(args mock.Arguments) {
elems := args.Get(1).([]rpc.BatchElem)
if len(cl.Calls) == 2 {
firsDoneWG.Wait()
}
// Only first item is processed, second is left pending
res := elems[0].Result.(*map[string]interface{})
(*res)["blockNumber"] = transactionBlockNo
})
eventChan := make(chan walletevent.Event, 6)
sub := eventFeed.Subscribe(eventChan)
for i := range txs {
// Track the first transaction
err := m.TrackPendingTransaction(txs[i].ChainID, txs[i].Hash, txs[i].From, txs[i].Type, Keep)
require.NoError(t, err)
}
storeEventCount := 0
statusEventCount := 0
validateStatusChange := func(we *walletevent.Event) {
var p StatusChangedPayload
err := json.Unmarshal([]byte(we.Message), &p)
require.NoError(t, err)
if statusEventCount == 0 {
require.Equal(t, txs[0].ChainID, p.ChainID)
require.Equal(t, txs[0].Hash, p.Hash)
require.Nil(t, p.Status)
status, err := m.Watch(context.Background(), txs[0].ChainID, txs[0].Hash)
require.NoError(t, err)
require.Equal(t, Done, *status)
err = m.Delete(context.Background(), txs[0].ChainID, txs[0].Hash)
require.NoError(t, err)
status, err = m.Watch(context.Background(), txs[1].ChainID, txs[1].Hash)
require.NoError(t, err)
require.Equal(t, Pending, *status)
firsDoneWG.Done()
} else {
_, err := m.Watch(context.Background(), txs[0].ChainID, txs[0].Hash)
require.Equal(t, err, sql.ErrNoRows)
status, err := m.Watch(context.Background(), txs[1].ChainID, txs[1].Hash)
require.NoError(t, err)
require.Equal(t, Done, *status)
err = m.Delete(context.Background(), txs[1].ChainID, txs[1].Hash)
require.NoError(t, err)
}
statusEventCount++
}
for j := 0; j < 6; j++ {
select {
case we := <-eventChan:
if EventPendingTransactionUpdate == we.Type {
storeEventCount++
} else if EventPendingTransactionStatusChanged == we.Type {
validateStatusChange(&we)
}
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for the status update event")
}
}
_, err := m.Watch(context.Background(), txs[1].ChainID, txs[1].Hash)
require.Equal(t, err, sql.ErrNoRows)
// One for add and one for delete
require.Equal(t, 4, storeEventCount)
require.Equal(t, 2, statusEventCount)
err = m.Stop()
require.NoError(t, err)
waitForTaskToStop(m)
res, err := m.GetAllPending()
require.NoError(t, err)
require.Equal(t, 0, len(res), "should have no pending tx")
sub.Unsubscribe()
}
func TestPendingTransactions(t *testing.T) { func TestPendingTransactions(t *testing.T) {
manager, stop, _, _ := setupTestTransactionDB(t) manager, stop, _, _ := setupTestTransactionDB(t, nil)
defer stop() defer stop()
tx := generateTestTransactions(1)[0] tx := generateTestTransactions(1)[0]