diff --git a/services/wallet/api.go b/services/wallet/api.go index 51e508413..1f1bed1e8 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -558,6 +558,11 @@ func (api *API) CreateMultiTransaction(ctx context.Context, multiTransaction *tr return api.s.transactionManager.CreateMultiTransaction(ctx, multiTransaction, data, api.router.bridges, password) } +func (api *API) GetMultiTransactions(ctx context.Context, transactionIDs []transfer.MultiTransactionIDType) ([]*transfer.MultiTransaction, error) { + log.Debug("[WalletAPI:: GetMultiTransactions] for IDs", transactionIDs) + return api.s.transactionManager.GetMultiTransactions(ctx, transactionIDs) +} + func (api *API) GetCachedCurrencyFormats() (currency.FormatPerSymbol, error) { log.Debug("call to GetCachedCurrencyFormats") return api.s.currency.GetCachedCurrencyFormats() diff --git a/services/wallet/transfer/transaction.go b/services/wallet/transfer/transaction.go index 58dd3137d..8d63d1fb0 100644 --- a/services/wallet/transfer/transaction.go +++ b/services/wallet/transfer/transaction.go @@ -250,18 +250,13 @@ func (tm *TransactionManager) Watch(ctx context.Context, transactionHash common. return watchTxCommand.Command()(commandContext) } -func (tm *TransactionManager) CreateMultiTransaction(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) (*MultiTransactionResult, error) { - selectedAccount, err := tm.getVerifiedWalletAccount(multiTransaction.FromAddress.Hex(), password) - if err != nil { - return nil, err - } +const multiTransactionColumns = "from_address, from_asset, from_amount, to_address, to_asset, type, timestamp" - insert, err := tm.db.Prepare(`INSERT OR REPLACE INTO multi_transactions - (from_address, from_asset, from_amount, to_address, to_asset, type, timestamp) - VALUES - (?, ?, ?, ?, ?, ?, ?)`) +func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) (MultiTransactionIDType, error) { + insert, err := db.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO multi_transactions (%s) + VALUES(?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) if err != nil { - return nil, err + return 0, err } result, err := insert.Exec( multiTransaction.FromAddress, @@ -273,10 +268,20 @@ func (tm *TransactionManager) CreateMultiTransaction(ctx context.Context, multiT time.Now().Unix(), ) if err != nil { - return nil, err + return 0, err } defer insert.Close() multiTransactionID, err := result.LastInsertId() + return MultiTransactionIDType(multiTransactionID), err +} + +func (tm *TransactionManager) CreateMultiTransaction(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) (*MultiTransactionResult, error) { + selectedAccount, err := tm.getVerifiedWalletAccount(multiTransaction.FromAddress.Hex(), password) + if err != nil { + return nil, err + } + + multiTransactionID, err := insertMultiTransaction(tm.db, multiTransaction) if err != nil { return nil, err } @@ -296,7 +301,7 @@ func (tm *TransactionManager) CreateMultiTransaction(ctx context.Context, multiT Data: tx.Data().String(), Type: WalletTransfer, ChainID: tx.ChainID, - MultiTransactionID: MultiTransactionIDType(multiTransactionID), + MultiTransactionID: multiTransactionID, Symbol: multiTransaction.FromAsset, } err = tm.AddPending(pendingTransaction) @@ -307,11 +312,65 @@ func (tm *TransactionManager) CreateMultiTransaction(ctx context.Context, multiT } return &MultiTransactionResult{ - ID: multiTransactionID, + ID: int64(multiTransactionID), Hashes: hashes, }, nil } +func (tm *TransactionManager) GetMultiTransactions(ctx context.Context, ids []MultiTransactionIDType) ([]*MultiTransaction, error) { + placeholders := make([]string, len(ids)) + args := make([]interface{}, len(ids)) + for i, v := range ids { + placeholders[i] = "?" + args[i] = v + } + + stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT rowid, %s + FROM multi_transactions + WHERE rowid in (%s)`, + multiTransactionColumns, + strings.Join(placeholders, ","))) + if err != nil { + return nil, err + } + defer stmt.Close() + + rows, err := stmt.Query(args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var multiTransactions []*MultiTransaction + for rows.Next() { + multiTransaction := &MultiTransaction{} + var fromAmount string + err := rows.Scan( + &multiTransaction.ID, + &multiTransaction.FromAddress, + &multiTransaction.FromAsset, + &fromAmount, + &multiTransaction.ToAddress, + &multiTransaction.ToAsset, + &multiTransaction.Type, + &multiTransaction.Timestamp, + ) + if err != nil { + return nil, err + } + + multiTransaction.FromAmount = new(hexutil.Big) + _, ok := (*big.Int)(multiTransaction.FromAmount).SetString(fromAmount, 0) + if !ok { + return nil, errors.New("failed to convert fromAmount to big.Int: " + fromAmount) + } + + multiTransactions = append(multiTransactions, multiTransaction) + } + + return multiTransactions, nil +} + func (tm *TransactionManager) getVerifiedWalletAccount(address, password string) (*account.SelectedExtKey, error) { exists, err := tm.accountsDB.AddressExists(types.HexToAddress(address)) if err != nil { diff --git a/services/wallet/transfer/transaction_test.go b/services/wallet/transfer/transaction_test.go index 741482e8d..a0aefe2f4 100644 --- a/services/wallet/transfer/transaction_test.go +++ b/services/wallet/transfer/transaction_test.go @@ -1,20 +1,21 @@ package transfer import ( + "context" "math/big" "testing" "github.com/stretchr/testify/require" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/status-im/status-go/appdatabase" "github.com/status-im/status-go/services/wallet/bigint" - "github.com/status-im/status-go/sqlite" ) func setupTestTransactionDB(t *testing.T) (*TransactionManager, func()) { - db, err := appdatabase.InitializeDB(sqlite.InMemoryPath, "wallet-tests", sqlite.ReducedKDFIterationsNumber) + db, err := appdatabase.SetupTestMemorySQLDB("wallet-transfer-transaction-tests") require.NoError(t, err) return &TransactionManager{db, nil, nil, nil, nil}, func() { require.NoError(t, db.Close()) @@ -73,3 +74,45 @@ func TestPendingTransactions(t *testing.T) { require.NoError(t, err) require.Equal(t, 0, len(rst)) } + +func TestMultiTransactions(t *testing.T) { + manager, stop := setupTestTransactionDB(t) + defer stop() + + trx1 := MultiTransaction{ + Timestamp: 123, + FromAddress: common.Address{1}, + ToAddress: common.Address{2}, + FromAsset: "fromAsset", + ToAsset: "toAsset", + FromAmount: (*hexutil.Big)(big.NewInt(123)), + Type: MultiTransactionBridge, + } + trx2 := trx1 + trx2.FromAmount = (*hexutil.Big)(big.NewInt(456)) + + var err error + ids := make([]MultiTransactionIDType, 2) + ids[0], err = insertMultiTransaction(manager.db, &trx1) + require.NoError(t, err) + require.Equal(t, MultiTransactionIDType(1), ids[0]) + ids[1], err = insertMultiTransaction(manager.db, &trx2) + require.NoError(t, err) + require.Equal(t, MultiTransactionIDType(2), ids[1]) + + rst, err := manager.GetMultiTransactions(context.Background(), []MultiTransactionIDType{ids[0], 555}) + require.NoError(t, err) + require.Equal(t, 1, len(rst)) + + rst, err = manager.GetMultiTransactions(context.Background(), ids) + require.NoError(t, err) + require.Equal(t, 2, len(rst)) + + for _, id := range ids { + found := false + for _, trx := range rst { + found = found || id == MultiTransactionIDType(trx.ID) + } + require.True(t, found, "result contains transaction with id %d", id) + } +}