From bf78c15e6fca8bdb1a734cf547d84798fa07c733 Mon Sep 17 00:00:00 2001 From: Ivan Belyakov Date: Mon, 3 Jun 2024 01:37:51 +0200 Subject: [PATCH] test(wallet)_: implement build/sign/watch unit tests for multitransaction manager Closes #14848 --- multiaccounts/accounts/database.go | 7 + services/wallet/api.go | 31 +-- .../wallet/transfer/transaction_manager.go | 8 +- .../transaction_manager_multitransaction.go | 42 ++++ ...ansaction_manager_multitransaction_test.go | 84 ++++++- .../transfer/transaction_manager_test.go | 230 +++++++++++++++++- 6 files changed, 357 insertions(+), 45 deletions(-) diff --git a/multiaccounts/accounts/database.go b/multiaccounts/accounts/database.go index 887e0c625..b651b5070 100644 --- a/multiaccounts/accounts/database.go +++ b/multiaccounts/accounts/database.go @@ -284,6 +284,13 @@ func (a *Keypair) Operability() AccountOperable { return AccountFullyOperable } +// TODO: implement clean full interface. This might require refactoring Database methods +type AccountsStorage interface { + GetKeypairByKeyUID(keyUID string) (*Keypair, error) + GetAccountByAddress(address types.Address) (*Account, error) + AddressExists(address types.Address) (bool, error) +} + // Database sql wrapper for operations with browser objects. type Database struct { settings.DatabaseSettingsManager diff --git a/services/wallet/api.go b/services/wallet/api.go index 0bc9ff7ef..bea2e836a 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -4,7 +4,6 @@ import ( "context" "encoding/hex" "encoding/json" - "errors" "fmt" "math/big" "strings" @@ -31,7 +30,6 @@ import ( "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/transfer" "github.com/status-im/status-go/services/wallet/walletconnect" - "github.com/status-im/status-go/services/wallet/walletevent" "github.com/status-im/status-go/transactions" ) @@ -271,38 +269,11 @@ func (api *API) GetPendingTransactionsForIdentities(ctx context.Context, identit // TODO - #11861: Remove this and replace with EventPendingTransactionStatusChanged event and Delete to confirm the transaction where it is needed func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) (err error) { log.Debug("wallet.api.WatchTransactionByChainID", "chainID", chainID, "transactionHash", transactionHash) - var status *transactions.TxStatus defer func() { log.Debug("wallet.api.WatchTransactionByChainID return", "err", err, "chainID", chainID, "transactionHash", transactionHash) }() - // Workaround to keep the blocking call until the clients use the PendingTxTracker APIs - eventChan := make(chan walletevent.Event, 2) - sub := api.s.feed.Subscribe(eventChan) - defer sub.Unsubscribe() - - status, err = api.s.pendingTxManager.Watch(ctx, wcommon.ChainID(chainID), transactionHash) - if err == nil && *status != transactions.Pending { - return nil - } - - for { - select { - case we := <-eventChan: - if transactions.EventPendingTransactionStatusChanged == we.Type { - var p transactions.StatusChangedPayload - err = json.Unmarshal([]byte(we.Message), &p) - if err != nil { - return err - } - if p.ChainID == wcommon.ChainID(chainID) && p.Hash == transactionHash { - return nil - } - } - case <-time.After(10 * time.Minute): - return errors.New("timeout watching for pending transaction") - } - } + return api.s.transactionManager.WatchTransaction(ctx, chainID, transactionHash) } func (api *API) GetCryptoOnRamps(ctx context.Context) ([]onramp.CryptoOnRamp, error) { diff --git a/services/wallet/transfer/transaction_manager.go b/services/wallet/transfer/transaction_manager.go index 1c21800ed..d6679cbb5 100644 --- a/services/wallet/transfer/transaction_manager.go +++ b/services/wallet/transfer/transaction_manager.go @@ -38,7 +38,7 @@ type TransactionManager struct { gethManager *account.GethManager transactor transactions.TransactorIface config *params.NodeConfig - accountsDB *accounts.Database + accountsDB accounts.AccountsStorage pendingTracker *transactions.PendingTxTracker eventFeed *event.Feed @@ -59,7 +59,7 @@ func NewTransactionManager( gethManager *account.GethManager, transactor transactions.TransactorIface, config *params.NodeConfig, - accountsDB *accounts.Database, + accountsDB accounts.AccountsStorage, pendingTxManager *transactions.PendingTxTracker, eventFeed *event.Feed, ) *TransactionManager { @@ -160,6 +160,10 @@ func NewMultiTransaction(timestamp uint64, fromNetworkID, toNetworkID uint64, fr } func (tm *TransactionManager) SignMessage(message types.HexBytes, account *types.Key) (string, error) { + if account == nil || account.PrivateKey == nil { + return "", fmt.Errorf("account or private key is nil") + } + signature, err := crypto.Sign(message[:], account.PrivateKey) return types.EncodeHex(signature), err diff --git a/services/wallet/transfer/transaction_manager_multitransaction.go b/services/wallet/transfer/transaction_manager_multitransaction.go index a40b04c04..36fb7fc1d 100644 --- a/services/wallet/transfer/transaction_manager_multitransaction.go +++ b/services/wallet/transfer/transaction_manager_multitransaction.go @@ -2,19 +2,30 @@ package transfer import ( "context" + "encoding/json" "fmt" + "time" + "github.com/pkg/errors" + + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" "github.com/status-im/status-go/account" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/services/wallet/bridge" wallet_common "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/walletevent" "github.com/status-im/status-go/signal" + "github.com/status-im/status-go/transactions" ) const multiTransactionColumns = "id, from_network_id, from_tx_hash, from_address, from_asset, from_amount, to_network_id, to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" const selectMultiTransactionColumns = "id, COALESCE(from_network_id, 0), from_tx_hash, from_address, from_asset, from_amount, COALESCE(to_network_id, 0), to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" +var pendingTxTimeout time.Duration = 10 * time.Minute +var ErrWatchPendingTxTimeout = errors.New("timeout watching for pending transaction") +var ErrPendingTxNotExists = errors.New("pending transaction does not exist") + func (tm *TransactionManager) InsertMultiTransaction(multiTransaction *MultiTransaction) (wallet_common.MultiTransactionIDType, error) { return multiTransaction.ID, tm.storage.CreateMultiTransaction(multiTransaction) } @@ -149,3 +160,34 @@ func (tm *TransactionManager) GetBridgeDestinationMultiTransaction(ctx context.C return nil, nil } + +func (tm *TransactionManager) WatchTransaction(ctx context.Context, chainID uint64, transactionHash common.Hash) error { + // Workaround to keep the blocking call until the clients use the PendingTxTracker APIs + eventChan := make(chan walletevent.Event, 2) + sub := tm.eventFeed.Subscribe(eventChan) + defer sub.Unsubscribe() + + status, err := tm.pendingTracker.Watch(ctx, wallet_common.ChainID(chainID), transactionHash) + if err == nil && *status != transactions.Pending { + log.Error("transaction is not pending", "status", status) + return nil + } + + for { + select { + case we := <-eventChan: + if transactions.EventPendingTransactionStatusChanged == we.Type { + var p transactions.StatusChangedPayload + err = json.Unmarshal([]byte(we.Message), &p) + if err != nil { + return err + } + if p.ChainID == wallet_common.ChainID(chainID) && p.Hash == transactionHash { + return nil + } + } + case <-time.After(pendingTxTimeout): + return ErrWatchPendingTxTimeout + } + } +} diff --git a/services/wallet/transfer/transaction_manager_multitransaction_test.go b/services/wallet/transfer/transaction_manager_multitransaction_test.go index b542150bc..4b7e5871c 100644 --- a/services/wallet/transfer/transaction_manager_multitransaction_test.go +++ b/services/wallet/transfer/transaction_manager_multitransaction_test.go @@ -2,21 +2,28 @@ package transfer import ( "context" + "encoding/json" "math/big" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/event" "github.com/status-im/status-go/account" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/services/wallet/bridge" "github.com/status-im/status-go/services/wallet/bridge/mock_bridge" + wallet_common "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/services/wallet/walletevent" + "github.com/status-im/status-go/t/helpers" "github.com/status-im/status-go/transactions" "github.com/status-im/status-go/transactions/mock_transactor" + "github.com/status-im/status-go/walletdatabase" ) func deepCopy(tx *transactions.SendTxArgs) *transactions.SendTxArgs { @@ -146,7 +153,7 @@ func TestSendTransactionsETHFailOnBridge(t *testing.T) { // Call the SendTransactions method _, err := tm.SendTransactions(context.Background(), multiTransaction, data, bridges, account) - require.Error(t, expectedErr, err) + require.ErrorIs(t, expectedErr, err) } func TestSendTransactionsETHFailOnTransactor(t *testing.T) { @@ -162,5 +169,78 @@ func TestSendTransactionsETHFailOnTransactor(t *testing.T) { // Call the SendTransactions method _, err := tm.SendTransactions(context.Background(), multiTransaction, data, bridges, account) - require.Error(t, expectedErr, err) + require.ErrorIs(t, expectedErr, err) +} + +func TestWatchTransaction(t *testing.T) { + tm, _, _ := setupTransactionManager(t) + chainID := uint64(1) + pendingTxTimeout = 2 * time.Millisecond + + walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) + require.NoError(t, err) + chainClient := transactions.NewMockChainClient() + eventFeed := &event.Feed{} + // For now, pending tracker is not interface, so we have to use a real one + tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, chainClient, nil, eventFeed, pendingTxTimeout) + tm.eventFeed = eventFeed + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 2*pendingTxTimeout) + defer cancel() + + // Insert a pending transaction + txs := transactions.MockTestTransactions(t, chainClient, []transactions.TestTxSummary{{}}) + err = tm.pendingTracker.StoreAndTrackPendingTx(&txs[0]) // We dont need to track it, but no other way to insert it + require.NoError(t, err) + + txEventPayload := transactions.StatusChangedPayload{ + TxIdentity: transactions.TxIdentity{ + Hash: txs[0].Hash, + ChainID: wallet_common.ChainID(chainID), + }, + Status: transactions.Pending, + } + jsonPayload, err := json.Marshal(txEventPayload) + require.NoError(t, err) + + go func() { + time.Sleep(pendingTxTimeout / 2) + eventFeed.Send(walletevent.Event{ + Type: transactions.EventPendingTransactionStatusChanged, + Message: string(jsonPayload), + }) + }() + + // Call the WatchTransaction method + err = tm.WatchTransaction(ctx, chainID, txs[0].Hash) + require.NoError(t, err) +} + +func TestWatchTransaction_Timeout(t *testing.T) { + tm, _, _ := setupTransactionManager(t) + chainID := uint64(1) + transactionHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + pendingTxTimeout = 2 * time.Millisecond + + walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) + require.NoError(t, err) + chainClient := transactions.NewMockChainClient() + eventFeed := &event.Feed{} + // For now, pending tracker is not interface, so we have to use a real one + tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, chainClient, nil, eventFeed, pendingTxTimeout) + tm.eventFeed = eventFeed + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) + defer cancel() + + // Insert a pending transaction + txs := transactions.MockTestTransactions(t, chainClient, []transactions.TestTxSummary{{}}) + err = tm.pendingTracker.StoreAndTrackPendingTx(&txs[0]) // We dont need to track it, but no other way to insert it + require.NoError(t, err) + + // Call the WatchTransaction method + err = tm.WatchTransaction(ctx, chainID, transactionHash) + require.ErrorIs(t, err, ErrWatchPendingTxTimeout) } diff --git a/services/wallet/transfer/transaction_manager_test.go b/services/wallet/transfer/transaction_manager_test.go index 3dfcf72c2..b13ac6b61 100644 --- a/services/wallet/transfer/transaction_manager_test.go +++ b/services/wallet/transfer/transaction_manager_test.go @@ -2,25 +2,80 @@ package transfer import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "fmt" "math/big" + "reflect" "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + gethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/multiaccounts/accounts" wallet_common "github.com/status-im/status-go/services/wallet/common" - "github.com/status-im/status-go/t/helpers" - "github.com/status-im/status-go/walletdatabase" + "github.com/status-im/status-go/transactions" + "github.com/status-im/status-go/transactions/mock_transactor" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" ) -func setupTestTransactionDB(t *testing.T) (*TransactionManager, func()) { - db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) - require.NoError(t, err) +type dummyAccountsStorage struct { + keypair *accounts.Keypair + account *accounts.Account +} + +func (d *dummyAccountsStorage) GetAccountByAddress(address types.Address) (*accounts.Account, error) { + if address != d.account.Address { + return nil, fmt.Errorf("address not found") + } + return d.account, nil +} + +func (d *dummyAccountsStorage) GetKeypairByKeyUID(keyUID string) (*accounts.Keypair, error) { + if keyUID != d.keypair.KeyUID { + return nil, fmt.Errorf("keyUID not found") + } + return d.keypair, nil +} + +func (d *dummyAccountsStorage) AddressExists(address types.Address) (bool, error) { + return d.account.Address == address, nil +} + +type dummySigner struct{} + +func (d *dummySigner) Hash(tx *gethtypes.Transaction) common.Hash { + return common.HexToHash("0xc8e7a34af766c4ba9dc9b3d49939806fbf41fa01250c5a26afa5659e87b2020b") +} + +func setupTestSuite(t *testing.T) (*TransactionManager, *mock_transactor.MockTransactorIface) { SetMultiTransactionIDGenerator(StaticIDCounter()) // to have different multi-transaction IDs even with fast execution - return &TransactionManager{NewMultiTransactionDB(db), nil, nil, nil, nil, nil, nil, nil, nil, nil}, func() { - require.NoError(t, db.Close()) + accountsDB := setupAccountsStorage() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + transactor := mock_transactor.NewMockTransactorIface(ctrl) + return &TransactionManager{ + storage: NewInMemMultiTransactionStorage(), + accountsDB: accountsDB, + transactor: transactor, + }, transactor +} + +func setupAccountsStorage() *dummyAccountsStorage { + return &dummyAccountsStorage{ + keypair: &accounts.Keypair{ + KeyUID: "keyUid", + }, + account: &accounts.Account{ + KeyUID: "keyUid", + Address: types.Address{1}, + }, } } @@ -41,8 +96,7 @@ func areMultiTransactionsEqual(mt1, mt2 *MultiTransaction) bool { } func TestBridgeMultiTransactions(t *testing.T) { - manager, stop := setupTestTransactionDB(t) - defer stop() + manager, _ := setupTestSuite(t) trx1 := NewMultiTransaction( /* Timestamp: */ 123, @@ -105,8 +159,7 @@ func TestBridgeMultiTransactions(t *testing.T) { } func TestMultiTransactions(t *testing.T) { - manager, stop := setupTestTransactionDB(t) - defer stop() + manager, _ := setupTestSuite(t) trx1 := *NewMultiTransaction( /* Timestamp: */ 123, @@ -165,3 +218,158 @@ func TestMultiTransactions(t *testing.T) { require.True(t, found, "result contains transaction with id %d", id) } } + +func TestSignMessage(t *testing.T) { + tm, _ := setupTestSuite(t) + + message := (types.HexBytes)(make([]byte, 32)) + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + account := &types.Key{ + PrivateKey: privateKey, + } + + signature, err := tm.SignMessage(message, account) + require.NoError(t, err) + require.NotEmpty(t, signature) +} + +func TestSignMessage_InvalidAccount(t *testing.T) { + tm, _ := setupTestSuite(t) + + message := (types.HexBytes)(make([]byte, 32)) + account := &types.Key{ + PrivateKey: nil, + } + + signature, err := tm.SignMessage(message, account) + require.Error(t, err) + require.Empty(t, signature) +} + +func TestSignMessage_InvalidMessage(t *testing.T) { + tm, _ := setupTestSuite(t) + + message := types.HexBytes{} + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + account := &types.Key{ + PrivateKey: privateKey, + } + + signature, err := tm.SignMessage(message, account) + require.Error(t, err) + require.Equal(t, "0x", signature) +} + +func TestBuildTransaction(t *testing.T) { + manager, transactor := setupTestSuite(t) + + chainID := uint64(1) + nonce := uint64(1) + gas := uint64(21000) + sendArgs := transactions.SendTxArgs{ + From: types.Address{1}, + To: &types.Address{2}, + Value: (*hexutil.Big)(big.NewInt(123)), + Nonce: (*hexutil.Uint64)(&nonce), + Gas: (*hexutil.Uint64)(&gas), + GasPrice: (*hexutil.Big)(big.NewInt(1000000000)), + MaxFeePerGas: (*hexutil.Big)(big.NewInt(2000000000)), + MaxPriorityFeePerGas: (*hexutil.Big)(big.NewInt(1000000000)), + } + + expectedTx := gethtypes.NewTransaction(nonce, common.Address(*sendArgs.To), sendArgs.Value.ToInt(), gas, sendArgs.GasPrice.ToInt(), nil) + transactor.EXPECT().ValidateAndBuildTransaction(chainID, sendArgs).Return(expectedTx, nil) + + response, err := manager.BuildTransaction(chainID, sendArgs) + require.NoError(t, err) + require.NotNil(t, response) + + accDB := manager.accountsDB.(*dummyAccountsStorage) + signer := dummySigner{} + expectedKeyUID := accDB.keypair.KeyUID + expectedAddress := accDB.account.Address + expectedAddressPath := "" + expectedSignOnKeycard := false + expectedMessageToSign := signer.Hash(expectedTx) + + require.Equal(t, expectedKeyUID, response.KeyUID) + require.Equal(t, expectedAddress, response.Address) + require.Equal(t, expectedAddressPath, response.AddressPath) + require.Equal(t, expectedSignOnKeycard, response.SignOnKeycard) + require.Equal(t, chainID, response.ChainID) + require.Equal(t, expectedMessageToSign, response.MessageToSign) + require.True(t, reflect.DeepEqual(sendArgs, response.TxArgs)) +} + +func TestBuildTransaction_AccountNotFound(t *testing.T) { + manager, _ := setupTestSuite(t) + + chainID := uint64(1) + nonce := uint64(1) + gas := uint64(21000) + sendArgs := transactions.SendTxArgs{ + From: types.Address{2}, + To: &types.Address{2}, + Value: (*hexutil.Big)(big.NewInt(123)), + Nonce: (*hexutil.Uint64)(&nonce), + Gas: (*hexutil.Uint64)(&gas), + GasPrice: (*hexutil.Big)(big.NewInt(1000000000)), + MaxFeePerGas: (*hexutil.Big)(big.NewInt(2000000000)), + MaxPriorityFeePerGas: (*hexutil.Big)(big.NewInt(1000000000)), + } + + _, err := manager.BuildTransaction(chainID, sendArgs) + require.Error(t, err) +} + +func TestBuildTransaction_InvalidSendTxArgs(t *testing.T) { + manager, transactor := setupTestSuite(t) + + chainID := uint64(1) + sendArgs := transactions.SendTxArgs{ + From: types.Address{1}, + To: &types.Address{2}, + } + + expectedErr := fmt.Errorf("invalid SendTxArgs") + transactor.EXPECT().ValidateAndBuildTransaction(chainID, sendArgs).Return(nil, expectedErr) + tx, err := manager.BuildTransaction(chainID, sendArgs) + require.Equal(t, expectedErr, err) + require.Nil(t, tx) +} + +func TestBuildRawTransaction(t *testing.T) { + manager, transactor := setupTestSuite(t) + + chainID := uint64(1) + nonce := uint64(1) + gas := uint64(21000) + sendArgs := transactions.SendTxArgs{ + From: types.Address{1}, + To: &types.Address{2}, + Value: (*hexutil.Big)(big.NewInt(123)), + Nonce: (*hexutil.Uint64)(&nonce), + Gas: (*hexutil.Uint64)(&gas), + GasPrice: (*hexutil.Big)(big.NewInt(1000000000)), + MaxFeePerGas: (*hexutil.Big)(big.NewInt(2000000000)), + MaxPriorityFeePerGas: (*hexutil.Big)(big.NewInt(1000000000)), + } + + expectedTx := gethtypes.NewTransaction(1, common.Address(*sendArgs.To), sendArgs.Value.ToInt(), 21000, sendArgs.GasPrice.ToInt(), nil) + signature := []byte("signature") + transactor.EXPECT().BuildTransactionWithSignature(chainID, sendArgs, signature).Return(expectedTx, nil) + + response, err := manager.BuildRawTransaction(chainID, sendArgs, signature) + require.NoError(t, err) + require.NotNil(t, response) + + expectedData, _ := expectedTx.MarshalBinary() + expectedHash := expectedTx.Hash() + + require.Equal(t, chainID, response.ChainID) + require.Equal(t, sendArgs, response.TxArgs) + require.Equal(t, types.EncodeHex(expectedData), response.RawTx) + require.Equal(t, expectedHash, response.TxHash) +}