From 524c21834b7249da97c1c3bc712763a1d37b786d Mon Sep 17 00:00:00 2001 From: Stefan Date: Tue, 1 Aug 2023 19:50:30 +0100 Subject: [PATCH] fix(wallet) propagate multi-transactions IDs to transfers Mainly refactor API to have control on pending_transactions operations. Use the new API to migrate the multi-transaction ID from to transfers in one SQL transaction. The refactoring was done to better mirror the purpose of pending_transactions Also: - Externalize TransactionManager from WalletService to be used by other services - Extract walletEvent as a dependency for all services that need to propagate events - Batch chain requests - Remove unused APIs - Add auto delete option for clients that fire and forget transactions Updates status-desktop #11754 --- api/geth_backend.go | 53 +- node/get_status_node.go | 5 + node/status_node_services.go | 28 +- rpc/chain/client.go | 13 + rpc/client.go | 15 + services/collectibles/api.go | 110 ++-- services/collectibles/service.go | 6 +- services/ens/api.go | 62 +- services/ens/service.go | 10 +- services/stickers/api.go | 8 +- services/stickers/service.go | 8 +- services/stickers/transactions.go | 20 +- services/wallet/activity/activity.go | 16 +- services/wallet/api.go | 77 ++- services/wallet/service.go | 44 +- services/wallet/transfer/commands.go | 91 +-- .../wallet/transfer/commands_sequential.go | 6 +- services/wallet/transfer/controller.go | 18 +- services/wallet/transfer/database.go | 48 +- services/wallet/transfer/database_test.go | 6 +- services/wallet/transfer/reactor.go | 6 +- .../transfer/sequential_fetch_strategy.go | 4 +- services/wallet/transfer/transaction.go | 13 +- transactions/conditionalrepeater.go | 95 +++ transactions/conditionalrepeater_test.go | 79 +++ transactions/pending.go | 364 ----------- transactions/pendingtxtracker.go | 604 ++++++++++++++++++ transactions/pendingtxtracker_test.go | 453 +++++++++++++ transactions/transaction_test.go | 75 --- 29 files changed, 1615 insertions(+), 722 deletions(-) create mode 100644 transactions/conditionalrepeater.go create mode 100644 transactions/conditionalrepeater_test.go delete mode 100644 transactions/pending.go create mode 100644 transactions/pendingtxtracker.go create mode 100644 transactions/pendingtxtracker_test.go delete mode 100644 transactions/transaction_test.go diff --git a/api/geth_backend.go b/api/geth_backend.go index e89c6047e..9ec9838ab 100644 --- a/api/geth_backend.go +++ b/api/geth_backend.go @@ -44,8 +44,8 @@ import ( "github.com/status-im/status-go/server/pairing/statecontrol" "github.com/status-im/status-go/services/ext" "github.com/status-im/status-go/services/personal" - "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/typeddata" + wcommon "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/signal" "github.com/status-im/status-go/sqlite" "github.com/status-im/status-go/transactions" @@ -1789,12 +1789,17 @@ func (b *GethStatusBackend) SendTransaction(sendArgs transactions.SendTxArgs, pa return } - go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: common.Hash(hash), - Type: string(transactions.WalletTransfer), - From: common.Address(sendArgs.From), - ChainID: b.transactor.NetworkID(), - }) + err = b.statusNode.PendingTracker().TrackPendingTransaction( + wcommon.ChainID(b.transactor.NetworkID()), + common.Hash(hash), + common.Address(sendArgs.From), + transactions.WalletTransfer, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return + } return } @@ -1810,12 +1815,17 @@ func (b *GethStatusBackend) SendTransactionWithChainID(chainID uint64, sendArgs return } - go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: common.Hash(hash), - Type: string(transactions.WalletTransfer), - From: common.Address(sendArgs.From), - ChainID: b.transactor.NetworkID(), - }) + err = b.statusNode.PendingTracker().TrackPendingTransaction( + wcommon.ChainID(b.transactor.NetworkID()), + common.Hash(hash), + common.Address(sendArgs.From), + transactions.WalletTransfer, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return + } return } @@ -1826,12 +1836,17 @@ func (b *GethStatusBackend) SendTransactionWithSignature(sendArgs transactions.S return } - go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: common.Hash(hash), - Type: string(transactions.WalletTransfer), - From: common.Address(sendArgs.From), - ChainID: b.transactor.NetworkID(), - }) + err = b.statusNode.PendingTracker().TrackPendingTransaction( + wcommon.ChainID(b.transactor.NetworkID()), + common.Hash(hash), + common.Address(sendArgs.From), + transactions.WalletTransfer, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return + } return } diff --git a/node/get_status_node.go b/node/get_status_node.go index d43c524b0..aab447a02 100644 --- a/node/get_status_node.go +++ b/node/get_status_node.go @@ -14,6 +14,7 @@ import ( "github.com/syndtr/goleveldb/leveldb" "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/p2p" @@ -127,6 +128,9 @@ type StatusNode struct { stickersSrvc *stickers.Service chatSrvc *chat.Service updatesSrvc *updates.Service + pendingTracker *transactions.PendingTxTracker + + walletFeed event.Feed } // New makes new instance of StatusNode. @@ -502,6 +506,7 @@ func (n *StatusNode) stop() error { n.collectiblesSrvc = nil n.stickersSrvc = nil n.publicMethods = make(map[string]bool) + n.pendingTracker = nil return nil } diff --git a/node/status_node_services.go b/node/status_node_services.go index fcf035776..26451a414 100644 --- a/node/status_node_services.go +++ b/node/status_node_services.go @@ -9,6 +9,7 @@ import ( "github.com/status-im/status-go/server" "github.com/status-im/status-go/signal" + "github.com/status-im/status-go/transactions" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/p2p/enode" @@ -77,6 +78,7 @@ func (b *StatusNode) initServices(config *params.NodeConfig, mediaServer *server services = append(services, b.peerService()) services = append(services, b.personalService()) services = append(services, b.statusPublicService()) + services = append(services, b.pendingTrackerService(&b.walletFeed)) services = append(services, b.ensService(b.timeSourceNow())) services = append(services, b.collectiblesService()) services = append(services, b.stickersService(accDB)) @@ -92,7 +94,7 @@ func (b *StatusNode) initServices(config *params.NodeConfig, mediaServer *server // Wallet Service is used by wakuExtSrvc/wakuV2ExtSrvc // Keep this initialization before the other two if config.WalletConfig.Enabled { - walletService := b.walletService(accDB, accountsFeed) + walletService := b.walletService(accDB, accountsFeed, &b.walletFeed) services = append(services, walletService) } @@ -413,21 +415,28 @@ func (b *StatusNode) browsersService() *browsers.Service { func (b *StatusNode) ensService(timesource func() time.Time) *ens.Service { if b.ensSrvc == nil { - b.ensSrvc = ens.NewService(b.rpcClient, b.gethAccountManager, b.rpcFiltersSrvc, b.config, b.appDB, timesource) + b.ensSrvc = ens.NewService(b.rpcClient, b.gethAccountManager, b.pendingTracker, b.config, b.appDB, timesource) } return b.ensSrvc } +func (b *StatusNode) pendingTrackerService(walletFeed *event.Feed) *transactions.PendingTxTracker { + if b.pendingTracker == nil { + b.pendingTracker = transactions.NewPendingTxTracker(b.appDB, b.rpcClient, b.rpcFiltersSrvc, walletFeed) + } + return b.pendingTracker +} + func (b *StatusNode) collectiblesService() *collectibles.Service { if b.collectiblesSrvc == nil { - b.collectiblesSrvc = collectibles.NewService(b.rpcClient, b.gethAccountManager, b.rpcFiltersSrvc, b.config, b.appDB) + b.collectiblesSrvc = collectibles.NewService(b.rpcClient, b.gethAccountManager, b.pendingTracker, b.config, b.appDB) } return b.collectiblesSrvc } func (b *StatusNode) stickersService(accountDB *accounts.Database) *stickers.Service { if b.stickersSrvc == nil { - b.stickersSrvc = stickers.NewService(accountDB, b.rpcClient, b.gethAccountManager, b.rpcFiltersSrvc, b.config, b.downloader, b.httpServer) + b.stickersSrvc = stickers.NewService(accountDB, b.rpcClient, b.gethAccountManager, b.config, b.downloader, b.httpServer, b.pendingTracker) } return b.stickersSrvc } @@ -498,13 +507,14 @@ func (b *StatusNode) CollectiblesService() *collectibles.Service { return b.collectiblesSrvc } -func (b *StatusNode) walletService(accountsDB *accounts.Database, accountsFeed *event.Feed) *wallet.Service { +func (b *StatusNode) walletService(accountsDB *accounts.Database, accountsFeed *event.Feed, walletFeed *event.Feed) *wallet.Service { if b.walletSrvc == nil { b.walletSrvc = wallet.NewService( b.walletDB, accountsDB, b.rpcClient, accountsFeed, b.gethAccountManager, b.transactor, b.config, b.ensService(b.timeSourceNow()), b.stickersService(accountsDB), - b.rpcFiltersSrvc, + b.pendingTracker, + walletFeed, ) } return b.walletSrvc @@ -546,6 +556,10 @@ func (b *StatusNode) RPCFiltersService() *rpcfilters.Service { return b.rpcFiltersSrvc } +func (b *StatusNode) PendingTracker() *transactions.PendingTxTracker { + return b.pendingTracker +} + func (b *StatusNode) StopLocalNotifications() error { if b.localNotificationsSrvc == nil { return nil @@ -580,7 +594,7 @@ func (b *StatusNode) StartLocalNotifications() error { } } - err := b.localNotificationsSrvc.SubscribeWallet(b.walletSrvc.GetFeed()) + err := b.localNotificationsSrvc.SubscribeWallet(&b.walletFeed) if err != nil { b.log.Error("LocalNotifications service could not subscribe to wallet on StartLocalNotifications", "error", err) diff --git a/rpc/chain/client.go b/rpc/chain/client.go index d9e80b588..b24ab24d7 100644 --- a/rpc/chain/client.go +++ b/rpc/chain/client.go @@ -25,6 +25,10 @@ type FeeHistory struct { BaseFeePerGas []string `json:"baseFeePerGas"` } +type ClientInterface interface { + BatchCallContext(ctx context.Context, b []rpc.BatchElem) error +} + type ClientWithFallback struct { ChainID uint64 main *ethclient.Client @@ -816,6 +820,15 @@ func (c *ClientWithFallback) CallContext(ctx context.Context, result interface{} ) } +func (c *ClientWithFallback) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error { + rpcstats.CountCall("eth_BatchCallContext") + + return c.makeCallNoReturn( + func() error { return c.mainRPC.BatchCallContext(ctx, b) }, + func() error { return c.fallbackRPC.BatchCallContext(ctx, b) }, + ) +} + func (c *ClientWithFallback) ToBigInt() *big.Int { return big.NewInt(int64(c.ChainID)) } diff --git a/rpc/client.go b/rpc/client.go index a25c48a20..973e30b98 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -17,6 +17,7 @@ import ( "github.com/status-im/status-go/rpc/chain" "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/services/rpcstats" + "github.com/status-im/status-go/services/wallet/common" ) const ( @@ -32,6 +33,10 @@ var ( // Handler defines handler for RPC methods. type Handler func(context.Context, uint64, ...interface{}) (interface{}, error) +type ClientInterface interface { + AbstractEthClient(chainID common.ChainID) (chain.ClientInterface, error) +} + // Client represents RPC client with custom routing // scheme. It automatically decides where RPC call // goes - Upstream or Local node. @@ -154,6 +159,16 @@ func (c *Client) EthClient(chainID uint64) (*chain.ClientWithFallback, error) { return client, nil } +// AbstractEthClient returns a partial abstraction used by new components for testing purposes +func (c *Client) AbstractEthClient(chainID common.ChainID) (chain.ClientInterface, error) { + client, err := c.getClientUsingCache(uint64(chainID)) + if err != nil { + return nil, err + } + + return client, nil +} + func (c *Client) EthClients(chainIDs []uint64) (map[uint64]*chain.ClientWithFallback, error) { clients := make(map[uint64]*chain.ClientWithFallback, 0) for _, chainID := range chainIDs { diff --git a/services/collectibles/api.go b/services/collectibles/api.go index 924574c50..afb85eae9 100644 --- a/services/collectibles/api.go +++ b/services/collectibles/api.go @@ -22,17 +22,17 @@ import ( "github.com/status-im/status-go/params" "github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/rpc" - "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/utils" "github.com/status-im/status-go/services/wallet/bigint" + wcommon "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/transactions" ) -func NewAPI(rpcClient *rpc.Client, accountsManager *account.GethManager, rpcFiltersSrvc *rpcfilters.Service, config *params.NodeConfig, appDb *sql.DB) *API { +func NewAPI(rpcClient *rpc.Client, accountsManager *account.GethManager, pendingTracker *transactions.PendingTxTracker, config *params.NodeConfig, appDb *sql.DB) *API { return &API{ RPCClient: rpcClient, accountsManager: accountsManager, - rpcFiltersSrvc: rpcFiltersSrvc, + pendingTracker: pendingTracker, config: config, db: NewCommunityTokensDatabase(appDb), } @@ -41,7 +41,7 @@ func NewAPI(rpcClient *rpc.Client, accountsManager *account.GethManager, rpcFilt type API struct { RPCClient *rpc.Client accountsManager *account.GethManager - rpcFiltersSrvc *rpcfilters.Service + pendingTracker *transactions.PendingTxTracker config *params.NodeConfig db *Database } @@ -126,12 +126,17 @@ func (api *API) DeployCollectibles(ctx context.Context, chainID uint64, deployme return DeploymentDetails{}, err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.DeployCommunityToken), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.DeployCommunityToken, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return DeploymentDetails{}, err + } return DeploymentDetails{address.Hex(), tx.Hash().Hex()}, nil } @@ -166,12 +171,17 @@ func (api *API) DeployOwnerToken(ctx context.Context, chainID uint64, ownerToken return DeploymentDetails{}, err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.DeployOwnerToken), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.DeployOwnerToken, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return DeploymentDetails{}, err + } return DeploymentDetails{address.Hex(), tx.Hash().Hex()}, nil } @@ -229,12 +239,17 @@ func (api *API) DeployAssets(ctx context.Context, chainID uint64, deploymentPara return DeploymentDetails{}, err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.DeployCommunityToken), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.DeployCommunityToken, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return DeploymentDetails{}, err + } return DeploymentDetails{address.Hex(), tx.Hash().Hex()}, nil } @@ -321,12 +336,17 @@ func (api *API) MintTokens(ctx context.Context, chainID uint64, contractAddress return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.AirdropCommunityToken), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.AirdropCommunityToken, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return "", err + } return tx.Hash().Hex(), nil } @@ -426,12 +446,17 @@ func (api *API) RemoteBurn(ctx context.Context, chainID uint64, contractAddress return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.RemoteDestructCollectible), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.RemoteDestructCollectible, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return "", err + } return tx.Hash().Hex(), nil } @@ -589,12 +614,17 @@ func (api *API) Burn(ctx context.Context, chainID uint64, contractAddress string return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.BurnCommunityToken), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.BurnCommunityToken, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return "", err + } return tx.Hash().Hex(), nil } diff --git a/services/collectibles/service.go b/services/collectibles/service.go index ed6021238..e134fae46 100644 --- a/services/collectibles/service.go +++ b/services/collectibles/service.go @@ -8,7 +8,7 @@ import ( "github.com/status-im/status-go/account" "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc" - "github.com/status-im/status-go/services/rpcfilters" + "github.com/status-im/status-go/transactions" ) // Collectibles service @@ -17,9 +17,9 @@ type Service struct { } // Returns a new Collectibles Service. -func NewService(rpcClient *rpc.Client, accountsManager *account.GethManager, rpcFiltersSrvc *rpcfilters.Service, config *params.NodeConfig, appDb *sql.DB) *Service { +func NewService(rpcClient *rpc.Client, accountsManager *account.GethManager, pendingTracker *transactions.PendingTxTracker, config *params.NodeConfig, appDb *sql.DB) *Service { return &Service{ - NewAPI(rpcClient, accountsManager, rpcFiltersSrvc, config, appDb), + NewAPI(rpcClient, accountsManager, pendingTracker, config, appDb), } } diff --git a/services/ens/api.go b/services/ens/api.go index 4d91f1d7a..7e79bf7bf 100644 --- a/services/ens/api.go +++ b/services/ens/api.go @@ -32,20 +32,20 @@ import ( "github.com/status-im/status-go/contracts/snt" "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc" - "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/utils" + wcommon "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/transactions" ) const StatusDomain = "stateofus.eth" -func NewAPI(rpcClient *rpc.Client, accountsManager *account.GethManager, rpcFiltersSrvc *rpcfilters.Service, config *params.NodeConfig, appDb *sql.DB, timeSource func() time.Time, syncUserDetailFunc *syncUsernameDetail) *API { +func NewAPI(rpcClient *rpc.Client, accountsManager *account.GethManager, pendingTracker *transactions.PendingTxTracker, config *params.NodeConfig, appDb *sql.DB, timeSource func() time.Time, syncUserDetailFunc *syncUsernameDetail) *API { return &API{ contractMaker: &contracts.ContractMaker{ RPCClient: rpcClient, }, accountsManager: accountsManager, - rpcFiltersSrvc: rpcFiltersSrvc, + pendingTracker: pendingTracker, config: config, addrPerChain: make(map[uint64]common.Address), db: NewEnsDatabase(appDb), @@ -68,7 +68,7 @@ type syncUsernameDetail func(context.Context, *UsernameDetail) error type API struct { contractMaker *contracts.ContractMaker accountsManager *account.GethManager - rpcFiltersSrvc *rpcfilters.Service + pendingTracker *transactions.PendingTxTracker config *params.NodeConfig addrPerChain map[uint64]common.Address @@ -353,12 +353,17 @@ func (api *API) Release(ctx context.Context, chainID uint64, txArgs transactions return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.ReleaseENS), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.ReleaseENS, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return "", err + } err = api.Remove(ctx, chainID, fullDomainName(username)) @@ -443,14 +448,19 @@ func (api *API) Register(ctx context.Context, chainID uint64, txArgs transaction return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.RegisterENS), - From: common.Address(txArgs.From), - ChainID: chainID, - }) - err = api.Add(ctx, chainID, fullDomainName(username)) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.RegisterENS, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return "", err + } + err = api.Add(ctx, chainID, fullDomainName(username)) if err != nil { log.Warn("Registering ENS username: transaction successful, but adding failed") } @@ -554,12 +564,18 @@ func (api *API) SetPubKey(ctx context.Context, chainID uint64, txArgs transactio return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.SetPubKey), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.SetPubKey, + transactions.AutoDelete, + ) + if err != nil { + log.Error("TrackPendingTransaction error", "error", err) + return "", err + } + err = api.Add(ctx, chainID, fullDomainName(username)) if err != nil { diff --git a/services/ens/service.go b/services/ens/service.go index 5edd70dcc..18ae8bbd0 100644 --- a/services/ens/service.go +++ b/services/ens/service.go @@ -9,20 +9,20 @@ import ( "github.com/status-im/status-go/account" "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc" - "github.com/status-im/status-go/services/rpcfilters" + "github.com/status-im/status-go/transactions" ) // NewService initializes service instance. -func NewService(rpcClient *rpc.Client, accountsManager *account.GethManager, rpcFiltersSrvc *rpcfilters.Service, config *params.NodeConfig, appDb *sql.DB, timeSource func() time.Time) *Service { +func NewService(rpcClient *rpc.Client, accountsManager *account.GethManager, pendingTracker *transactions.PendingTxTracker, config *params.NodeConfig, appDb *sql.DB, timeSource func() time.Time) *Service { service := &Service{ rpcClient, accountsManager, - rpcFiltersSrvc, + pendingTracker, config, nil, nil, } - service.api = NewAPI(rpcClient, accountsManager, rpcFiltersSrvc, config, appDb, timeSource, &service.syncUserDetailFunc) + service.api = NewAPI(rpcClient, accountsManager, pendingTracker, config, appDb, timeSource, &service.syncUserDetailFunc) return service } @@ -30,7 +30,7 @@ func NewService(rpcClient *rpc.Client, accountsManager *account.GethManager, rpc type Service struct { rpcClient *rpc.Client accountsManager *account.GethManager - rpcFiltersSrvc *rpcfilters.Service + pendingTracker *transactions.PendingTxTracker config *params.NodeConfig api *API syncUserDetailFunc syncUsernameDetail diff --git a/services/stickers/api.go b/services/stickers/api.go index a99298c8d..20d787aaf 100644 --- a/services/stickers/api.go +++ b/services/stickers/api.go @@ -20,8 +20,8 @@ import ( "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/server" - "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/wallet/bigint" + "github.com/status-im/status-go/transactions" ) const maxConcurrentRequests = 3 @@ -41,7 +41,7 @@ type API struct { contractMaker *contracts.ContractMaker accountsManager *account.GethManager accountsDB *accounts.Database - rpcFiltersSrvc *rpcfilters.Service + pendingTracker *transactions.PendingTxTracker keyStoreDir string downloader *ipfs.Downloader @@ -86,14 +86,14 @@ type ednStickerPackInfo struct { Meta ednStickerPack } -func NewAPI(ctx context.Context, acc *accounts.Database, rpcClient *rpc.Client, accountsManager *account.GethManager, rpcFiltersSrvc *rpcfilters.Service, keyStoreDir string, downloader *ipfs.Downloader, httpServer *server.MediaServer) *API { +func NewAPI(ctx context.Context, acc *accounts.Database, rpcClient *rpc.Client, accountsManager *account.GethManager, pendingTracker *transactions.PendingTxTracker, keyStoreDir string, downloader *ipfs.Downloader, httpServer *server.MediaServer) *API { result := &API{ contractMaker: &contracts.ContractMaker{ RPCClient: rpcClient, }, accountsManager: accountsManager, accountsDB: acc, - rpcFiltersSrvc: rpcFiltersSrvc, + pendingTracker: pendingTracker, keyStoreDir: keyStoreDir, downloader: downloader, ctx: ctx, diff --git a/services/stickers/service.go b/services/stickers/service.go index 37ecc6bd9..ae6b0ef06 100644 --- a/services/stickers/service.go +++ b/services/stickers/service.go @@ -11,24 +11,23 @@ import ( "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/server" - "github.com/status-im/status-go/services/rpcfilters" + "github.com/status-im/status-go/transactions" ) // NewService initializes service instance. -func NewService(acc *accounts.Database, rpcClient *rpc.Client, accountsManager *account.GethManager, rpcFiltersSrvc *rpcfilters.Service, config *params.NodeConfig, downloader *ipfs.Downloader, httpServer *server.MediaServer) *Service { +func NewService(acc *accounts.Database, rpcClient *rpc.Client, accountsManager *account.GethManager, config *params.NodeConfig, downloader *ipfs.Downloader, httpServer *server.MediaServer, pendingTracker *transactions.PendingTxTracker) *Service { ctx, cancel := context.WithCancel(context.Background()) return &Service{ accountsDB: acc, rpcClient: rpcClient, accountsManager: accountsManager, - rpcFiltersSrvc: rpcFiltersSrvc, keyStoreDir: config.KeyStoreDir, downloader: downloader, httpServer: httpServer, ctx: ctx, cancel: cancel, - api: NewAPI(ctx, acc, rpcClient, accountsManager, rpcFiltersSrvc, config.KeyStoreDir, downloader, httpServer), + api: NewAPI(ctx, acc, rpcClient, accountsManager, pendingTracker, config.KeyStoreDir, downloader, httpServer), } } @@ -37,7 +36,6 @@ type Service struct { accountsDB *accounts.Database rpcClient *rpc.Client accountsManager *account.GethManager - rpcFiltersSrvc *rpcfilters.Service downloader *ipfs.Downloader keyStoreDir string httpServer *server.MediaServer diff --git a/services/stickers/transactions.go b/services/stickers/transactions.go index 6ad0ce3fb..81b73dfa1 100644 --- a/services/stickers/transactions.go +++ b/services/stickers/transactions.go @@ -14,9 +14,9 @@ import ( "github.com/status-im/status-go/contracts/snt" "github.com/status-im/status-go/contracts/stickers" "github.com/status-im/status-go/eth-node/types" - "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/utils" "github.com/status-im/status-go/services/wallet/bigint" + wcommon "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/transactions" ) @@ -70,13 +70,17 @@ func (api *API) Buy(ctx context.Context, chainID uint64, txArgs transactions.Sen return "", err } - // TODO: track pending transaction (do this in ENS service too) - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ - Hash: tx.Hash(), - Type: string(transactions.BuyStickerPack), - From: common.Address(txArgs.From), - ChainID: chainID, - }) + err = api.pendingTracker.TrackPendingTransaction( + wcommon.ChainID(chainID), + tx.Hash(), + common.Address(txArgs.From), + transactions.BuyStickerPack, + transactions.AutoDelete, + ) + if err != nil { + return "", err + } + return tx.Hash().String(), nil } diff --git a/services/wallet/activity/activity.go b/services/wallet/activity/activity.go index 5fe5b018d..627a8ebb3 100644 --- a/services/wallet/activity/activity.go +++ b/services/wallet/activity/activity.go @@ -18,6 +18,7 @@ import ( "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/transfer" + "github.com/status-im/status-go/transactions" "golang.org/x/exp/constraints" ) @@ -301,7 +302,9 @@ const ( ? AS includeAllTokenTypeAssets, - ? AS includeAllNetworks + ? AS includeAllNetworks, + + ? AS pendingStatus ), filter_addresses(address) AS ( SELECT HEX(address) FROM %s WHERE (SELECT filterAllAddresses FROM filter_conditions) != 0 @@ -348,16 +351,16 @@ const ( COUNT(*) AS count, network_id FROM - pending_transactions - WHERE pending_transactions.multi_transaction_id != 0 + pending_transactions, filter_conditions + WHERE pending_transactions.multi_transaction_id != 0 AND pending_transactions.status = pendingStatus GROUP BY pending_transactions.multi_transaction_id ), pending_network_ids AS ( SELECT multi_transaction_id FROM - pending_transactions - WHERE pending_transactions.multi_transaction_id != 0 + pending_transactions, filter_conditions + WHERE pending_transactions.multi_transaction_id != 0 AND pending_transactions.status = pendingStatus AND pending_transactions.network_id IN filter_networks GROUP BY pending_transactions.multi_transaction_id ) @@ -485,7 +488,7 @@ const ( filter_addresses from_join ON HEX(pending_transactions.from_address) = from_join.address LEFT JOIN filter_addresses to_join ON HEX(pending_transactions.to_address) = to_join.address - WHERE pending_transactions.multi_transaction_id = 0 + WHERE pending_transactions.multi_transaction_id = 0 AND pending_transactions.status = pendingStatus AND (filterAllActivityStatus OR filterStatusPending) AND ((startFilterDisabled OR timestamp >= startTimestamp) AND (endFilterDisabled OR timestamp <= endTimestamp) @@ -678,6 +681,7 @@ func getActivityEntries(ctx context.Context, deps FilterDependencies, addresses FailedAS, CompleteAS, PendingAS, includeAllTokenTypeAssets, includeAllNetworks, + transactions.Pending, limit, offset) if err != nil { return nil, err diff --git a/services/wallet/api.go b/services/wallet/api.go index 7181f7f15..612c238e8 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -2,6 +2,8 @@ package wallet import ( "context" + "encoding/json" + "errors" "fmt" "math/big" "strings" @@ -24,6 +26,7 @@ import ( "github.com/status-im/status-go/services/wallet/thirdparty/opensea" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/transfer" + "github.com/status-im/status-go/services/wallet/walletevent" "github.com/status-im/status-go/transactions" ) @@ -232,67 +235,61 @@ func (api *API) DeleteSavedAddress(ctx context.Context, address common.Address, } func (api *API) GetPendingTransactions(ctx context.Context) ([]*transactions.PendingTransaction, error) { - log.Debug("call to get pending transactions") - rst, err := api.s.pendingTxManager.GetAllPending([]uint64{api.s.rpcClient.UpstreamChainID}) - log.Debug("result from database for pending transactions", "len", len(rst)) - return rst, err -} - -func (api *API) GetPendingTransactionsByChainIDs(ctx context.Context, chainIDs []uint64) ([]*transactions.PendingTransaction, error) { - log.Debug("call to get pending transactions") - rst, err := api.s.pendingTxManager.GetAllPending(chainIDs) - log.Debug("result from database for pending transactions", "len", len(rst)) + log.Debug("wallet.api.GetPendingTransactions") + rst, err := api.s.pendingTxManager.GetAllPending() + log.Debug("wallet.api.GetPendingTransactions RESULT", "len", len(rst)) return rst, err } func (api *API) GetPendingTransactionsForIdentities(ctx context.Context, identities []transfer.TransactionIdentity) ( result []*transactions.PendingTransaction, err error) { - log.Debug("call to GetPendingTransactionsForIdentities") + log.Debug("wallet.api.GetPendingTransactionsForIdentities") result = make([]*transactions.PendingTransaction, 0, len(identities)) var pt *transactions.PendingTransaction for _, identity := range identities { - pt, err = api.s.pendingTxManager.GetPendingEntry(uint64(identity.ChainID), identity.Hash) + pt, err = api.s.pendingTxManager.GetPendingEntry(identity.ChainID, identity.Hash) result = append(result, pt) } - log.Debug("result from GetPendingTransactionsForIdentities", "len", len(result)) + log.Debug("wallet.api.GetPendingTransactionsForIdentities RES", "len", len(result)) return } -func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ( - []*transactions.PendingTransaction, error) { +// TODO - #11861: Remove this and replace with EventPendingTransactionStatusChanged event and Delete to confirm the transaction where it is needed +func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) (err error) { + log.Debug("wallet.api.WatchTransactionByChainID", "chainID", chainID, "transactionHash", transactionHash) + var status *transactions.TxStatus + defer log.Debug("wallet.api.WatchTransactionByChainID return", "err", err, "chainID", chainID, "transactionHash", transactionHash) - log.Debug("call to get pending outbound transactions by address") - rst, err := api.s.pendingTxManager.GetPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address) - log.Debug("result from database for pending transactions by address", "len", len(rst)) - return rst, err -} + // Workaround to keep the blocking call until the clients use the PendingTxTracker APIs + eventChan := make(chan walletevent.Event, 2) + sub := api.s.feed.Subscribe(eventChan) + defer sub.Unsubscribe() -func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainIDs []uint64, - address common.Address) ([]*transactions.PendingTransaction, error) { - - log.Debug("call to get pending outbound transactions by address") - rst, err := api.s.pendingTxManager.GetPendingByAddress(chainIDs, address) - log.Debug("result from database for pending transactions by address", "len", len(rst)) - return rst, err -} - -func (api *API) WatchTransaction(ctx context.Context, transactionHash common.Hash) error { - chainClient, err := api.s.rpcClient.EthClient(api.s.rpcClient.UpstreamChainID) - if err != nil { - return err + status, err = api.s.pendingTxManager.Watch(ctx, wcommon.ChainID(chainID), transactionHash) + if err == nil && *status != transactions.Pending { + return nil } - return api.s.pendingTxManager.Watch(ctx, transactionHash, chainClient) -} -func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) error { - chainClient, err := api.s.rpcClient.EthClient(chainID) - if err != nil { - return err + for { + select { + case we := <-eventChan: + if transactions.EventPendingTransactionStatusChanged == we.Type { + var p transactions.StatusChangedPayload + err = json.Unmarshal([]byte(we.Message), &p) + if err != nil { + return err + } + if p.ChainID == wcommon.ChainID(chainID) && p.Hash == transactionHash { + return nil + } + } + case <-time.After(10 * time.Minute): + return errors.New("timeout watching for pending transaction") + } } - return api.s.pendingTxManager.Watch(ctx, transactionHash, chainClient) } func (api *API) GetCryptoOnRamps(ctx context.Context) ([]CryptoOnRamp, error) { diff --git a/services/wallet/service.go b/services/wallet/service.go index 819e36947..4d21fc638 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -17,7 +17,6 @@ import ( "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/services/ens" - "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/stickers" "github.com/status-im/status-go/services/wallet/activity" "github.com/status-im/status-go/services/wallet/collectibles" @@ -51,14 +50,15 @@ func NewService( config *params.NodeConfig, ens *ens.Service, stickers *stickers.Service, - rpcFilterSrvc *rpcfilters.Service, + pendingTxManager *transactions.PendingTxTracker, + feed *event.Feed, ) *Service { cryptoOnRampManager := NewCryptoOnRampManager(&CryptoOnRampOptions{ dataSourceType: DataSourceStatic, }) - walletFeed := &event.Feed{} + signals := &walletevent.SignalsTransmitter{ - Publisher: walletFeed, + Publisher: feed, } blockchainStatus := make(map[uint64]string) mutex := sync.Mutex{} @@ -83,7 +83,7 @@ func NewService( return } - walletFeed.Send(walletevent.Event{ + feed.Send(walletevent.Event{ Type: EventBlockchainStatusChanged, Accounts: []common.Address{}, Message: string(encodedmessage), @@ -93,21 +93,20 @@ func NewService( }) tokenManager := token.NewTokenManager(db, rpcClient, rpcClient.NetworkManager) savedAddressesManager := &SavedAddressesManager{db: db} - pendingTxManager := transactions.NewTransactionManager(db, rpcFilterSrvc.TransactionSentToUpstreamEvent(), walletFeed) - transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB, pendingTxManager, walletFeed) - transferController := transfer.NewTransferController(db, rpcClient, accountFeed, walletFeed, transactionManager, pendingTxManager, + transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB, pendingTxManager, feed) + transferController := transfer.NewTransferController(db, rpcClient, accountFeed, feed, transactionManager, pendingTxManager, tokenManager, config.WalletConfig.LoadAllTransfers) cryptoCompare := cryptocompare.NewClient() coingecko := coingecko.NewClient() - marketManager := market.NewManager(cryptoCompare, coingecko, walletFeed) - reader := NewReader(rpcClient, tokenManager, marketManager, accountsDB, NewPersistence(db), walletFeed) - history := history.NewService(db, walletFeed, rpcClient, tokenManager, marketManager) - currency := currency.NewService(db, walletFeed, tokenManager, marketManager) - activity := activity.NewService(db, tokenManager, walletFeed, accountsDB) + marketManager := market.NewManager(cryptoCompare, coingecko, feed) + reader := NewReader(rpcClient, tokenManager, marketManager, accountsDB, NewPersistence(db), feed) + history := history.NewService(db, feed, rpcClient, tokenManager, marketManager) + currency := currency.NewService(db, feed, tokenManager, marketManager) + activity := activity.NewService(db, tokenManager, feed, accountsDB) openseaHTTPClient := opensea.NewHTTPClient() - openseaClient := opensea.NewClient(config.WalletConfig.OpenseaAPIKey, openseaHTTPClient, walletFeed) - openseaV2Client := opensea.NewClientV2(config.WalletConfig.OpenseaAPIKey, openseaHTTPClient, walletFeed) + openseaClient := opensea.NewClient(config.WalletConfig.OpenseaAPIKey, openseaHTTPClient, feed) + openseaV2Client := opensea.NewClientV2(config.WalletConfig.OpenseaAPIKey, openseaHTTPClient, feed) infuraClient := infura.NewClient(config.WalletConfig.InfuraAPIKey, config.WalletConfig.InfuraAPIKeySecret) alchemyClient := alchemy.NewClient(config.WalletConfig.AlchemyAPIKeys) @@ -138,7 +137,7 @@ func NewService( } collectiblesManager := collectibles.NewManager(db, rpcClient, contractOwnershipProviders, accountOwnershipProviders, collectibleDataProviders, collectionDataProviders, openseaClient) - collectibles := collectibles.NewService(db, walletFeed, accountsDB, accountFeed, rpcClient.NetworkManager, collectiblesManager) + collectibles := collectibles.NewService(db, feed, accountsDB, accountFeed, rpcClient.NetworkManager, collectiblesManager) return &Service{ db: db, accountsDB: accountsDB, @@ -157,8 +156,7 @@ func NewService( transactor: transactor, ens: ens, stickers: stickers, - rpcFilterSrvc: rpcFilterSrvc, - feed: walletFeed, + feed: feed, signals: signals, reader: reader, history: history, @@ -176,7 +174,7 @@ type Service struct { savedAddressesManager *SavedAddressesManager tokenManager *token.Manager transactionManager *transfer.TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker cryptoOnRampManager *CryptoOnRampManager transferController *transfer.Controller feesManager *FeeManager @@ -188,7 +186,6 @@ type Service struct { transactor *transactions.Transactor ens *ens.Service stickers *stickers.Service - rpcFilterSrvc *rpcfilters.Service feed *event.Feed signals *walletevent.SignalsTransmitter reader *Reader @@ -204,17 +201,11 @@ func (s *Service) Start() error { s.currency.Start() err := s.signals.Start() s.history.Start() - _ = s.pendingTxManager.Start() s.collectibles.Start() s.started = true return err } -// GetFeed returns signals feed. -func (s *Service) GetFeed() *event.Feed { - return s.transferController.TransferFeed -} - // Set external Collectibles metadata provider func (s *Service) SetCollectibleMetadataProvider(provider thirdparty.CollectibleMetadataProvider) { s.collectiblesManager.SetMetadataProvider(provider) @@ -229,7 +220,6 @@ func (s *Service) Stop() error { s.reader.Stop() s.history.Stop() s.activity.Stop() - s.pendingTxManager.Stop() s.collectibles.Stop() s.started = false log.Info("wallet stopped") diff --git a/services/wallet/transfer/commands.go b/services/wallet/transfer/commands.go index 864a3b339..cdc598d31 100644 --- a/services/wallet/transfer/commands.go +++ b/services/wallet/transfer/commands.go @@ -2,7 +2,7 @@ package transfer import ( "context" - "database/sql" + "fmt" "math/big" "strings" "time" @@ -187,7 +187,7 @@ type controlCommand struct { errorsCount int nonArchivalRPCNode bool transactionManager *TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker tokenManager *token.Manager } @@ -366,7 +366,7 @@ type transfersCommand struct { chainClient *chain.ClientWithFallback blocksLimit int transactionManager *TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker tokenManager *token.Manager feed *event.Feed @@ -410,9 +410,9 @@ func (c *transfersCommand) Run(ctx context.Context) (err error) { } if len(allTransfers) > 0 { - err = c.db.SaveTransfersMarkBlocksLoaded(c.chainClient.ChainID, c.address, allTransfers, []*big.Int{blockNum}) + err := c.saveAndConfirmPending(allTransfers, blockNum) if err != nil { - log.Error("SaveTransfers error", "error", err) + log.Error("saveAndConfirmPending error", "error", err) return err } } else { @@ -448,6 +448,54 @@ func (c *transfersCommand) Run(ctx context.Context) (err error) { return nil } +func (c *transfersCommand) saveAndConfirmPending(allTransfers []Transfer, blockNum *big.Int) error { + tx, err := c.db.client.Begin() + if err != nil { + return err + } + + notifyFunctions := make([]func(), 0) + // Confirm all pending transactions that are included in this block + for i, transfer := range allTransfers { + txType, MTID, err := transactions.GetTransferData(tx, w_common.ChainID(transfer.NetworkID), transfer.Receipt.TxHash) + if err != nil { + log.Error("GetTransferData error", "error", err) + } + if MTID != nil { + allTransfers[i].MultiTransactionID = MultiTransactionIDType(*MTID) + } + if txType != nil && *txType == transactions.WalletTransfer { + notify, err := c.pendingTxManager.DeleteBySQLTx(tx, w_common.ChainID(transfer.NetworkID), transfer.Receipt.TxHash) + if err != nil && err != transactions.ErrStillPending { + log.Error("DeleteBySqlTx error", "error", err) + } + notifyFunctions = append(notifyFunctions, notify) + } + } + + err = saveTransfersMarkBlocksLoaded(tx, c.chainClient.ChainID, c.address, allTransfers, []*big.Int{blockNum}) + if err != nil { + log.Error("SaveTransfers error", "error", err) + return err + } + + if err == nil { + err = tx.Commit() + if err != nil { + return err + } + for _, notify := range notifyFunctions { + notify() + } + } else { + err = tx.Rollback() + if err != nil { + return fmt.Errorf("failed to rollback: %w", err) + } + } + return nil +} + // Mark all subTxs of a given Tx with the same multiTxID func setMultiTxID(tx Transaction, multiTxID MultiTransactionIDType) { for _, subTx := range tx { @@ -455,28 +503,6 @@ func setMultiTxID(tx Transaction, multiTxID MultiTransactionIDType) { } } -func (c *transfersCommand) propagatePendingMultiTx(tx Transaction) error { - multiTxID := NoMultiTransactionID - // If any subTx matches a pending entry, mark all of them with the corresponding multiTxID - for _, subTx := range tx { - // Update MultiTransactionID from pending entry - entry, err := c.pendingTxManager.GetPendingEntry(c.chainClient.ChainID, subTx.ID) - if err == nil { - // Propagate the MultiTransactionID, in case the pending entry was a multi-transaction - multiTxID = MultiTransactionIDType(entry.MultiTransactionID) - break - } else if err != sql.ErrNoRows { - log.Error("GetPendingEntry error", "error", err) - return err - } - } - - if multiTxID != NoMultiTransactionID { - setMultiTxID(tx, multiTxID) - } - return nil -} - func (c *transfersCommand) checkAndProcessSwapMultiTx(ctx context.Context, tx Transaction) (bool, error) { for _, subTx := range tx { switch subTx.Type { @@ -527,13 +553,6 @@ func (c *transfersCommand) processMultiTransactions(ctx context.Context, allTran // Detect / Generate multitransactions // Iterate over all detected transactions for _, tx := range txByTxHash { - var err error - // First check for pre-existing pending transaction - err = c.propagatePendingMultiTx(tx) - if err != nil { - return err - } - // Then check for a Swap transaction txProcessed, err := c.checkAndProcessSwapMultiTx(ctx, tx) if err != nil { @@ -571,7 +590,7 @@ type loadTransfersCommand struct { chainClient *chain.ClientWithFallback blocksByAddress map[common.Address][]*big.Int transactionManager *TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker blocksLimit int tokenManager *token.Manager feed *event.Feed @@ -761,7 +780,7 @@ func (c *findAndCheckBlockRangeCommand) fastIndexErc20(ctx context.Context, from func loadTransfers(ctx context.Context, accounts []common.Address, blockDAO *BlockDAO, db *Database, chainClient *chain.ClientWithFallback, blocksLimitPerAccount int, blocksByAddress map[common.Address][]*big.Int, - transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, + transactionManager *TransactionManager, pendingTxManager *transactions.PendingTxTracker, tokenManager *token.Manager, feed *event.Feed) error { log.Info("loadTransfers start", "accounts", accounts, "chain", chainClient.ChainID, "limit", blocksLimitPerAccount) diff --git a/services/wallet/transfer/commands_sequential.go b/services/wallet/transfer/commands_sequential.go index 42d5c3a13..d7a68b14e 100644 --- a/services/wallet/transfer/commands_sequential.go +++ b/services/wallet/transfer/commands_sequential.go @@ -319,7 +319,7 @@ func (c *findBlocksCommand) fastIndexErc20(ctx context.Context, fromBlockNumber } func loadTransfersLoop(ctx context.Context, account common.Address, blockDAO *BlockDAO, db *Database, - chainClient *chain.ClientWithFallback, transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, + chainClient *chain.ClientWithFallback, transactionManager *TransactionManager, pendingTxManager *transactions.PendingTxTracker, tokenManager *token.Manager, feed *event.Feed, blocksLoadedCh <-chan []*DBHeader) { log.Debug("loadTransfersLoop start", "chain", chainClient.ChainID, "account", account) @@ -348,7 +348,7 @@ func loadTransfersLoop(ctx context.Context, account common.Address, blockDAO *Bl func newLoadBlocksAndTransfersCommand(account common.Address, db *Database, blockDAO *BlockDAO, chainClient *chain.ClientWithFallback, feed *event.Feed, - transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, + transactionManager *TransactionManager, pendingTxManager *transactions.PendingTxTracker, tokenManager *token.Manager) *loadBlocksAndTransfersCommand { return &loadBlocksAndTransfersCommand{ @@ -377,7 +377,7 @@ type loadBlocksAndTransfersCommand struct { errorsCount int // nonArchivalRPCNode bool // TODO Make use of it transactionManager *TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker tokenManager *token.Manager blocksLoadedCh chan []*DBHeader diff --git a/services/wallet/transfer/controller.go b/services/wallet/transfer/controller.go index 148bf0f98..2d850c542 100644 --- a/services/wallet/transfer/controller.go +++ b/services/wallet/transfer/controller.go @@ -3,6 +3,7 @@ package transfer import ( "context" "database/sql" + "fmt" "math/big" "time" @@ -27,13 +28,13 @@ type Controller struct { TransferFeed *event.Feed group *async.Group transactionManager *TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker tokenManager *token.Manager loadAllTransfers bool } func NewTransferController(db *sql.DB, rpcClient *rpc.Client, accountFeed *event.Feed, transferFeed *event.Feed, - transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, tokenManager *token.Manager, loadAllTransfers bool) *Controller { + transactionManager *TransactionManager, pendingTxManager *transactions.PendingTxTracker, tokenManager *token.Manager, loadAllTransfers bool) *Controller { blockDAO := &BlockDAO{db} return &Controller{ @@ -211,12 +212,21 @@ func (c *Controller) LoadTransferByHash(ctx context.Context, rpcClient *rpc.Clie return err } - blocks := []*big.Int{transfer.BlockNumber} - err = c.db.SaveTransfersMarkBlocksLoaded(rpcClient.UpstreamChainID, address, transfers, blocks) + tx, err := c.db.client.BeginTx(ctx, nil) if err != nil { return err } + blocks := []*big.Int{transfer.BlockNumber} + err = saveTransfersMarkBlocksLoaded(tx, rpcClient.UpstreamChainID, address, transfers, blocks) + if err != nil { + rollErr := tx.Rollback() + if rollErr != nil { + return fmt.Errorf("failed to rollback transaction due to error: %v", err) + } + return err + } + return nil } diff --git a/services/wallet/transfer/database.go b/services/wallet/transfer/database.go index e6908f642..f85350776 100644 --- a/services/wallet/transfer/database.go +++ b/services/wallet/transfer/database.go @@ -170,53 +170,17 @@ func (db *Database) ProcessTransfers(chainID uint64, transfers []Transfer, remov return } -// SaveTransfersMarkBlocksLoaded -func (db *Database) SaveTransfersMarkBlocksLoaded(chainID uint64, address common.Address, transfers []Transfer, blocks []*big.Int) (err error) { - err = db.SaveTransfers(chainID, address, transfers) - if err != nil { - return - } - - var tx *sql.Tx - tx, err = db.client.Begin() - if err != nil { - return err - } - defer func() { - if err == nil { - err = tx.Commit() - return - } - _ = tx.Rollback() - }() - err = markBlocksAsLoaded(chainID, tx, address, blocks) - if err != nil { - return - } - - return -} - -// SaveTransfers -func (db *Database) SaveTransfers(chainID uint64, address common.Address, transfers []Transfer) (err error) { - var tx *sql.Tx - tx, err = db.client.Begin() - if err != nil { - return err - } - defer func() { - if err == nil { - err = tx.Commit() - return - } - _ = tx.Rollback() - }() - +func saveTransfersMarkBlocksLoaded(tx *sql.Tx, chainID uint64, address common.Address, transfers []Transfer, blocks []*big.Int) (err error) { err = updateOrInsertTransfers(chainID, tx, transfers) if err != nil { return } + err = markBlocksAsLoaded(chainID, tx, address, blocks) + if err != nil { + return + } + return } diff --git a/services/wallet/transfer/database_test.go b/services/wallet/transfer/database_test.go index e5014e95e..365cca8a0 100644 --- a/services/wallet/transfer/database_test.go +++ b/services/wallet/transfer/database_test.go @@ -58,7 +58,11 @@ func TestDBProcessBlocks(t *testing.T) { From: common.Address{1}, }, } - require.NoError(t, db.SaveTransfersMarkBlocksLoaded(777, address, transfers, []*big.Int{big.NewInt(1), big.NewInt(2)})) + tx, err := db.client.BeginTx(context.Background(), nil) + require.NoError(t, err) + + require.NoError(t, saveTransfersMarkBlocksLoaded(tx, 777, address, transfers, []*big.Int{big.NewInt(1), big.NewInt(2)})) + require.NoError(t, tx.Commit()) } func TestDBProcessTransfer(t *testing.T) { diff --git a/services/wallet/transfer/reactor.go b/services/wallet/transfer/reactor.go index f20847067..5e6b6f826 100644 --- a/services/wallet/transfer/reactor.go +++ b/services/wallet/transfer/reactor.go @@ -65,7 +65,7 @@ type OnDemandFetchStrategy struct { group *async.Group balanceCache *balanceCache transactionManager *TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker tokenManager *token.Manager chainClients map[uint64]*chain.ClientWithFallback accounts []common.Address @@ -239,13 +239,13 @@ type Reactor struct { blockDAO *BlockDAO feed *event.Feed transactionManager *TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker tokenManager *token.Manager strategy HistoryFetcher } func NewReactor(db *Database, blockDAO *BlockDAO, feed *event.Feed, tm *TransactionManager, - pendingTxManager *transactions.TransactionManager, tokenManager *token.Manager) *Reactor { + pendingTxManager *transactions.PendingTxTracker, tokenManager *token.Manager) *Reactor { return &Reactor{ db: db, blockDAO: blockDAO, diff --git a/services/wallet/transfer/sequential_fetch_strategy.go b/services/wallet/transfer/sequential_fetch_strategy.go index 5e941d9e0..1f6adc597 100644 --- a/services/wallet/transfer/sequential_fetch_strategy.go +++ b/services/wallet/transfer/sequential_fetch_strategy.go @@ -16,7 +16,7 @@ import ( ) func NewSequentialFetchStrategy(db *Database, blockDAO *BlockDAO, feed *event.Feed, - transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, + transactionManager *TransactionManager, pendingTxManager *transactions.PendingTxTracker, tokenManager *token.Manager, chainClients map[uint64]*chain.ClientWithFallback, accounts []common.Address) *SequentialFetchStrategy { @@ -40,7 +40,7 @@ type SequentialFetchStrategy struct { mu sync.Mutex group *async.Group transactionManager *TransactionManager - pendingTxManager *transactions.TransactionManager + pendingTxManager *transactions.PendingTxTracker tokenManager *token.Manager chainClients map[uint64]*chain.ClientWithFallback accounts []common.Address diff --git a/services/wallet/transfer/transaction.go b/services/wallet/transfer/transaction.go index 9deb5df25..2ee1f4fe8 100644 --- a/services/wallet/transfer/transaction.go +++ b/services/wallet/transfer/transaction.go @@ -39,7 +39,7 @@ type TransactionManager struct { transactor *transactions.Transactor config *params.NodeConfig accountsDB *accounts.Database - pendingManager *transactions.TransactionManager + pendingTracker *transactions.PendingTxTracker eventFeed *event.Feed } @@ -49,7 +49,7 @@ func NewTransactionManager( transactor *transactions.Transactor, config *params.NodeConfig, accountsDB *accounts.Database, - pendingTxManager *transactions.TransactionManager, + pendingTxManager *transactions.PendingTxTracker, eventFeed *event.Feed, ) *TransactionManager { return &TransactionManager{ @@ -58,7 +58,7 @@ func NewTransactionManager( transactor: transactor, config: config, accountsDB: accountsDB, - pendingManager: pendingTxManager, + pendingTracker: pendingTxManager, eventFeed: eventFeed, } } @@ -307,7 +307,7 @@ func (tm *TransactionManager) storePendingTransactions(multiTransaction *MultiTr txs := createPendingTransactions(hashes, data, multiTransaction) for _, tx := range txs { - err := tm.pendingManager.AddPending(tx) + err := tm.pendingTracker.StoreAndTrackPendingTx(tx) if err != nil { return err } @@ -329,10 +329,13 @@ func createPendingTransactions(hashes map[uint64][]types.Hash, data []*bridge.Tr To: common.Address(tx.To()), Data: tx.Data().String(), Type: transactions.WalletTransfer, - ChainID: tx.ChainID, + ChainID: wallet_common.ChainID(tx.ChainID), MultiTransactionID: int64(multiTransaction.ID), Symbol: multiTransaction.FromAsset, + AutoDelete: new(bool), } + // Transaction downloader will delete pending transaction as soon as it is confirmed + *pendingTransaction.AutoDelete = false txs = append(txs, pendingTransaction) } } diff --git a/transactions/conditionalrepeater.go b/transactions/conditionalrepeater.go new file mode 100644 index 000000000..a534b339b --- /dev/null +++ b/transactions/conditionalrepeater.go @@ -0,0 +1,95 @@ +package transactions + +import ( + "context" + "sync" + "time" +) + +// TaskFunc defines the task to be run. The context is canceled when Stop is +// called to early stop scheduled task. +type TaskFunc func(ctx context.Context) (done bool) + +const ( + WorkNotDone = false + WorkDone = true +) + +// ConditionalRepeater runs a task at regular intervals until the task returns +// true. It doesn't allow running task in parallel and can be triggered early +// by call to RunUntilDone. +type ConditionalRepeater struct { + interval time.Duration + task TaskFunc + // nil if not running + ctx context.Context + cancel context.CancelFunc + runNowCh chan bool + runNowMu sync.Mutex + onceMu sync.Mutex +} + +func NewConditionalRepeater(interval time.Duration, task TaskFunc) *ConditionalRepeater { + return &ConditionalRepeater{ + interval: interval, + task: task, + runNowCh: make(chan bool, 1), + } +} + +// RunUntilDone starts the task immediately and continues to run it at the defined +// interval until the task returns true. Can be called multiple times but it +// does not allow multiple concurrent executions of the task. +func (t *ConditionalRepeater) RunUntilDone() { + t.onceMu.Lock() + defer func() { + if len(t.runNowCh) == 0 { + t.runNowCh <- true + } + t.onceMu.Unlock() + }() + + if t.ctx != nil { + return + } + t.ctx, t.cancel = context.WithCancel(context.Background()) + + go func() { + defer func() { + t.runNowMu.Lock() + defer t.runNowMu.Unlock() + t.cancel() + t.ctx = nil + }() + + ticker := time.NewTicker(t.interval) + defer ticker.Stop() + + for { + select { + // Stop was called or task returned true + case <-t.ctx.Done(): + return + // Scheduled execution + case <-ticker.C: + if t.task(t.ctx) { + return + } + // Start right away if requested + case <-t.runNowCh: + ticker.Reset(t.interval) + if t.task(t.ctx) { + return + } + } + } + }() +} + +// Stop forcefully stops the running task by canceling its context. +func (t *ConditionalRepeater) Stop() { + t.onceMu.Lock() + defer t.onceMu.Unlock() + t.cancel() + t.ctx = nil +} diff --git a/transactions/conditionalrepeater_test.go b/transactions/conditionalrepeater_test.go new file mode 100644 index 000000000..8551844e3 --- /dev/null +++ b/transactions/conditionalrepeater_test.go @@ -0,0 +1,79 @@ +package transactions + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestConditionalRepeater_RunOnce(t *testing.T) { + var wg sync.WaitGroup + runCount := 0 + wg.Add(1) + taskRunner := NewConditionalRepeater(1*time.Nanosecond, func(ctx context.Context) bool { + runCount++ + defer wg.Done() + return WorkDone + }) + taskRunner.RunUntilDone() + // Wait for task to run + wg.Wait() + taskRunner.Stop() + require.Greater(t, runCount, 0) +} + +func TestConditionalRepeater_RunUntilDone_MultipleCalls(t *testing.T) { + var wg sync.WaitGroup + wg.Add(5) + runCount := 0 + taskRunner := NewConditionalRepeater(1*time.Nanosecond, func(ctx context.Context) bool { + runCount++ + wg.Done() + return runCount == 5 + }) + for i := 0; i < 10; i++ { + taskRunner.RunUntilDone() + } + // Wait for all tasks to run + wg.Wait() + taskRunner.Stop() + require.Greater(t, runCount, 4) +} + +func TestConditionalRepeater_Stop(t *testing.T) { + var taskRunningWG, taskCanceledWG, taskFinishedWG sync.WaitGroup + taskRunningWG.Add(1) + taskCanceledWG.Add(1) + taskFinishedWG.Add(1) + taskRunner := NewConditionalRepeater(1*time.Nanosecond, func(ctx context.Context) bool { + defer taskFinishedWG.Done() + select { + case <-ctx.Done(): + require.Fail(t, "task should not be canceled yet") + default: + } + + // Wait to caller to stop the task + taskRunningWG.Done() + taskCanceledWG.Wait() + + select { + case <-ctx.Done(): + require.Error(t, ctx.Err()) + default: + require.Fail(t, "task should be canceled") + } + + return WorkDone + }) + taskRunner.RunUntilDone() + taskRunningWG.Wait() + + taskRunner.Stop() + taskCanceledWG.Done() + + taskFinishedWG.Wait() +} diff --git a/transactions/pending.go b/transactions/pending.go deleted file mode 100644 index 9af735d31..000000000 --- a/transactions/pending.go +++ /dev/null @@ -1,364 +0,0 @@ -package transactions - -import ( - "context" - "database/sql" - "errors" - "fmt" - "math/big" - "strings" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/log" - "github.com/status-im/status-go/rpc/chain" - "github.com/status-im/status-go/services/rpcfilters" - "github.com/status-im/status-go/services/wallet/async" - "github.com/status-im/status-go/services/wallet/bigint" - "github.com/status-im/status-go/services/wallet/walletevent" -) - -const ( - // PendingTransactionUpdate is emitted when a pending transaction is updated (added or deleted) - EventPendingTransactionUpdate walletevent.EventType = "pending-transaction-update" -) - -type TransactionManager struct { - db *sql.DB - pendingTxEvent rpcfilters.ChainEvent - eventFeed *event.Feed - quit chan struct{} -} - -func NewTransactionManager(db *sql.DB, pendingTxEvent rpcfilters.ChainEvent, eventFeed *event.Feed) *TransactionManager { - return &TransactionManager{ - db: db, - eventFeed: eventFeed, - pendingTxEvent: pendingTxEvent, - } -} - -func (tm *TransactionManager) Start() error { - if tm.quit != nil { - return errors.New("latest transaction sent to upstream event is already started") - } - - tm.quit = make(chan struct{}) - - go func() { - _, chi := tm.pendingTxEvent.Subscribe() - ch, ok := chi.(chan *rpcfilters.PendingTxInfo) - if !ok { - panic("pendingTxEvent returned wront type of channel") - } - - for { - select { - case tx := <-ch: - log.Info("Pending transaction event received", tx) - err := tm.AddPending(&PendingTransaction{ - Hash: tx.Hash, - Timestamp: uint64(time.Now().Unix()), - From: tx.From, - ChainID: tx.ChainID, - }) - if err != nil { - log.Error("Failed to add pending transaction", "error", err, "hash", tx.Hash, - "chainID", tx.ChainID) - } - case <-tm.quit: - return - } - } - }() - - return tm.pendingTxEvent.Start() -} - -func (tm *TransactionManager) Stop() { - if tm.quit == nil { - return - } - - select { - case <-tm.quit: - return - default: - close(tm.quit) - } - - tm.quit = nil -} - -type PendingTrxType string - -const ( - RegisterENS PendingTrxType = "RegisterENS" - ReleaseENS PendingTrxType = "ReleaseENS" - SetPubKey PendingTrxType = "SetPubKey" - BuyStickerPack PendingTrxType = "BuyStickerPack" - WalletTransfer PendingTrxType = "WalletTransfer" - DeployCommunityToken PendingTrxType = "DeployCommunityToken" - AirdropCommunityToken PendingTrxType = "AirdropCommunityToken" - RemoteDestructCollectible PendingTrxType = "RemoteDestructCollectible" - BurnCommunityToken PendingTrxType = "BurnCommunityToken" - DeployOwnerToken PendingTrxType = "DeployOwnerToken" -) - -type PendingTransaction struct { - Hash common.Hash `json:"hash"` - Timestamp uint64 `json:"timestamp"` - Value bigint.BigInt `json:"value"` - From common.Address `json:"from"` - To common.Address `json:"to"` - Data string `json:"data"` - Symbol string `json:"symbol"` - GasPrice bigint.BigInt `json:"gasPrice"` - GasLimit bigint.BigInt `json:"gasLimit"` - Type PendingTrxType `json:"type"` - AdditionalData string `json:"additionalData"` - ChainID uint64 `json:"network_id"` - MultiTransactionID int64 `json:"multi_transaction_id"` -} - -const selectFromPending = `SELECT hash, timestamp, value, from_address, to_address, data, - symbol, gas_price, gas_limit, type, additional_data, - network_id, COALESCE(multi_transaction_id, 0) - FROM pending_transactions - ` - -func rowsToTransactions(rows *sql.Rows) (transactions []*PendingTransaction, err error) { - for rows.Next() { - transaction := &PendingTransaction{ - Value: bigint.BigInt{Int: new(big.Int)}, - GasPrice: bigint.BigInt{Int: new(big.Int)}, - GasLimit: bigint.BigInt{Int: new(big.Int)}, - } - err := rows.Scan(&transaction.Hash, - &transaction.Timestamp, - (*bigint.SQLBigIntBytes)(transaction.Value.Int), - &transaction.From, - &transaction.To, - &transaction.Data, - &transaction.Symbol, - (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), - (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), - &transaction.Type, - &transaction.AdditionalData, - &transaction.ChainID, - &transaction.MultiTransactionID, - ) - if err != nil { - return nil, err - } - - transactions = append(transactions, transaction) - } - return transactions, nil -} - -func (tm *TransactionManager) GetAllPending(chainIDs []uint64) ([]*PendingTransaction, error) { - log.Info("Getting all pending transactions", "chainIDs", chainIDs) - - if len(chainIDs) == 0 { - return nil, errors.New("at least 1 chainID is required") - } - - inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?" - var parameters []interface{} - for _, c := range chainIDs { - parameters = append(parameters, c) - } - - rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s)", inVector), parameters...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rowsToTransactions(rows) -} - -func (tm *TransactionManager) GetPendingByAddress(chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) { - log.Info("Getting pending transaction by address", "chainIDs", chainIDs, "address", address) - - if len(chainIDs) == 0 { - return nil, errors.New("at least 1 chainID is required") - } - - inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?" - var parameters []interface{} - for _, c := range chainIDs { - parameters = append(parameters, c) - } - - parameters = append(parameters, address) - - rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s) AND from_address = ?", inVector), parameters...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rowsToTransactions(rows) -} - -// GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity -// TODO: consider using address also in case we expect to have also for the receiver -func (tm *TransactionManager) GetPendingEntry(chainID uint64, hash common.Hash) (*PendingTransaction, error) { - log.Info("Getting pending transaction", "chainID", chainID, "hash", hash) - - row := tm.db.QueryRow(`SELECT timestamp, value, from_address, to_address, data, - symbol, gas_price, gas_limit, type, additional_data, - network_id, COALESCE(multi_transaction_id, 0) - FROM pending_transactions - WHERE network_id = ? AND hash = ?`, chainID, hash) - transaction := &PendingTransaction{ - Hash: hash, - Value: bigint.BigInt{Int: new(big.Int)}, - GasPrice: bigint.BigInt{Int: new(big.Int)}, - GasLimit: bigint.BigInt{Int: new(big.Int)}, - ChainID: chainID, - } - err := row.Scan( - &transaction.Timestamp, - (*bigint.SQLBigIntBytes)(transaction.Value.Int), - &transaction.From, - &transaction.To, - &transaction.Data, - &transaction.Symbol, - (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), - (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), - &transaction.Type, - &transaction.AdditionalData, - &transaction.ChainID, - &transaction.MultiTransactionID, - ) - if err != nil { - return nil, err - } - - return transaction, nil -} - -func (tm *TransactionManager) AddPending(transaction *PendingTransaction) error { - insert, err := tm.db.Prepare(`INSERT OR REPLACE INTO pending_transactions - (network_id, hash, timestamp, value, from_address, to_address, - data, symbol, gas_price, gas_limit, type, additional_data, multi_transaction_id) - VALUES - (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) - if err != nil { - return err - } - _, err = insert.Exec( - transaction.ChainID, - transaction.Hash, - transaction.Timestamp, - (*bigint.SQLBigIntBytes)(transaction.Value.Int), - transaction.From, - transaction.To, - transaction.Data, - transaction.Symbol, - (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), - (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), - transaction.Type, - transaction.AdditionalData, - transaction.MultiTransactionID, - ) - // Notify listeners of new pending transaction (used in activity history) - if err == nil { - tm.notifyPendingTransactionListeners(transaction.ChainID, []common.Address{transaction.From, transaction.To}, transaction.Timestamp) - } - return err -} - -func (tm *TransactionManager) notifyPendingTransactionListeners(chainID uint64, addresses []common.Address, timestamp uint64) { - if tm.eventFeed != nil { - tm.eventFeed.Send(walletevent.Event{ - Type: EventPendingTransactionUpdate, - ChainID: chainID, - Accounts: addresses, - At: int64(timestamp), - }) - } -} - -func (tm *TransactionManager) deletePending(chainID uint64, hash common.Hash) error { - tx, err := tm.db.BeginTx(context.Background(), nil) - if err != nil { - return err - } - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - - row := tx.QueryRow(`SELECT from_address, to_address, timestamp FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash) - var from, to common.Address - var timestamp uint64 - err = row.Scan(&from, &to, ×tamp) - if err != nil { - return err - } - - _, err = tx.Exec(`DELETE FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash) - if err != nil { - return err - } - err = tx.Commit() - if err == nil { - tm.notifyPendingTransactionListeners(chainID, []common.Address{from, to}, timestamp) - } - return err -} - -func (tm *TransactionManager) Watch(ctx context.Context, transactionHash common.Hash, client *chain.ClientWithFallback) error { - log.Info("Watching transaction", "chainID", client.ChainID, "hash", transactionHash) - - watchTxCommand := &watchTransactionCommand{ - hash: transactionHash, - client: client, - } - - commandContext, cancel := context.WithTimeout(ctx, 10*time.Minute) - defer cancel() - - err := watchTxCommand.Command()(commandContext) - if err != nil { - log.Error("watchTxCommand error", "error", err, "chainID", client.ChainID, "hash", transactionHash) - return err - } - - return tm.deletePending(client.ChainID, transactionHash) -} - -type watchTransactionCommand struct { - client *chain.ClientWithFallback - hash common.Hash -} - -func (c *watchTransactionCommand) Command() async.Command { - return async.FiniteCommand{ - Interval: 10 * time.Second, - Runable: c.Run, - }.Run -} - -func (c *watchTransactionCommand) Run(ctx context.Context) error { - requestContext, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - _, isPending, err := c.client.TransactionByHash(requestContext, c.hash) - - if err != nil { - log.Error("Watching transaction error", "error", err) - return err - } - - if isPending { - return errors.New("transaction is pending") - } - - return nil -} diff --git a/transactions/pendingtxtracker.go b/transactions/pendingtxtracker.go new file mode 100644 index 000000000..1856d8aa4 --- /dev/null +++ b/transactions/pendingtxtracker.go @@ -0,0 +1,604 @@ +package transactions + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "math/big" + "strings" + "time" + + eth "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + ethrpc "github.com/ethereum/go-ethereum/rpc" + + "github.com/status-im/status-go/rpc" + "github.com/status-im/status-go/services/rpcfilters" + "github.com/status-im/status-go/services/wallet/bigint" + "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/walletevent" +) + +const ( + // PendingTransactionUpdate is emitted when a pending transaction is updated (added or deleted) + EventPendingTransactionUpdate walletevent.EventType = "pending-transaction-update" + // Caries StatusChangedPayload in message + EventPendingTransactionStatusChanged walletevent.EventType = "pending-transaction-status-changed" + + pendingCheckInterval = 10 * time.Second +) + +var ( + ErrStillPending = errors.New("transaction is still pending") +) + +type TxStatus = string + +// Values for status column in pending_transactions +const ( + Pending TxStatus = "Pending" + Done TxStatus = "Done" +) + +type AutoDeleteType = bool + +const ( + AutoDelete AutoDeleteType = true + Keep AutoDeleteType = false +) + +type StatusChangedPayload struct { + ChainID common.ChainID `json:"chainId"` + Hash eth.Hash `json:"hash"` + Status *TxStatus `json:"status,omitempty"` +} + +type PendingTxTracker struct { + db *sql.DB + rpcClient rpc.ClientInterface + + rpcFilter *rpcfilters.Service + eventFeed *event.Feed + + taskRunner *ConditionalRepeater +} + +func NewPendingTxTracker(db *sql.DB, rpcClient rpc.ClientInterface, rpcFilter *rpcfilters.Service, eventFeed *event.Feed) *PendingTxTracker { + tm := &PendingTxTracker{ + db: db, + rpcClient: rpcClient, + eventFeed: eventFeed, + rpcFilter: rpcFilter, + } + tm.taskRunner = NewConditionalRepeater(pendingCheckInterval, func(ctx context.Context) bool { + return tm.fetchTransactions(ctx) + }) + return tm +} + +type txStatusRes struct { + // TODO - 11861: propagate real status + Status TxStatus + hash eth.Hash +} + +func (tm *PendingTxTracker) fetchTransactions(ctx context.Context) bool { + res := WorkDone + + txs, err := tm.GetAllPending() + if err != nil { + log.Error("Failed to get pending transactions", "error", err) + return WorkDone + } + + txsMap := make(map[common.ChainID][]eth.Hash) + for _, tx := range txs { + chainID := tx.ChainID + txsMap[chainID] = append(txsMap[chainID], tx.Hash) + } + + // Batch request for each chain + for chainID, txs := range txsMap { + log.Debug("Processing pending transactions", "chainID", chainID, "count", len(txs)) + batchRes, err := fetchBatchTxStatus(ctx, tm.rpcClient, chainID, txs) + if err != nil { + log.Error("Failed to batch fetch pending transactions status for", "chainID", chainID, "error", err) + continue + } + updateRes, err := tm.updateDBStatus(ctx, chainID, batchRes) + if err != nil { + log.Error("Failed to update pending transactions status for", "chainID", chainID, "error", err) + continue + } + + if len(updateRes) != len(batchRes) { + res = WorkNotDone + } + + tm.emitNotifications(chainID, updateRes) + } + + return res +} + +// fetchBatchTxStatus will exclude the still pending or errored request from the result +func fetchBatchTxStatus(ctx context.Context, rpcClient rpc.ClientInterface, chainID common.ChainID, hashes []eth.Hash) ([]txStatusRes, error) { + chainClient, err := rpcClient.AbstractEthClient(chainID) + if err != nil { + log.Error("Failed to get chain client", "error", err) + return nil, err + } + + reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + batch := make([]ethrpc.BatchElem, 0, len(hashes)) + for _, hash := range hashes { + jsonRes := make(map[string]interface{}) + batch = append(batch, ethrpc.BatchElem{ + Method: "eth_getTransactionByHash", + Args: []interface{}{hash}, + Result: &jsonRes, + }) + } + + err = chainClient.BatchCallContext(reqCtx, batch) + if err != nil { + log.Error("Transactions request fail", "error", err) + return nil, err + } + + res := make([]txStatusRes, 0, len(batch)) + for i, b := range batch { + isPending := true + err := b.Error + if err != nil { + log.Error("Failed to get transaction", "error", err, "hash", hashes[i]) + continue + } else { + jsonRes := *(b.Result.(*map[string]interface{})) + if jsonRes != nil { + if blNo, ok := jsonRes["blockNumber"]; ok { + isPending = blNo == nil + } + } + } + + if !isPending { + res = append(res, txStatusRes{ + hash: hashes[i], + }) + } + } + return res, nil +} + +// updateDBStatus returns entries that were updated only +func (tm *PendingTxTracker) updateDBStatus(ctx context.Context, chainID common.ChainID, statuses []txStatusRes) ([]eth.Hash, error) { + res := make([]eth.Hash, 0, len(statuses)) + tx, err := tm.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to begin transaction: %w", err) + } + + updateStmt, err := tx.PrepareContext(ctx, `UPDATE pending_transactions SET status = ? WHERE network_id = ? AND hash = ?`) + if err != nil { + rollErr := tx.Rollback() + if rollErr != nil { + err = fmt.Errorf("failed to rollback transaction due to: %w", err) + } + return nil, fmt.Errorf("failed to prepare update statement: %w", err) + } + + checkAutoDelStmt, err := tx.PrepareContext(ctx, `SELECT auto_delete FROM pending_transactions WHERE network_id = ? AND hash = ?`) + if err != nil { + rollErr := tx.Rollback() + if rollErr != nil { + err = fmt.Errorf("failed to rollback transaction: %w", err) + } + return nil, fmt.Errorf("failed to prepare auto delete statement: %w", err) + } + + notifyFunctions := make([]func(), 0, len(statuses)) + for _, br := range statuses { + row := checkAutoDelStmt.QueryRowContext(ctx, chainID, br.hash) + var autoDel bool + err = row.Scan(&autoDel) + if err != nil { + if err == sql.ErrNoRows { + log.Warn("Missing entry while checking for auto_delete", "hash", br.hash) + } else { + log.Error("Failed to retrieve auto_delete for pending transaction", "error", err, "hash", br.hash) + } + continue + } + + if autoDel { + notifyFn, err := tm.DeleteBySQLTx(tx, chainID, br.hash) + if err != nil && err != ErrStillPending { + log.Error("Failed to delete pending transaction", "error", err, "hash", br.hash) + continue + } + notifyFunctions = append(notifyFunctions, notifyFn) + } else { + // If the entry was not deleted, update the status + // TODO - #11861: fix status - `br.status` + txStatus := Done + + res, err := updateStmt.ExecContext(ctx, txStatus, chainID, br.hash) + if err != nil { + log.Error("Failed to update pending transaction status", "error", err, "hash", br.hash) + continue + } + affected, err := res.RowsAffected() + if err != nil { + log.Error("Failed to get updated rows", "error", err, "hash", br.hash) + continue + } + + if affected == 0 { + log.Warn("Missing entry to update for", "hash", br.hash) + continue + } + } + + res = append(res, br.hash) + } + + err = tx.Commit() + if err != nil { + return nil, fmt.Errorf("failed to commit transaction: %w", err) + } + + for _, fn := range notifyFunctions { + fn() + } + + return res, nil +} + +func (tm *PendingTxTracker) emitNotifications(chainID common.ChainID, changes []eth.Hash) { + if tm.eventFeed != nil { + for _, hash := range changes { + status := StatusChangedPayload{ + ChainID: chainID, + Hash: hash, + // TODO - #11861: status + } + + jsonPayload, err := json.Marshal(status) + if err != nil { + log.Error("Failed to marshal pending transaction status", "error", err, "hash", hash) + continue + } + tm.eventFeed.Send(walletevent.Event{ + Type: EventPendingTransactionStatusChanged, + ChainID: uint64(chainID), + Message: string(jsonPayload), + }) + } + } +} + +// PendingTransaction called with autoDelete = false will keep the transaction in the database until it is confirmed by the caller using Delete +func (tm *PendingTxTracker) TrackPendingTransaction(chainID common.ChainID, hash eth.Hash, from eth.Address, trType PendingTrxType, autoDelete AutoDeleteType) error { + err := tm.addPending(&PendingTransaction{ + ChainID: chainID, + Hash: hash, + From: from, + Timestamp: uint64(time.Now().Unix()), + Type: trType, + AutoDelete: &autoDelete, + }) + if err != nil { + return err + } + + tm.taskRunner.RunUntilDone() + + return nil +} + +func (tm *PendingTxTracker) Start() error { + tm.taskRunner.RunUntilDone() + return nil +} + +// APIs returns a list of new APIs. +func (tm *PendingTxTracker) APIs() []ethrpc.API { + return []ethrpc.API{ + { + Namespace: "pending", + Version: "0.1.0", + Service: tm, + Public: true, + }, + } +} + +// Protocols returns a new protocols list. In this case, there are none. +func (tm *PendingTxTracker) Protocols() []p2p.Protocol { + return []p2p.Protocol{} +} + +func (tm *PendingTxTracker) Stop() error { + tm.taskRunner.Stop() + return nil +} + +type PendingTrxType string + +const ( + RegisterENS PendingTrxType = "RegisterENS" + ReleaseENS PendingTrxType = "ReleaseENS" + SetPubKey PendingTrxType = "SetPubKey" + BuyStickerPack PendingTrxType = "BuyStickerPack" + WalletTransfer PendingTrxType = "WalletTransfer" + DeployCommunityToken PendingTrxType = "DeployCommunityToken" + AirdropCommunityToken PendingTrxType = "AirdropCommunityToken" + RemoteDestructCollectible PendingTrxType = "RemoteDestructCollectible" + BurnCommunityToken PendingTrxType = "BurnCommunityToken" + DeployOwnerToken PendingTrxType = "DeployOwnerToken" +) + +type PendingTransaction struct { + Hash eth.Hash `json:"hash"` + Timestamp uint64 `json:"timestamp"` + Value bigint.BigInt `json:"value"` + From eth.Address `json:"from"` + To eth.Address `json:"to"` + Data string `json:"data"` + Symbol string `json:"symbol"` + GasPrice bigint.BigInt `json:"gasPrice"` + GasLimit bigint.BigInt `json:"gasLimit"` + Type PendingTrxType `json:"type"` + AdditionalData string `json:"additionalData"` + ChainID common.ChainID `json:"network_id"` + MultiTransactionID int64 `json:"multi_transaction_id"` + + // nil will insert the default value (Pending) in DB + Status *TxStatus `json:"status,omitempty"` + // nil will insert the default value (true) in DB + AutoDelete *bool `json:"autoDelete,omitempty"` +} + +const selectFromPending = `SELECT hash, timestamp, value, from_address, to_address, data, + symbol, gas_price, gas_limit, type, additional_data, + network_id, COALESCE(multi_transaction_id, 0), status, auto_delete + FROM pending_transactions + ` + +func rowsToTransactions(rows *sql.Rows) (transactions []*PendingTransaction, err error) { + for rows.Next() { + transaction := &PendingTransaction{ + Value: bigint.BigInt{Int: new(big.Int)}, + GasPrice: bigint.BigInt{Int: new(big.Int)}, + GasLimit: bigint.BigInt{Int: new(big.Int)}, + } + + transaction.Status = new(TxStatus) + transaction.AutoDelete = new(bool) + err := rows.Scan(&transaction.Hash, + &transaction.Timestamp, + (*bigint.SQLBigIntBytes)(transaction.Value.Int), + &transaction.From, + &transaction.To, + &transaction.Data, + &transaction.Symbol, + (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), + (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), + &transaction.Type, + &transaction.AdditionalData, + &transaction.ChainID, + &transaction.MultiTransactionID, + transaction.Status, + transaction.AutoDelete, + ) + if err != nil { + return nil, err + } + + transactions = append(transactions, transaction) + } + return transactions, nil +} + +func (tm *PendingTxTracker) GetAllPending() ([]*PendingTransaction, error) { + log.Debug("Getting all pending transactions") + + rows, err := tm.db.Query(selectFromPending+"WHERE status = ?", Pending) + if err != nil { + return nil, err + } + defer rows.Close() + + return rowsToTransactions(rows) +} + +func (tm *PendingTxTracker) GetPendingByAddress(chainIDs []uint64, address eth.Address) ([]*PendingTransaction, error) { + log.Debug("Getting pending transaction by address", "chainIDs", chainIDs, "address", address) + + if len(chainIDs) == 0 { + return nil, errors.New("at least 1 chainID is required") + } + + inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?" + var parameters []interface{} + for _, c := range chainIDs { + parameters = append(parameters, c) + } + + parameters = append(parameters, address) + + rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s) AND from_address = ?", inVector), parameters...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rowsToTransactions(rows) +} + +// GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity +func (tm *PendingTxTracker) GetPendingEntry(chainID common.ChainID, hash eth.Hash) (*PendingTransaction, error) { + log.Debug("Getting pending transaction", "chainID", chainID, "hash", hash) + + rows, err := tm.db.Query(selectFromPending+"WHERE network_id = ? AND hash = ?", chainID, hash) + if err != nil { + return nil, err + } + defer rows.Close() + + trs, err := rowsToTransactions(rows) + if err != nil { + return nil, err + } + + if len(trs) == 0 { + return nil, sql.ErrNoRows + } + return trs[0], nil +} + +// StoreAndTrackPendingTx store the details of a pending transaction and track it until it is mined +func (tm *PendingTxTracker) StoreAndTrackPendingTx(transaction *PendingTransaction) error { + err := tm.addPending(transaction) + if err != nil { + return err + } + + tm.taskRunner.RunUntilDone() + + return err +} + +func (tm *PendingTxTracker) addPending(transaction *PendingTransaction) error { + insert, err := tm.db.Prepare(`INSERT OR REPLACE INTO pending_transactions + (network_id, hash, timestamp, value, from_address, to_address, + data, symbol, gas_price, gas_limit, type, additional_data, multi_transaction_id, status, auto_delete) + VALUES + (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? , ?)`) + if err != nil { + return err + } + _, err = insert.Exec( + transaction.ChainID, + transaction.Hash, + transaction.Timestamp, + (*bigint.SQLBigIntBytes)(transaction.Value.Int), + transaction.From, + transaction.To, + transaction.Data, + transaction.Symbol, + (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), + (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), + transaction.Type, + transaction.AdditionalData, + transaction.MultiTransactionID, + transaction.Status, + transaction.AutoDelete, + ) + // Notify listeners of new pending transaction (used in activity history) + if err == nil { + tm.notifyPendingTransactionListeners(transaction.ChainID, []eth.Address{transaction.From, transaction.To}, transaction.Timestamp) + } + if tm.rpcFilter != nil { + tm.rpcFilter.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: transaction.Hash, + Type: string(transaction.Type), + From: transaction.From, + ChainID: uint64(transaction.ChainID), + }) + } + return err +} + +func (tm *PendingTxTracker) notifyPendingTransactionListeners(chainID common.ChainID, addresses []eth.Address, timestamp uint64) { + if tm.eventFeed != nil { + tm.eventFeed.Send(walletevent.Event{ + Type: EventPendingTransactionUpdate, + ChainID: uint64(chainID), + Accounts: addresses, + At: int64(timestamp), + }) + } +} + +// DeleteBySQLTx returns ErrStillPending if the transaction is still pending +func (tm *PendingTxTracker) DeleteBySQLTx(tx *sql.Tx, chainID common.ChainID, hash eth.Hash) (notify func(), err error) { + row := tx.QueryRow(`SELECT from_address, to_address, timestamp, status FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash) + var from, to eth.Address + var timestamp uint64 + var status TxStatus + err = row.Scan(&from, &to, ×tamp, &status) + if err != nil { + return nil, err + } + + _, err = tx.Exec(`DELETE FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash) + if err != nil { + return nil, err + } + + if err == nil && status == Pending { + err = ErrStillPending + } + return func() { + tm.notifyPendingTransactionListeners(chainID, []eth.Address{from, to}, timestamp) + }, err +} + +func GetTransferData(tx *sql.Tx, chainID common.ChainID, hash eth.Hash) (txType *PendingTrxType, MTID *int64, err error) { + row := tx.QueryRow(`SELECT type, multi_transaction_id FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash, txType) + txType = new(PendingTrxType) + MTID = new(int64) + err = row.Scan(txType, MTID) + if err != nil { + return nil, nil, err + } + return txType, MTID, nil +} + +// Watch returns sql.ErrNoRows if no pending transaction is found for the given identity +// tx.Status is not nill if err is nil +func (tm *PendingTxTracker) Watch(ctx context.Context, chainID common.ChainID, hash eth.Hash) (*TxStatus, error) { + log.Debug("Watching transaction", "chainID", chainID, "hash", hash) + + tx, err := tm.GetPendingEntry(chainID, hash) + if err != nil { + return nil, err + } + + return tx.Status, nil +} + +// Delete returns ErrStillPending if the deleted transaction was still pending +// The transactions are suppose to be deleted by the client only after they are confirmed +func (tm *PendingTxTracker) Delete(ctx context.Context, chainID common.ChainID, transactionHash eth.Hash) error { + log.Debug("Delete pending transaction to confirm it", "chainID", chainID, "hash", transactionHash) + + tx, err := tm.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + notifyFn, err := tm.DeleteBySQLTx(tx, chainID, transactionHash) + if err != nil && err != ErrStillPending { + rollErr := tx.Rollback() + if rollErr != nil { + return fmt.Errorf("failed to rollback transaction due to error: %w", err) + } + return err + } + + commitErr := tx.Commit() + if commitErr != nil { + return fmt.Errorf("failed to commit transaction: %w", commitErr) + } + notifyFn() + return err +} diff --git a/transactions/pendingtxtracker_test.go b/transactions/pendingtxtracker_test.go new file mode 100644 index 000000000..9c7c4db2f --- /dev/null +++ b/transactions/pendingtxtracker_test.go @@ -0,0 +1,453 @@ +package transactions + +import ( + "context" + "encoding/json" + "fmt" + "math/big" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + eth "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/rpc" + + "github.com/status-im/status-go/rpc/chain" + "github.com/status-im/status-go/services/wallet/bigint" + "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/walletevent" + "github.com/status-im/status-go/t/helpers" + "github.com/status-im/status-go/walletdatabase" +) + +type MockETHClient struct { + mock.Mock +} + +func (m *MockETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error { + args := m.Called(ctx, b) + return args.Error(0) +} + +type MockChainClient struct { + mock.Mock + + clients map[common.ChainID]*MockETHClient +} + +func newMockChainClient() *MockChainClient { + return &MockChainClient{ + clients: make(map[common.ChainID]*MockETHClient), + } +} + +func (m *MockChainClient) setAvailableClients(chainIDs []common.ChainID) *MockChainClient { + for _, chainID := range chainIDs { + if _, ok := m.clients[chainID]; !ok { + m.clients[chainID] = new(MockETHClient) + } + } + return m +} + +func (m *MockChainClient) AbstractEthClient(chainID common.ChainID) (chain.ClientInterface, error) { + if _, ok := m.clients[chainID]; !ok { + panic(fmt.Sprintf("no mock client for chainID %d", chainID)) + } + return m.clients[chainID], nil +} + +func setupTestTransactionDB(t *testing.T) (*PendingTxTracker, func(), *MockChainClient, *event.Feed) { + db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) + require.NoError(t, err) + + chainClient := newMockChainClient() + eventFeed := &event.Feed{} + return NewPendingTxTracker(db, chainClient, nil, eventFeed), func() { + require.NoError(t, db.Close()) + }, chainClient, eventFeed +} + +const ( + transactionSuccessStatus = "0x1" + transactionFailStatus = "0x0" + transactionByHashRPCName = "eth_getTransactionByHash" +) + +func TestPendingTxTracker_ValidateConfirmed(t *testing.T) { + m, stop, chainClient, eventFeed := setupTestTransactionDB(t) + defer stop() + + txs := generateTestTransactions(1) + + // Mock the first call to getTransactionByHash + chainClient.setAvailableClients([]common.ChainID{txs[0].ChainID}) + cl := chainClient.clients[txs[0].ChainID] + cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { + return len(b) == 1 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash + })).Return(nil).Once().Run(func(args mock.Arguments) { + elems := args.Get(1).([]rpc.BatchElem) + res := elems[0].Result.(*map[string]interface{}) + (*res)["blockNumber"] = transactionSuccessStatus + }) + + eventChan := make(chan walletevent.Event, 2) + sub := eventFeed.Subscribe(eventChan) + + err := m.StoreAndTrackPendingTx(&txs[0]) + require.NoError(t, err) + + for i := 0; i < 3; i++ { + select { + case we := <-eventChan: + if i == 0 || i == 1 { + // Check add and delete + require.Equal(t, EventPendingTransactionUpdate, we.Type) + } else { + require.Equal(t, EventPendingTransactionStatusChanged, we.Type) + var p StatusChangedPayload + err = json.Unmarshal([]byte(we.Message), &p) + require.NoError(t, err) + require.Equal(t, txs[0].Hash, p.Hash) + require.Nil(t, p.Status) + } + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for event") + } + } + + // Wait for the answer to be processed + err = m.Stop() + require.NoError(t, err) + + res, err := m.GetAllPending() + require.NoError(t, err) + require.Equal(t, 0, len(res)) + + sub.Unsubscribe() +} + +func TestPendingTxTracker_InterruptWatching(t *testing.T) { + m, stop, chainClient, eventFeed := setupTestTransactionDB(t) + defer stop() + + txs := generateTestTransactions(2) + + // Mock the first call to getTransactionByHash + chainClient.setAvailableClients([]common.ChainID{txs[0].ChainID}) + cl := chainClient.clients[txs[0].ChainID] + cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { + return (len(b) == 2 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash && b[1].Method == transactionByHashRPCName && b[1].Args[0] == txs[1].Hash) + })).Return(nil).Once().Run(func(args mock.Arguments) { + elems := args.Get(1).([]rpc.BatchElem) + res := elems[0].Result.(*map[string]interface{}) + (*res)["blockNumber"] = nil + res = elems[1].Result.(*map[string]interface{}) + (*res)["blockNumber"] = transactionFailStatus + }) + + eventChan := make(chan walletevent.Event, 2) + sub := eventFeed.Subscribe(eventChan) + + for i := range txs { + err := m.addPending(&txs[i]) + require.NoError(t, err) + } + + // Check add + for i := 0; i < 2; i++ { + select { + case we := <-eventChan: + require.Equal(t, EventPendingTransactionUpdate, we.Type) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for event") + } + } + + err := m.Start() + require.NoError(t, err) + + for i := 0; i < 2; i++ { + select { + case we := <-eventChan: + if i == 0 { + require.Equal(t, EventPendingTransactionUpdate, we.Type) + } else { + require.Equal(t, EventPendingTransactionStatusChanged, we.Type) + var p StatusChangedPayload + err := json.Unmarshal([]byte(we.Message), &p) + require.NoError(t, err) + require.Equal(t, txs[1].Hash, p.Hash) + require.Equal(t, txs[1].ChainID, p.ChainID) + require.Nil(t, p.Status) + } + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for event") + } + } + + // Stop the next timed call + err = m.Stop() + require.NoError(t, err) + + res, err := m.GetAllPending() + require.NoError(t, err) + require.Equal(t, 1, len(res), "should have only one pending tx") + + // Restart the tracker to process leftovers + // + cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { + return (len(b) == 1 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash) + })).Return(nil).Once().Run(func(args mock.Arguments) { + elems := args.Get(1).([]rpc.BatchElem) + res := elems[0].Result.(*map[string]interface{}) + (*res)["blockNumber"] = transactionSuccessStatus + }) + + err = m.Start() + require.NoError(t, err) + + for i := 0; i < 2; i++ { + select { + case we := <-eventChan: + if i == 0 { + require.Equal(t, EventPendingTransactionUpdate, we.Type) + } else { + require.Equal(t, EventPendingTransactionStatusChanged, we.Type) + var p StatusChangedPayload + err := json.Unmarshal([]byte(we.Message), &p) + require.NoError(t, err) + require.Equal(t, txs[0].ChainID, p.ChainID) + require.Equal(t, txs[0].Hash, p.Hash) + require.Nil(t, p.Status) + } + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for event") + } + } + + err = m.Stop() + require.NoError(t, err) + + res, err = m.GetAllPending() + require.NoError(t, err) + require.Equal(t, 0, len(res)) + + sub.Unsubscribe() +} + +func TestPendingTxTracker_MultipleClients(t *testing.T) { + m, stop, chainClient, eventFeed := setupTestTransactionDB(t) + defer stop() + + txs := generateTestTransactions(2) + txs[1].ChainID++ + + // Mock the both clients to be available + chainClient.setAvailableClients([]common.ChainID{txs[0].ChainID, txs[1].ChainID}) + cl := chainClient.clients[txs[0].ChainID] + cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { + return (len(b) == 1 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash) + })).Return(nil).Once().Run(func(args mock.Arguments) { + elems := args.Get(1).([]rpc.BatchElem) + res := elems[0].Result.(*map[string]interface{}) + (*res)["blockNumber"] = transactionFailStatus + }) + cl = chainClient.clients[txs[1].ChainID] + cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { + return (len(b) == 1 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[1].Hash) + })).Return(nil).Once().Run(func(args mock.Arguments) { + elems := args.Get(1).([]rpc.BatchElem) + res := elems[0].Result.(*map[string]interface{}) + (*res)["blockNumber"] = transactionSuccessStatus + }) + + for i := range txs { + err := m.TrackPendingTransaction(txs[i].ChainID, txs[i].Hash, txs[i].From, txs[i].Type, true) + require.NoError(t, err) + } + + eventChan := make(chan walletevent.Event) + sub := eventFeed.Subscribe(eventChan) + + err := m.Start() + require.NoError(t, err) + + for i := 0; i < 2; i++ { + for j := 0; j < 2; j++ { + select { + case we := <-eventChan: + if j == 0 { + require.Equal(t, EventPendingTransactionUpdate, we.Type) + } else { + require.Equal(t, EventPendingTransactionStatusChanged, we.Type) + var p StatusChangedPayload + err := json.Unmarshal([]byte(we.Message), &p) + require.NoError(t, err) + require.Nil(t, p.Status) + } + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for event") + } + } + } + + err = m.Stop() + require.NoError(t, err) + + res, err := m.GetAllPending() + require.NoError(t, err) + require.Equal(t, 0, len(res)) + + sub.Unsubscribe() +} + +func TestPendingTxTracker_Watch(t *testing.T) { + m, stop, chainClient, eventFeed := setupTestTransactionDB(t) + defer stop() + + txs := generateTestTransactions(2) + // Make the second already confirmed + *txs[1].Status = Done + + // Mock the first call to getTransactionByHash + chainClient.setAvailableClients([]common.ChainID{txs[0].ChainID}) + cl := chainClient.clients[txs[0].ChainID] + cl.On("BatchCallContext", mock.Anything, mock.MatchedBy(func(b []rpc.BatchElem) bool { + return len(b) == 1 && b[0].Method == transactionByHashRPCName && b[0].Args[0] == txs[0].Hash + })).Return(nil).Once().Run(func(args mock.Arguments) { + elems := args.Get(1).([]rpc.BatchElem) + res := elems[0].Result.(*map[string]interface{}) + (*res)["blockNumber"] = transactionFailStatus + }) + + eventChan := make(chan walletevent.Event, 2) + sub := eventFeed.Subscribe(eventChan) + + // Track the first transaction + err := m.TrackPendingTransaction(txs[0].ChainID, txs[0].Hash, txs[0].From, txs[0].Type, false) + require.NoError(t, err) + + // Store the confirmed already + err = m.StoreAndTrackPendingTx(&txs[1]) + require.NoError(t, err) + + storeEventCount := 0 + statusEventCount := 0 + for j := 0; j < 3; j++ { + select { + case we := <-eventChan: + if EventPendingTransactionUpdate == we.Type { + storeEventCount++ + } else if EventPendingTransactionStatusChanged == we.Type { + statusEventCount++ + var p StatusChangedPayload + err := json.Unmarshal([]byte(we.Message), &p) + require.NoError(t, err) + require.Equal(t, txs[0].ChainID, p.ChainID) + require.Equal(t, txs[0].Hash, p.Hash) + require.Nil(t, p.Status) + } + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for the status update event") + } + } + require.Equal(t, 2, storeEventCount) + require.Equal(t, 1, statusEventCount) + + // Stop the next timed call + err = m.Stop() + require.NoError(t, err) + + res, err := m.GetAllPending() + require.NoError(t, err) + require.Equal(t, 0, len(res), "should have only one pending tx") + + status, err := m.Watch(context.Background(), txs[0].ChainID, txs[0].Hash) + require.NoError(t, err) + require.NotEqual(t, Pending, status) + + err = m.Delete(context.Background(), txs[0].ChainID, txs[0].Hash) + require.NoError(t, err) + + select { + case we := <-eventChan: + require.Equal(t, EventPendingTransactionUpdate, we.Type) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for the delete event") + } + + sub.Unsubscribe() +} + +func TestPendingTransactions(t *testing.T) { + manager, stop, _, _ := setupTestTransactionDB(t) + defer stop() + + tx := generateTestTransactions(1)[0] + + rst, err := manager.GetAllPending() + require.NoError(t, err) + require.Nil(t, rst) + + rst, err = manager.GetPendingByAddress([]uint64{777}, tx.From) + require.NoError(t, err) + require.Nil(t, rst) + + err = manager.addPending(&tx) + require.NoError(t, err) + + rst, err = manager.GetPendingByAddress([]uint64{777}, tx.From) + require.NoError(t, err) + require.Equal(t, 1, len(rst)) + require.Equal(t, tx, *rst[0]) + + rst, err = manager.GetAllPending() + require.NoError(t, err) + require.Equal(t, 1, len(rst)) + require.Equal(t, tx, *rst[0]) + + rst, err = manager.GetPendingByAddress([]uint64{777}, eth.Address{2}) + require.NoError(t, err) + require.Nil(t, rst) + + err = manager.Delete(context.Background(), common.ChainID(777), tx.Hash) + require.Error(t, err, ErrStillPending) + + rst, err = manager.GetPendingByAddress([]uint64{777}, tx.From) + require.NoError(t, err) + require.Equal(t, 0, len(rst)) + + rst, err = manager.GetAllPending() + require.NoError(t, err) + require.Equal(t, 0, len(rst)) +} + +func generateTestTransactions(count int) []PendingTransaction { + if count > 127 { + panic("can't generate more than 127 distinct transactions") + } + + txs := make([]PendingTransaction, count) + for i := 0; i < count; i++ { + txs[i] = PendingTransaction{ + Hash: eth.Hash{byte(i)}, + From: eth.Address{byte(i)}, + To: eth.Address{byte(i * 2)}, + Type: RegisterENS, + AdditionalData: "someuser.stateofus.eth", + Value: bigint.BigInt{Int: big.NewInt(int64(i))}, + GasLimit: bigint.BigInt{Int: big.NewInt(21000)}, + GasPrice: bigint.BigInt{Int: big.NewInt(int64(i))}, + ChainID: 777, + Status: new(TxStatus), + AutoDelete: new(bool), + } + *txs[i].Status = Pending // set to pending by default + *txs[i].AutoDelete = true // set to true by default + } + return txs +} diff --git a/transactions/transaction_test.go b/transactions/transaction_test.go deleted file mode 100644 index a3e4e1a82..000000000 --- a/transactions/transaction_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package transactions - -import ( - "math/big" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ethereum/go-ethereum/common" - - "github.com/status-im/status-go/services/wallet/bigint" - "github.com/status-im/status-go/t/helpers" - "github.com/status-im/status-go/walletdatabase" -) - -func setupTestTransactionDB(t *testing.T) (*TransactionManager, func()) { - db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) - require.NoError(t, err) - return &TransactionManager{db, nil, nil, nil}, func() { - require.NoError(t, db.Close()) - } -} - -func TestPendingTransactions(t *testing.T) { - manager, stop := setupTestTransactionDB(t) - defer stop() - - trx := PendingTransaction{ - Hash: common.Hash{1}, - From: common.Address{1}, - To: common.Address{2}, - Type: RegisterENS, - AdditionalData: "someuser.stateofus.eth", - Value: bigint.BigInt{Int: big.NewInt(123)}, - GasLimit: bigint.BigInt{Int: big.NewInt(21000)}, - GasPrice: bigint.BigInt{Int: big.NewInt(1)}, - ChainID: 777, - } - - rst, err := manager.GetAllPending([]uint64{777}) - require.NoError(t, err) - require.Nil(t, rst) - - rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) - require.NoError(t, err) - require.Nil(t, rst) - - err = manager.AddPending(&trx) - require.NoError(t, err) - - rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) - require.NoError(t, err) - require.Equal(t, 1, len(rst)) - require.Equal(t, trx, *rst[0]) - - rst, err = manager.GetAllPending([]uint64{777}) - require.NoError(t, err) - require.Equal(t, 1, len(rst)) - require.Equal(t, trx, *rst[0]) - - rst, err = manager.GetPendingByAddress([]uint64{777}, common.Address{2}) - require.NoError(t, err) - require.Nil(t, rst) - - err = manager.deletePending(777, trx.Hash) - require.NoError(t, err) - - rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) - require.NoError(t, err) - require.Equal(t, 0, len(rst)) - - rst, err = manager.GetAllPending([]uint64{777}) - require.NoError(t, err) - require.Equal(t, 0, len(rst)) -}