From ed164e4ac57f6cf0e9c4ee719de32b73867ee24e Mon Sep 17 00:00:00 2001 From: Ivan Belyakov Date: Thu, 2 May 2024 17:36:42 +0200 Subject: [PATCH] chore(wallet)_: clean up wallet API send and sign transactions --- services/wallet/activity/filter.go | 2 +- services/wallet/activity/filter.sql | 10 +- services/wallet/activity/recipients.sql | 6 +- services/wallet/api.go | 61 +++++- services/wallet/service.go | 6 + services/wallet/transfer/bridge_identifier.go | 63 +++--- services/wallet/transfer/controller_test.go | 72 ++++--- .../wallet/transfer/transaction_manager.go | 29 ++- .../transaction_manager_multitransaction.go | 198 ++++++++---------- .../transfer/transaction_manager_test.go | 99 ++++----- 10 files changed, 307 insertions(+), 239 deletions(-) diff --git a/services/wallet/activity/filter.go b/services/wallet/activity/filter.go index 0b7229c64..b2042e976 100644 --- a/services/wallet/activity/filter.go +++ b/services/wallet/activity/filter.go @@ -152,7 +152,7 @@ func GetRecipients(ctx context.Context, db *sql.DB, chainIDs []common.ChainID, a return entries, hasMore, nil } -func GetOldestTimestamp(ctx context.Context, db *sql.DB, addresses []eth.Address) (timestamp int64, err error) { +func GetOldestTimestamp(ctx context.Context, db *sql.DB, addresses []eth.Address) (timestamp uint64, err error) { filterAllAddresses := len(addresses) == 0 involvedAddresses := noEntriesInTmpTableSQLValues if !filterAllAddresses { diff --git a/services/wallet/activity/filter.sql b/services/wallet/activity/filter.sql index 336d707d4..900be2144 100644 --- a/services/wallet/activity/filter.sql +++ b/services/wallet/activity/filter.sql @@ -445,7 +445,7 @@ SELECT NULL AS transfer_hash, NULL AS pending_hash, NULL AS network_id, - multi_transactions.ROWID AS multi_tx_id, + multi_transactions.id AS multi_tx_id, multi_transactions.timestamp AS timestamp, multi_transactions.type AS mt_type, NULL as tr_type, @@ -488,8 +488,8 @@ SELECT FROM multi_transactions CROSS JOIN filter_conditions - LEFT JOIN tr_status ON multi_transactions.ROWID = tr_status.multi_transaction_id - LEFT JOIN pending_status ON multi_transactions.ROWID = pending_status.multi_transaction_id + LEFT JOIN tr_status ON multi_transactions.id = tr_status.multi_transaction_id + LEFT JOIN pending_status ON multi_transactions.id = pending_status.multi_transaction_id WHERE ( ( @@ -577,7 +577,7 @@ WHERE FROM tr_network_ids WHERE - multi_transactions.ROWID = tr_network_ids.multi_transaction_id + multi_transactions.id = tr_network_ids.multi_transaction_id ) OR EXISTS ( SELECT @@ -585,7 +585,7 @@ WHERE FROM pending_network_ids WHERE - multi_transactions.ROWID = pending_network_ids.multi_transaction_id + multi_transactions.id = pending_network_ids.multi_transaction_id ) ) ) diff --git a/services/wallet/activity/recipients.sql b/services/wallet/activity/recipients.sql index 6c35924f5..eb65e2cb5 100644 --- a/services/wallet/activity/recipients.sql +++ b/services/wallet/activity/recipients.sql @@ -91,7 +91,7 @@ FROM ( FROM tr_network_ids WHERE - multi_transactions.ROWID = tr_network_ids.multi_transaction_id + multi_transactions.id = tr_network_ids.multi_transaction_id ) OR EXISTS ( SELECT @@ -99,7 +99,7 @@ FROM ( FROM pending_network_ids WHERE - multi_transactions.ROWID = pending_network_ids.multi_transaction_id + multi_transactions.id = pending_network_ids.multi_transaction_id ) ) ) @@ -111,4 +111,4 @@ GROUP BY to_address ORDER BY min_timestamp DESC -LIMIT ? OFFSET ?; \ No newline at end of file +LIMIT ? OFFSET ?; diff --git a/services/wallet/api.go b/services/wallet/api.go index d8b9520ad..762c761e1 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -15,6 +15,7 @@ import ( "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/log" gethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/status-im/status-go/account" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc/network" @@ -526,7 +527,13 @@ func (api *API) GetAddressDetails(ctx context.Context, chainID uint64, address s func (api *API) SignMessage(ctx context.Context, message types.HexBytes, address common.Address, password string) (string, error) { log.Debug("[WalletAPI::SignMessage]", "message", message, "address", address) - return api.s.transactionManager.SignMessage(message, address, password) + + selectedAccount, err := api.s.gethManager.VerifyAccountPassword(api.s.Config().KeyStoreDir, address.Hex(), password) + if err != nil { + return "", err + } + + return api.s.transactionManager.SignMessage(message, selectedAccount) } func (api *API) BuildTransaction(ctx context.Context, chainID uint64, sendTxArgsJSON string) (response *transfer.TxResponse, err error) { @@ -574,7 +581,32 @@ func (api *API) SendTransactionWithSignature(ctx context.Context, chainID uint64 func (api *API) CreateMultiTransaction(ctx context.Context, multiTransactionCommand *transfer.MultiTransactionCommand, data []*bridge.TransactionBridge, password string) (*transfer.MultiTransactionCommandResult, error) { log.Debug("[WalletAPI:: CreateMultiTransaction] create multi transaction") - return api.s.transactionManager.CreateMultiTransactionFromCommand(ctx, multiTransactionCommand, data, api.router.bridges, password) + + cmd, err := api.s.transactionManager.CreateMultiTransactionFromCommand(ctx, multiTransactionCommand, data) + if err != nil { + return nil, err + } + + if password != "" { + selectedAccount, err := api.getVerifiedWalletAccount(multiTransactionCommand.FromAddress.Hex(), password) + if err != nil { + return nil, err + } + + cmdRes, err := api.s.transactionManager.SendTransactions(ctx, cmd, data, api.router.bridges, selectedAccount) + if err != nil { + return nil, err + } + + _, err = api.s.transactionManager.InsertMultiTransaction(cmd) + if err != nil { + return nil, err + } + + return cmdRes, nil + } + + return nil, api.s.transactionManager.SendTransactionForSigningToKeycard(ctx, cmd, data, api.router.bridges) } func (api *API) ProceedWithTransactionsSignatures(ctx context.Context, signatures map[string]transfer.SignatureDetails) (*transfer.MultiTransactionCommandResult, error) { @@ -740,3 +772,28 @@ func (api *API) WCAuthRequest(ctx context.Context, address common.Address, authM return api.s.walletConnect.AuthRequest(address, authMessage) } + +func (api *API) getVerifiedWalletAccount(address, password string) (*account.SelectedExtKey, error) { + exists, err := api.s.accountsDB.AddressExists(types.HexToAddress(address)) + if err != nil { + log.Error("failed to query db for a given address", "address", address, "error", err) + return nil, err + } + + if !exists { + log.Error("failed to get a selected account", "err", transactions.ErrInvalidTxSender) + return nil, transactions.ErrAccountDoesntExist + } + + keyStoreDir := api.s.Config().KeyStoreDir + key, err := api.s.gethManager.VerifyAccountPassword(keyStoreDir, address, password) + if err != nil { + log.Error("failed to verify account", "account", address, "error", err) + return nil, err + } + + return &account.SelectedExtKey{ + Address: key.Address, + AccountKey: key, + }, nil +} diff --git a/services/wallet/service.go b/services/wallet/service.go index decec5175..32110fcad 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -201,6 +201,7 @@ func NewService( blockChainState: blockChainState, keycardPairings: NewKeycardPairings(), walletConnect: walletconnect, + config: config, } } @@ -235,6 +236,7 @@ type Service struct { blockChainState *blockchainstate.BlockChainState keycardPairings *KeycardPairings walletConnect *walletconnect.Service + config *params.NodeConfig } // Start signals transmitter. @@ -293,3 +295,7 @@ func (s *Service) IsStarted() bool { func (s *Service) KeycardPairings() *KeycardPairings { return s.keycardPairings } + +func (s *Service) Config() *params.NodeConfig { + return s.config +} diff --git a/services/wallet/transfer/bridge_identifier.go b/services/wallet/transfer/bridge_identifier.go index 3dbe94735..76207e30d 100644 --- a/services/wallet/transfer/bridge_identifier.go +++ b/services/wallet/transfer/bridge_identifier.go @@ -45,23 +45,21 @@ func upsertHopBridgeOriginTx(ctx context.Context, transactionManager *Transactio } if multiTx == nil { - multiTx = &MultiTransaction{ - // Data from "origin" transaction - FromNetworkID: params.fromNetworkID, - FromTxHash: params.fromTxHash, - FromAddress: params.fromAddress, - FromAsset: params.fromAsset, - FromAmount: (*hexutil.Big)(params.fromAmount), - ToNetworkID: params.toNetworkID, - ToAddress: params.toAddress, - // To be replaced by "destination" transaction, need to be non-null - ToAsset: params.fromAsset, - ToAmount: (*hexutil.Big)(params.fromAmount), - // Common data - Type: MultiTransactionBridge, - CrossTxID: params.crossTxID, - Timestamp: params.timestamp, - } + multiTx = NewMultiTransaction( + /* Timestamp: */ params.timestamp, // Common data + /* FromNetworkID: */ params.fromNetworkID, // Data from "origin" transaction + /* ToNetworkID: */ params.toNetworkID, // Data from "origin" transaction + /* FromTxHash: */ params.fromTxHash, // Data from "origin" transaction + /* ToTxHash: */ common.Hash{}, + /* FromAddress: */ params.fromAddress, // Data from "origin" transaction + /* ToAddress: */ params.toAddress, // Data from "origin" transaction + /* FromAsset: */ params.fromAsset, // Data from "origin" transaction + /* ToAsset: */ params.fromAsset, // To be replaced by "destination" transaction, need to be non-null + /* FromAmount: */ (*hexutil.Big)(params.fromAmount), // Data from "origin" transaction + /* ToAmount: */ (*hexutil.Big)(params.fromAmount), // To be replaced by "destination" transaction, need to be non-null + /* Type: */ MultiTransactionBridge, // Common data + /* CrossTxID: */ params.crossTxID, // Common data + ) _, err := transactionManager.InsertMultiTransaction(multiTx) if err != nil { @@ -102,22 +100,21 @@ func upsertHopBridgeDestinationTx(ctx context.Context, transactionManager *Trans } if multiTx == nil { - multiTx = &MultiTransaction{ - // To be replaced by "origin" transaction, need to be non-null - FromAddress: params.toAddress, - FromAsset: params.toAsset, - FromAmount: (*hexutil.Big)(params.toAmount), - // Data from "destination" transaction - ToNetworkID: params.toNetworkID, - ToTxHash: params.toTxHash, - ToAddress: params.toAddress, - ToAsset: params.toAsset, - ToAmount: (*hexutil.Big)(params.toAmount), - // Common data - Type: MultiTransactionBridge, - CrossTxID: params.crossTxID, - Timestamp: params.timestamp, - } + multiTx = NewMultiTransaction( + /* Timestamp: */ params.timestamp, // Common data + /* FromNetworkID: */ 0, // not set + /* ToNetworkID: */ params.toNetworkID, // Data from "destination" transaction + /* FromTxHash: */ common.Hash{}, + /* ToTxHash: */ params.toTxHash, // Data from "destination" transaction + /* FromAddress: */ params.toAddress, // To be replaced by "origin" transaction, need to be non-null + /* ToAddress: */ params.toAddress, // Data from "destination" transaction + /* FromAsset: */ params.toAsset, // To be replaced by "origin" transaction, need to be non-null + /* ToAsset: */ params.toAsset, // Data from "destination" transaction + /* FromAmount: */ (*hexutil.Big)(params.toAmount), // To be replaced by "origin" transaction, need to be non-null + /* ToAmount: */ (*hexutil.Big)(params.toAmount), // Data from "destination" transaction + /* Type: */ MultiTransactionBridge, // Common data + /* CrossTxID: */ params.crossTxID, // Common data + ) _, err := transactionManager.InsertMultiTransaction(multiTx) if err != nil { diff --git a/services/wallet/transfer/controller_test.go b/services/wallet/transfer/controller_test.go index 3274d90e3..8fc8d8543 100644 --- a/services/wallet/transfer/controller_test.go +++ b/services/wallet/transfer/controller_test.go @@ -94,15 +94,21 @@ func TestController_watchAccountsChanges(t *testing.T) { require.NoError(t, err) // Self multi transaction - midSelf, err := transactionManager.InsertMultiTransaction(&MultiTransaction{ - FromAddress: address, - ToAddress: address, - FromAsset: "ETH", - ToAsset: "DAI", - FromAmount: &hexutil.Big{}, - ToAmount: &hexutil.Big{}, - Timestamp: 1, - }) + midSelf, err := transactionManager.InsertMultiTransaction(NewMultiTransaction( + /* Timestamp: */ 1, + /* FromNetworkID: */ 1, + /* ToNetworkID: */ 1, + /* FromTxHash: */ common.Hash{}, + /* ToTxHash: */ common.Hash{}, + /* FromAddress: */ address, + /* ToAddress: */ address, + /* FromAsset: */ "ETH", + /* ToAsset: */ "DAI", + /* FromAmount: */ &hexutil.Big{}, + /* ToAmount: */ &hexutil.Big{}, + /* Type: */ MultiTransactionSend, + /* CrossTxID: */ "", + )) require.NoError(t, err) mtxs, err := transactionManager.GetMultiTransactions(context.Background(), []wallet_common.MultiTransactionIDType{midSelf}) @@ -110,15 +116,21 @@ func TestController_watchAccountsChanges(t *testing.T) { require.Len(t, mtxs, 1) // Send multi transaction - mt := &MultiTransaction{ - FromAddress: address, - ToAddress: counterparty, - FromAsset: "ETH", - ToAsset: "DAI", - FromAmount: &hexutil.Big{}, - ToAmount: &hexutil.Big{}, - Timestamp: 2, - } + mt := NewMultiTransaction( + /* Timestamp: */ 2, + /* FromNetworkID: */ 1, + /* ToNetworkID: */ 1, + /* FromTxHash: */ common.Hash{}, + /* ToTxHash: */ common.Hash{}, + /* FromAddress: */ address, + /* ToAddress: */ counterparty, + /* FromAsset: */ "ETH", + /* ToAsset: */ "DAI", + /* FromAmount: */ &hexutil.Big{}, + /* ToAmount: */ &hexutil.Big{}, + /* Type: */ MultiTransactionSend, + /* CrossTxID: */ "", + ) mid, err := transactionManager.InsertMultiTransaction(mt) require.NoError(t, err) @@ -127,15 +139,21 @@ func TestController_watchAccountsChanges(t *testing.T) { require.Len(t, mtxs, 2) // Another Send multi-transaction where sender and receiver are inverted (both accounts are in accounts DB) - midReverse, err := transactionManager.InsertMultiTransaction(&MultiTransaction{ - FromAddress: mt.ToAddress, - ToAddress: mt.FromAddress, - FromAsset: mt.FromAsset, - ToAsset: mt.ToAsset, - FromAmount: mt.FromAmount, - ToAmount: mt.ToAmount, - Timestamp: mt.Timestamp + 1, - }) + midReverse, err := transactionManager.InsertMultiTransaction(NewMultiTransaction( + /* Timestamp: */ mt.Timestamp+1, + /* FromNetworkID: */ 1, + /* ToNetworkID: */ 1, + /* FromTxHash: */ common.Hash{}, + /* ToTxHash: */ common.Hash{}, + /* FromAddress: */ mt.ToAddress, + /* ToAddress: */ mt.FromAddress, + /* FromAsset: */ mt.FromAsset, + /* ToAsset: */ mt.ToAsset, + /* FromAmount: */ mt.FromAmount, + /* ToAmount: */ mt.ToAmount, + /* Type: */ MultiTransactionSend, + /* CrossTxID: */ "", + )) require.NoError(t, err) mtxs, err = transactionManager.GetMultiTransactions(context.Background(), []wallet_common.MultiTransactionIDType{midSelf, mid, midReverse}) diff --git a/services/wallet/transfer/transaction_manager.go b/services/wallet/transfer/transaction_manager.go index 2908e92da..072aa7a1a 100644 --- a/services/wallet/transfer/transaction_manager.go +++ b/services/wallet/transfer/transaction_manager.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "math/big" + "time" ethTypes "github.com/ethereum/go-ethereum/core/types" @@ -128,13 +129,31 @@ type TxResponse struct { TxHash common.Hash `json:"txHash,omitempty"` } -func (tm *TransactionManager) SignMessage(message types.HexBytes, address common.Address, password string) (string, error) { - selectedAccount, err := tm.gethManager.VerifyAccountPassword(tm.config.KeyStoreDir, address.Hex(), password) - if err != nil { - return "", err +func NewMultiTransaction(timestamp uint64, fromNetworkID, toNetworkID uint64, fromTxHash, toTxHash common.Hash, fromAddress, toAddress common.Address, fromAsset, toAsset string, fromAmount, toAmount *hexutil.Big, txType MultiTransactionType, crossTxID string) *MultiTransaction { + if timestamp == 0 { + timestamp = uint64(time.Now().Unix()) } - signature, err := crypto.Sign(message[:], selectedAccount.PrivateKey) + return &MultiTransaction{ + ID: multiTransactionIDGenerator(), + Timestamp: timestamp, + FromNetworkID: fromNetworkID, + ToNetworkID: toNetworkID, + FromTxHash: fromTxHash, + ToTxHash: toTxHash, + FromAddress: fromAddress, + ToAddress: toAddress, + FromAsset: fromAsset, + ToAsset: toAsset, + FromAmount: fromAmount, + ToAmount: toAmount, + Type: txType, + CrossTxID: crossTxID, + } +} + +func (tm *TransactionManager) SignMessage(message types.HexBytes, account *types.Key) (string, error) { + signature, err := crypto.Sign(message[:], account.PrivateKey) return types.EncodeHex(signature), err } diff --git a/services/wallet/transfer/transaction_manager_multitransaction.go b/services/wallet/transfer/transaction_manager_multitransaction.go index 88a41a467..3afb559bd 100644 --- a/services/wallet/transfer/transaction_manager_multitransaction.go +++ b/services/wallet/transfer/transaction_manager_multitransaction.go @@ -21,11 +21,10 @@ import ( "github.com/status-im/status-go/services/wallet/bridge" wallet_common "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/signal" - "github.com/status-im/status-go/transactions" ) -const multiTransactionColumns = "from_network_id, from_tx_hash, from_address, from_asset, from_amount, to_network_id, to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" -const selectMultiTransactionColumns = "COALESCE(from_network_id, 0), from_tx_hash, from_address, from_asset, from_amount, COALESCE(to_network_id, 0), to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" +const multiTransactionColumns = "id, from_network_id, from_tx_hash, from_address, from_asset, from_amount, to_network_id, to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" +const selectMultiTransactionColumns = "id, COALESCE(from_network_id, 0), from_tx_hash, from_address, from_asset, from_amount, COALESCE(to_network_id, 0), to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" func rowsToMultiTransactions(rows *sql.Rows) ([]*MultiTransaction, error) { var multiTransactions []*MultiTransaction @@ -79,22 +78,15 @@ func rowsToMultiTransactions(rows *sql.Rows) ([]*MultiTransaction, error) { return multiTransactions, nil } -func getMultiTransactionTimestamp(multiTransaction *MultiTransaction) uint64 { - if multiTransaction.Timestamp != 0 { - return multiTransaction.Timestamp - } - return uint64(time.Now().Unix()) -} - -// insertMultiTransaction inserts a multi transaction into the database and updates multi-transaction ID and timestamp -func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) (wallet_common.MultiTransactionIDType, error) { +// insertMultiTransaction inserts a multi transaction into the database and updates timestamp +func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) error { insert, err := db.Prepare(fmt.Sprintf(`INSERT INTO multi_transactions (%s) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) if err != nil { - return wallet_common.NoMultiTransactionID, err + return err } - timestamp := getMultiTransactionTimestamp(multiTransaction) - result, err := insert.Exec( + _, err = insert.Exec( + multiTransaction.ID, multiTransaction.FromNetworkID, multiTransaction.FromTxHash, multiTransaction.FromAddress, @@ -107,22 +99,18 @@ func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) (wal multiTransaction.ToAmount.String(), multiTransaction.Type, multiTransaction.CrossTxID, - timestamp, + multiTransaction.Timestamp, ) if err != nil { - return wallet_common.NoMultiTransactionID, err + return err } defer insert.Close() - multiTransactionID, err := result.LastInsertId() - multiTransaction.Timestamp = timestamp - multiTransaction.ID = wallet_common.MultiTransactionIDType(multiTransactionID) - - return wallet_common.MultiTransactionIDType(multiTransactionID), err + return err } func (tm *TransactionManager) InsertMultiTransaction(multiTransaction *MultiTransaction) (wallet_common.MultiTransactionIDType, error) { - return insertMultiTransaction(tm.db, multiTransaction) + return multiTransaction.ID, insertMultiTransaction(tm.db, multiTransaction) } func updateMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) error { @@ -130,13 +118,12 @@ func updateMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) erro return fmt.Errorf("no multitransaction ID") } - update, err := db.Prepare(fmt.Sprintf(`REPLACE INTO multi_transactions (rowid, %s) + update, err := db.Prepare(fmt.Sprintf(`REPLACE INTO multi_transactions (%s) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) if err != nil { return err } - timestamp := getMultiTransactionTimestamp(multiTransaction) _, err = update.Exec( multiTransaction.ID, multiTransaction.FromNetworkID, @@ -151,7 +138,7 @@ func updateMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) erro multiTransaction.ToAmount.String(), multiTransaction.Type, multiTransaction.CrossTxID, - timestamp, + multiTransaction.Timestamp, ) if err != nil { return err @@ -163,55 +150,53 @@ func (tm *TransactionManager) UpdateMultiTransaction(multiTransaction *MultiTran return updateMultiTransaction(tm.db, multiTransaction) } -// In case of keycard account, password should be empty func (tm *TransactionManager) CreateMultiTransactionFromCommand(ctx context.Context, command *MultiTransactionCommand, - data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) (*MultiTransactionCommandResult, error) { + data []*bridge.TransactionBridge) (*MultiTransaction, error) { multiTransaction := multiTransactionFromCommand(command) if multiTransaction.Type == MultiTransactionSend && multiTransaction.FromNetworkID == 0 && len(data) == 1 { multiTransaction.FromNetworkID = data[0].ChainID } - multiTransactionID, err := insertMultiTransaction(tm.db, multiTransaction) + + return multiTransaction, nil +} + +func (tm *TransactionManager) SendTransactionForSigningToKeycard(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge) error { + acc, err := tm.accountsDB.GetAccountByAddress(types.Address(multiTransaction.FromAddress)) if err != nil { - return nil, err + return err } - multiTransaction.ID = multiTransactionID - if password == "" { - acc, err := tm.accountsDB.GetAccountByAddress(types.Address(multiTransaction.FromAddress)) - if err != nil { - return nil, err - } - - kp, err := tm.accountsDB.GetKeypairByKeyUID(acc.KeyUID) - if err != nil { - return nil, err - } - - if !kp.MigratedToKeycard() { - return nil, fmt.Errorf("account being used is not migrated to a keycard, password is required") - } - - tm.multiTransactionForKeycardSigning = multiTransaction - tm.transactionsBridgeData = data - hashes, err := tm.buildTransactions(bridges) - if err != nil { - return nil, err - } - - signal.SendTransactionsForSigningEvent(hashes) - - return nil, nil + kp, err := tm.accountsDB.GetKeypairByKeyUID(acc.KeyUID) + if err != nil { + return err } - hashes, err := tm.sendTransactions(multiTransaction, data, bridges, password) + if !kp.MigratedToKeycard() { + return fmt.Errorf("account being used is not migrated to a keycard, password is required") + } + + tm.multiTransactionForKeycardSigning = multiTransaction + tm.transactionsBridgeData = data + hashes, err := tm.buildTransactions(bridges) + if err != nil { + return err + } + + signal.SendTransactionsForSigningEvent(hashes) + + return nil +} + +func (tm *TransactionManager) SendTransactions(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, account *account.SelectedExtKey) (*MultiTransactionCommandResult, error) { + hashes, err := tm.sendTransactions(multiTransaction, data, bridges, account) if err != nil { return nil, err } return &MultiTransactionCommandResult{ - ID: int64(multiTransactionID), + ID: int64(multiTransaction.ID), Hashes: hashes, }, nil } @@ -267,6 +252,11 @@ func (tm *TransactionManager) ProceedWithTransactionsSignatures(ctx context.Cont hashes[desc.chainID] = append(hashes[desc.chainID], hash) } + _, err := tm.InsertMultiTransaction(tm.multiTransactionForKeycardSigning) + if err != nil { + log.Error("failed to insert multi transaction", "err", err) + } + return &MultiTransactionCommandResult{ ID: int64(tm.multiTransactionForKeycardSigning.ID), Hashes: hashes, @@ -274,18 +264,21 @@ func (tm *TransactionManager) ProceedWithTransactionsSignatures(ctx context.Cont } func multiTransactionFromCommand(command *MultiTransactionCommand) *MultiTransaction { - - log.Info("Creating multi transaction", "command", command) - - multiTransaction := &MultiTransaction{ - FromAddress: command.FromAddress, - ToAddress: command.ToAddress, - FromAsset: command.FromAsset, - ToAsset: command.ToAsset, - FromAmount: command.FromAmount, - ToAmount: new(hexutil.Big), - Type: command.Type, - } + multiTransaction := NewMultiTransaction( + /* Timestamp: */ uint64(time.Now().Unix()), + /* FromNetworkID: */ 0, + /* ToNetworkID: */ 0, + /* FromTxHash: */ common.Hash{}, + /* ToTxHash: */ common.Hash{}, + /* FromAddress: */ command.FromAddress, + /* ToAddress: */ command.ToAddress, + /* FromAsset: */ command.FromAsset, + /* ToAsset: */ command.ToAsset, + /* FromAmount: */ command.FromAmount, + /* ToAmount: */ new(hexutil.Big), + /* Type: */ command.Type, + /* CrossTxID: */ "", + ) return multiTransaction } @@ -315,16 +308,9 @@ func (tm *TransactionManager) buildTransactions(bridges map[string]bridge.Bridge } func (tm *TransactionManager) sendTransactions(multiTransaction *MultiTransaction, - data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) ( + data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, account *account.SelectedExtKey) ( map[uint64][]types.Hash, error) { - log.Info("Making transactions", "multiTransaction", multiTransaction) - - selectedAccount, err := tm.getVerifiedWalletAccount(multiTransaction.FromAddress.Hex(), password) - if err != nil { - return nil, err - } - hashes := make(map[uint64][]types.Hash) for _, tx := range data { if tx.TransferTx != nil { @@ -352,9 +338,9 @@ func (tm *TransactionManager) sendTransactions(multiTransaction *MultiTransactio tx.SwapTx.Symbol = multiTransaction.FromAsset } - hash, err := bridges[tx.BridgeName].Send(tx, selectedAccount) + hash, err := bridges[tx.BridgeName].Send(tx, account) if err != nil { - return nil, err + return nil, err // TODO: One of transfers within transaction could have been sent. Need to notify user about it } hashes[tx.ChainID] = append(hashes[tx.ChainID], hash) } @@ -369,9 +355,9 @@ func (tm *TransactionManager) GetMultiTransactions(ctx context.Context, ids []wa args[i] = v } - stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT rowid, %s + stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT %s FROM multi_transactions - WHERE rowid in (%s)`, + WHERE id in (%s)`, selectMultiTransactionColumns, strings.Join(placeholders, ","))) if err != nil { @@ -389,7 +375,7 @@ func (tm *TransactionManager) GetMultiTransactions(ctx context.Context, ids []wa } func (tm *TransactionManager) getBridgeMultiTransactions(ctx context.Context, toChainID uint64, crossTxID string) ([]*MultiTransaction, error) { - stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT rowid, %s + stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT %s FROM multi_transactions WHERE type=? AND to_network_id=? AND cross_tx_id=?`, multiTransactionColumns)) @@ -439,37 +425,19 @@ func (tm *TransactionManager) GetBridgeDestinationMultiTransaction(ctx context.C return nil, nil } -func (tm *TransactionManager) getVerifiedWalletAccount(address, password string) (*account.SelectedExtKey, error) { - exists, err := tm.accountsDB.AddressExists(types.HexToAddress(address)) - if err != nil { - log.Error("failed to query db for a given address", "address", address, "error", err) - return nil, err - } - - if !exists { - log.Error("failed to get a selected account", "err", transactions.ErrInvalidTxSender) - return nil, transactions.ErrAccountDoesntExist - } - - key, err := tm.gethManager.VerifyAccountPassword(tm.config.KeyStoreDir, address, password) - if err != nil { - log.Error("failed to verify account", "account", address, "error", err) - return nil, err - } - - return &account.SelectedExtKey{ - Address: key.Address, - AccountKey: key, - }, nil +func idFromTimestamp() wallet_common.MultiTransactionIDType { + return wallet_common.MultiTransactionIDType(time.Now().UnixMilli()) } +var multiTransactionIDGenerator func() wallet_common.MultiTransactionIDType = idFromTimestamp + func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Address) error { // We must not remove those transactions, where from_address and to_address are different and both are stored in accounts DB // and one of them is equal to the address, as we want to keep the records for the other address // That is why we don't use cascade delete here with references to transfers table, as we might have 2 records in multi_transactions // for the same transaction, one for each address - stmt, err := tm.db.Prepare(`SELECT rowid, from_address, to_address + stmt, err := tm.db.Prepare(`SELECT id, from_address, to_address FROM multi_transactions WHERE from_address=? OR to_address=?`) if err != nil { @@ -482,10 +450,10 @@ func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Add } defer rows.Close() - rowIDs := make([]int, 0) - rowID, fromAddress, toAddress := 0, common.Address{}, common.Address{} + ids := make([]int, 0) + id, fromAddress, toAddress := 0, common.Address{}, common.Address{} for rows.Next() { - err = rows.Scan(&rowID, &fromAddress, &toAddress) + err = rows.Scan(&id, &fromAddress, &toAddress) if err != nil { log.Error("Failed to scan row", "error", err) continue @@ -512,14 +480,14 @@ func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Add } } - rowIDs = append(rowIDs, rowID) + ids = append(ids, id) } - if len(rowIDs) > 0 { - for _, rowID := range rowIDs { - _, err = tm.db.Exec(`DELETE FROM multi_transactions WHERE rowid=?`, rowID) + if len(ids) > 0 { + for _, id := range ids { + _, err = tm.db.Exec(`DELETE FROM multi_transactions WHERE id=?`, id) if err != nil { - log.Error("Failed to remove multitransaction", "rowid", rowID, "error", err) + log.Error("Failed to remove multitransaction", "id", id, "error", err) } } } diff --git a/services/wallet/transfer/transaction_manager_test.go b/services/wallet/transfer/transaction_manager_test.go index 3fa698a78..a0fe6cb43 100644 --- a/services/wallet/transfer/transaction_manager_test.go +++ b/services/wallet/transfer/transaction_manager_test.go @@ -43,39 +43,39 @@ func TestBridgeMultiTransactions(t *testing.T) { manager, stop := setupTestTransactionDB(t) defer stop() - trx1 := MultiTransaction{ - Timestamp: 123, - FromNetworkID: 0, - ToNetworkID: 1, - FromTxHash: common.Hash{5}, - // Empty ToTxHash - FromAddress: common.Address{1}, - ToAddress: common.Address{2}, - FromAsset: "fromAsset", - ToAsset: "toAsset", - FromAmount: (*hexutil.Big)(big.NewInt(123)), - ToAmount: (*hexutil.Big)(big.NewInt(234)), - Type: MultiTransactionBridge, - CrossTxID: "crossTxD1", - } + trx1 := NewMultiTransaction( + /* Timestamp: */ 123, + /* FromNetworkID: */ 0, + /* ToNetworkID: */ 1, + /* FromTxHash: */ common.Hash{5}, + /* // Empty ToTxHash */ common.Hash{}, + /* FromAddress: */ common.Address{1}, + /* ToAddress: */ common.Address{2}, + /* FromAsset: */ "fromAsset", + /* ToAsset: */ "toAsset", + /* FromAmount: */ (*hexutil.Big)(big.NewInt(123)), + /* ToAmount: */ (*hexutil.Big)(big.NewInt(234)), + /* Type: */ MultiTransactionBridge, + /* CrossTxID: */ "crossTxD1", + ) - trx2 := MultiTransaction{ - Timestamp: 321, - FromNetworkID: 1, - ToNetworkID: 0, - //Empty FromTxHash - ToTxHash: common.Hash{6}, - FromAddress: common.Address{2}, - ToAddress: common.Address{1}, - FromAsset: "fromAsset", - ToAsset: "toAsset", - FromAmount: (*hexutil.Big)(big.NewInt(123)), - ToAmount: (*hexutil.Big)(big.NewInt(234)), - Type: MultiTransactionBridge, - CrossTxID: "crossTxD2", - } + trx2 := NewMultiTransaction( + /* Timestamp: */ 321, + /* FromNetworkID: */ 1, + /* ToNetworkID: */ 0, + /* //Empty FromTxHash */ common.Hash{}, + /* ToTxHash: */ common.Hash{6}, + /* FromAddress: */ common.Address{2}, + /* ToAddress: */ common.Address{1}, + /* FromAsset: */ "fromAsset", + /* ToAsset: */ "toAsset", + /* FromAmount: */ (*hexutil.Big)(big.NewInt(123)), + /* ToAmount: */ (*hexutil.Big)(big.NewInt(234)), + /* Type: */ MultiTransactionBridge, + /* CrossTxID: */ "crossTxD2", + ) - trxs := []*MultiTransaction{&trx1, &trx2} + trxs := []*MultiTransaction{trx1, trx2} var err error ids := make([]wallet_common.MultiTransactionIDType, len(trxs)) @@ -88,7 +88,7 @@ func TestBridgeMultiTransactions(t *testing.T) { rst, err := manager.GetBridgeOriginMultiTransaction(context.Background(), trx1.ToNetworkID, trx1.CrossTxID) require.NoError(t, err) require.NotEmpty(t, rst) - require.True(t, areMultiTransactionsEqual(&trx1, rst)) + require.True(t, areMultiTransactionsEqual(trx1, rst)) rst, err = manager.GetBridgeDestinationMultiTransaction(context.Background(), trx1.ToNetworkID, trx1.CrossTxID) require.NoError(t, err) @@ -101,31 +101,34 @@ func TestBridgeMultiTransactions(t *testing.T) { rst, err = manager.GetBridgeDestinationMultiTransaction(context.Background(), trx2.ToNetworkID, trx2.CrossTxID) require.NoError(t, err) require.NotEmpty(t, rst) - require.True(t, areMultiTransactionsEqual(&trx2, rst)) + require.True(t, areMultiTransactionsEqual(trx2, rst)) } func TestMultiTransactions(t *testing.T) { manager, stop := setupTestTransactionDB(t) defer stop() - trx1 := MultiTransaction{ - Timestamp: 123, - FromNetworkID: 0, - ToNetworkID: 1, - FromTxHash: common.Hash{5}, - ToTxHash: common.Hash{6}, - FromAddress: common.Address{1}, - ToAddress: common.Address{2}, - FromAsset: "fromAsset", - ToAsset: "toAsset", - FromAmount: (*hexutil.Big)(big.NewInt(123)), - ToAmount: (*hexutil.Big)(big.NewInt(234)), - Type: MultiTransactionBridge, - CrossTxID: "crossTxD", - } + trx1 := *NewMultiTransaction( + /* Timestamp: */ 123, + /* FromNetworkID:*/ 0, + /* ToNetworkID: */ 1, + /* FromTxHash: */ common.Hash{5}, + /* ToTxHash: */ common.Hash{6}, + /* FromAddress: */ common.Address{1}, + /* ToAddress: */ common.Address{2}, + /* FromAsset: */ "fromAsset", + /* ToAsset: */ "toAsset", + /* FromAmount: */ (*hexutil.Big)(big.NewInt(123)), + /* ToAmount: */ (*hexutil.Big)(big.NewInt(234)), + /* Type: */ MultiTransactionBridge, + /* CrossTxID: */ "crossTxD", + ) trx2 := trx1 trx2.FromAmount = (*hexutil.Big)(big.NewInt(456)) trx2.ToAmount = (*hexutil.Big)(big.NewInt(567)) + trx2.ID = generateMultiTransactionID() + + require.NotEqual(t, trx1.ID, trx2.ID) trxs := []*MultiTransaction{&trx1, &trx2}