From 4d1149100f2073aec582334ee92710b18a9a3ed6 Mon Sep 17 00:00:00 2001 From: Ivan Belyakov Date: Thu, 23 May 2024 20:22:57 +0400 Subject: [PATCH] chore(wallet)_: code structure improved for multi_transaction manager - exported API methods left at the same place - private methods moved to helpers.go - stuff for testing moved to testutils.go - created storage interface with clean API and multi transaction related db calls moved to MultiTransactionDBStorage implementation - created dummy in-mem storage for tests with multi transactions - written tests for MultiTransactionDBStorage --- services/wallet/api.go | 2 +- services/wallet/service.go | 2 +- .../transfer/commands_sequential_test.go | 2 +- services/wallet/transfer/controller_test.go | 4 +- services/wallet/transfer/helpers.go | 219 ++++++++++ .../wallet/transfer/multi_transaction_db.go | 183 ++++++++ .../transfer/multi_transaction_db_test.go | 124 ++++++ services/wallet/transfer/testutils.go | 98 ++++- .../wallet/transfer/transaction_manager.go | 28 +- .../transfer/transaction_manager_internal.go | 33 ++ .../transaction_manager_multitransaction.go | 393 ++---------------- .../transfer/transaction_manager_test.go | 4 +- 12 files changed, 694 insertions(+), 398 deletions(-) create mode 100644 services/wallet/transfer/helpers.go create mode 100644 services/wallet/transfer/multi_transaction_db.go create mode 100644 services/wallet/transfer/multi_transaction_db_test.go create mode 100644 services/wallet/transfer/transaction_manager_internal.go diff --git a/services/wallet/api.go b/services/wallet/api.go index 8f735b122..6b068a4b0 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -626,7 +626,7 @@ func (api *API) CreateMultiTransaction(ctx context.Context, multiTransactionComm _, err = api.s.transactionManager.InsertMultiTransaction(cmd) if err != nil { - return nil, err + log.Error("Failed to save multi transaction", "error", err) // not critical } return cmdRes, nil diff --git a/services/wallet/service.go b/services/wallet/service.go index f6e05ec97..196f35526 100644 --- a/services/wallet/service.go +++ b/services/wallet/service.go @@ -105,7 +105,7 @@ func NewService( tokenManager := token.NewTokenManager(db, rpcClient, communityManager, rpcClient.NetworkManager, appDB, mediaServer, feed, accountFeed, accountsDB) tokenManager.Start() savedAddressesManager := &SavedAddressesManager{db: db} - transactionManager := transfer.NewTransactionManager(db, gethManager, transactor, config, accountsDB, pendingTxManager, feed) + transactionManager := transfer.NewTransactionManager(transfer.NewMultiTransactionDB(db), gethManager, transactor, config, accountsDB, pendingTxManager, feed) blockChainState := blockchainstate.NewBlockChainState() transferController := transfer.NewTransferController(db, accountsDB, rpcClient, accountFeed, feed, transactionManager, pendingTxManager, tokenManager, balanceCacher, blockChainState) diff --git a/services/wallet/transfer/commands_sequential_test.go b/services/wallet/transfer/commands_sequential_test.go index de4485ae0..371286f89 100644 --- a/services/wallet/transfer/commands_sequential_test.go +++ b/services/wallet/transfer/commands_sequential_test.go @@ -1318,7 +1318,7 @@ func TestFetchTransfersForLoadedBlocks(t *testing.T) { db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) - tm := &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil} + tm := &TransactionManager{NewMultiTransactionDB(db), nil, nil, nil, nil, nil, nil, nil, nil, nil} mediaServer, err := server.NewMediaServer(appdb, nil, nil, db) require.NoError(t, err) diff --git a/services/wallet/transfer/controller_test.go b/services/wallet/transfer/controller_test.go index 531f8cef9..c3befd3dc 100644 --- a/services/wallet/transfer/controller_test.go +++ b/services/wallet/transfer/controller_test.go @@ -36,7 +36,7 @@ func TestController_watchAccountsChanges(t *testing.T) { bcstate := blockchainstate.NewBlockChainState() SetMultiTransactionIDGenerator(StaticIDCounter()) // to have different multi-transaction IDs even with fast execution - transactionManager := NewTransactionManager(walletDB, nil, nil, nil, accountsDB, nil, nil) + transactionManager := NewTransactionManager(NewInMemMultiTransactionStorage(), nil, nil, nil, accountsDB, nil, nil) c := NewTransferController( walletDB, accountsDB, @@ -239,7 +239,7 @@ func TestController_cleanupAccountLeftovers(t *testing.T) { require.NoError(t, err) require.Len(t, storedAccs, 1) - transactionManager := NewTransactionManager(walletDB, nil, nil, nil, accountsDB, nil, nil) + transactionManager := NewTransactionManager(NewMultiTransactionDB(walletDB), nil, nil, nil, accountsDB, nil, nil) bcstate := blockchainstate.NewBlockChainState() c := NewTransferController( walletDB, diff --git a/services/wallet/transfer/helpers.go b/services/wallet/transfer/helpers.go new file mode 100644 index 000000000..1591f58d1 --- /dev/null +++ b/services/wallet/transfer/helpers.go @@ -0,0 +1,219 @@ +package transfer + +import ( + "database/sql" + "encoding/hex" + "errors" + "fmt" + "math/big" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/log" + "github.com/status-im/status-go/account" + "github.com/status-im/status-go/eth-node/crypto" + "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/services/wallet/bridge" + wallet_common "github.com/status-im/status-go/services/wallet/common" +) + +func rowsToMultiTransactions(rows *sql.Rows) ([]*MultiTransaction, error) { + var multiTransactions []*MultiTransaction + for rows.Next() { + multiTransaction := &MultiTransaction{} + var fromAmountDB, toAmountDB sql.NullString + var fromTxHash, toTxHash sql.RawBytes + err := rows.Scan( + &multiTransaction.ID, + &multiTransaction.FromNetworkID, + &fromTxHash, + &multiTransaction.FromAddress, + &multiTransaction.FromAsset, + &fromAmountDB, + &multiTransaction.ToNetworkID, + &toTxHash, + &multiTransaction.ToAddress, + &multiTransaction.ToAsset, + &toAmountDB, + &multiTransaction.Type, + &multiTransaction.CrossTxID, + &multiTransaction.Timestamp, + ) + if len(fromTxHash) > 0 { + multiTransaction.FromTxHash = common.BytesToHash(fromTxHash) + } + if len(toTxHash) > 0 { + multiTransaction.ToTxHash = common.BytesToHash(toTxHash) + } + if err != nil { + return nil, err + } + + if fromAmountDB.Valid { + multiTransaction.FromAmount = new(hexutil.Big) + if _, ok := (*big.Int)(multiTransaction.FromAmount).SetString(fromAmountDB.String, 0); !ok { + return nil, errors.New("failed to convert fromAmountDB.String to big.Int: " + fromAmountDB.String) + } + } + + if toAmountDB.Valid { + multiTransaction.ToAmount = new(hexutil.Big) + if _, ok := (*big.Int)(multiTransaction.ToAmount).SetString(toAmountDB.String, 0); !ok { + return nil, errors.New("failed to convert fromAmountDB.String to big.Int: " + toAmountDB.String) + } + } + + multiTransactions = append(multiTransactions, multiTransaction) + } + + return multiTransactions, nil +} + +func addSignaturesToTransactions(transactions map[common.Hash]*TransactionDescription, signatures map[string]SignatureDetails) error { + if len(transactions) == 0 { + return errors.New("no transactions to proceed with") + } + if len(signatures) != len(transactions) { + return errors.New("not all transactions have been signed") + } + + // check if all transactions have been signed + for hash, desc := range transactions { + sigDetails, ok := signatures[hash.String()] + if !ok { + return fmt.Errorf("missing signature for transaction %s", hash) + } + + rBytes, _ := hex.DecodeString(sigDetails.R) + sBytes, _ := hex.DecodeString(sigDetails.S) + vByte := byte(0) + if sigDetails.V == "01" { + vByte = 1 + } + + desc.signature = make([]byte, crypto.SignatureLength) + copy(desc.signature[32-len(rBytes):32], rBytes) + copy(desc.signature[64-len(rBytes):64], sBytes) + desc.signature[64] = vByte + } + + return nil +} + +func multiTransactionFromCommand(command *MultiTransactionCommand) *MultiTransaction { + multiTransaction := NewMultiTransaction( + /* Timestamp: */ uint64(time.Now().Unix()), + /* FromNetworkID: */ 0, + /* ToNetworkID: */ 0, + /* FromTxHash: */ common.Hash{}, + /* ToTxHash: */ common.Hash{}, + /* FromAddress: */ command.FromAddress, + /* ToAddress: */ command.ToAddress, + /* FromAsset: */ command.FromAsset, + /* ToAsset: */ command.ToAsset, + /* FromAmount: */ command.FromAmount, + /* ToAmount: */ new(hexutil.Big), + /* Type: */ command.Type, + /* CrossTxID: */ "", + ) + + return multiTransaction +} + +func updateDataFromMultiTx(data []*bridge.TransactionBridge, multiTransaction *MultiTransaction) { + for _, tx := range data { + if tx.TransferTx != nil { + tx.TransferTx.MultiTransactionID = multiTransaction.ID + tx.TransferTx.Symbol = multiTransaction.FromAsset + } + if tx.HopTx != nil { + tx.HopTx.MultiTransactionID = multiTransaction.ID + tx.HopTx.Symbol = multiTransaction.FromAsset + } + if tx.CbridgeTx != nil { + tx.CbridgeTx.MultiTransactionID = multiTransaction.ID + tx.CbridgeTx.Symbol = multiTransaction.FromAsset + } + if tx.ERC721TransferTx != nil { + tx.ERC721TransferTx.MultiTransactionID = multiTransaction.ID + tx.ERC721TransferTx.Symbol = multiTransaction.FromAsset + } + if tx.ERC1155TransferTx != nil { + tx.ERC1155TransferTx.MultiTransactionID = multiTransaction.ID + tx.ERC1155TransferTx.Symbol = multiTransaction.FromAsset + } + if tx.SwapTx != nil { + tx.SwapTx.MultiTransactionID = multiTransaction.ID + tx.SwapTx.Symbol = multiTransaction.FromAsset + } + } +} + +func sendTransactions(data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, account *account.SelectedExtKey) ( + map[uint64][]types.Hash, error) { + + hashes := make(map[uint64][]types.Hash) + for _, tx := range data { + hash, err := bridges[tx.BridgeName].Send(tx, account) + if err != nil { + return nil, err // TODO: One of transfers within transaction could have been sent. Need to notify user about it + } + hashes[tx.ChainID] = append(hashes[tx.ChainID], hash) + } + return hashes, nil +} + +func idFromTimestamp() wallet_common.MultiTransactionIDType { + return wallet_common.MultiTransactionIDType(time.Now().UnixMilli()) +} + +var multiTransactionIDGenerator func() wallet_common.MultiTransactionIDType = idFromTimestamp + +func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Address) error { + // We must not remove those transactions, where from_address and to_address are different and both are stored in accounts DB + // and one of them is equal to the address, as we want to keep the records for the other address + // That is why we don't use cascade delete here with references to transfers table, as we might have 2 records in multi_transactions + // for the same transaction, one for each address + + details := NewMultiTxDetails() + details.FromAddress = address + mtxs, err := tm.storage.ReadMultiTransactionsByDetails(details) + + ids := make([]wallet_common.MultiTransactionIDType, 0) + for _, mtx := range mtxs { + // Remove self transactions as well, leave only those where we have the counterparty in accounts DB + if mtx.FromAddress != mtx.ToAddress { + // If both addresses are stored in accounts DB, we don't remove the record + var addressToCheck common.Address + if mtx.FromAddress == address { + addressToCheck = mtx.ToAddress + } else { + addressToCheck = mtx.FromAddress + } + counterpartyExists, err := tm.accountsDB.AddressExists(types.Address(addressToCheck)) + if err != nil { + log.Error("Failed to query accounts db for a given address", "address", address, "error", err) + continue + } + + // Skip removal if counterparty is in accounts DB and removed address is not sender + if counterpartyExists && address != mtx.FromAddress { + continue + } + } + + ids = append(ids, mtx.ID) + } + + if len(ids) > 0 { + for _, id := range ids { + err = tm.storage.DeleteMultiTransaction(id) + if err != nil { + log.Error("Failed to delete multi transaction", "id", id, "error", err) + } + } + } + + return err +} diff --git a/services/wallet/transfer/multi_transaction_db.go b/services/wallet/transfer/multi_transaction_db.go new file mode 100644 index 000000000..3e3fb32bd --- /dev/null +++ b/services/wallet/transfer/multi_transaction_db.go @@ -0,0 +1,183 @@ +package transfer + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/ethereum/go-ethereum/common" + wallet_common "github.com/status-im/status-go/services/wallet/common" +) + +// DO NOT CREATE IT MANUALLY! Use NewMultiTxDetails() instead +type MultiTxDetails struct { + AnyAddress common.Address + FromAddress common.Address + ToAddress common.Address + ToChainID uint64 + CrossTxID string + Type MultiTransactionType +} + +func NewMultiTxDetails() *MultiTxDetails { + details := &MultiTxDetails{} + details.Type = MultiTransactionTypeInvalid + return details +} + +type MultiTransactionDB struct { + db *sql.DB +} + +func NewMultiTransactionDB(db *sql.DB) *MultiTransactionDB { + return &MultiTransactionDB{ + db: db, + } +} + +func (mtDB *MultiTransactionDB) CreateMultiTransaction(multiTransaction *MultiTransaction) error { + insert, err := mtDB.db.Prepare(fmt.Sprintf(`INSERT INTO multi_transactions (%s) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) + if err != nil { + return err + } + _, err = insert.Exec( + multiTransaction.ID, + multiTransaction.FromNetworkID, + multiTransaction.FromTxHash, + multiTransaction.FromAddress, + multiTransaction.FromAsset, + multiTransaction.FromAmount.String(), + multiTransaction.ToNetworkID, + multiTransaction.ToTxHash, + multiTransaction.ToAddress, + multiTransaction.ToAsset, + multiTransaction.ToAmount.String(), + multiTransaction.Type, + multiTransaction.CrossTxID, + multiTransaction.Timestamp, + ) + if err != nil { + return err + } + defer insert.Close() + + return err +} + +func (mtDB *MultiTransactionDB) ReadMultiTransactions(ids []wallet_common.MultiTransactionIDType) ([]*MultiTransaction, error) { + placeholders := make([]string, len(ids)) + args := make([]interface{}, len(ids)) + for i, v := range ids { + placeholders[i] = "?" + args[i] = v + } + + stmt, err := mtDB.db.Prepare(fmt.Sprintf(`SELECT %s + FROM multi_transactions + WHERE id in (%s)`, + selectMultiTransactionColumns, + strings.Join(placeholders, ","))) + if err != nil { + return nil, err + } + defer stmt.Close() + + rows, err := stmt.Query(args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rowsToMultiTransactions(rows) +} + +func (mtDB *MultiTransactionDB) ReadMultiTransactionsByDetails(details *MultiTxDetails) ([]*MultiTransaction, error) { + if details == nil { + return nil, fmt.Errorf("details is nil") + } + + whereClause := "" + + args := []interface{}{} + + if (details.AnyAddress != common.Address{}) { + whereClause += "(from_address=? OR to_address=?) AND " + args = append(args, details.AnyAddress, details.AnyAddress) + } + if (details.FromAddress != common.Address{}) { + whereClause += "from_address=? AND " + args = append(args, details.FromAddress) + } + if (details.ToAddress != common.Address{}) { + whereClause += "to_address=? AND " + args = append(args, details.ToAddress) + } + if details.ToChainID != 0 { + whereClause += "to_network_id=? AND " + args = append(args, details.ToChainID) + } + if details.CrossTxID != "" { + whereClause += "cross_tx_id=? AND " + args = append(args, details.CrossTxID) + } + if details.Type != MultiTransactionTypeInvalid { + whereClause += "type=? AND " + args = append(args, details.Type) + } + + stmt, err := mtDB.db.Prepare(fmt.Sprintf(`SELECT %s + FROM multi_transactions + WHERE %s`, + selectMultiTransactionColumns, whereClause[:len(whereClause)-5])) + if err != nil { + return nil, err + } + defer stmt.Close() + + rows, err := stmt.Query(args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rowsToMultiTransactions(rows) +} + +func (mtDB *MultiTransactionDB) UpdateMultiTransaction(multiTransaction *MultiTransaction) error { + if multiTransaction.ID == wallet_common.NoMultiTransactionID { + return fmt.Errorf("no multitransaction ID") + } + + update, err := mtDB.db.Prepare(fmt.Sprintf(`REPLACE INTO multi_transactions (%s) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) + + if err != nil { + return err + } + _, err = update.Exec( + multiTransaction.ID, + multiTransaction.FromNetworkID, + multiTransaction.FromTxHash, + multiTransaction.FromAddress, + multiTransaction.FromAsset, + multiTransaction.FromAmount.String(), + multiTransaction.ToNetworkID, + multiTransaction.ToTxHash, + multiTransaction.ToAddress, + multiTransaction.ToAsset, + multiTransaction.ToAmount.String(), + multiTransaction.Type, + multiTransaction.CrossTxID, + multiTransaction.Timestamp, + ) + if err != nil { + return err + } + return update.Close() +} + +func (mtDB *MultiTransactionDB) DeleteMultiTransaction(id wallet_common.MultiTransactionIDType) error { + _, err := mtDB.db.Exec(`DELETE FROM multi_transactions WHERE id=?`, id) + return err +} diff --git a/services/wallet/transfer/multi_transaction_db_test.go b/services/wallet/transfer/multi_transaction_db_test.go new file mode 100644 index 000000000..090a39cd9 --- /dev/null +++ b/services/wallet/transfer/multi_transaction_db_test.go @@ -0,0 +1,124 @@ +package transfer + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + wallet_common "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/t/helpers" + "github.com/status-im/status-go/walletdatabase" +) + +func setupTestMultiTransactionDB(t *testing.T) (*MultiTransactionDB, func()) { + db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) + require.NoError(t, err) + SetMultiTransactionIDGenerator(StaticIDCounter()) // to have different multi-transaction IDs even with fast execution + return NewMultiTransactionDB(db), func() { + require.NoError(t, db.Close()) + } +} +func TestCreateMultiTransaction(t *testing.T) { + mtDB, cleanup := setupTestMultiTransactionDB(t) + defer cleanup() + + tr := generateTestTransfer(0) + multiTransaction := GenerateTestSendMultiTransaction(tr) + + err := mtDB.CreateMultiTransaction(&multiTransaction) + require.NoError(t, err) + + // Add assertions here to verify the result of the CreateMultiTransaction method + mtx, err := mtDB.ReadMultiTransactions([]wallet_common.MultiTransactionIDType{multiTransaction.ID}) + require.NoError(t, err) + require.Len(t, mtx, 1) + require.True(t, areMultiTransactionsEqual(&multiTransaction, mtx[0])) +} +func TestReadMultiTransactions(t *testing.T) { + mtDB, cleanup := setupTestMultiTransactionDB(t) + defer cleanup() + + // Create test multi transactions + tr := generateTestTransfer(0) + mt1 := GenerateTestSendMultiTransaction(tr) + tr20 := generateTestTransfer(1) + tr21 := generateTestTransfer(2) + mt2 := GenerateTestBridgeMultiTransaction(tr20, tr21) + tr3 := generateTestTransfer(3) + mt3 := GenerateTestSwapMultiTransaction(tr3, "SNT", 100) + + err := mtDB.CreateMultiTransaction(&mt1) + require.NoError(t, err) + err = mtDB.CreateMultiTransaction(&mt2) + require.NoError(t, err) + err = mtDB.CreateMultiTransaction(&mt3) + require.NoError(t, err) + + // Read multi transactions + ids := []wallet_common.MultiTransactionIDType{mt1.ID, mt2.ID, mt3.ID} + mtx, err := mtDB.ReadMultiTransactions(ids) + require.NoError(t, err) + require.Len(t, mtx, 3) + require.True(t, areMultiTransactionsEqual(&mt1, mtx[0])) + require.True(t, areMultiTransactionsEqual(&mt2, mtx[1])) + require.True(t, areMultiTransactionsEqual(&mt3, mtx[2])) +} + +func TestUpdateMultiTransaction(t *testing.T) { + mtDB, cleanup := setupTestMultiTransactionDB(t) + defer cleanup() + + // Create test multi transaction + tr := generateTestTransfer(0) + multiTransaction := GenerateTestSendMultiTransaction(tr) + + err := mtDB.CreateMultiTransaction(&multiTransaction) + require.NoError(t, err) + + // Update the multi transaction + multiTransaction.FromNetworkID = 1 + multiTransaction.FromTxHash = common.Hash{1} + multiTransaction.FromAddress = common.Address{2} + multiTransaction.FromAsset = "fromAsset1" + multiTransaction.FromAmount = (*hexutil.Big)(big.NewInt(234)) + multiTransaction.ToNetworkID = 2 + multiTransaction.ToTxHash = common.Hash{3} + multiTransaction.ToAddress = common.Address{4} + multiTransaction.ToAsset = "toAsset1" + multiTransaction.ToAmount = (*hexutil.Big)(big.NewInt(345)) + multiTransaction.Type = MultiTransactionBridge + multiTransaction.CrossTxID = "crossTxD2" + + err = mtDB.UpdateMultiTransaction(&multiTransaction) + require.NoError(t, err) + + // Read the updated multi transaction + mtx, err := mtDB.ReadMultiTransactions([]wallet_common.MultiTransactionIDType{multiTransaction.ID}) + require.NoError(t, err) + require.Len(t, mtx, 1) + require.True(t, areMultiTransactionsEqual(&multiTransaction, mtx[0])) +} + +func TestDeleteMultiTransaction(t *testing.T) { + mtDB, cleanup := setupTestMultiTransactionDB(t) + defer cleanup() + + // Create test multi transaction + tr := generateTestTransfer(0) + multiTransaction := GenerateTestSendMultiTransaction(tr) + + err := mtDB.CreateMultiTransaction(&multiTransaction) + require.NoError(t, err) + + // Delete the multi transaction + err = mtDB.DeleteMultiTransaction(multiTransaction.ID) + require.NoError(t, err) + + // Read the deleted multi transaction + mtx, err := mtDB.ReadMultiTransactions([]wallet_common.MultiTransactionIDType{multiTransaction.ID}) + require.NoError(t, err) + require.Len(t, mtx, 0) +} diff --git a/services/wallet/transfer/testutils.go b/services/wallet/transfer/testutils.go index af09fc603..834916d4a 100644 --- a/services/wallet/transfer/testutils.go +++ b/services/wallet/transfer/testutils.go @@ -37,20 +37,6 @@ type TestTransfer struct { Token *token.Token } -type TestMultiTransaction struct { - MultiTransactionID common.MultiTransactionIDType - Type MultiTransactionType - FromAddress eth_common.Address - ToAddress eth_common.Address - FromAsset string - ToToken string - FromAmount int64 - ToAmount int64 - Timestamp int64 - FromNetworkID *uint64 - ToNetworkID *uint64 -} - func SeedToToken(seed int) *token.Token { tokenIndex := seed % len(TestTokens) return TestTokens[tokenIndex] @@ -95,6 +81,7 @@ func generateTestTransfer(seed int) TestTransfer { func GenerateTestSendMultiTransaction(tr TestTransfer) MultiTransaction { return MultiTransaction{ + ID: multiTransactionIDGenerator(), Type: MultiTransactionSend, FromAddress: tr.From, ToAddress: tr.To, @@ -108,6 +95,7 @@ func GenerateTestSendMultiTransaction(tr TestTransfer) MultiTransaction { func GenerateTestSwapMultiTransaction(tr TestTransfer, toToken string, toAmount int64) MultiTransaction { return MultiTransaction{ + ID: multiTransactionIDGenerator(), Type: MultiTransactionSwap, FromAddress: tr.From, ToAddress: tr.To, @@ -121,6 +109,7 @@ func GenerateTestSwapMultiTransaction(tr TestTransfer, toToken string, toAmount func GenerateTestBridgeMultiTransaction(fromTr, toTr TestTransfer) MultiTransaction { return MultiTransaction{ + ID: multiTransactionIDGenerator(), Type: MultiTransactionBridge, FromAddress: fromTr.From, ToAddress: toTr.To, @@ -376,7 +365,8 @@ func InsertTestMultiTransaction(tb testing.TB, db *sql.DB, tr *MultiTransaction) } tr.ID = multiTransactionIDGenerator() - err := insertMultiTransaction(db, tr) + multiTxDB := NewMultiTransactionDB(db) + err := multiTxDB.CreateMultiTransaction(tr) require.NoError(tb, err) return tr.ID } @@ -398,3 +388,81 @@ func StaticIDCounter() (f func() common.MultiTransactionIDType) { } return } + +type InMemMultiTransactionStorage struct { + storage map[common.MultiTransactionIDType]*MultiTransaction +} + +func NewInMemMultiTransactionStorage() *InMemMultiTransactionStorage { + return &InMemMultiTransactionStorage{ + storage: make(map[common.MultiTransactionIDType]*MultiTransaction), + } +} + +func (s *InMemMultiTransactionStorage) CreateMultiTransaction(multiTx *MultiTransaction) error { + s.storage[multiTx.ID] = multiTx + return nil +} + +func (s *InMemMultiTransactionStorage) GetMultiTransaction(id common.MultiTransactionIDType) (*MultiTransaction, error) { + multiTx, ok := s.storage[id] + if !ok { + return nil, nil + } + return multiTx, nil +} + +func (s *InMemMultiTransactionStorage) UpdateMultiTransaction(multiTx *MultiTransaction) error { + s.storage[multiTx.ID] = multiTx + return nil +} + +func (s *InMemMultiTransactionStorage) DeleteMultiTransaction(id common.MultiTransactionIDType) error { + delete(s.storage, id) + return nil +} + +func (s *InMemMultiTransactionStorage) ReadMultiTransactions(ids []common.MultiTransactionIDType) ([]*MultiTransaction, error) { + var multiTxs []*MultiTransaction + for _, id := range ids { + multiTx, ok := s.storage[id] + if !ok { + continue + } + multiTxs = append(multiTxs, multiTx) + } + return multiTxs, nil +} + +func (s *InMemMultiTransactionStorage) ReadMultiTransactionsByDetails(details *MultiTxDetails) ([]*MultiTransaction, error) { + var multiTxs []*MultiTransaction + for _, multiTx := range s.storage { + if (details.AnyAddress != eth_common.Address{}) && + (multiTx.FromAddress != details.AnyAddress && multiTx.ToAddress != details.AnyAddress) { + continue + } + + if (details.FromAddress != eth_common.Address{}) && multiTx.FromAddress != details.FromAddress { + continue + } + + if (details.ToAddress != eth_common.Address{}) && multiTx.ToAddress != details.ToAddress { + continue + } + + if details.ToChainID != 0 && multiTx.ToNetworkID != details.ToChainID { + continue + } + + if details.Type != MultiTransactionTypeInvalid && multiTx.Type != details.Type { + continue + } + + if details.CrossTxID != "" && multiTx.CrossTxID != details.CrossTxID { + continue + } + + multiTxs = append(multiTxs, multiTx) + } + return multiTxs, nil +} diff --git a/services/wallet/transfer/transaction_manager.go b/services/wallet/transfer/transaction_manager.go index 072aa7a1a..a0c2a6881 100644 --- a/services/wallet/transfer/transaction_manager.go +++ b/services/wallet/transfer/transaction_manager.go @@ -1,7 +1,6 @@ package transfer import ( - "database/sql" "fmt" "math/big" "time" @@ -35,7 +34,7 @@ type TransactionDescription struct { } type TransactionManager struct { - db *sql.DB + storage MultiTransactionStorage gethManager *account.GethManager transactor *transactions.Transactor config *params.NodeConfig @@ -45,11 +44,19 @@ type TransactionManager struct { multiTransactionForKeycardSigning *MultiTransaction transactionsBridgeData []*bridge.TransactionBridge - transactionsForKeycardSingning map[common.Hash]*TransactionDescription + transactionsForKeycardSigning map[common.Hash]*TransactionDescription +} + +type MultiTransactionStorage interface { + CreateMultiTransaction(tx *MultiTransaction) error + ReadMultiTransactions(ids []wallet_common.MultiTransactionIDType) ([]*MultiTransaction, error) + ReadMultiTransactionsByDetails(details *MultiTxDetails) ([]*MultiTransaction, error) + UpdateMultiTransaction(tx *MultiTransaction) error + DeleteMultiTransaction(id wallet_common.MultiTransactionIDType) error } func NewTransactionManager( - db *sql.DB, + storage MultiTransactionStorage, gethManager *account.GethManager, transactor *transactions.Transactor, config *params.NodeConfig, @@ -58,7 +65,7 @@ func NewTransactionManager( eventFeed *event.Feed, ) *TransactionManager { return &TransactionManager{ - db: db, + storage: storage, gethManager: gethManager, transactor: transactor, config: config, @@ -78,6 +85,7 @@ const ( MultiTransactionSend = iota MultiTransactionSwap MultiTransactionBridge + MultiTransactionTypeInvalid = 255 ) type MultiTransaction struct { @@ -235,6 +243,12 @@ func (tm *TransactionManager) BuildRawTransaction(chainID uint64, sendArgs trans }, nil } -func (tm *TransactionManager) SendTransactionWithSignature(chainID uint64, txType transactions.PendingTrxType, sendArgs transactions.SendTxArgs, signature []byte) (hash types.Hash, err error) { - return tm.transactor.BuildTransactionAndSendWithSignature(chainID, sendArgs, signature) +func (tm *TransactionManager) SendTransactionWithSignature(chainID uint64, sendArgs transactions.SendTxArgs, signature []byte) (hash types.Hash, err error) { + txWithSignature, err := tm.transactor.BuildTransactionWithSignature(chainID, sendArgs, signature) + if err != nil { + return hash, err + } + + hash, err = tm.transactor.SendTransactionWithSignature(common.Address(sendArgs.From), sendArgs.Symbol, sendArgs.MultiTransactionID, txWithSignature) + return hash, err } diff --git a/services/wallet/transfer/transaction_manager_internal.go b/services/wallet/transfer/transaction_manager_internal.go new file mode 100644 index 000000000..7ce957b50 --- /dev/null +++ b/services/wallet/transfer/transaction_manager_internal.go @@ -0,0 +1,33 @@ +package transfer + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" + ethTypes "github.com/ethereum/go-ethereum/core/types" + "github.com/status-im/status-go/services/wallet/bridge" +) + +func (tm *TransactionManager) buildTransactions(bridges map[string]bridge.Bridge) ([]string, error) { + tm.transactionsForKeycardSigning = make(map[common.Hash]*TransactionDescription) + var hashes []string + for _, bridgeTx := range tm.transactionsBridgeData { + builtTx, err := bridges[bridgeTx.BridgeName].BuildTransaction(bridgeTx) + if err != nil { + return hashes, err + } + + signer := ethTypes.NewLondonSigner(big.NewInt(int64(bridgeTx.ChainID))) + txHash := signer.Hash(builtTx) + + tm.transactionsForKeycardSigning[txHash] = &TransactionDescription{ + from: common.Address(bridgeTx.From()), + chainID: bridgeTx.ChainID, + builtTx: builtTx, + } + + hashes = append(hashes, txHash.String()) + } + + return hashes, nil +} diff --git a/services/wallet/transfer/transaction_manager_multitransaction.go b/services/wallet/transfer/transaction_manager_multitransaction.go index 3afb559bd..c888e22d1 100644 --- a/services/wallet/transfer/transaction_manager_multitransaction.go +++ b/services/wallet/transfer/transaction_manager_multitransaction.go @@ -2,21 +2,10 @@ package transfer import ( "context" - "database/sql" - "encoding/hex" - "errors" "fmt" - "math/big" - "strings" - "time" - ethTypes "github.com/ethereum/go-ethereum/core/types" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/log" "github.com/status-im/status-go/account" - "github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/services/wallet/bridge" wallet_common "github.com/status-im/status-go/services/wallet/common" @@ -26,128 +15,12 @@ import ( const multiTransactionColumns = "id, from_network_id, from_tx_hash, from_address, from_asset, from_amount, to_network_id, to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" const selectMultiTransactionColumns = "id, COALESCE(from_network_id, 0), from_tx_hash, from_address, from_asset, from_amount, COALESCE(to_network_id, 0), to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" -func rowsToMultiTransactions(rows *sql.Rows) ([]*MultiTransaction, error) { - var multiTransactions []*MultiTransaction - for rows.Next() { - multiTransaction := &MultiTransaction{} - var fromAmountDB, toAmountDB sql.NullString - var fromTxHash, toTxHash sql.RawBytes - err := rows.Scan( - &multiTransaction.ID, - &multiTransaction.FromNetworkID, - &fromTxHash, - &multiTransaction.FromAddress, - &multiTransaction.FromAsset, - &fromAmountDB, - &multiTransaction.ToNetworkID, - &toTxHash, - &multiTransaction.ToAddress, - &multiTransaction.ToAsset, - &toAmountDB, - &multiTransaction.Type, - &multiTransaction.CrossTxID, - &multiTransaction.Timestamp, - ) - if len(fromTxHash) > 0 { - multiTransaction.FromTxHash = common.BytesToHash(fromTxHash) - } - if len(toTxHash) > 0 { - multiTransaction.ToTxHash = common.BytesToHash(toTxHash) - } - if err != nil { - return nil, err - } - - if fromAmountDB.Valid { - multiTransaction.FromAmount = new(hexutil.Big) - if _, ok := (*big.Int)(multiTransaction.FromAmount).SetString(fromAmountDB.String, 0); !ok { - return nil, errors.New("failed to convert fromAmountDB.String to big.Int: " + fromAmountDB.String) - } - } - - if toAmountDB.Valid { - multiTransaction.ToAmount = new(hexutil.Big) - if _, ok := (*big.Int)(multiTransaction.ToAmount).SetString(toAmountDB.String, 0); !ok { - return nil, errors.New("failed to convert fromAmountDB.String to big.Int: " + toAmountDB.String) - } - } - - multiTransactions = append(multiTransactions, multiTransaction) - } - - return multiTransactions, nil -} - -// insertMultiTransaction inserts a multi transaction into the database and updates timestamp -func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) error { - insert, err := db.Prepare(fmt.Sprintf(`INSERT INTO multi_transactions (%s) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) - if err != nil { - return err - } - _, err = insert.Exec( - multiTransaction.ID, - multiTransaction.FromNetworkID, - multiTransaction.FromTxHash, - multiTransaction.FromAddress, - multiTransaction.FromAsset, - multiTransaction.FromAmount.String(), - multiTransaction.ToNetworkID, - multiTransaction.ToTxHash, - multiTransaction.ToAddress, - multiTransaction.ToAsset, - multiTransaction.ToAmount.String(), - multiTransaction.Type, - multiTransaction.CrossTxID, - multiTransaction.Timestamp, - ) - if err != nil { - return err - } - defer insert.Close() - - return err -} - func (tm *TransactionManager) InsertMultiTransaction(multiTransaction *MultiTransaction) (wallet_common.MultiTransactionIDType, error) { - return multiTransaction.ID, insertMultiTransaction(tm.db, multiTransaction) -} - -func updateMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) error { - if multiTransaction.ID == wallet_common.NoMultiTransactionID { - return fmt.Errorf("no multitransaction ID") - } - - update, err := db.Prepare(fmt.Sprintf(`REPLACE INTO multi_transactions (%s) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) - - if err != nil { - return err - } - _, err = update.Exec( - multiTransaction.ID, - multiTransaction.FromNetworkID, - multiTransaction.FromTxHash, - multiTransaction.FromAddress, - multiTransaction.FromAsset, - multiTransaction.FromAmount.String(), - multiTransaction.ToNetworkID, - multiTransaction.ToTxHash, - multiTransaction.ToAddress, - multiTransaction.ToAsset, - multiTransaction.ToAmount.String(), - multiTransaction.Type, - multiTransaction.CrossTxID, - multiTransaction.Timestamp, - ) - if err != nil { - return err - } - return update.Close() + return multiTransaction.ID, tm.storage.CreateMultiTransaction(multiTransaction) } func (tm *TransactionManager) UpdateMultiTransaction(multiTransaction *MultiTransaction) error { - return updateMultiTransaction(tm.db, multiTransaction) + return tm.storage.UpdateMultiTransaction(multiTransaction) } func (tm *TransactionManager) CreateMultiTransactionFromCommand(ctx context.Context, command *MultiTransactionCommand, @@ -190,7 +63,8 @@ func (tm *TransactionManager) SendTransactionForSigningToKeycard(ctx context.Con } func (tm *TransactionManager) SendTransactions(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, account *account.SelectedExtKey) (*MultiTransactionCommandResult, error) { - hashes, err := tm.sendTransactions(multiTransaction, data, bridges, account) + updateDataFromMultiTx(data, multiTransaction) + hashes, err := sendTransactions(data, bridges, account) if err != nil { return nil, err } @@ -202,53 +76,22 @@ func (tm *TransactionManager) SendTransactions(ctx context.Context, multiTransac } func (tm *TransactionManager) ProceedWithTransactionsSignatures(ctx context.Context, signatures map[string]SignatureDetails) (*MultiTransactionCommandResult, error) { - if tm.multiTransactionForKeycardSigning == nil { - return nil, errors.New("no multi transaction to proceed with") - } - if len(tm.transactionsBridgeData) == 0 { - return nil, errors.New("no transactions bridge data to proceed with") - } - if len(tm.transactionsForKeycardSingning) == 0 { - return nil, errors.New("no transactions to proceed with") - } - if len(signatures) != len(tm.transactionsForKeycardSingning) { - return nil, errors.New("not all transactions have been signed") - } - - // check if all transactions have been signed - for hash, desc := range tm.transactionsForKeycardSingning { - sigDetails, ok := signatures[hash.String()] - if !ok { - return nil, fmt.Errorf("missing signature for transaction %s", hash) - } - - rBytes, _ := hex.DecodeString(sigDetails.R) - sBytes, _ := hex.DecodeString(sigDetails.S) - vByte := byte(0) - if sigDetails.V == "01" { - vByte = 1 - } - - desc.signature = make([]byte, crypto.SignatureLength) - copy(desc.signature[32-len(rBytes):32], rBytes) - copy(desc.signature[64-len(rBytes):64], sBytes) - desc.signature[64] = vByte + if err := addSignaturesToTransactions(tm.transactionsForKeycardSigning, signatures); err != nil { + return nil, err } // send transactions hashes := make(map[uint64][]types.Hash) - for _, desc := range tm.transactionsForKeycardSingning { - hash, err := tm.transactor.AddSignatureToTransactionAndSend( - desc.chainID, - desc.from, - tm.multiTransactionForKeycardSigning.FromAsset, - tm.multiTransactionForKeycardSigning.ID, - desc.builtTx, - desc.signature, - ) + for _, desc := range tm.transactionsForKeycardSigning { + txWithSignature, err := tm.transactor.AddSignatureToTransaction(desc.chainID, desc.builtTx, desc.signature) if err != nil { return nil, err } + + hash, err := tm.transactor.SendTransactionWithSignature(desc.from, tm.multiTransactionForKeycardSigning.FromAsset, tm.multiTransactionForKeycardSigning.ID, txWithSignature) + if err != nil { + return nil, err // TODO: One of transfers within transaction could have been sent. Need to notify user about it + } hashes[desc.chainID] = append(hashes[desc.chainID], hash) } @@ -263,138 +106,16 @@ func (tm *TransactionManager) ProceedWithTransactionsSignatures(ctx context.Cont }, nil } -func multiTransactionFromCommand(command *MultiTransactionCommand) *MultiTransaction { - multiTransaction := NewMultiTransaction( - /* Timestamp: */ uint64(time.Now().Unix()), - /* FromNetworkID: */ 0, - /* ToNetworkID: */ 0, - /* FromTxHash: */ common.Hash{}, - /* ToTxHash: */ common.Hash{}, - /* FromAddress: */ command.FromAddress, - /* ToAddress: */ command.ToAddress, - /* FromAsset: */ command.FromAsset, - /* ToAsset: */ command.ToAsset, - /* FromAmount: */ command.FromAmount, - /* ToAmount: */ new(hexutil.Big), - /* Type: */ command.Type, - /* CrossTxID: */ "", - ) - - return multiTransaction -} - -func (tm *TransactionManager) buildTransactions(bridges map[string]bridge.Bridge) ([]string, error) { - tm.transactionsForKeycardSingning = make(map[common.Hash]*TransactionDescription) - var hashes []string - for _, bridgeTx := range tm.transactionsBridgeData { - builtTx, err := bridges[bridgeTx.BridgeName].BuildTransaction(bridgeTx) - if err != nil { - return hashes, err - } - - signer := ethTypes.NewLondonSigner(big.NewInt(int64(bridgeTx.ChainID))) - txHash := signer.Hash(builtTx) - - tm.transactionsForKeycardSingning[txHash] = &TransactionDescription{ - from: common.Address(bridgeTx.From()), - chainID: bridgeTx.ChainID, - builtTx: builtTx, - } - - hashes = append(hashes, txHash.String()) - } - - return hashes, nil -} - -func (tm *TransactionManager) sendTransactions(multiTransaction *MultiTransaction, - data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, account *account.SelectedExtKey) ( - map[uint64][]types.Hash, error) { - - hashes := make(map[uint64][]types.Hash) - for _, tx := range data { - if tx.TransferTx != nil { - tx.TransferTx.MultiTransactionID = multiTransaction.ID - tx.TransferTx.Symbol = multiTransaction.FromAsset - } - if tx.HopTx != nil { - tx.HopTx.MultiTransactionID = multiTransaction.ID - tx.HopTx.Symbol = multiTransaction.FromAsset - } - if tx.CbridgeTx != nil { - tx.CbridgeTx.MultiTransactionID = multiTransaction.ID - tx.CbridgeTx.Symbol = multiTransaction.FromAsset - } - if tx.ERC721TransferTx != nil { - tx.ERC721TransferTx.MultiTransactionID = multiTransaction.ID - tx.ERC721TransferTx.Symbol = multiTransaction.FromAsset - } - if tx.ERC1155TransferTx != nil { - tx.ERC1155TransferTx.MultiTransactionID = multiTransaction.ID - tx.ERC1155TransferTx.Symbol = multiTransaction.FromAsset - } - if tx.SwapTx != nil { - tx.SwapTx.MultiTransactionID = multiTransaction.ID - tx.SwapTx.Symbol = multiTransaction.FromAsset - } - - hash, err := bridges[tx.BridgeName].Send(tx, account) - if err != nil { - return nil, err // TODO: One of transfers within transaction could have been sent. Need to notify user about it - } - hashes[tx.ChainID] = append(hashes[tx.ChainID], hash) - } - return hashes, nil -} - func (tm *TransactionManager) GetMultiTransactions(ctx context.Context, ids []wallet_common.MultiTransactionIDType) ([]*MultiTransaction, error) { - placeholders := make([]string, len(ids)) - args := make([]interface{}, len(ids)) - for i, v := range ids { - placeholders[i] = "?" - args[i] = v - } - - stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT %s - FROM multi_transactions - WHERE id in (%s)`, - selectMultiTransactionColumns, - strings.Join(placeholders, ","))) - if err != nil { - return nil, err - } - defer stmt.Close() - - rows, err := stmt.Query(args...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rowsToMultiTransactions(rows) -} - -func (tm *TransactionManager) getBridgeMultiTransactions(ctx context.Context, toChainID uint64, crossTxID string) ([]*MultiTransaction, error) { - stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT %s - FROM multi_transactions - WHERE type=? AND to_network_id=? AND cross_tx_id=?`, - multiTransactionColumns)) - if err != nil { - return nil, err - } - defer stmt.Close() - - rows, err := stmt.Query(MultiTransactionBridge, toChainID, crossTxID) - if err != nil { - return nil, err - } - defer rows.Close() - - return rowsToMultiTransactions(rows) + return tm.storage.ReadMultiTransactions(ids) } func (tm *TransactionManager) GetBridgeOriginMultiTransaction(ctx context.Context, toChainID uint64, crossTxID string) (*MultiTransaction, error) { - multiTxs, err := tm.getBridgeMultiTransactions(ctx, toChainID, crossTxID) + details := NewMultiTxDetails() + details.ToChainID = toChainID + details.CrossTxID = crossTxID + + multiTxs, err := tm.storage.ReadMultiTransactionsByDetails(details) if err != nil { return nil, err } @@ -410,7 +131,11 @@ func (tm *TransactionManager) GetBridgeOriginMultiTransaction(ctx context.Contex } func (tm *TransactionManager) GetBridgeDestinationMultiTransaction(ctx context.Context, toChainID uint64, crossTxID string) (*MultiTransaction, error) { - multiTxs, err := tm.getBridgeMultiTransactions(ctx, toChainID, crossTxID) + details := NewMultiTxDetails() + details.ToChainID = toChainID + details.CrossTxID = crossTxID + + multiTxs, err := tm.storage.ReadMultiTransactionsByDetails(details) if err != nil { return nil, err } @@ -424,73 +149,3 @@ func (tm *TransactionManager) GetBridgeDestinationMultiTransaction(ctx context.C return nil, nil } - -func idFromTimestamp() wallet_common.MultiTransactionIDType { - return wallet_common.MultiTransactionIDType(time.Now().UnixMilli()) -} - -var multiTransactionIDGenerator func() wallet_common.MultiTransactionIDType = idFromTimestamp - -func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Address) error { - // We must not remove those transactions, where from_address and to_address are different and both are stored in accounts DB - // and one of them is equal to the address, as we want to keep the records for the other address - // That is why we don't use cascade delete here with references to transfers table, as we might have 2 records in multi_transactions - // for the same transaction, one for each address - - stmt, err := tm.db.Prepare(`SELECT id, from_address, to_address - FROM multi_transactions - WHERE from_address=? OR to_address=?`) - if err != nil { - return err - } - - rows, err := stmt.Query(address, address) - if err != nil { - return err - } - defer rows.Close() - - ids := make([]int, 0) - id, fromAddress, toAddress := 0, common.Address{}, common.Address{} - for rows.Next() { - err = rows.Scan(&id, &fromAddress, &toAddress) - if err != nil { - log.Error("Failed to scan row", "error", err) - continue - } - - // Remove self transactions as well, leave only those where we have the counterparty in accounts DB - if fromAddress != toAddress { - // If both addresses are stored in accounts DB, we don't remove the record - var addressToCheck common.Address - if fromAddress == address { - addressToCheck = toAddress - } else { - addressToCheck = fromAddress - } - counterpartyExists, err := tm.accountsDB.AddressExists(types.Address(addressToCheck)) - if err != nil { - log.Error("Failed to query accounts db for a given address", "address", address, "error", err) - continue - } - - // Skip removal if counterparty is in accounts DB and removed address is not sender - if counterpartyExists && address != fromAddress { - continue - } - } - - ids = append(ids, id) - } - - if len(ids) > 0 { - for _, id := range ids { - _, err = tm.db.Exec(`DELETE FROM multi_transactions WHERE id=?`, id) - if err != nil { - log.Error("Failed to remove multitransaction", "id", id, "error", err) - } - } - } - - return err -} diff --git a/services/wallet/transfer/transaction_manager_test.go b/services/wallet/transfer/transaction_manager_test.go index bd9405a85..3dfcf72c2 100644 --- a/services/wallet/transfer/transaction_manager_test.go +++ b/services/wallet/transfer/transaction_manager_test.go @@ -19,7 +19,7 @@ func setupTestTransactionDB(t *testing.T) (*TransactionManager, func()) { db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) require.NoError(t, err) SetMultiTransactionIDGenerator(StaticIDCounter()) // to have different multi-transaction IDs even with fast execution - return &TransactionManager{db, nil, nil, nil, nil, nil, nil, nil, nil, nil}, func() { + return &TransactionManager{NewMultiTransactionDB(db), nil, nil, nil, nil, nil, nil, nil, nil, nil}, func() { require.NoError(t, db.Close()) } } @@ -146,7 +146,7 @@ func TestMultiTransactions(t *testing.T) { trx1.FromAmount = (*hexutil.Big)(big.NewInt(789)) trx1.ToAmount = (*hexutil.Big)(big.NewInt(890)) - err = updateMultiTransaction(manager.db, &trx1) + err = manager.UpdateMultiTransaction(&trx1) require.NoError(t, err) rst, err = manager.GetMultiTransactions(context.Background(), ids)