chore(wallet)_: clean up wallet API send and sign transactions

This commit is contained in:
Ivan Belyakov 2024-05-02 17:36:42 +02:00 committed by IvanBelyakoff
parent 8cd4560823
commit ed164e4ac5
10 changed files with 307 additions and 239 deletions

View File

@ -152,7 +152,7 @@ func GetRecipients(ctx context.Context, db *sql.DB, chainIDs []common.ChainID, a
return entries, hasMore, nil return entries, hasMore, nil
} }
func GetOldestTimestamp(ctx context.Context, db *sql.DB, addresses []eth.Address) (timestamp int64, err error) { func GetOldestTimestamp(ctx context.Context, db *sql.DB, addresses []eth.Address) (timestamp uint64, err error) {
filterAllAddresses := len(addresses) == 0 filterAllAddresses := len(addresses) == 0
involvedAddresses := noEntriesInTmpTableSQLValues involvedAddresses := noEntriesInTmpTableSQLValues
if !filterAllAddresses { if !filterAllAddresses {

View File

@ -445,7 +445,7 @@ SELECT
NULL AS transfer_hash, NULL AS transfer_hash,
NULL AS pending_hash, NULL AS pending_hash,
NULL AS network_id, NULL AS network_id,
multi_transactions.ROWID AS multi_tx_id, multi_transactions.id AS multi_tx_id,
multi_transactions.timestamp AS timestamp, multi_transactions.timestamp AS timestamp,
multi_transactions.type AS mt_type, multi_transactions.type AS mt_type,
NULL as tr_type, NULL as tr_type,
@ -488,8 +488,8 @@ SELECT
FROM FROM
multi_transactions multi_transactions
CROSS JOIN filter_conditions CROSS JOIN filter_conditions
LEFT JOIN tr_status ON multi_transactions.ROWID = tr_status.multi_transaction_id LEFT JOIN tr_status ON multi_transactions.id = tr_status.multi_transaction_id
LEFT JOIN pending_status ON multi_transactions.ROWID = pending_status.multi_transaction_id LEFT JOIN pending_status ON multi_transactions.id = pending_status.multi_transaction_id
WHERE WHERE
( (
( (
@ -577,7 +577,7 @@ WHERE
FROM FROM
tr_network_ids tr_network_ids
WHERE WHERE
multi_transactions.ROWID = tr_network_ids.multi_transaction_id multi_transactions.id = tr_network_ids.multi_transaction_id
) )
OR EXISTS ( OR EXISTS (
SELECT SELECT
@ -585,7 +585,7 @@ WHERE
FROM FROM
pending_network_ids pending_network_ids
WHERE WHERE
multi_transactions.ROWID = pending_network_ids.multi_transaction_id multi_transactions.id = pending_network_ids.multi_transaction_id
) )
) )
) )

View File

@ -91,7 +91,7 @@ FROM (
FROM FROM
tr_network_ids tr_network_ids
WHERE WHERE
multi_transactions.ROWID = tr_network_ids.multi_transaction_id multi_transactions.id = tr_network_ids.multi_transaction_id
) )
OR EXISTS ( OR EXISTS (
SELECT SELECT
@ -99,7 +99,7 @@ FROM (
FROM FROM
pending_network_ids pending_network_ids
WHERE WHERE
multi_transactions.ROWID = pending_network_ids.multi_transaction_id multi_transactions.id = pending_network_ids.multi_transaction_id
) )
) )
) )

View File

@ -15,6 +15,7 @@ import (
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
gethrpc "github.com/ethereum/go-ethereum/rpc" gethrpc "github.com/ethereum/go-ethereum/rpc"
"github.com/status-im/status-go/account"
"github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/rpc/network"
@ -526,7 +527,13 @@ func (api *API) GetAddressDetails(ctx context.Context, chainID uint64, address s
func (api *API) SignMessage(ctx context.Context, message types.HexBytes, address common.Address, password string) (string, error) { func (api *API) SignMessage(ctx context.Context, message types.HexBytes, address common.Address, password string) (string, error) {
log.Debug("[WalletAPI::SignMessage]", "message", message, "address", address) log.Debug("[WalletAPI::SignMessage]", "message", message, "address", address)
return api.s.transactionManager.SignMessage(message, address, password)
selectedAccount, err := api.s.gethManager.VerifyAccountPassword(api.s.Config().KeyStoreDir, address.Hex(), password)
if err != nil {
return "", err
}
return api.s.transactionManager.SignMessage(message, selectedAccount)
} }
func (api *API) BuildTransaction(ctx context.Context, chainID uint64, sendTxArgsJSON string) (response *transfer.TxResponse, err error) { func (api *API) BuildTransaction(ctx context.Context, chainID uint64, sendTxArgsJSON string) (response *transfer.TxResponse, err error) {
@ -574,7 +581,32 @@ func (api *API) SendTransactionWithSignature(ctx context.Context, chainID uint64
func (api *API) CreateMultiTransaction(ctx context.Context, multiTransactionCommand *transfer.MultiTransactionCommand, data []*bridge.TransactionBridge, password string) (*transfer.MultiTransactionCommandResult, error) { func (api *API) CreateMultiTransaction(ctx context.Context, multiTransactionCommand *transfer.MultiTransactionCommand, data []*bridge.TransactionBridge, password string) (*transfer.MultiTransactionCommandResult, error) {
log.Debug("[WalletAPI:: CreateMultiTransaction] create multi transaction") log.Debug("[WalletAPI:: CreateMultiTransaction] create multi transaction")
return api.s.transactionManager.CreateMultiTransactionFromCommand(ctx, multiTransactionCommand, data, api.router.bridges, password)
cmd, err := api.s.transactionManager.CreateMultiTransactionFromCommand(ctx, multiTransactionCommand, data)
if err != nil {
return nil, err
}
if password != "" {
selectedAccount, err := api.getVerifiedWalletAccount(multiTransactionCommand.FromAddress.Hex(), password)
if err != nil {
return nil, err
}
cmdRes, err := api.s.transactionManager.SendTransactions(ctx, cmd, data, api.router.bridges, selectedAccount)
if err != nil {
return nil, err
}
_, err = api.s.transactionManager.InsertMultiTransaction(cmd)
if err != nil {
return nil, err
}
return cmdRes, nil
}
return nil, api.s.transactionManager.SendTransactionForSigningToKeycard(ctx, cmd, data, api.router.bridges)
} }
func (api *API) ProceedWithTransactionsSignatures(ctx context.Context, signatures map[string]transfer.SignatureDetails) (*transfer.MultiTransactionCommandResult, error) { func (api *API) ProceedWithTransactionsSignatures(ctx context.Context, signatures map[string]transfer.SignatureDetails) (*transfer.MultiTransactionCommandResult, error) {
@ -740,3 +772,28 @@ func (api *API) WCAuthRequest(ctx context.Context, address common.Address, authM
return api.s.walletConnect.AuthRequest(address, authMessage) return api.s.walletConnect.AuthRequest(address, authMessage)
} }
func (api *API) getVerifiedWalletAccount(address, password string) (*account.SelectedExtKey, error) {
exists, err := api.s.accountsDB.AddressExists(types.HexToAddress(address))
if err != nil {
log.Error("failed to query db for a given address", "address", address, "error", err)
return nil, err
}
if !exists {
log.Error("failed to get a selected account", "err", transactions.ErrInvalidTxSender)
return nil, transactions.ErrAccountDoesntExist
}
keyStoreDir := api.s.Config().KeyStoreDir
key, err := api.s.gethManager.VerifyAccountPassword(keyStoreDir, address, password)
if err != nil {
log.Error("failed to verify account", "account", address, "error", err)
return nil, err
}
return &account.SelectedExtKey{
Address: key.Address,
AccountKey: key,
}, nil
}

View File

@ -201,6 +201,7 @@ func NewService(
blockChainState: blockChainState, blockChainState: blockChainState,
keycardPairings: NewKeycardPairings(), keycardPairings: NewKeycardPairings(),
walletConnect: walletconnect, walletConnect: walletconnect,
config: config,
} }
} }
@ -235,6 +236,7 @@ type Service struct {
blockChainState *blockchainstate.BlockChainState blockChainState *blockchainstate.BlockChainState
keycardPairings *KeycardPairings keycardPairings *KeycardPairings
walletConnect *walletconnect.Service walletConnect *walletconnect.Service
config *params.NodeConfig
} }
// Start signals transmitter. // Start signals transmitter.
@ -293,3 +295,7 @@ func (s *Service) IsStarted() bool {
func (s *Service) KeycardPairings() *KeycardPairings { func (s *Service) KeycardPairings() *KeycardPairings {
return s.keycardPairings return s.keycardPairings
} }
func (s *Service) Config() *params.NodeConfig {
return s.config
}

View File

@ -45,23 +45,21 @@ func upsertHopBridgeOriginTx(ctx context.Context, transactionManager *Transactio
} }
if multiTx == nil { if multiTx == nil {
multiTx = &MultiTransaction{ multiTx = NewMultiTransaction(
// Data from "origin" transaction /* Timestamp: */ params.timestamp, // Common data
FromNetworkID: params.fromNetworkID, /* FromNetworkID: */ params.fromNetworkID, // Data from "origin" transaction
FromTxHash: params.fromTxHash, /* ToNetworkID: */ params.toNetworkID, // Data from "origin" transaction
FromAddress: params.fromAddress, /* FromTxHash: */ params.fromTxHash, // Data from "origin" transaction
FromAsset: params.fromAsset, /* ToTxHash: */ common.Hash{},
FromAmount: (*hexutil.Big)(params.fromAmount), /* FromAddress: */ params.fromAddress, // Data from "origin" transaction
ToNetworkID: params.toNetworkID, /* ToAddress: */ params.toAddress, // Data from "origin" transaction
ToAddress: params.toAddress, /* FromAsset: */ params.fromAsset, // Data from "origin" transaction
// To be replaced by "destination" transaction, need to be non-null /* ToAsset: */ params.fromAsset, // To be replaced by "destination" transaction, need to be non-null
ToAsset: params.fromAsset, /* FromAmount: */ (*hexutil.Big)(params.fromAmount), // Data from "origin" transaction
ToAmount: (*hexutil.Big)(params.fromAmount), /* ToAmount: */ (*hexutil.Big)(params.fromAmount), // To be replaced by "destination" transaction, need to be non-null
// Common data /* Type: */ MultiTransactionBridge, // Common data
Type: MultiTransactionBridge, /* CrossTxID: */ params.crossTxID, // Common data
CrossTxID: params.crossTxID, )
Timestamp: params.timestamp,
}
_, err := transactionManager.InsertMultiTransaction(multiTx) _, err := transactionManager.InsertMultiTransaction(multiTx)
if err != nil { if err != nil {
@ -102,22 +100,21 @@ func upsertHopBridgeDestinationTx(ctx context.Context, transactionManager *Trans
} }
if multiTx == nil { if multiTx == nil {
multiTx = &MultiTransaction{ multiTx = NewMultiTransaction(
// To be replaced by "origin" transaction, need to be non-null /* Timestamp: */ params.timestamp, // Common data
FromAddress: params.toAddress, /* FromNetworkID: */ 0, // not set
FromAsset: params.toAsset, /* ToNetworkID: */ params.toNetworkID, // Data from "destination" transaction
FromAmount: (*hexutil.Big)(params.toAmount), /* FromTxHash: */ common.Hash{},
// Data from "destination" transaction /* ToTxHash: */ params.toTxHash, // Data from "destination" transaction
ToNetworkID: params.toNetworkID, /* FromAddress: */ params.toAddress, // To be replaced by "origin" transaction, need to be non-null
ToTxHash: params.toTxHash, /* ToAddress: */ params.toAddress, // Data from "destination" transaction
ToAddress: params.toAddress, /* FromAsset: */ params.toAsset, // To be replaced by "origin" transaction, need to be non-null
ToAsset: params.toAsset, /* ToAsset: */ params.toAsset, // Data from "destination" transaction
ToAmount: (*hexutil.Big)(params.toAmount), /* FromAmount: */ (*hexutil.Big)(params.toAmount), // To be replaced by "origin" transaction, need to be non-null
// Common data /* ToAmount: */ (*hexutil.Big)(params.toAmount), // Data from "destination" transaction
Type: MultiTransactionBridge, /* Type: */ MultiTransactionBridge, // Common data
CrossTxID: params.crossTxID, /* CrossTxID: */ params.crossTxID, // Common data
Timestamp: params.timestamp, )
}
_, err := transactionManager.InsertMultiTransaction(multiTx) _, err := transactionManager.InsertMultiTransaction(multiTx)
if err != nil { if err != nil {

View File

@ -94,15 +94,21 @@ func TestController_watchAccountsChanges(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Self multi transaction // Self multi transaction
midSelf, err := transactionManager.InsertMultiTransaction(&MultiTransaction{ midSelf, err := transactionManager.InsertMultiTransaction(NewMultiTransaction(
FromAddress: address, /* Timestamp: */ 1,
ToAddress: address, /* FromNetworkID: */ 1,
FromAsset: "ETH", /* ToNetworkID: */ 1,
ToAsset: "DAI", /* FromTxHash: */ common.Hash{},
FromAmount: &hexutil.Big{}, /* ToTxHash: */ common.Hash{},
ToAmount: &hexutil.Big{}, /* FromAddress: */ address,
Timestamp: 1, /* ToAddress: */ address,
}) /* FromAsset: */ "ETH",
/* ToAsset: */ "DAI",
/* FromAmount: */ &hexutil.Big{},
/* ToAmount: */ &hexutil.Big{},
/* Type: */ MultiTransactionSend,
/* CrossTxID: */ "",
))
require.NoError(t, err) require.NoError(t, err)
mtxs, err := transactionManager.GetMultiTransactions(context.Background(), []wallet_common.MultiTransactionIDType{midSelf}) mtxs, err := transactionManager.GetMultiTransactions(context.Background(), []wallet_common.MultiTransactionIDType{midSelf})
@ -110,15 +116,21 @@ func TestController_watchAccountsChanges(t *testing.T) {
require.Len(t, mtxs, 1) require.Len(t, mtxs, 1)
// Send multi transaction // Send multi transaction
mt := &MultiTransaction{ mt := NewMultiTransaction(
FromAddress: address, /* Timestamp: */ 2,
ToAddress: counterparty, /* FromNetworkID: */ 1,
FromAsset: "ETH", /* ToNetworkID: */ 1,
ToAsset: "DAI", /* FromTxHash: */ common.Hash{},
FromAmount: &hexutil.Big{}, /* ToTxHash: */ common.Hash{},
ToAmount: &hexutil.Big{}, /* FromAddress: */ address,
Timestamp: 2, /* ToAddress: */ counterparty,
} /* FromAsset: */ "ETH",
/* ToAsset: */ "DAI",
/* FromAmount: */ &hexutil.Big{},
/* ToAmount: */ &hexutil.Big{},
/* Type: */ MultiTransactionSend,
/* CrossTxID: */ "",
)
mid, err := transactionManager.InsertMultiTransaction(mt) mid, err := transactionManager.InsertMultiTransaction(mt)
require.NoError(t, err) require.NoError(t, err)
@ -127,15 +139,21 @@ func TestController_watchAccountsChanges(t *testing.T) {
require.Len(t, mtxs, 2) require.Len(t, mtxs, 2)
// Another Send multi-transaction where sender and receiver are inverted (both accounts are in accounts DB) // Another Send multi-transaction where sender and receiver are inverted (both accounts are in accounts DB)
midReverse, err := transactionManager.InsertMultiTransaction(&MultiTransaction{ midReverse, err := transactionManager.InsertMultiTransaction(NewMultiTransaction(
FromAddress: mt.ToAddress, /* Timestamp: */ mt.Timestamp+1,
ToAddress: mt.FromAddress, /* FromNetworkID: */ 1,
FromAsset: mt.FromAsset, /* ToNetworkID: */ 1,
ToAsset: mt.ToAsset, /* FromTxHash: */ common.Hash{},
FromAmount: mt.FromAmount, /* ToTxHash: */ common.Hash{},
ToAmount: mt.ToAmount, /* FromAddress: */ mt.ToAddress,
Timestamp: mt.Timestamp + 1, /* ToAddress: */ mt.FromAddress,
}) /* FromAsset: */ mt.FromAsset,
/* ToAsset: */ mt.ToAsset,
/* FromAmount: */ mt.FromAmount,
/* ToAmount: */ mt.ToAmount,
/* Type: */ MultiTransactionSend,
/* CrossTxID: */ "",
))
require.NoError(t, err) require.NoError(t, err)
mtxs, err = transactionManager.GetMultiTransactions(context.Background(), []wallet_common.MultiTransactionIDType{midSelf, mid, midReverse}) mtxs, err = transactionManager.GetMultiTransactions(context.Background(), []wallet_common.MultiTransactionIDType{midSelf, mid, midReverse})

View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"math/big" "math/big"
"time"
ethTypes "github.com/ethereum/go-ethereum/core/types" ethTypes "github.com/ethereum/go-ethereum/core/types"
@ -128,13 +129,31 @@ type TxResponse struct {
TxHash common.Hash `json:"txHash,omitempty"` TxHash common.Hash `json:"txHash,omitempty"`
} }
func (tm *TransactionManager) SignMessage(message types.HexBytes, address common.Address, password string) (string, error) { func NewMultiTransaction(timestamp uint64, fromNetworkID, toNetworkID uint64, fromTxHash, toTxHash common.Hash, fromAddress, toAddress common.Address, fromAsset, toAsset string, fromAmount, toAmount *hexutil.Big, txType MultiTransactionType, crossTxID string) *MultiTransaction {
selectedAccount, err := tm.gethManager.VerifyAccountPassword(tm.config.KeyStoreDir, address.Hex(), password) if timestamp == 0 {
if err != nil { timestamp = uint64(time.Now().Unix())
return "", err
} }
signature, err := crypto.Sign(message[:], selectedAccount.PrivateKey) return &MultiTransaction{
ID: multiTransactionIDGenerator(),
Timestamp: timestamp,
FromNetworkID: fromNetworkID,
ToNetworkID: toNetworkID,
FromTxHash: fromTxHash,
ToTxHash: toTxHash,
FromAddress: fromAddress,
ToAddress: toAddress,
FromAsset: fromAsset,
ToAsset: toAsset,
FromAmount: fromAmount,
ToAmount: toAmount,
Type: txType,
CrossTxID: crossTxID,
}
}
func (tm *TransactionManager) SignMessage(message types.HexBytes, account *types.Key) (string, error) {
signature, err := crypto.Sign(message[:], account.PrivateKey)
return types.EncodeHex(signature), err return types.EncodeHex(signature), err
} }

View File

@ -21,11 +21,10 @@ import (
"github.com/status-im/status-go/services/wallet/bridge" "github.com/status-im/status-go/services/wallet/bridge"
wallet_common "github.com/status-im/status-go/services/wallet/common" wallet_common "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/signal" "github.com/status-im/status-go/signal"
"github.com/status-im/status-go/transactions"
) )
const multiTransactionColumns = "from_network_id, from_tx_hash, from_address, from_asset, from_amount, to_network_id, to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" const multiTransactionColumns = "id, from_network_id, from_tx_hash, from_address, from_asset, from_amount, to_network_id, to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp"
const selectMultiTransactionColumns = "COALESCE(from_network_id, 0), from_tx_hash, from_address, from_asset, from_amount, COALESCE(to_network_id, 0), to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp" const selectMultiTransactionColumns = "id, COALESCE(from_network_id, 0), from_tx_hash, from_address, from_asset, from_amount, COALESCE(to_network_id, 0), to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp"
func rowsToMultiTransactions(rows *sql.Rows) ([]*MultiTransaction, error) { func rowsToMultiTransactions(rows *sql.Rows) ([]*MultiTransaction, error) {
var multiTransactions []*MultiTransaction var multiTransactions []*MultiTransaction
@ -79,22 +78,15 @@ func rowsToMultiTransactions(rows *sql.Rows) ([]*MultiTransaction, error) {
return multiTransactions, nil return multiTransactions, nil
} }
func getMultiTransactionTimestamp(multiTransaction *MultiTransaction) uint64 { // insertMultiTransaction inserts a multi transaction into the database and updates timestamp
if multiTransaction.Timestamp != 0 { func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) error {
return multiTransaction.Timestamp
}
return uint64(time.Now().Unix())
}
// insertMultiTransaction inserts a multi transaction into the database and updates multi-transaction ID and timestamp
func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) (wallet_common.MultiTransactionIDType, error) {
insert, err := db.Prepare(fmt.Sprintf(`INSERT INTO multi_transactions (%s) insert, err := db.Prepare(fmt.Sprintf(`INSERT INTO multi_transactions (%s)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns))
if err != nil { if err != nil {
return wallet_common.NoMultiTransactionID, err return err
} }
timestamp := getMultiTransactionTimestamp(multiTransaction) _, err = insert.Exec(
result, err := insert.Exec( multiTransaction.ID,
multiTransaction.FromNetworkID, multiTransaction.FromNetworkID,
multiTransaction.FromTxHash, multiTransaction.FromTxHash,
multiTransaction.FromAddress, multiTransaction.FromAddress,
@ -107,22 +99,18 @@ func insertMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) (wal
multiTransaction.ToAmount.String(), multiTransaction.ToAmount.String(),
multiTransaction.Type, multiTransaction.Type,
multiTransaction.CrossTxID, multiTransaction.CrossTxID,
timestamp, multiTransaction.Timestamp,
) )
if err != nil { if err != nil {
return wallet_common.NoMultiTransactionID, err return err
} }
defer insert.Close() defer insert.Close()
multiTransactionID, err := result.LastInsertId()
multiTransaction.Timestamp = timestamp return err
multiTransaction.ID = wallet_common.MultiTransactionIDType(multiTransactionID)
return wallet_common.MultiTransactionIDType(multiTransactionID), err
} }
func (tm *TransactionManager) InsertMultiTransaction(multiTransaction *MultiTransaction) (wallet_common.MultiTransactionIDType, error) { func (tm *TransactionManager) InsertMultiTransaction(multiTransaction *MultiTransaction) (wallet_common.MultiTransactionIDType, error) {
return insertMultiTransaction(tm.db, multiTransaction) return multiTransaction.ID, insertMultiTransaction(tm.db, multiTransaction)
} }
func updateMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) error { func updateMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) error {
@ -130,13 +118,12 @@ func updateMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) erro
return fmt.Errorf("no multitransaction ID") return fmt.Errorf("no multitransaction ID")
} }
update, err := db.Prepare(fmt.Sprintf(`REPLACE INTO multi_transactions (rowid, %s) update, err := db.Prepare(fmt.Sprintf(`REPLACE INTO multi_transactions (%s)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns)) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, multiTransactionColumns))
if err != nil { if err != nil {
return err return err
} }
timestamp := getMultiTransactionTimestamp(multiTransaction)
_, err = update.Exec( _, err = update.Exec(
multiTransaction.ID, multiTransaction.ID,
multiTransaction.FromNetworkID, multiTransaction.FromNetworkID,
@ -151,7 +138,7 @@ func updateMultiTransaction(db *sql.DB, multiTransaction *MultiTransaction) erro
multiTransaction.ToAmount.String(), multiTransaction.ToAmount.String(),
multiTransaction.Type, multiTransaction.Type,
multiTransaction.CrossTxID, multiTransaction.CrossTxID,
timestamp, multiTransaction.Timestamp,
) )
if err != nil { if err != nil {
return err return err
@ -163,55 +150,53 @@ func (tm *TransactionManager) UpdateMultiTransaction(multiTransaction *MultiTran
return updateMultiTransaction(tm.db, multiTransaction) return updateMultiTransaction(tm.db, multiTransaction)
} }
// In case of keycard account, password should be empty
func (tm *TransactionManager) CreateMultiTransactionFromCommand(ctx context.Context, command *MultiTransactionCommand, func (tm *TransactionManager) CreateMultiTransactionFromCommand(ctx context.Context, command *MultiTransactionCommand,
data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) (*MultiTransactionCommandResult, error) { data []*bridge.TransactionBridge) (*MultiTransaction, error) {
multiTransaction := multiTransactionFromCommand(command) multiTransaction := multiTransactionFromCommand(command)
if multiTransaction.Type == MultiTransactionSend && multiTransaction.FromNetworkID == 0 && len(data) == 1 { if multiTransaction.Type == MultiTransactionSend && multiTransaction.FromNetworkID == 0 && len(data) == 1 {
multiTransaction.FromNetworkID = data[0].ChainID multiTransaction.FromNetworkID = data[0].ChainID
} }
multiTransactionID, err := insertMultiTransaction(tm.db, multiTransaction)
if err != nil {
return nil, err
}
multiTransaction.ID = multiTransactionID return multiTransaction, nil
if password == "" { }
func (tm *TransactionManager) SendTransactionForSigningToKeycard(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge) error {
acc, err := tm.accountsDB.GetAccountByAddress(types.Address(multiTransaction.FromAddress)) acc, err := tm.accountsDB.GetAccountByAddress(types.Address(multiTransaction.FromAddress))
if err != nil { if err != nil {
return nil, err return err
} }
kp, err := tm.accountsDB.GetKeypairByKeyUID(acc.KeyUID) kp, err := tm.accountsDB.GetKeypairByKeyUID(acc.KeyUID)
if err != nil { if err != nil {
return nil, err return err
} }
if !kp.MigratedToKeycard() { if !kp.MigratedToKeycard() {
return nil, fmt.Errorf("account being used is not migrated to a keycard, password is required") return fmt.Errorf("account being used is not migrated to a keycard, password is required")
} }
tm.multiTransactionForKeycardSigning = multiTransaction tm.multiTransactionForKeycardSigning = multiTransaction
tm.transactionsBridgeData = data tm.transactionsBridgeData = data
hashes, err := tm.buildTransactions(bridges) hashes, err := tm.buildTransactions(bridges)
if err != nil { if err != nil {
return nil, err return err
} }
signal.SendTransactionsForSigningEvent(hashes) signal.SendTransactionsForSigningEvent(hashes)
return nil, nil return nil
} }
hashes, err := tm.sendTransactions(multiTransaction, data, bridges, password) func (tm *TransactionManager) SendTransactions(ctx context.Context, multiTransaction *MultiTransaction, data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, account *account.SelectedExtKey) (*MultiTransactionCommandResult, error) {
hashes, err := tm.sendTransactions(multiTransaction, data, bridges, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &MultiTransactionCommandResult{ return &MultiTransactionCommandResult{
ID: int64(multiTransactionID), ID: int64(multiTransaction.ID),
Hashes: hashes, Hashes: hashes,
}, nil }, nil
} }
@ -267,6 +252,11 @@ func (tm *TransactionManager) ProceedWithTransactionsSignatures(ctx context.Cont
hashes[desc.chainID] = append(hashes[desc.chainID], hash) hashes[desc.chainID] = append(hashes[desc.chainID], hash)
} }
_, err := tm.InsertMultiTransaction(tm.multiTransactionForKeycardSigning)
if err != nil {
log.Error("failed to insert multi transaction", "err", err)
}
return &MultiTransactionCommandResult{ return &MultiTransactionCommandResult{
ID: int64(tm.multiTransactionForKeycardSigning.ID), ID: int64(tm.multiTransactionForKeycardSigning.ID),
Hashes: hashes, Hashes: hashes,
@ -274,18 +264,21 @@ func (tm *TransactionManager) ProceedWithTransactionsSignatures(ctx context.Cont
} }
func multiTransactionFromCommand(command *MultiTransactionCommand) *MultiTransaction { func multiTransactionFromCommand(command *MultiTransactionCommand) *MultiTransaction {
multiTransaction := NewMultiTransaction(
log.Info("Creating multi transaction", "command", command) /* Timestamp: */ uint64(time.Now().Unix()),
/* FromNetworkID: */ 0,
multiTransaction := &MultiTransaction{ /* ToNetworkID: */ 0,
FromAddress: command.FromAddress, /* FromTxHash: */ common.Hash{},
ToAddress: command.ToAddress, /* ToTxHash: */ common.Hash{},
FromAsset: command.FromAsset, /* FromAddress: */ command.FromAddress,
ToAsset: command.ToAsset, /* ToAddress: */ command.ToAddress,
FromAmount: command.FromAmount, /* FromAsset: */ command.FromAsset,
ToAmount: new(hexutil.Big), /* ToAsset: */ command.ToAsset,
Type: command.Type, /* FromAmount: */ command.FromAmount,
} /* ToAmount: */ new(hexutil.Big),
/* Type: */ command.Type,
/* CrossTxID: */ "",
)
return multiTransaction return multiTransaction
} }
@ -315,16 +308,9 @@ func (tm *TransactionManager) buildTransactions(bridges map[string]bridge.Bridge
} }
func (tm *TransactionManager) sendTransactions(multiTransaction *MultiTransaction, func (tm *TransactionManager) sendTransactions(multiTransaction *MultiTransaction,
data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, password string) ( data []*bridge.TransactionBridge, bridges map[string]bridge.Bridge, account *account.SelectedExtKey) (
map[uint64][]types.Hash, error) { map[uint64][]types.Hash, error) {
log.Info("Making transactions", "multiTransaction", multiTransaction)
selectedAccount, err := tm.getVerifiedWalletAccount(multiTransaction.FromAddress.Hex(), password)
if err != nil {
return nil, err
}
hashes := make(map[uint64][]types.Hash) hashes := make(map[uint64][]types.Hash)
for _, tx := range data { for _, tx := range data {
if tx.TransferTx != nil { if tx.TransferTx != nil {
@ -352,9 +338,9 @@ func (tm *TransactionManager) sendTransactions(multiTransaction *MultiTransactio
tx.SwapTx.Symbol = multiTransaction.FromAsset tx.SwapTx.Symbol = multiTransaction.FromAsset
} }
hash, err := bridges[tx.BridgeName].Send(tx, selectedAccount) hash, err := bridges[tx.BridgeName].Send(tx, account)
if err != nil { if err != nil {
return nil, err return nil, err // TODO: One of transfers within transaction could have been sent. Need to notify user about it
} }
hashes[tx.ChainID] = append(hashes[tx.ChainID], hash) hashes[tx.ChainID] = append(hashes[tx.ChainID], hash)
} }
@ -369,9 +355,9 @@ func (tm *TransactionManager) GetMultiTransactions(ctx context.Context, ids []wa
args[i] = v args[i] = v
} }
stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT rowid, %s stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT %s
FROM multi_transactions FROM multi_transactions
WHERE rowid in (%s)`, WHERE id in (%s)`,
selectMultiTransactionColumns, selectMultiTransactionColumns,
strings.Join(placeholders, ","))) strings.Join(placeholders, ",")))
if err != nil { if err != nil {
@ -389,7 +375,7 @@ func (tm *TransactionManager) GetMultiTransactions(ctx context.Context, ids []wa
} }
func (tm *TransactionManager) getBridgeMultiTransactions(ctx context.Context, toChainID uint64, crossTxID string) ([]*MultiTransaction, error) { func (tm *TransactionManager) getBridgeMultiTransactions(ctx context.Context, toChainID uint64, crossTxID string) ([]*MultiTransaction, error) {
stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT rowid, %s stmt, err := tm.db.Prepare(fmt.Sprintf(`SELECT %s
FROM multi_transactions FROM multi_transactions
WHERE type=? AND to_network_id=? AND cross_tx_id=?`, WHERE type=? AND to_network_id=? AND cross_tx_id=?`,
multiTransactionColumns)) multiTransactionColumns))
@ -439,37 +425,19 @@ func (tm *TransactionManager) GetBridgeDestinationMultiTransaction(ctx context.C
return nil, nil return nil, nil
} }
func (tm *TransactionManager) getVerifiedWalletAccount(address, password string) (*account.SelectedExtKey, error) { func idFromTimestamp() wallet_common.MultiTransactionIDType {
exists, err := tm.accountsDB.AddressExists(types.HexToAddress(address)) return wallet_common.MultiTransactionIDType(time.Now().UnixMilli())
if err != nil {
log.Error("failed to query db for a given address", "address", address, "error", err)
return nil, err
}
if !exists {
log.Error("failed to get a selected account", "err", transactions.ErrInvalidTxSender)
return nil, transactions.ErrAccountDoesntExist
}
key, err := tm.gethManager.VerifyAccountPassword(tm.config.KeyStoreDir, address, password)
if err != nil {
log.Error("failed to verify account", "account", address, "error", err)
return nil, err
}
return &account.SelectedExtKey{
Address: key.Address,
AccountKey: key,
}, nil
} }
var multiTransactionIDGenerator func() wallet_common.MultiTransactionIDType = idFromTimestamp
func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Address) error { func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Address) error {
// We must not remove those transactions, where from_address and to_address are different and both are stored in accounts DB // We must not remove those transactions, where from_address and to_address are different and both are stored in accounts DB
// and one of them is equal to the address, as we want to keep the records for the other address // and one of them is equal to the address, as we want to keep the records for the other address
// That is why we don't use cascade delete here with references to transfers table, as we might have 2 records in multi_transactions // That is why we don't use cascade delete here with references to transfers table, as we might have 2 records in multi_transactions
// for the same transaction, one for each address // for the same transaction, one for each address
stmt, err := tm.db.Prepare(`SELECT rowid, from_address, to_address stmt, err := tm.db.Prepare(`SELECT id, from_address, to_address
FROM multi_transactions FROM multi_transactions
WHERE from_address=? OR to_address=?`) WHERE from_address=? OR to_address=?`)
if err != nil { if err != nil {
@ -482,10 +450,10 @@ func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Add
} }
defer rows.Close() defer rows.Close()
rowIDs := make([]int, 0) ids := make([]int, 0)
rowID, fromAddress, toAddress := 0, common.Address{}, common.Address{} id, fromAddress, toAddress := 0, common.Address{}, common.Address{}
for rows.Next() { for rows.Next() {
err = rows.Scan(&rowID, &fromAddress, &toAddress) err = rows.Scan(&id, &fromAddress, &toAddress)
if err != nil { if err != nil {
log.Error("Failed to scan row", "error", err) log.Error("Failed to scan row", "error", err)
continue continue
@ -512,14 +480,14 @@ func (tm *TransactionManager) removeMultiTransactionByAddress(address common.Add
} }
} }
rowIDs = append(rowIDs, rowID) ids = append(ids, id)
} }
if len(rowIDs) > 0 { if len(ids) > 0 {
for _, rowID := range rowIDs { for _, id := range ids {
_, err = tm.db.Exec(`DELETE FROM multi_transactions WHERE rowid=?`, rowID) _, err = tm.db.Exec(`DELETE FROM multi_transactions WHERE id=?`, id)
if err != nil { if err != nil {
log.Error("Failed to remove multitransaction", "rowid", rowID, "error", err) log.Error("Failed to remove multitransaction", "id", id, "error", err)
} }
} }
} }

View File

@ -43,39 +43,39 @@ func TestBridgeMultiTransactions(t *testing.T) {
manager, stop := setupTestTransactionDB(t) manager, stop := setupTestTransactionDB(t)
defer stop() defer stop()
trx1 := MultiTransaction{ trx1 := NewMultiTransaction(
Timestamp: 123, /* Timestamp: */ 123,
FromNetworkID: 0, /* FromNetworkID: */ 0,
ToNetworkID: 1, /* ToNetworkID: */ 1,
FromTxHash: common.Hash{5}, /* FromTxHash: */ common.Hash{5},
// Empty ToTxHash /* // Empty ToTxHash */ common.Hash{},
FromAddress: common.Address{1}, /* FromAddress: */ common.Address{1},
ToAddress: common.Address{2}, /* ToAddress: */ common.Address{2},
FromAsset: "fromAsset", /* FromAsset: */ "fromAsset",
ToAsset: "toAsset", /* ToAsset: */ "toAsset",
FromAmount: (*hexutil.Big)(big.NewInt(123)), /* FromAmount: */ (*hexutil.Big)(big.NewInt(123)),
ToAmount: (*hexutil.Big)(big.NewInt(234)), /* ToAmount: */ (*hexutil.Big)(big.NewInt(234)),
Type: MultiTransactionBridge, /* Type: */ MultiTransactionBridge,
CrossTxID: "crossTxD1", /* CrossTxID: */ "crossTxD1",
} )
trx2 := MultiTransaction{ trx2 := NewMultiTransaction(
Timestamp: 321, /* Timestamp: */ 321,
FromNetworkID: 1, /* FromNetworkID: */ 1,
ToNetworkID: 0, /* ToNetworkID: */ 0,
//Empty FromTxHash /* //Empty FromTxHash */ common.Hash{},
ToTxHash: common.Hash{6}, /* ToTxHash: */ common.Hash{6},
FromAddress: common.Address{2}, /* FromAddress: */ common.Address{2},
ToAddress: common.Address{1}, /* ToAddress: */ common.Address{1},
FromAsset: "fromAsset", /* FromAsset: */ "fromAsset",
ToAsset: "toAsset", /* ToAsset: */ "toAsset",
FromAmount: (*hexutil.Big)(big.NewInt(123)), /* FromAmount: */ (*hexutil.Big)(big.NewInt(123)),
ToAmount: (*hexutil.Big)(big.NewInt(234)), /* ToAmount: */ (*hexutil.Big)(big.NewInt(234)),
Type: MultiTransactionBridge, /* Type: */ MultiTransactionBridge,
CrossTxID: "crossTxD2", /* CrossTxID: */ "crossTxD2",
} )
trxs := []*MultiTransaction{&trx1, &trx2} trxs := []*MultiTransaction{trx1, trx2}
var err error var err error
ids := make([]wallet_common.MultiTransactionIDType, len(trxs)) ids := make([]wallet_common.MultiTransactionIDType, len(trxs))
@ -88,7 +88,7 @@ func TestBridgeMultiTransactions(t *testing.T) {
rst, err := manager.GetBridgeOriginMultiTransaction(context.Background(), trx1.ToNetworkID, trx1.CrossTxID) rst, err := manager.GetBridgeOriginMultiTransaction(context.Background(), trx1.ToNetworkID, trx1.CrossTxID)
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, rst) require.NotEmpty(t, rst)
require.True(t, areMultiTransactionsEqual(&trx1, rst)) require.True(t, areMultiTransactionsEqual(trx1, rst))
rst, err = manager.GetBridgeDestinationMultiTransaction(context.Background(), trx1.ToNetworkID, trx1.CrossTxID) rst, err = manager.GetBridgeDestinationMultiTransaction(context.Background(), trx1.ToNetworkID, trx1.CrossTxID)
require.NoError(t, err) require.NoError(t, err)
@ -101,31 +101,34 @@ func TestBridgeMultiTransactions(t *testing.T) {
rst, err = manager.GetBridgeDestinationMultiTransaction(context.Background(), trx2.ToNetworkID, trx2.CrossTxID) rst, err = manager.GetBridgeDestinationMultiTransaction(context.Background(), trx2.ToNetworkID, trx2.CrossTxID)
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, rst) require.NotEmpty(t, rst)
require.True(t, areMultiTransactionsEqual(&trx2, rst)) require.True(t, areMultiTransactionsEqual(trx2, rst))
} }
func TestMultiTransactions(t *testing.T) { func TestMultiTransactions(t *testing.T) {
manager, stop := setupTestTransactionDB(t) manager, stop := setupTestTransactionDB(t)
defer stop() defer stop()
trx1 := MultiTransaction{ trx1 := *NewMultiTransaction(
Timestamp: 123, /* Timestamp: */ 123,
FromNetworkID: 0, /* FromNetworkID:*/ 0,
ToNetworkID: 1, /* ToNetworkID: */ 1,
FromTxHash: common.Hash{5}, /* FromTxHash: */ common.Hash{5},
ToTxHash: common.Hash{6}, /* ToTxHash: */ common.Hash{6},
FromAddress: common.Address{1}, /* FromAddress: */ common.Address{1},
ToAddress: common.Address{2}, /* ToAddress: */ common.Address{2},
FromAsset: "fromAsset", /* FromAsset: */ "fromAsset",
ToAsset: "toAsset", /* ToAsset: */ "toAsset",
FromAmount: (*hexutil.Big)(big.NewInt(123)), /* FromAmount: */ (*hexutil.Big)(big.NewInt(123)),
ToAmount: (*hexutil.Big)(big.NewInt(234)), /* ToAmount: */ (*hexutil.Big)(big.NewInt(234)),
Type: MultiTransactionBridge, /* Type: */ MultiTransactionBridge,
CrossTxID: "crossTxD", /* CrossTxID: */ "crossTxD",
} )
trx2 := trx1 trx2 := trx1
trx2.FromAmount = (*hexutil.Big)(big.NewInt(456)) trx2.FromAmount = (*hexutil.Big)(big.NewInt(456))
trx2.ToAmount = (*hexutil.Big)(big.NewInt(567)) trx2.ToAmount = (*hexutil.Big)(big.NewInt(567))
trx2.ID = generateMultiTransactionID()
require.NotEqual(t, trx1.ID, trx2.ID)
trxs := []*MultiTransaction{&trx1, &trx2} trxs := []*MultiTransaction{&trx1, &trx2}