diff --git a/api/geth_backend.go b/api/geth_backend.go index a37b17cd6..6767b5a8c 100644 --- a/api/geth_backend.go +++ b/api/geth_backend.go @@ -16,6 +16,7 @@ import ( "github.com/imdario/mergo" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" ethcrypto "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" @@ -30,18 +31,19 @@ import ( "github.com/status-im/status-go/logutils" "github.com/status-im/status-go/multiaccounts" "github.com/status-im/status-go/multiaccounts/accounts" - "github.com/status-im/status-go/multiaccounts/common" + multiacccommon "github.com/status-im/status-go/multiaccounts/common" "github.com/status-im/status-go/multiaccounts/settings" "github.com/status-im/status-go/node" "github.com/status-im/status-go/nodecfg" "github.com/status-im/status-go/params" "github.com/status-im/status-go/protocol" - identityUtils "github.com/status-im/status-go/protocol/identity" + identityutils "github.com/status-im/status-go/protocol/identity" "github.com/status-im/status-go/protocol/identity/colorhash" "github.com/status-im/status-go/protocol/requests" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/services/ext" "github.com/status-im/status-go/services/personal" + "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/typeddata" "github.com/status-im/status-go/signal" "github.com/status-im/status-go/sqlite" @@ -1064,7 +1066,7 @@ func (b *GethStatusBackend) generateOrImportAccount(mnemonic string, request *re account := multiaccounts.Account{ KeyUID: info.KeyUID, Name: request.DisplayName, - CustomizationColor: common.CustomizationColor(request.CustomizationColor), + CustomizationColor: multiacccommon.CustomizationColor(request.CustomizationColor), KDFIterations: sqlite.ReducedKDFIterationsNumber, } if request.ImagePath != "" { @@ -1255,7 +1257,7 @@ func enrichMultiAccountBySubAccounts(account *multiaccounts.Account, subaccs []* } account.ColorHash = colorHash - colorID, err := identityUtils.ToColorID(pk) + colorID, err := identityutils.ToColorID(pk) if err != nil { return err } @@ -1276,7 +1278,7 @@ func enrichMultiAccountByPublicKey(account *multiaccounts.Account, publicKey typ } account.ColorHash = colorHash - colorID, err := identityUtils.ToColorID(pk) + colorID, err := identityutils.ToColorID(pk) if err != nil { return err } @@ -1586,7 +1588,12 @@ func (b *GethStatusBackend) SendTransaction(sendArgs transactions.SendTxArgs, pa return } - go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(hash) + go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: common.Hash(hash), + Type: string(transactions.WalletTransfer), + From: common.Address(sendArgs.From), + ChainID: b.transactor.NetworkID(), + }) return } @@ -1602,7 +1609,12 @@ func (b *GethStatusBackend) SendTransactionWithChainID(chainID uint64, sendArgs return } - go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(hash) + go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: common.Hash(hash), + Type: string(transactions.WalletTransfer), + From: common.Address(sendArgs.From), + ChainID: b.transactor.NetworkID(), + }) return } @@ -1613,7 +1625,12 @@ func (b *GethStatusBackend) SendTransactionWithSignature(sendArgs transactions.S return } - go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(hash) + go b.statusNode.RPCFiltersService().TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: common.Hash(hash), + Type: string(transactions.WalletTransfer), + From: common.Address(sendArgs.From), + ChainID: b.transactor.NetworkID(), + }) return } diff --git a/node/status_node_services.go b/node/status_node_services.go index 78a81f903..4af8760c8 100644 --- a/node/status_node_services.go +++ b/node/status_node_services.go @@ -489,6 +489,7 @@ func (b *StatusNode) walletService(accountsDB *accounts.Database, accountsFeed * b.appDB, accountsDB, b.rpcClient, accountsFeed, b.gethAccountManager, b.transactor, b.config, b.ensService(b.timeSourceNow()), b.stickersService(accountsDB), + b.rpcFiltersSrvc, extService, ) } diff --git a/services/ens/api.go b/services/ens/api.go index 601068937..4d91f1d7a 100644 --- a/services/ens/api.go +++ b/services/ens/api.go @@ -30,7 +30,6 @@ import ( "github.com/status-im/status-go/contracts/registrar" "github.com/status-im/status-go/contracts/resolver" "github.com/status-im/status-go/contracts/snt" - "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/services/rpcfilters" @@ -354,7 +353,12 @@ func (api *API) Release(ctx context.Context, chainID uint64, txArgs transactions return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(types.Hash(tx.Hash())) + go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: tx.Hash(), + Type: string(transactions.ReleaseENS), + From: common.Address(txArgs.From), + ChainID: chainID, + }) err = api.Remove(ctx, chainID, fullDomainName(username)) @@ -439,7 +443,12 @@ func (api *API) Register(ctx context.Context, chainID uint64, txArgs transaction return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(types.Hash(tx.Hash())) + go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: tx.Hash(), + Type: string(transactions.RegisterENS), + From: common.Address(txArgs.From), + ChainID: chainID, + }) err = api.Add(ctx, chainID, fullDomainName(username)) if err != nil { @@ -545,7 +554,12 @@ func (api *API) SetPubKey(ctx context.Context, chainID uint64, txArgs transactio return "", err } - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(types.Hash(tx.Hash())) + go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: tx.Hash(), + Type: string(transactions.SetPubKey), + From: common.Address(txArgs.From), + ChainID: chainID, + }) err = api.Add(ctx, chainID, fullDomainName(username)) if err != nil { diff --git a/services/rpcfilters/api.go b/services/rpcfilters/api.go index 25fd693b4..6ad291082 100644 --- a/services/rpcfilters/api.go +++ b/services/rpcfilters/api.go @@ -10,6 +10,7 @@ import ( "github.com/pborman/uuid" ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/log" @@ -33,6 +34,13 @@ type filter interface { deadline() *time.Timer } +type ChainEvent interface { + Start() error + Stop() + Subscribe() (id int, ch interface{}) + Unsubscribe(id int) +} + // PublicAPI represents filter API that is exported to `eth` namespace type PublicAPI struct { filtersMu sync.Mutex @@ -123,7 +131,12 @@ func (api *PublicAPI) NewBlockFilter() getrpc.ID { api.filters[id] = f go func() { - id, s := api.latestBlockChangedEvent.Subscribe() + id, si := api.latestBlockChangedEvent.Subscribe() + s, ok := si.(chan common.Hash) + if !ok { + panic("latestBlockChangedEvent returned wrong type") + } + defer api.latestBlockChangedEvent.Unsubscribe(id) for { @@ -154,7 +167,11 @@ func (api *PublicAPI) NewPendingTransactionFilter() getrpc.ID { api.filters[id] = f go func() { - id, s := api.transactionSentToUpstreamEvent.Subscribe() + id, si := api.transactionSentToUpstreamEvent.Subscribe() + s, ok := si.(chan *PendingTxInfo) + if !ok { + panic("transactionSentToUpstreamEvent returned wrong type") + } defer api.transactionSentToUpstreamEvent.Unsubscribe(id) for { @@ -167,7 +184,6 @@ func (api *PublicAPI) NewPendingTransactionFilter() getrpc.ID { return } } - }() return id diff --git a/services/rpcfilters/latest_block_changed_event.go b/services/rpcfilters/latest_block_changed_event.go index 8de93e76f..b6e76f544 100644 --- a/services/rpcfilters/latest_block_changed_event.go +++ b/services/rpcfilters/latest_block_changed_event.go @@ -137,7 +137,7 @@ func (e *latestBlockChangedEvent) Stop() { e.quit = nil } -func (e *latestBlockChangedEvent) Subscribe() (int, chan common.Hash) { +func (e *latestBlockChangedEvent) Subscribe() (int, interface{}) { e.sxMu.Lock() defer e.sxMu.Unlock() diff --git a/services/rpcfilters/latest_block_changed_event_test.go b/services/rpcfilters/latest_block_changed_event_test.go index 776bd26ef..2f8a7cb9f 100644 --- a/services/rpcfilters/latest_block_changed_event_test.go +++ b/services/rpcfilters/latest_block_changed_event_test.go @@ -61,7 +61,9 @@ func TestZeroSubsciptionsOptimization(t *testing.T) { assert.Equal(t, int64(0), atomic.LoadInt64(&counter)) // subscribing an event, checking that it works - id, channel := event.Subscribe() + id, channelInterface := event.Subscribe() + channel, ok := channelInterface.(chan common.Hash) + assert.True(t, ok) timeout := time.After(1 * time.Second) select { @@ -128,7 +130,9 @@ func testEventSubscribe(t *testing.T, f func() (blockInfo, error), expectedHashe } func testEvent(t *testing.T, event *latestBlockChangedEvent, expectedHashes []common.Hash) { - id, channel := event.Subscribe() + id, channelInterface := event.Subscribe() + channel, ok := channelInterface.(chan common.Hash) + assert.True(t, ok) timeout := time.After(1 * time.Second) diff --git a/services/rpcfilters/service.go b/services/rpcfilters/service.go index 1dbb75545..08a20fcf7 100644 --- a/services/rpcfilters/service.go +++ b/services/rpcfilters/service.go @@ -4,8 +4,6 @@ import ( "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rpc" - - "github.com/status-im/status-go/eth-node/types" ) // Make sure that Service implements node.Lifecycle interface. @@ -68,8 +66,10 @@ func (s *Service) Stop() error { return nil } -// TriggerTransactionSentToUpstreamEvent notifies the subscribers -// of the TransactionSentToUpstream event -func (s *Service) TriggerTransactionSentToUpstreamEvent(transactionHash types.Hash) { - s.transactionSentToUpstreamEvent.Trigger(transactionHash) +func (s *Service) TransactionSentToUpstreamEvent() ChainEvent { + return s.transactionSentToUpstreamEvent +} + +func (s *Service) TriggerTransactionSentToUpstreamEvent(txInfo *PendingTxInfo) { + s.transactionSentToUpstreamEvent.Trigger(txInfo) } diff --git a/services/rpcfilters/transaction_sent_to_upstream_event.go b/services/rpcfilters/transaction_sent_to_upstream_event.go index 78833291c..17d1fdbab 100644 --- a/services/rpcfilters/transaction_sent_to_upstream_event.go +++ b/services/rpcfilters/transaction_sent_to_upstream_event.go @@ -4,23 +4,29 @@ import ( "errors" "sync" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" - - "github.com/status-im/status-go/eth-node/types" ) +type PendingTxInfo struct { + Hash common.Hash + Type string + From common.Address + ChainID uint64 +} + // transactionSentToUpstreamEvent represents an event that one can subscribe to type transactionSentToUpstreamEvent struct { sxMu sync.Mutex - sx map[int]chan types.Hash - listener chan types.Hash + sx map[int]chan *PendingTxInfo + listener chan *PendingTxInfo quit chan struct{} } func newTransactionSentToUpstreamEvent() *transactionSentToUpstreamEvent { return &transactionSentToUpstreamEvent{ - sx: make(map[int]chan types.Hash), - listener: make(chan types.Hash), + sx: make(map[int]chan *PendingTxInfo), + listener: make(chan *PendingTxInfo), } } @@ -34,11 +40,11 @@ func (e *transactionSentToUpstreamEvent) Start() error { go func() { for { select { - case transactionHash := <-e.listener: + case transactionInfo := <-e.listener: if e.numberOfSubscriptions() == 0 { continue } - e.processTransactionSentToUpstream(transactionHash) + e.processTransactionSentToUpstream(transactionInfo) case <-e.quit: return } @@ -54,16 +60,16 @@ func (e *transactionSentToUpstreamEvent) numberOfSubscriptions() int { return len(e.sx) } -func (e *transactionSentToUpstreamEvent) processTransactionSentToUpstream(transactionHash types.Hash) { +func (e *transactionSentToUpstreamEvent) processTransactionSentToUpstream(transactionInfo *PendingTxInfo) { e.sxMu.Lock() defer e.sxMu.Unlock() for id, channel := range e.sx { select { - case channel <- transactionHash: + case channel <- transactionInfo: default: - log.Error("dropping messages %s for subscriotion %d because the channel is full", transactionHash, id) + log.Error("dropping messages %s for subscriotion %d because the channel is full", transactionInfo, id) } } } @@ -83,11 +89,11 @@ func (e *transactionSentToUpstreamEvent) Stop() { e.quit = nil } -func (e *transactionSentToUpstreamEvent) Subscribe() (int, chan types.Hash) { +func (e *transactionSentToUpstreamEvent) Subscribe() (int, interface{}) { e.sxMu.Lock() defer e.sxMu.Unlock() - channel := make(chan types.Hash, 512) + channel := make(chan *PendingTxInfo, 512) id := len(e.sx) e.sx[id] = channel return id, channel @@ -101,6 +107,6 @@ func (e *transactionSentToUpstreamEvent) Unsubscribe(id int) { } // Trigger gets called in order to trigger the event -func (e *transactionSentToUpstreamEvent) Trigger(transactionHash types.Hash) { - e.listener <- transactionHash +func (e *transactionSentToUpstreamEvent) Trigger(transactionInfo *PendingTxInfo) { + e.listener <- transactionInfo } diff --git a/services/rpcfilters/transaction_sent_to_upstream_event_test.go b/services/rpcfilters/transaction_sent_to_upstream_event_test.go index 4acf6c825..280dcecd5 100644 --- a/services/rpcfilters/transaction_sent_to_upstream_event_test.go +++ b/services/rpcfilters/transaction_sent_to_upstream_event_test.go @@ -1,6 +1,7 @@ package rpcfilters import ( + "reflect" "sync" "testing" "time" @@ -8,19 +9,39 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/status-im/status-go/eth-node/types" + "github.com/ethereum/go-ethereum/common" ) -var transactionHashes = []types.Hash{types.HexToHash("0xAA"), types.HexToHash("0xBB"), types.HexToHash("0xCC")} +var transactionInfos = []*PendingTxInfo{ + { + Hash: common.HexToHash("0xAA"), + Type: "RegisterENS", + From: common.Address{1}, + ChainID: 0, + }, + { + Hash: common.HexToHash("0xBB"), + Type: "WalletTransfer", + ChainID: 1, + }, + { + Hash: common.HexToHash("0xCC"), + Type: "SetPubKey", + From: common.Address{3}, + ChainID: 2, + }, +} func TestTransactionSentToUpstreamEventMultipleSubscribe(t *testing.T) { event := newTransactionSentToUpstreamEvent() require.NoError(t, event.Start()) defer event.Stop() - var subscriptionChannels []chan types.Hash + var subscriptionChannels []chan *PendingTxInfo for i := 0; i < 3; i++ { - id, channel := event.Subscribe() + id, channelInterface := event.Subscribe() + channel, ok := channelInterface.(chan *PendingTxInfo) + require.True(t, ok) // test id assignment require.Equal(t, i, id) // test numberOfSubscriptions @@ -35,10 +56,10 @@ func TestTransactionSentToUpstreamEventMultipleSubscribe(t *testing.T) { for _, channel := range subscriptionChannels { ch := channel go func() { - for _, expectedHash := range transactionHashes { + for _, expectedTxInfo := range transactionInfos { select { - case receivedHash := <-ch: - require.Equal(t, expectedHash, receivedHash) + case receivedTxInfo := <-ch: + require.True(t, reflect.DeepEqual(expectedTxInfo, receivedTxInfo)) case <-time.After(1 * time.Second): assert.Fail(t, "timeout") } @@ -48,8 +69,8 @@ func TestTransactionSentToUpstreamEventMultipleSubscribe(t *testing.T) { } }() - for _, hashToTrigger := range transactionHashes { - event.Trigger(hashToTrigger) + for _, txInfo := range transactionInfos { + event.Trigger(txInfo) } wg.Wait() } diff --git a/services/stickers/transactions.go b/services/stickers/transactions.go index cb1650ea1..6ad0ce3fb 100644 --- a/services/stickers/transactions.go +++ b/services/stickers/transactions.go @@ -14,6 +14,7 @@ import ( "github.com/status-im/status-go/contracts/snt" "github.com/status-im/status-go/contracts/stickers" "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/utils" "github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/transactions" @@ -70,8 +71,12 @@ func (api *API) Buy(ctx context.Context, chainID uint64, txArgs transactions.Sen } // TODO: track pending transaction (do this in ENS service too) - - go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(types.Hash(tx.Hash())) + go api.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: tx.Hash(), + Type: string(transactions.BuyStickerPack), + From: common.Address(txArgs.From), + ChainID: chainID, + }) return tx.Hash().String(), nil } diff --git a/services/wallet/api.go b/services/wallet/api.go index 98825b7df..e1a61d708 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -19,6 +19,7 @@ import ( "github.com/status-im/status-go/services/wallet/thirdparty/opensea" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/transfer" + "github.com/status-im/status-go/transactions" wcommon "github.com/status-im/status-go/services/wallet/common" ) @@ -237,27 +238,29 @@ func (api *API) DeleteSavedAddress(ctx context.Context, address common.Address, return err } -func (api *API) GetPendingTransactions(ctx context.Context) ([]*transfer.PendingTransaction, error) { +func (api *API) GetPendingTransactions(ctx context.Context) ([]*transactions.PendingTransaction, error) { log.Debug("call to get pending transactions") - rst, err := api.s.transactionManager.GetAllPending([]uint64{api.s.rpcClient.UpstreamChainID}) + rst, err := api.s.pendingTxManager.GetAllPending([]uint64{api.s.rpcClient.UpstreamChainID}) log.Debug("result from database for pending transactions", "len", len(rst)) return rst, err } -func (api *API) GetPendingTransactionsByChainIDs(ctx context.Context, chainIDs []uint64) ([]*transfer.PendingTransaction, error) { +func (api *API) GetPendingTransactionsByChainIDs(ctx context.Context, chainIDs []uint64) ([]*transactions.PendingTransaction, error) { log.Debug("call to get pending transactions") - rst, err := api.s.transactionManager.GetAllPending(chainIDs) + rst, err := api.s.pendingTxManager.GetAllPending(chainIDs) log.Debug("result from database for pending transactions", "len", len(rst)) return rst, err } -func (api *API) GetPendingTransactionsForIdentities(ctx context.Context, identities []transfer.TransactionIdentity) (result []*transfer.PendingTransaction, err error) { +func (api *API) GetPendingTransactionsForIdentities(ctx context.Context, identities []transfer.TransactionIdentity) ( + result []*transactions.PendingTransaction, err error) { + log.Debug("call to GetPendingTransactionsForIdentities") - result = make([]*transfer.PendingTransaction, 0, len(identities)) - var pt *transfer.PendingTransaction + result = make([]*transactions.PendingTransaction, 0, len(identities)) + var pt *transactions.PendingTransaction for _, identity := range identities { - pt, err = api.s.transactionManager.GetPendingEntry(uint64(identity.ChainID), identity.Hash) + pt, err = api.s.pendingTxManager.GetPendingEntry(uint64(identity.ChainID), identity.Hash) result = append(result, pt) } @@ -265,50 +268,30 @@ func (api *API) GetPendingTransactionsForIdentities(ctx context.Context, identit return } -func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ([]*transfer.PendingTransaction, error) { +func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ( + []*transactions.PendingTransaction, error) { + log.Debug("call to get pending outbound transactions by address") - rst, err := api.s.transactionManager.GetPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address) + rst, err := api.s.pendingTxManager.GetPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address) log.Debug("result from database for pending transactions by address", "len", len(rst)) return rst, err } -func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainIDs []uint64, address common.Address) ([]*transfer.PendingTransaction, error) { +func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainIDs []uint64, + address common.Address) ([]*transactions.PendingTransaction, error) { + log.Debug("call to get pending outbound transactions by address") - rst, err := api.s.transactionManager.GetPendingByAddress(chainIDs, address) + rst, err := api.s.pendingTxManager.GetPendingByAddress(chainIDs, address) log.Debug("result from database for pending transactions by address", "len", len(rst)) return rst, err } -func (api *API) StorePendingTransaction(ctx context.Context, trx transfer.PendingTransaction) error { - log.Debug("call to create or edit pending transaction") - if trx.ChainID == 0 { - trx.ChainID = api.s.rpcClient.UpstreamChainID - } - err := api.s.transactionManager.AddPending(trx) - log.Debug("result from database for creating or editing a pending transaction", "err", err) - return err -} - -func (api *API) DeletePendingTransaction(ctx context.Context, transactionHash common.Hash) error { - log.Debug("call to remove pending transaction") - err := api.s.transactionManager.DeletePending(api.s.rpcClient.UpstreamChainID, transactionHash) - log.Debug("result from database for remove pending transaction", "err", err) - return err -} - -func (api *API) DeletePendingTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) error { - log.Debug("call to remove pending transaction") - err := api.s.transactionManager.DeletePending(chainID, transactionHash) - log.Debug("result from database for remove pending transaction", "err", err) - return err -} - func (api *API) WatchTransaction(ctx context.Context, transactionHash common.Hash) error { chainClient, err := api.s.rpcClient.EthClient(api.s.rpcClient.UpstreamChainID) if err != nil { return err } - return api.s.transactionManager.Watch(ctx, transactionHash, chainClient) + return api.s.pendingTxManager.Watch(ctx, transactionHash, chainClient) } func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) error { @@ -316,7 +299,7 @@ func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, t if err != nil { return err } - return api.s.transactionManager.Watch(ctx, transactionHash, chainClient) + return api.s.pendingTxManager.Watch(ctx, transactionHash, chainClient) } func (api *API) GetCryptoOnRamps(ctx context.Context) ([]CryptoOnRamp, error) { diff --git a/services/wallet/service.go b/services/wallet/service.go index 158a4568c..bac8ffe19 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -17,6 +17,7 @@ import ( "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/services/ens" + "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/stickers" "github.com/status-im/status-go/services/wallet/activity" "github.com/status-im/status-go/services/wallet/collectibles" @@ -49,6 +50,7 @@ func NewService( config *params.NodeConfig, ens *ens.Service, stickers *stickers.Service, + rpcFilterSrvc *rpcfilters.Service, nftMetadataProvider thirdparty.NFTMetadataProvider, ) *Service { cryptoOnRampManager := NewCryptoOnRampManager(&CryptoOnRampOptions{ @@ -91,8 +93,10 @@ func NewService( }) tokenManager := token.NewTokenManager(db, rpcClient, rpcClient.NetworkManager) savedAddressesManager := &SavedAddressesManager{db: db} - transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB, walletFeed) - transferController := transfer.NewTransferController(db, rpcClient, accountFeed, walletFeed, transactionManager, tokenManager, config.WalletConfig.LoadAllTransfers) + pendingTxManager := transactions.NewTransactionManager(db, rpcFilterSrvc.TransactionSentToUpstreamEvent(), walletFeed) + transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB, pendingTxManager) + transferController := transfer.NewTransferController(db, rpcClient, accountFeed, walletFeed, transactionManager, pendingTxManager, + tokenManager, config.WalletConfig.LoadAllTransfers) cryptoCompare := cryptocompare.NewClient() coingecko := coingecko.NewClient() marketManager := market.NewManager(cryptoCompare, coingecko, walletFeed) @@ -111,6 +115,7 @@ func NewService( tokenManager: tokenManager, savedAddressesManager: savedAddressesManager, transactionManager: transactionManager, + pendingTxManager: pendingTxManager, transferController: transferController, cryptoOnRampManager: cryptoOnRampManager, collectiblesManager: collectiblesManager, @@ -120,6 +125,7 @@ func NewService( transactor: transactor, ens: ens, stickers: stickers, + rpcFilterSrvc: rpcFilterSrvc, feed: walletFeed, signals: signals, reader: reader, @@ -138,6 +144,7 @@ type Service struct { savedAddressesManager *SavedAddressesManager tokenManager *token.Manager transactionManager *transfer.TransactionManager + pendingTxManager *transactions.TransactionManager cryptoOnRampManager *CryptoOnRampManager transferController *transfer.Controller feesManager *FeeManager @@ -148,6 +155,7 @@ type Service struct { transactor *transactions.Transactor ens *ens.Service stickers *stickers.Service + rpcFilterSrvc *rpcfilters.Service feed *event.Feed signals *walletevent.SignalsTransmitter reader *Reader @@ -163,6 +171,7 @@ func (s *Service) Start() error { s.currency.Start() err := s.signals.Start() s.history.Start() + _ = s.pendingTxManager.Start() s.started = true return err } @@ -181,6 +190,7 @@ func (s *Service) Stop() error { s.reader.Stop() s.history.Stop() s.activity.Stop() + s.pendingTxManager.Stop() s.started = false log.Info("wallet stopped") return nil diff --git a/services/wallet/transfer/commands.go b/services/wallet/transfer/commands.go index a238791b4..fbf3376b5 100644 --- a/services/wallet/transfer/commands.go +++ b/services/wallet/transfer/commands.go @@ -17,6 +17,7 @@ import ( w_common "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/walletevent" + "github.com/status-im/status-go/transactions" ) const ( @@ -180,11 +181,13 @@ type controlCommand struct { errorsCount int nonArchivalRPCNode bool transactionManager *TransactionManager + pendingTxManager *transactions.TransactionManager tokenManager *token.Manager } func (c *controlCommand) LoadTransfers(ctx context.Context, limit int) error { - return loadTransfers(ctx, c.accounts, c.blockDAO, c.db, c.chainClient, limit, make(map[common.Address][]*big.Int), c.transactionManager, c.tokenManager, c.feed) + return loadTransfers(ctx, c.accounts, c.blockDAO, c.db, c.chainClient, limit, make(map[common.Address][]*big.Int), + c.transactionManager, c.pendingTxManager, c.tokenManager, c.feed) } func (c *controlCommand) Run(parent context.Context) error { @@ -357,6 +360,7 @@ type transfersCommand struct { chainClient *chain.ClientWithFallback blocksLimit int transactionManager *TransactionManager + pendingTxManager *transactions.TransactionManager tokenManager *token.Manager feed *event.Feed @@ -450,10 +454,10 @@ func (c *transfersCommand) propagatePendingMultiTx(tx Transaction) error { // If any subTx matches a pending entry, mark all of them with the corresponding multiTxID for _, subTx := range tx { // Update MultiTransactionID from pending entry - entry, err := c.transactionManager.GetPendingEntry(c.chainClient.ChainID, subTx.ID) + entry, err := c.pendingTxManager.GetPendingEntry(c.chainClient.ChainID, subTx.ID) if err == nil { // Propagate the MultiTransactionID, in case the pending entry was a multi-transaction - multiTxID = entry.MultiTransactionID + multiTxID = MultiTransactionIDType(entry.MultiTransactionID) break } else if err != sql.ErrNoRows { log.Error("GetPendingEntry error", "error", err) @@ -561,6 +565,7 @@ type loadTransfersCommand struct { chainClient *chain.ClientWithFallback blocksByAddress map[common.Address][]*big.Int transactionManager *TransactionManager + pendingTxManager *transactions.TransactionManager blocksLimit int tokenManager *token.Manager feed *event.Feed @@ -574,7 +579,8 @@ func (c *loadTransfersCommand) Command() async.Command { } func (c *loadTransfersCommand) LoadTransfers(ctx context.Context, limit int, blocksByAddress map[common.Address][]*big.Int) error { - return loadTransfers(ctx, c.accounts, c.blockDAO, c.db, c.chainClient, limit, blocksByAddress, c.transactionManager, c.tokenManager, c.feed) + return loadTransfers(ctx, c.accounts, c.blockDAO, c.db, c.chainClient, limit, blocksByAddress, + c.transactionManager, c.pendingTxManager, c.tokenManager, c.feed) } func (c *loadTransfersCommand) Run(parent context.Context) (err error) { @@ -749,7 +755,9 @@ func (c *findAndCheckBlockRangeCommand) fastIndexErc20(ctx context.Context, from func loadTransfers(ctx context.Context, accounts []common.Address, blockDAO *BlockDAO, db *Database, chainClient *chain.ClientWithFallback, blocksLimitPerAccount int, blocksByAddress map[common.Address][]*big.Int, - transactionManager *TransactionManager, tokenManager *token.Manager, feed *event.Feed) error { + transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, + tokenManager *token.Manager, feed *event.Feed) error { + log.Info("loadTransfers start", "accounts", accounts, "chain", chainClient.ChainID, "limit", blocksLimitPerAccount) start := time.Now() @@ -769,6 +777,7 @@ func loadTransfers(ctx context.Context, accounts []common.Address, blockDAO *Blo }, blockNums: blocksByAddress[address], transactionManager: transactionManager, + pendingTxManager: pendingTxManager, tokenManager: tokenManager, feed: feed, } diff --git a/services/wallet/transfer/commands_sequential.go b/services/wallet/transfer/commands_sequential.go index 663305b23..bfe9b1ca7 100644 --- a/services/wallet/transfer/commands_sequential.go +++ b/services/wallet/transfer/commands_sequential.go @@ -13,6 +13,7 @@ import ( "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/walletevent" + "github.com/status-im/status-go/transactions" ) type findNewBlocksCommand struct { @@ -318,8 +319,8 @@ func (c *findBlocksCommand) fastIndexErc20(ctx context.Context, fromBlockNumber } func loadTransfersLoop(ctx context.Context, account common.Address, blockDAO *BlockDAO, db *Database, - chainClient *chain.ClientWithFallback, transactionManager *TransactionManager, tokenManager *token.Manager, - feed *event.Feed, blocksLoadedCh <-chan []*DBHeader) { + chainClient *chain.ClientWithFallback, transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, + tokenManager *token.Manager, feed *event.Feed, blocksLoadedCh <-chan []*DBHeader) { log.Debug("loadTransfersLoop start", "chain", chainClient.ChainID, "account", account) @@ -339,7 +340,7 @@ func loadTransfersLoop(ctx context.Context, account common.Address, blockDAO *Bl blocksByAddress := map[common.Address][]*big.Int{account: blockNums} go func() { _ = loadTransfers(ctx, []common.Address{account}, blockDAO, db, chainClient, noBlockLimit, - blocksByAddress, transactionManager, tokenManager, feed) + blocksByAddress, transactionManager, pendingTxManager, tokenManager, feed) }() } } @@ -347,7 +348,8 @@ func loadTransfersLoop(ctx context.Context, account common.Address, blockDAO *Bl func newLoadBlocksAndTransfersCommand(account common.Address, db *Database, blockDAO *BlockDAO, chainClient *chain.ClientWithFallback, feed *event.Feed, - transactionManager *TransactionManager, tokenManager *token.Manager) *loadBlocksAndTransfersCommand { + transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, + tokenManager *token.Manager) *loadBlocksAndTransfersCommand { return &loadBlocksAndTransfersCommand{ account: account, @@ -358,6 +360,7 @@ func newLoadBlocksAndTransfersCommand(account common.Address, db *Database, feed: feed, errorsCount: 0, transactionManager: transactionManager, + pendingTxManager: pendingTxManager, tokenManager: tokenManager, blocksLoadedCh: make(chan []*DBHeader, 100), } @@ -374,6 +377,7 @@ type loadBlocksAndTransfersCommand struct { errorsCount int // nonArchivalRPCNode bool // TODO Make use of it transactionManager *TransactionManager + pendingTxManager *transactions.TransactionManager tokenManager *token.Manager blocksLoadedCh chan []*DBHeader @@ -425,8 +429,8 @@ func (c *loadBlocksAndTransfersCommand) Command() async.Command { } func (c *loadBlocksAndTransfersCommand) startTransfersLoop(ctx context.Context) { - go loadTransfersLoop(ctx, c.account, c.blockDAO, c.db, c.chainClient, c.transactionManager, c.tokenManager, - c.feed, c.blocksLoadedCh) + go loadTransfersLoop(ctx, c.account, c.blockDAO, c.db, c.chainClient, c.transactionManager, + c.pendingTxManager, c.tokenManager, c.feed, c.blocksLoadedCh) } func (c *loadBlocksAndTransfersCommand) fetchHistoryBlocks(ctx context.Context, group *async.Group, blocksLoadedCh chan []*DBHeader) error { @@ -524,6 +528,7 @@ func (c *loadBlocksAndTransfersCommand) fetchTransfersForLoadedBlocks(group *asy blockDAO: c.blockDAO, chainClient: c.chainClient, transactionManager: c.transactionManager, + pendingTxManager: c.pendingTxManager, tokenManager: c.tokenManager, blocksByAddress: blocksMap, feed: c.feed, diff --git a/services/wallet/transfer/controller.go b/services/wallet/transfer/controller.go index 9ad10cc6a..148bf0f98 100644 --- a/services/wallet/transfer/controller.go +++ b/services/wallet/transfer/controller.go @@ -15,6 +15,7 @@ import ( "github.com/status-im/status-go/services/accounts/accountsevent" "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/token" + "github.com/status-im/status-go/transactions" ) type Controller struct { @@ -26,12 +27,13 @@ type Controller struct { TransferFeed *event.Feed group *async.Group transactionManager *TransactionManager + pendingTxManager *transactions.TransactionManager tokenManager *token.Manager loadAllTransfers bool } func NewTransferController(db *sql.DB, rpcClient *rpc.Client, accountFeed *event.Feed, transferFeed *event.Feed, - transactionManager *TransactionManager, tokenManager *token.Manager, loadAllTransfers bool) *Controller { + transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, tokenManager *token.Manager, loadAllTransfers bool) *Controller { blockDAO := &BlockDAO{db} return &Controller{ @@ -41,6 +43,7 @@ func NewTransferController(db *sql.DB, rpcClient *rpc.Client, accountFeed *event accountFeed: accountFeed, TransferFeed: transferFeed, transactionManager: transactionManager, + pendingTxManager: pendingTxManager, tokenManager: tokenManager, loadAllTransfers: loadAllTransfers, } @@ -115,7 +118,8 @@ func (c *Controller) CheckRecentHistory(chainIDs []uint64, accounts []common.Add return err } } else { - c.reactor = NewReactor(c.db, c.blockDAO, c.TransferFeed, c.transactionManager, c.tokenManager) + c.reactor = NewReactor(c.db, c.blockDAO, c.TransferFeed, c.transactionManager, + c.pendingTxManager, c.tokenManager) err = c.reactor.start(chainClients, accounts, c.loadAllTransfers) if err != nil { diff --git a/services/wallet/transfer/downloader.go b/services/wallet/transfer/downloader.go index e82e9f8c9..7860e04f7 100644 --- a/services/wallet/transfer/downloader.go +++ b/services/wallet/transfer/downloader.go @@ -18,12 +18,6 @@ import ( w_common "github.com/status-im/status-go/services/wallet/common" ) -type MultiTransactionIDType int64 - -const ( - NoMultiTransactionID = MultiTransactionIDType(0) -) - func getLogSubTxID(log types.Log) common.Hash { // Get unique ID by using TxHash and log index index := [4]byte{} diff --git a/services/wallet/transfer/reactor.go b/services/wallet/transfer/reactor.go index 2acfb7c2e..ce8911a0e 100644 --- a/services/wallet/transfer/reactor.go +++ b/services/wallet/transfer/reactor.go @@ -15,6 +15,7 @@ import ( "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/walletevent" + "github.com/status-im/status-go/transactions" ) const ( @@ -64,6 +65,7 @@ type OnDemandFetchStrategy struct { group *async.Group balanceCache *balanceCache transactionManager *TransactionManager + pendingTxManager *transactions.TransactionManager tokenManager *token.Manager chainClients map[uint64]*chain.ClientWithFallback accounts []common.Address @@ -86,6 +88,7 @@ func (s *OnDemandFetchStrategy) newControlCommand(chainClient *chain.ClientWithF feed: s.feed, errorsCount: 0, transactionManager: s.transactionManager, + pendingTxManager: s.pendingTxManager, tokenManager: s.tokenManager, } @@ -235,16 +238,19 @@ type Reactor struct { blockDAO *BlockDAO feed *event.Feed transactionManager *TransactionManager + pendingTxManager *transactions.TransactionManager tokenManager *token.Manager strategy HistoryFetcher } -func NewReactor(db *Database, blockDAO *BlockDAO, feed *event.Feed, tm *TransactionManager, tokenManager *token.Manager) *Reactor { +func NewReactor(db *Database, blockDAO *BlockDAO, feed *event.Feed, tm *TransactionManager, + pendingTxManager *transactions.TransactionManager, tokenManager *token.Manager) *Reactor { return &Reactor{ db: db, blockDAO: blockDAO, feed: feed, transactionManager: tm, + pendingTxManager: pendingTxManager, tokenManager: tokenManager, } } @@ -280,6 +286,7 @@ func (r *Reactor) createFetchStrategy(chainClients map[uint64]*chain.ClientWithF r.blockDAO, r.feed, r.transactionManager, + r.pendingTxManager, r.tokenManager, chainClients, accounts, @@ -291,6 +298,7 @@ func (r *Reactor) createFetchStrategy(chainClients map[uint64]*chain.ClientWithF feed: r.feed, blockDAO: r.blockDAO, transactionManager: r.transactionManager, + pendingTxManager: r.pendingTxManager, tokenManager: r.tokenManager, chainClients: chainClients, accounts: accounts, diff --git a/services/wallet/transfer/sequential_fetch_strategy.go b/services/wallet/transfer/sequential_fetch_strategy.go index 42a36db84..5e941d9e0 100644 --- a/services/wallet/transfer/sequential_fetch_strategy.go +++ b/services/wallet/transfer/sequential_fetch_strategy.go @@ -12,10 +12,11 @@ import ( "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/walletevent" + "github.com/status-im/status-go/transactions" ) func NewSequentialFetchStrategy(db *Database, blockDAO *BlockDAO, feed *event.Feed, - transactionManager *TransactionManager, + transactionManager *TransactionManager, pendingTxManager *transactions.TransactionManager, tokenManager *token.Manager, chainClients map[uint64]*chain.ClientWithFallback, accounts []common.Address) *SequentialFetchStrategy { @@ -25,6 +26,7 @@ func NewSequentialFetchStrategy(db *Database, blockDAO *BlockDAO, feed *event.Fe blockDAO: blockDAO, feed: feed, transactionManager: transactionManager, + pendingTxManager: pendingTxManager, tokenManager: tokenManager, chainClients: chainClients, accounts: accounts, @@ -38,6 +40,7 @@ type SequentialFetchStrategy struct { mu sync.Mutex group *async.Group transactionManager *TransactionManager + pendingTxManager *transactions.TransactionManager tokenManager *token.Manager chainClients map[uint64]*chain.ClientWithFallback accounts []common.Address @@ -47,7 +50,7 @@ func (s *SequentialFetchStrategy) newCommand(chainClient *chain.ClientWithFallba account common.Address) async.Commander { return newLoadBlocksAndTransfersCommand(account, s.db, s.blockDAO, chainClient, s.feed, - s.transactionManager, s.tokenManager) + s.transactionManager, s.pendingTxManager, s.tokenManager) } func (s *SequentialFetchStrategy) start() error { diff --git a/services/wallet/transfer/transaction.go b/services/wallet/transfer/transaction.go index 7f6c87c48..8aa053256 100644 --- a/services/wallet/transfer/transaction.go +++ b/services/wallet/transfer/transaction.go @@ -11,44 +11,43 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/status-im/status-go/account" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/multiaccounts/accounts" "github.com/status-im/status-go/params" - "github.com/status-im/status-go/rpc/chain" - "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/bridge" wallet_common "github.com/status-im/status-go/services/wallet/common" - "github.com/status-im/status-go/services/wallet/walletevent" "github.com/status-im/status-go/transactions" ) +type MultiTransactionIDType int64 + const ( - // PendingTransactionUpdate is emitted when a pending transaction is updated (added or deleted) - EventPendingTransactionUpdate walletevent.EventType = "pending-transaction-update" + NoMultiTransactionID = MultiTransactionIDType(0) ) type TransactionManager struct { - db *sql.DB - gethManager *account.GethManager - transactor *transactions.Transactor - config *params.NodeConfig - accountsDB *accounts.Database - eventFeed *event.Feed + db *sql.DB + gethManager *account.GethManager + transactor *transactions.Transactor + config *params.NodeConfig + accountsDB *accounts.Database + pendingManager *transactions.TransactionManager } func NewTransactionManager(db *sql.DB, gethManager *account.GethManager, transactor *transactions.Transactor, - config *params.NodeConfig, accountsDB *accounts.Database, eventFeed *event.Feed) *TransactionManager { + config *params.NodeConfig, accountsDB *accounts.Database, + pendingTxManager *transactions.TransactionManager) *TransactionManager { + return &TransactionManager{ - db: db, - gethManager: gethManager, - transactor: transactor, - config: config, - accountsDB: accountsDB, - eventFeed: eventFeed, + db: db, + gethManager: gethManager, + transactor: transactor, + config: config, + accountsDB: accountsDB, + pendingManager: pendingTxManager, } } @@ -95,241 +94,12 @@ type MultiTransactionCommandResult struct { Hashes map[uint64][]types.Hash `json:"hashes"` } -type PendingTrxType string - -const ( - RegisterENS PendingTrxType = "RegisterENS" - ReleaseENS PendingTrxType = "ReleaseENS" - SetPubKey PendingTrxType = "SetPubKey" - BuyStickerPack PendingTrxType = "BuyStickerPack" - WalletTransfer PendingTrxType = "WalletTransfer" - CollectibleDeployment PendingTrxType = "CollectibleDeployment" - CollectibleAirdrop PendingTrxType = "CollectibleAirdrop" - CollectibleRemoteSelfDestruct PendingTrxType = "CollectibleRemoteSelfDestruct" - CollectibleBurn PendingTrxType = "CollectibleBurn" -) - -type PendingTransaction struct { - Hash common.Hash `json:"hash"` - Timestamp uint64 `json:"timestamp"` - Value bigint.BigInt `json:"value"` - From common.Address `json:"from"` - To common.Address `json:"to"` - Data string `json:"data"` - Symbol string `json:"symbol"` - GasPrice bigint.BigInt `json:"gasPrice"` - GasLimit bigint.BigInt `json:"gasLimit"` - Type PendingTrxType `json:"type"` - AdditionalData string `json:"additionalData"` - ChainID uint64 `json:"network_id"` - MultiTransactionID MultiTransactionIDType `json:"multi_transaction_id"` -} - type TransactionIdentity struct { ChainID wallet_common.ChainID `json:"chainId"` Hash common.Hash `json:"hash"` Address common.Address `json:"address"` } -const selectFromPending = `SELECT hash, timestamp, value, from_address, to_address, data, - symbol, gas_price, gas_limit, type, additional_data, - network_id, COALESCE(multi_transaction_id, 0) - FROM pending_transactions - ` - -func rowsToTransactions(rows *sql.Rows) (transactions []*PendingTransaction, err error) { - for rows.Next() { - transaction := &PendingTransaction{ - Value: bigint.BigInt{Int: new(big.Int)}, - GasPrice: bigint.BigInt{Int: new(big.Int)}, - GasLimit: bigint.BigInt{Int: new(big.Int)}, - } - err := rows.Scan(&transaction.Hash, - &transaction.Timestamp, - (*bigint.SQLBigIntBytes)(transaction.Value.Int), - &transaction.From, - &transaction.To, - &transaction.Data, - &transaction.Symbol, - (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), - (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), - &transaction.Type, - &transaction.AdditionalData, - &transaction.ChainID, - &transaction.MultiTransactionID, - ) - if err != nil { - return nil, err - } - - transactions = append(transactions, transaction) - } - return transactions, nil -} - -func (tm *TransactionManager) GetAllPending(chainIDs []uint64) ([]*PendingTransaction, error) { - if len(chainIDs) == 0 { - return nil, errors.New("at least 1 chainID is required") - } - - inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?" - var parameters []interface{} - for _, c := range chainIDs { - parameters = append(parameters, c) - } - - rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s)", inVector), parameters...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rowsToTransactions(rows) -} - -func (tm *TransactionManager) GetPendingByAddress(chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) { - if len(chainIDs) == 0 { - return nil, errors.New("at least 1 chainID is required") - } - - inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?" - var parameters []interface{} - for _, c := range chainIDs { - parameters = append(parameters, c) - } - - parameters = append(parameters, address) - - rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s) AND from_address = ?", inVector), parameters...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rowsToTransactions(rows) -} - -// GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity -// TODO: consider using address also in case we expect to have also for the receiver -func (tm *TransactionManager) GetPendingEntry(chainID uint64, hash common.Hash) (*PendingTransaction, error) { - row := tm.db.QueryRow(`SELECT timestamp, value, from_address, to_address, data, - symbol, gas_price, gas_limit, type, additional_data, - network_id, COALESCE(multi_transaction_id, 0) - FROM pending_transactions - WHERE network_id = ? AND hash = ?`, chainID, hash) - transaction := &PendingTransaction{ - Hash: hash, - Value: bigint.BigInt{Int: new(big.Int)}, - GasPrice: bigint.BigInt{Int: new(big.Int)}, - GasLimit: bigint.BigInt{Int: new(big.Int)}, - ChainID: chainID, - } - err := row.Scan( - &transaction.Timestamp, - (*bigint.SQLBigIntBytes)(transaction.Value.Int), - &transaction.From, - &transaction.To, - &transaction.Data, - &transaction.Symbol, - (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), - (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), - &transaction.Type, - &transaction.AdditionalData, - &transaction.ChainID, - &transaction.MultiTransactionID, - ) - if err != nil { - return nil, err - } - - return transaction, nil -} - -func (tm *TransactionManager) AddPending(transaction PendingTransaction) error { - insert, err := tm.db.Prepare(`INSERT OR REPLACE INTO pending_transactions - (network_id, hash, timestamp, value, from_address, to_address, - data, symbol, gas_price, gas_limit, type, additional_data, multi_transaction_id) - VALUES - (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) - if err != nil { - return err - } - _, err = insert.Exec( - transaction.ChainID, - transaction.Hash, - transaction.Timestamp, - (*bigint.SQLBigIntBytes)(transaction.Value.Int), - transaction.From, - transaction.To, - transaction.Data, - transaction.Symbol, - (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), - (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), - transaction.Type, - transaction.AdditionalData, - transaction.MultiTransactionID, - ) - - // Notify listeners of new pending transaction (used in activity history) - if err == nil { - tm.notifyPendingTransactionListeners(transaction.ChainID, []common.Address{transaction.From, transaction.To}, transaction.Timestamp) - } - return err -} - -func (tm *TransactionManager) notifyPendingTransactionListeners(chainID uint64, addresses []common.Address, timestamp uint64) { - if tm.eventFeed != nil { - tm.eventFeed.Send(walletevent.Event{ - Type: EventPendingTransactionUpdate, - ChainID: chainID, - Accounts: addresses, - At: int64(timestamp), - }) - } -} - -func (tm *TransactionManager) DeletePending(chainID uint64, hash common.Hash) error { - tx, err := tm.db.BeginTx(context.Background(), nil) - if err != nil { - return err - } - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - - row := tx.QueryRow(`SELECT from_address, to_address, timestamp FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash) - var from, to common.Address - var timestamp uint64 - err = row.Scan(&from, &to, ×tamp) - if err != nil { - return err - } - - _, err = tx.Exec(`DELETE FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash) - if err != nil { - return err - } - err = tx.Commit() - if err == nil { - tm.notifyPendingTransactionListeners(chainID, []common.Address{from, to}, timestamp) - } - return err -} - -func (tm *TransactionManager) Watch(ctx context.Context, transactionHash common.Hash, client *chain.ClientWithFallback) error { - watchTxCommand := &watchTransactionCommand{ - hash: transactionHash, - client: client, - } - - commandContext, cancel := context.WithTimeout(ctx, 10*time.Minute) - defer cancel() - - return watchTxCommand.Command()(commandContext) -} - const multiTransactionColumns = "from_network_id, from_tx_hash, from_address, from_asset, from_amount, to_network_id, to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" func rowsToMultiTransactions(rows *sql.Rows) ([]*MultiTransaction, error) { @@ -381,7 +151,7 @@ func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) (Mul insert, err := db.Prepare(fmt.Sprintf(`INSERT INTO multi_transactions (%s) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) if err != nil { - return 0, err + return NoMultiTransactionID, err } timestamp := time.Now().Unix() @@ -401,7 +171,7 @@ func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) (Mul timestamp, ) if err != nil { - return 0, err + return NoMultiTransactionID, err } defer insert.Close() multiTransactionID, err := result.LastInsertId() @@ -453,7 +223,74 @@ func (tm *TransactionManager) UpdateMultiTransaction(multiTransaction *MultiTran return updateMultiTransaction(tm.db, multiTransaction) } -func (tm *TransactionManager) CreateMultiTransactionFromCommand(ctx context.Context, command *MultiTransactionCommand, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) (*MultiTransactionCommandResult, error) { +func (tm *TransactionManager) CreateMultiTransactionFromCommand(ctx context.Context, command *MultiTransactionCommand, + data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) (*MultiTransactionCommandResult, error) { + + multiTransaction := multiTransactionFromCommand(command) + + multiTransactionID, err := insertMultiTransaction(tm.db, multiTransaction) + if err != nil { + return nil, err + } + + multiTransaction.ID = uint(multiTransactionID) + hashes, err := tm.sendTransactions(multiTransaction, data, bridges, password) + if err != nil { + return nil, err + } + + err = tm.storePendingTransactions(multiTransaction, hashes, data) + if err != nil { + return nil, err + } + + return &MultiTransactionCommandResult{ + ID: int64(multiTransactionID), + Hashes: hashes, + }, nil +} + +func (tm *TransactionManager) storePendingTransactions(multiTransaction *MultiTransaction, + hashes map[uint64][]types.Hash, data []*bridge.TransactionBridge) error { + + txs := createPendingTransactions(hashes, data, multiTransaction) + for _, tx := range txs { + err := tm.pendingManager.AddPending(tx) + if err != nil { + return err + } + } + return nil +} + +func createPendingTransactions(hashes map[uint64][]types.Hash, data []*bridge.TransactionBridge, + multiTransaction *MultiTransaction) []*transactions.PendingTransaction { + + txs := make([]*transactions.PendingTransaction, 0) + for _, tx := range data { + for _, hash := range hashes[tx.ChainID] { + pendingTransaction := &transactions.PendingTransaction{ + Hash: common.Hash(hash), + Timestamp: uint64(time.Now().Unix()), + Value: bigint.BigInt{Int: multiTransaction.FromAmount.ToInt()}, + From: common.Address(tx.From()), + To: common.Address(tx.To()), + Data: tx.Data().String(), + Type: transactions.WalletTransfer, + ChainID: tx.ChainID, + MultiTransactionID: int64(multiTransaction.ID), + Symbol: multiTransaction.FromAsset, + } + txs = append(txs, pendingTransaction) + } + } + return txs +} + +func multiTransactionFromCommand(command *MultiTransactionCommand) *MultiTransaction { + + log.Info("Creating multi transaction", "command", command) + multiTransaction := &MultiTransaction{ FromAddress: command.FromAddress, ToAddress: command.ToAddress, @@ -464,12 +301,16 @@ func (tm *TransactionManager) CreateMultiTransactionFromCommand(ctx context.Cont Type: command.Type, } - selectedAccount, err := tm.getVerifiedWalletAccount(multiTransaction.FromAddress.Hex(), password) - if err != nil { - return nil, err - } + return multiTransaction +} - multiTransactionID, err := insertMultiTransaction(tm.db, multiTransaction) +func (tm *TransactionManager) sendTransactions(multiTransaction *MultiTransaction, + data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) ( + map[uint64][]types.Hash, error) { + + log.Info("Making transactions", "multiTransaction", multiTransaction) + + selectedAccount, err := tm.getVerifiedWalletAccount(multiTransaction.FromAddress.Hex(), password) if err != nil { return nil, err } @@ -480,29 +321,9 @@ func (tm *TransactionManager) CreateMultiTransactionFromCommand(ctx context.Cont if err != nil { return nil, err } - pendingTransaction := PendingTransaction{ - Hash: common.Hash(hash), - Timestamp: uint64(time.Now().Unix()), - Value: bigint.BigInt{Int: multiTransaction.FromAmount.ToInt()}, - From: common.Address(tx.From()), - To: common.Address(tx.To()), - Data: tx.Data().String(), - Type: WalletTransfer, - ChainID: tx.ChainID, - MultiTransactionID: multiTransactionID, - Symbol: multiTransaction.FromAsset, - } - err = tm.AddPending(pendingTransaction) - if err != nil { - return nil, err - } hashes[tx.ChainID] = append(hashes[tx.ChainID], hash) } - - return &MultiTransactionCommandResult{ - ID: int64(multiTransactionID), - Hashes: hashes, - }, nil + return hashes, nil } func (tm *TransactionManager) GetMultiTransactions(ctx context.Context, ids []MultiTransactionIDType) ([]*MultiTransaction, error) { @@ -606,32 +427,3 @@ func (tm *TransactionManager) getVerifiedWalletAccount(address, password string) AccountKey: key, }, nil } - -type watchTransactionCommand struct { - client *chain.ClientWithFallback - hash common.Hash -} - -func (c *watchTransactionCommand) Command() async.Command { - return async.FiniteCommand{ - Interval: 10 * time.Second, - Runable: c.Run, - }.Run -} - -func (c *watchTransactionCommand) Run(ctx context.Context) error { - requestContext, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - _, isPending, err := c.client.TransactionByHash(requestContext, c.hash) - - if err != nil { - log.Error("Watching transaction error", "error", err) - return err - } - - if isPending { - return errors.New("transaction is pending") - } - - return nil -} diff --git a/services/wallet/transfer/transaction_test.go b/services/wallet/transfer/transaction_test.go index 0b421b4c5..0c63e1c62 100644 --- a/services/wallet/transfer/transaction_test.go +++ b/services/wallet/transfer/transaction_test.go @@ -9,16 +9,14 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/event" "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/services/wallet/bigint" ) func setupTestTransactionDB(t *testing.T) (*TransactionManager, func()) { db, err := appdatabase.SetupTestMemorySQLDB("wallet-transfer-transaction-tests") require.NoError(t, err) - return &TransactionManager{db, nil, nil, nil, nil, &event.Feed{}}, func() { + return &TransactionManager{db, nil, nil, nil, nil, nil}, func() { require.NoError(t, db.Close()) } } @@ -39,59 +37,6 @@ func areMultiTransactionsEqual(mt1, mt2 *MultiTransaction) bool { mt1.CrossTxID == mt2.CrossTxID } -func TestPendingTransactions(t *testing.T) { - manager, stop := setupTestTransactionDB(t) - defer stop() - - trx := PendingTransaction{ - Hash: common.Hash{1}, - From: common.Address{1}, - To: common.Address{2}, - Type: RegisterENS, - AdditionalData: "someuser.stateofus.eth", - Value: bigint.BigInt{Int: big.NewInt(123)}, - GasLimit: bigint.BigInt{Int: big.NewInt(21000)}, - GasPrice: bigint.BigInt{Int: big.NewInt(1)}, - ChainID: 777, - } - - rst, err := manager.GetAllPending([]uint64{777}) - require.NoError(t, err) - require.Nil(t, rst) - - rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) - require.NoError(t, err) - require.Nil(t, rst) - - err = manager.AddPending(trx) - require.NoError(t, err) - - rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) - require.NoError(t, err) - require.Equal(t, 1, len(rst)) - require.Equal(t, trx, *rst[0]) - - rst, err = manager.GetAllPending([]uint64{777}) - require.NoError(t, err) - require.Equal(t, 1, len(rst)) - require.Equal(t, trx, *rst[0]) - - rst, err = manager.GetPendingByAddress([]uint64{777}, common.Address{2}) - require.NoError(t, err) - require.Nil(t, rst) - - err = manager.DeletePending(777, trx.Hash) - require.NoError(t, err) - - rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) - require.NoError(t, err) - require.Equal(t, 0, len(rst)) - - rst, err = manager.GetAllPending([]uint64{777}) - require.NoError(t, err) - require.Equal(t, 0, len(rst)) -} - func TestBridgeMultiTransactions(t *testing.T) { manager, stop := setupTestTransactionDB(t) defer stop() diff --git a/services/web3provider/api.go b/services/web3provider/api.go index 8c447efe1..d517e0a39 100644 --- a/services/web3provider/api.go +++ b/services/web3provider/api.go @@ -327,7 +327,7 @@ func (api *API) ProcessWeb3ReadOnlyRequest(request Web3SendAsyncReadOnlyRequest) return nil, err } - hash, err := api.sendTransaction(request.Payload.ChainID, trxArgs, request.Payload.Password) + hash, err := api.sendTransaction(request.Payload.ChainID, trxArgs, request.Payload.Password, Web3SendAsyncReadOnly) if err != nil { log.Error("could not send transaction message", "err", err) return &Web3SendAsyncReadOnlyResponse{ diff --git a/services/web3provider/signature.go b/services/web3provider/signature.go index f26264fa8..6f20d7c0d 100644 --- a/services/web3provider/signature.go +++ b/services/web3provider/signature.go @@ -4,11 +4,13 @@ import ( "fmt" "math/big" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" signercore "github.com/ethereum/go-ethereum/signer/core/apitypes" "github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/services/rpcfilters" "github.com/status-im/status-go/services/typeddata" "github.com/status-im/status-go/transactions" ) @@ -71,7 +73,7 @@ func (api *API) signTypedDataV4(typed signercore.TypedData, address string, pass } // SendTransaction creates a new transaction and waits until it's complete. -func (api *API) sendTransaction(chainID uint64, sendArgs transactions.SendTxArgs, password string) (hash types.Hash, err error) { +func (api *API) sendTransaction(chainID uint64, sendArgs transactions.SendTxArgs, password string, requestType string) (hash types.Hash, err error) { verifiedAccount, err := api.getVerifiedWalletAccount(sendArgs.From.String(), password) if err != nil { return hash, err @@ -82,7 +84,12 @@ func (api *API) sendTransaction(chainID uint64, sendArgs transactions.SendTxArgs return } - go api.s.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(hash) + go api.s.rpcFiltersSrvc.TriggerTransactionSentToUpstreamEvent(&rpcfilters.PendingTxInfo{ + Hash: common.Hash(hash), + Type: requestType, + From: common.Address(sendArgs.From), + ChainID: chainID, + }) return } diff --git a/transactions/pending.go b/transactions/pending.go new file mode 100644 index 000000000..02b85420e --- /dev/null +++ b/transactions/pending.go @@ -0,0 +1,363 @@ +package transactions + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math/big" + "strings" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/log" + "github.com/status-im/status-go/rpc/chain" + "github.com/status-im/status-go/services/rpcfilters" + "github.com/status-im/status-go/services/wallet/async" + "github.com/status-im/status-go/services/wallet/bigint" + "github.com/status-im/status-go/services/wallet/walletevent" +) + +const ( + // PendingTransactionUpdate is emitted when a pending transaction is updated (added or deleted) + EventPendingTransactionUpdate walletevent.EventType = "pending-transaction-update" +) + +type TransactionManager struct { + db *sql.DB + pendingTxEvent rpcfilters.ChainEvent + eventFeed *event.Feed + quit chan struct{} +} + +func NewTransactionManager(db *sql.DB, pendingTxEvent rpcfilters.ChainEvent, eventFeed *event.Feed) *TransactionManager { + return &TransactionManager{ + db: db, + eventFeed: eventFeed, + pendingTxEvent: pendingTxEvent, + } +} + +func (tm *TransactionManager) Start() error { + if tm.quit != nil { + return errors.New("latest transaction sent to upstream event is already started") + } + + tm.quit = make(chan struct{}) + + go func() { + _, chi := tm.pendingTxEvent.Subscribe() + ch, ok := chi.(chan *rpcfilters.PendingTxInfo) + if !ok { + panic("pendingTxEvent returned wront type of channel") + } + + for { + select { + case tx := <-ch: + log.Info("Pending transaction event received", tx) + err := tm.AddPending(&PendingTransaction{ + Hash: tx.Hash, + Timestamp: uint64(time.Now().Unix()), + From: tx.From, + ChainID: tx.ChainID, + }) + if err != nil { + log.Error("Failed to add pending transaction", "error", err, "hash", tx.Hash, + "chainID", tx.ChainID) + } + case <-tm.quit: + return + } + } + }() + + return tm.pendingTxEvent.Start() +} + +func (tm *TransactionManager) Stop() { + if tm.quit == nil { + return + } + + select { + case <-tm.quit: + return + default: + close(tm.quit) + } + + tm.quit = nil +} + +type PendingTrxType string + +const ( + RegisterENS PendingTrxType = "RegisterENS" + ReleaseENS PendingTrxType = "ReleaseENS" + SetPubKey PendingTrxType = "SetPubKey" + BuyStickerPack PendingTrxType = "BuyStickerPack" + WalletTransfer PendingTrxType = "WalletTransfer" + CollectibleDeployment PendingTrxType = "CollectibleDeployment" + CollectibleAirdrop PendingTrxType = "CollectibleAirdrop" + CollectibleRemoteSelfDestruct PendingTrxType = "CollectibleRemoteSelfDestruct" + CollectibleBurn PendingTrxType = "CollectibleBurn" +) + +type PendingTransaction struct { + Hash common.Hash `json:"hash"` + Timestamp uint64 `json:"timestamp"` + Value bigint.BigInt `json:"value"` + From common.Address `json:"from"` + To common.Address `json:"to"` + Data string `json:"data"` + Symbol string `json:"symbol"` + GasPrice bigint.BigInt `json:"gasPrice"` + GasLimit bigint.BigInt `json:"gasLimit"` + Type PendingTrxType `json:"type"` + AdditionalData string `json:"additionalData"` + ChainID uint64 `json:"network_id"` + MultiTransactionID int64 `json:"multi_transaction_id"` +} + +const selectFromPending = `SELECT hash, timestamp, value, from_address, to_address, data, + symbol, gas_price, gas_limit, type, additional_data, + network_id, COALESCE(multi_transaction_id, 0) + FROM pending_transactions + ` + +func rowsToTransactions(rows *sql.Rows) (transactions []*PendingTransaction, err error) { + for rows.Next() { + transaction := &PendingTransaction{ + Value: bigint.BigInt{Int: new(big.Int)}, + GasPrice: bigint.BigInt{Int: new(big.Int)}, + GasLimit: bigint.BigInt{Int: new(big.Int)}, + } + err := rows.Scan(&transaction.Hash, + &transaction.Timestamp, + (*bigint.SQLBigIntBytes)(transaction.Value.Int), + &transaction.From, + &transaction.To, + &transaction.Data, + &transaction.Symbol, + (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), + (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), + &transaction.Type, + &transaction.AdditionalData, + &transaction.ChainID, + &transaction.MultiTransactionID, + ) + if err != nil { + return nil, err + } + + transactions = append(transactions, transaction) + } + return transactions, nil +} + +func (tm *TransactionManager) GetAllPending(chainIDs []uint64) ([]*PendingTransaction, error) { + log.Info("Getting all pending transactions", "chainIDs", chainIDs) + + if len(chainIDs) == 0 { + return nil, errors.New("at least 1 chainID is required") + } + + inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?" + var parameters []interface{} + for _, c := range chainIDs { + parameters = append(parameters, c) + } + + rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s)", inVector), parameters...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rowsToTransactions(rows) +} + +func (tm *TransactionManager) GetPendingByAddress(chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) { + log.Info("Getting pending transaction by address", "chainIDs", chainIDs, "address", address) + + if len(chainIDs) == 0 { + return nil, errors.New("at least 1 chainID is required") + } + + inVector := strings.Repeat("?, ", len(chainIDs)-1) + "?" + var parameters []interface{} + for _, c := range chainIDs { + parameters = append(parameters, c) + } + + parameters = append(parameters, address) + + rows, err := tm.db.Query(fmt.Sprintf(selectFromPending+"WHERE network_id in (%s) AND from_address = ?", inVector), parameters...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rowsToTransactions(rows) +} + +// GetPendingEntry returns sql.ErrNoRows if no pending transaction is found for the given identity +// TODO: consider using address also in case we expect to have also for the receiver +func (tm *TransactionManager) GetPendingEntry(chainID uint64, hash common.Hash) (*PendingTransaction, error) { + log.Info("Getting pending transaction", "chainID", chainID, "hash", hash) + + row := tm.db.QueryRow(`SELECT timestamp, value, from_address, to_address, data, + symbol, gas_price, gas_limit, type, additional_data, + network_id, COALESCE(multi_transaction_id, 0) + FROM pending_transactions + WHERE network_id = ? AND hash = ?`, chainID, hash) + transaction := &PendingTransaction{ + Hash: hash, + Value: bigint.BigInt{Int: new(big.Int)}, + GasPrice: bigint.BigInt{Int: new(big.Int)}, + GasLimit: bigint.BigInt{Int: new(big.Int)}, + ChainID: chainID, + } + err := row.Scan( + &transaction.Timestamp, + (*bigint.SQLBigIntBytes)(transaction.Value.Int), + &transaction.From, + &transaction.To, + &transaction.Data, + &transaction.Symbol, + (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), + (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), + &transaction.Type, + &transaction.AdditionalData, + &transaction.ChainID, + &transaction.MultiTransactionID, + ) + if err != nil { + return nil, err + } + + return transaction, nil +} + +func (tm *TransactionManager) AddPending(transaction *PendingTransaction) error { + insert, err := tm.db.Prepare(`INSERT OR REPLACE INTO pending_transactions + (network_id, hash, timestamp, value, from_address, to_address, + data, symbol, gas_price, gas_limit, type, additional_data, multi_transaction_id) + VALUES + (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + if err != nil { + return err + } + _, err = insert.Exec( + transaction.ChainID, + transaction.Hash, + transaction.Timestamp, + (*bigint.SQLBigIntBytes)(transaction.Value.Int), + transaction.From, + transaction.To, + transaction.Data, + transaction.Symbol, + (*bigint.SQLBigIntBytes)(transaction.GasPrice.Int), + (*bigint.SQLBigIntBytes)(transaction.GasLimit.Int), + transaction.Type, + transaction.AdditionalData, + transaction.MultiTransactionID, + ) + // Notify listeners of new pending transaction (used in activity history) + if err == nil { + tm.notifyPendingTransactionListeners(transaction.ChainID, []common.Address{transaction.From, transaction.To}, transaction.Timestamp) + } + return err +} + +func (tm *TransactionManager) notifyPendingTransactionListeners(chainID uint64, addresses []common.Address, timestamp uint64) { + if tm.eventFeed != nil { + tm.eventFeed.Send(walletevent.Event{ + Type: EventPendingTransactionUpdate, + ChainID: chainID, + Accounts: addresses, + At: int64(timestamp), + }) + } +} + +func (tm *TransactionManager) deletePending(chainID uint64, hash common.Hash) error { + tx, err := tm.db.BeginTx(context.Background(), nil) + if err != nil { + return err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + row := tx.QueryRow(`SELECT from_address, to_address, timestamp FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash) + var from, to common.Address + var timestamp uint64 + err = row.Scan(&from, &to, ×tamp) + if err != nil { + return err + } + + _, err = tx.Exec(`DELETE FROM pending_transactions WHERE network_id = ? AND hash = ?`, chainID, hash) + if err != nil { + return err + } + err = tx.Commit() + if err == nil { + tm.notifyPendingTransactionListeners(chainID, []common.Address{from, to}, timestamp) + } + return err +} + +func (tm *TransactionManager) Watch(ctx context.Context, transactionHash common.Hash, client *chain.ClientWithFallback) error { + log.Info("Watching transaction", "chainID", client.ChainID, "hash", transactionHash) + + watchTxCommand := &watchTransactionCommand{ + hash: transactionHash, + client: client, + } + + commandContext, cancel := context.WithTimeout(ctx, 10*time.Minute) + defer cancel() + + err := watchTxCommand.Command()(commandContext) + if err != nil { + log.Error("watchTxCommand error", "error", err, "chainID", client.ChainID, "hash", transactionHash) + return err + } + + return tm.deletePending(client.ChainID, transactionHash) +} + +type watchTransactionCommand struct { + client *chain.ClientWithFallback + hash common.Hash +} + +func (c *watchTransactionCommand) Command() async.Command { + return async.FiniteCommand{ + Interval: 10 * time.Second, + Runable: c.Run, + }.Run +} + +func (c *watchTransactionCommand) Run(ctx context.Context) error { + requestContext, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + _, isPending, err := c.client.TransactionByHash(requestContext, c.hash) + + if err != nil { + log.Error("Watching transaction error", "error", err) + return err + } + + if isPending { + return errors.New("transaction is pending") + } + + return nil +} diff --git a/transactions/transaction_test.go b/transactions/transaction_test.go new file mode 100644 index 000000000..db374abd3 --- /dev/null +++ b/transactions/transaction_test.go @@ -0,0 +1,74 @@ +package transactions + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ethereum/go-ethereum/common" + + "github.com/status-im/status-go/appdatabase" + "github.com/status-im/status-go/services/wallet/bigint" +) + +func setupTestTransactionDB(t *testing.T) (*TransactionManager, func()) { + db, err := appdatabase.SetupTestMemorySQLDB("wallet-transfer-transaction-tests") + require.NoError(t, err) + return &TransactionManager{db, nil, nil, nil}, func() { + require.NoError(t, db.Close()) + } +} + +func TestPendingTransactions(t *testing.T) { + manager, stop := setupTestTransactionDB(t) + defer stop() + + trx := PendingTransaction{ + Hash: common.Hash{1}, + From: common.Address{1}, + To: common.Address{2}, + Type: RegisterENS, + AdditionalData: "someuser.stateofus.eth", + Value: bigint.BigInt{Int: big.NewInt(123)}, + GasLimit: bigint.BigInt{Int: big.NewInt(21000)}, + GasPrice: bigint.BigInt{Int: big.NewInt(1)}, + ChainID: 777, + } + + rst, err := manager.GetAllPending([]uint64{777}) + require.NoError(t, err) + require.Nil(t, rst) + + rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) + require.NoError(t, err) + require.Nil(t, rst) + + err = manager.AddPending(&trx) + require.NoError(t, err) + + rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) + require.NoError(t, err) + require.Equal(t, 1, len(rst)) + require.Equal(t, trx, *rst[0]) + + rst, err = manager.GetAllPending([]uint64{777}) + require.NoError(t, err) + require.Equal(t, 1, len(rst)) + require.Equal(t, trx, *rst[0]) + + rst, err = manager.GetPendingByAddress([]uint64{777}, common.Address{2}) + require.NoError(t, err) + require.Nil(t, rst) + + err = manager.deletePending(777, trx.Hash) + require.NoError(t, err) + + rst, err = manager.GetPendingByAddress([]uint64{777}, trx.From) + require.NoError(t, err) + require.Equal(t, 0, len(rst)) + + rst, err = manager.GetAllPending([]uint64{777}) + require.NoError(t, err) + require.Equal(t, 0, len(rst)) +} diff --git a/transactions/transactor.go b/transactions/transactor.go index c524b7c3f..79b27675d 100644 --- a/transactions/transactor.go +++ b/transactions/transactor.go @@ -66,6 +66,10 @@ func (t *Transactor) SetNetworkID(networkID uint64) { t.networkID = networkID } +func (t *Transactor) NetworkID() uint64 { + return t.networkID +} + // SetRPC sets RPC params, a client and a timeout func (t *Transactor) SetRPC(rpcClient *rpc.Client, timeout time.Duration) { t.rpcWrapper = newRPCWrapper(rpcClient, rpcClient.UpstreamChainID)