test(wallet)_: implement build/sign/watch unit tests for multitransaction manager

Closes #14848
This commit is contained in:
Ivan Belyakov 2024-06-03 01:37:51 +02:00 committed by IvanBelyakoff
parent 1c116589cd
commit bf78c15e6f
6 changed files with 357 additions and 45 deletions

View File

@ -284,6 +284,13 @@ func (a *Keypair) Operability() AccountOperable {
return AccountFullyOperable
}
// TODO: implement clean full interface. This might require refactoring Database methods
type AccountsStorage interface {
GetKeypairByKeyUID(keyUID string) (*Keypair, error)
GetAccountByAddress(address types.Address) (*Account, error)
AddressExists(address types.Address) (bool, error)
}
// Database sql wrapper for operations with browser objects.
type Database struct {
settings.DatabaseSettingsManager

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math/big"
"strings"
@ -31,7 +30,6 @@ import (
"github.com/status-im/status-go/services/wallet/token"
"github.com/status-im/status-go/services/wallet/transfer"
"github.com/status-im/status-go/services/wallet/walletconnect"
"github.com/status-im/status-go/services/wallet/walletevent"
"github.com/status-im/status-go/transactions"
)
@ -271,38 +269,11 @@ func (api *API) GetPendingTransactionsForIdentities(ctx context.Context, identit
// TODO - #11861: Remove this and replace with EventPendingTransactionStatusChanged event and Delete to confirm the transaction where it is needed
func (api *API) WatchTransactionByChainID(ctx context.Context, chainID uint64, transactionHash common.Hash) (err error) {
log.Debug("wallet.api.WatchTransactionByChainID", "chainID", chainID, "transactionHash", transactionHash)
var status *transactions.TxStatus
defer func() {
log.Debug("wallet.api.WatchTransactionByChainID return", "err", err, "chainID", chainID, "transactionHash", transactionHash)
}()
// Workaround to keep the blocking call until the clients use the PendingTxTracker APIs
eventChan := make(chan walletevent.Event, 2)
sub := api.s.feed.Subscribe(eventChan)
defer sub.Unsubscribe()
status, err = api.s.pendingTxManager.Watch(ctx, wcommon.ChainID(chainID), transactionHash)
if err == nil && *status != transactions.Pending {
return nil
}
for {
select {
case we := <-eventChan:
if transactions.EventPendingTransactionStatusChanged == we.Type {
var p transactions.StatusChangedPayload
err = json.Unmarshal([]byte(we.Message), &p)
if err != nil {
return err
}
if p.ChainID == wcommon.ChainID(chainID) && p.Hash == transactionHash {
return nil
}
}
case <-time.After(10 * time.Minute):
return errors.New("timeout watching for pending transaction")
}
}
return api.s.transactionManager.WatchTransaction(ctx, chainID, transactionHash)
}
func (api *API) GetCryptoOnRamps(ctx context.Context) ([]onramp.CryptoOnRamp, error) {

View File

@ -38,7 +38,7 @@ type TransactionManager struct {
gethManager *account.GethManager
transactor transactions.TransactorIface
config *params.NodeConfig
accountsDB *accounts.Database
accountsDB accounts.AccountsStorage
pendingTracker *transactions.PendingTxTracker
eventFeed *event.Feed
@ -59,7 +59,7 @@ func NewTransactionManager(
gethManager *account.GethManager,
transactor transactions.TransactorIface,
config *params.NodeConfig,
accountsDB *accounts.Database,
accountsDB accounts.AccountsStorage,
pendingTxManager *transactions.PendingTxTracker,
eventFeed *event.Feed,
) *TransactionManager {
@ -160,6 +160,10 @@ func NewMultiTransaction(timestamp uint64, fromNetworkID, toNetworkID uint64, fr
}
func (tm *TransactionManager) SignMessage(message types.HexBytes, account *types.Key) (string, error) {
if account == nil || account.PrivateKey == nil {
return "", fmt.Errorf("account or private key is nil")
}
signature, err := crypto.Sign(message[:], account.PrivateKey)
return types.EncodeHex(signature), err

View File

@ -2,19 +2,30 @@ package transfer
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
"github.com/status-im/status-go/account"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/services/wallet/bridge"
wallet_common "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/walletevent"
"github.com/status-im/status-go/signal"
"github.com/status-im/status-go/transactions"
)
const multiTransactionColumns = "id, from_network_id, from_tx_hash, from_address, from_asset, from_amount, to_network_id, to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp"
const selectMultiTransactionColumns = "id, COALESCE(from_network_id, 0), from_tx_hash, from_address, from_asset, from_amount, COALESCE(to_network_id, 0), to_tx_hash, to_address, to_asset, to_amount, type, cross_tx_id, timestamp"
var pendingTxTimeout time.Duration = 10 * time.Minute
var ErrWatchPendingTxTimeout = errors.New("timeout watching for pending transaction")
var ErrPendingTxNotExists = errors.New("pending transaction does not exist")
func (tm *TransactionManager) InsertMultiTransaction(multiTransaction *MultiTransaction) (wallet_common.MultiTransactionIDType, error) {
return multiTransaction.ID, tm.storage.CreateMultiTransaction(multiTransaction)
}
@ -149,3 +160,34 @@ func (tm *TransactionManager) GetBridgeDestinationMultiTransaction(ctx context.C
return nil, nil
}
func (tm *TransactionManager) WatchTransaction(ctx context.Context, chainID uint64, transactionHash common.Hash) error {
// Workaround to keep the blocking call until the clients use the PendingTxTracker APIs
eventChan := make(chan walletevent.Event, 2)
sub := tm.eventFeed.Subscribe(eventChan)
defer sub.Unsubscribe()
status, err := tm.pendingTracker.Watch(ctx, wallet_common.ChainID(chainID), transactionHash)
if err == nil && *status != transactions.Pending {
log.Error("transaction is not pending", "status", status)
return nil
}
for {
select {
case we := <-eventChan:
if transactions.EventPendingTransactionStatusChanged == we.Type {
var p transactions.StatusChangedPayload
err = json.Unmarshal([]byte(we.Message), &p)
if err != nil {
return err
}
if p.ChainID == wallet_common.ChainID(chainID) && p.Hash == transactionHash {
return nil
}
}
case <-time.After(pendingTxTimeout):
return ErrWatchPendingTxTimeout
}
}
}

View File

@ -2,21 +2,28 @@ package transfer
import (
"context"
"encoding/json"
"math/big"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/event"
"github.com/status-im/status-go/account"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/services/wallet/bridge"
"github.com/status-im/status-go/services/wallet/bridge/mock_bridge"
wallet_common "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/services/wallet/walletevent"
"github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/transactions"
"github.com/status-im/status-go/transactions/mock_transactor"
"github.com/status-im/status-go/walletdatabase"
)
func deepCopy(tx *transactions.SendTxArgs) *transactions.SendTxArgs {
@ -146,7 +153,7 @@ func TestSendTransactionsETHFailOnBridge(t *testing.T) {
// Call the SendTransactions method
_, err := tm.SendTransactions(context.Background(), multiTransaction, data, bridges, account)
require.Error(t, expectedErr, err)
require.ErrorIs(t, expectedErr, err)
}
func TestSendTransactionsETHFailOnTransactor(t *testing.T) {
@ -162,5 +169,78 @@ func TestSendTransactionsETHFailOnTransactor(t *testing.T) {
// Call the SendTransactions method
_, err := tm.SendTransactions(context.Background(), multiTransaction, data, bridges, account)
require.Error(t, expectedErr, err)
require.ErrorIs(t, expectedErr, err)
}
func TestWatchTransaction(t *testing.T) {
tm, _, _ := setupTransactionManager(t)
chainID := uint64(1)
pendingTxTimeout = 2 * time.Millisecond
walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
chainClient := transactions.NewMockChainClient()
eventFeed := &event.Feed{}
// For now, pending tracker is not interface, so we have to use a real one
tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, chainClient, nil, eventFeed, pendingTxTimeout)
tm.eventFeed = eventFeed
// Create a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 2*pendingTxTimeout)
defer cancel()
// Insert a pending transaction
txs := transactions.MockTestTransactions(t, chainClient, []transactions.TestTxSummary{{}})
err = tm.pendingTracker.StoreAndTrackPendingTx(&txs[0]) // We dont need to track it, but no other way to insert it
require.NoError(t, err)
txEventPayload := transactions.StatusChangedPayload{
TxIdentity: transactions.TxIdentity{
Hash: txs[0].Hash,
ChainID: wallet_common.ChainID(chainID),
},
Status: transactions.Pending,
}
jsonPayload, err := json.Marshal(txEventPayload)
require.NoError(t, err)
go func() {
time.Sleep(pendingTxTimeout / 2)
eventFeed.Send(walletevent.Event{
Type: transactions.EventPendingTransactionStatusChanged,
Message: string(jsonPayload),
})
}()
// Call the WatchTransaction method
err = tm.WatchTransaction(ctx, chainID, txs[0].Hash)
require.NoError(t, err)
}
func TestWatchTransaction_Timeout(t *testing.T) {
tm, _, _ := setupTransactionManager(t)
chainID := uint64(1)
transactionHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
pendingTxTimeout = 2 * time.Millisecond
walletDB, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
chainClient := transactions.NewMockChainClient()
eventFeed := &event.Feed{}
// For now, pending tracker is not interface, so we have to use a real one
tm.pendingTracker = transactions.NewPendingTxTracker(walletDB, chainClient, nil, eventFeed, pendingTxTimeout)
tm.eventFeed = eventFeed
// Create a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
// Insert a pending transaction
txs := transactions.MockTestTransactions(t, chainClient, []transactions.TestTxSummary{{}})
err = tm.pendingTracker.StoreAndTrackPendingTx(&txs[0]) // We dont need to track it, but no other way to insert it
require.NoError(t, err)
// Call the WatchTransaction method
err = tm.WatchTransaction(ctx, chainID, transactionHash)
require.ErrorIs(t, err, ErrWatchPendingTxTimeout)
}

View File

@ -2,25 +2,80 @@ package transfer
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"fmt"
"math/big"
"reflect"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
gethtypes "github.com/ethereum/go-ethereum/core/types"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/multiaccounts/accounts"
wallet_common "github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/walletdatabase"
"github.com/status-im/status-go/transactions"
"github.com/status-im/status-go/transactions/mock_transactor"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
)
func setupTestTransactionDB(t *testing.T) (*TransactionManager, func()) {
db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
require.NoError(t, err)
type dummyAccountsStorage struct {
keypair *accounts.Keypair
account *accounts.Account
}
func (d *dummyAccountsStorage) GetAccountByAddress(address types.Address) (*accounts.Account, error) {
if address != d.account.Address {
return nil, fmt.Errorf("address not found")
}
return d.account, nil
}
func (d *dummyAccountsStorage) GetKeypairByKeyUID(keyUID string) (*accounts.Keypair, error) {
if keyUID != d.keypair.KeyUID {
return nil, fmt.Errorf("keyUID not found")
}
return d.keypair, nil
}
func (d *dummyAccountsStorage) AddressExists(address types.Address) (bool, error) {
return d.account.Address == address, nil
}
type dummySigner struct{}
func (d *dummySigner) Hash(tx *gethtypes.Transaction) common.Hash {
return common.HexToHash("0xc8e7a34af766c4ba9dc9b3d49939806fbf41fa01250c5a26afa5659e87b2020b")
}
func setupTestSuite(t *testing.T) (*TransactionManager, *mock_transactor.MockTransactorIface) {
SetMultiTransactionIDGenerator(StaticIDCounter()) // to have different multi-transaction IDs even with fast execution
return &TransactionManager{NewMultiTransactionDB(db), nil, nil, nil, nil, nil, nil, nil, nil, nil}, func() {
require.NoError(t, db.Close())
accountsDB := setupAccountsStorage()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
transactor := mock_transactor.NewMockTransactorIface(ctrl)
return &TransactionManager{
storage: NewInMemMultiTransactionStorage(),
accountsDB: accountsDB,
transactor: transactor,
}, transactor
}
func setupAccountsStorage() *dummyAccountsStorage {
return &dummyAccountsStorage{
keypair: &accounts.Keypair{
KeyUID: "keyUid",
},
account: &accounts.Account{
KeyUID: "keyUid",
Address: types.Address{1},
},
}
}
@ -41,8 +96,7 @@ func areMultiTransactionsEqual(mt1, mt2 *MultiTransaction) bool {
}
func TestBridgeMultiTransactions(t *testing.T) {
manager, stop := setupTestTransactionDB(t)
defer stop()
manager, _ := setupTestSuite(t)
trx1 := NewMultiTransaction(
/* Timestamp: */ 123,
@ -105,8 +159,7 @@ func TestBridgeMultiTransactions(t *testing.T) {
}
func TestMultiTransactions(t *testing.T) {
manager, stop := setupTestTransactionDB(t)
defer stop()
manager, _ := setupTestSuite(t)
trx1 := *NewMultiTransaction(
/* Timestamp: */ 123,
@ -165,3 +218,158 @@ func TestMultiTransactions(t *testing.T) {
require.True(t, found, "result contains transaction with id %d", id)
}
}
func TestSignMessage(t *testing.T) {
tm, _ := setupTestSuite(t)
message := (types.HexBytes)(make([]byte, 32))
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
account := &types.Key{
PrivateKey: privateKey,
}
signature, err := tm.SignMessage(message, account)
require.NoError(t, err)
require.NotEmpty(t, signature)
}
func TestSignMessage_InvalidAccount(t *testing.T) {
tm, _ := setupTestSuite(t)
message := (types.HexBytes)(make([]byte, 32))
account := &types.Key{
PrivateKey: nil,
}
signature, err := tm.SignMessage(message, account)
require.Error(t, err)
require.Empty(t, signature)
}
func TestSignMessage_InvalidMessage(t *testing.T) {
tm, _ := setupTestSuite(t)
message := types.HexBytes{}
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
account := &types.Key{
PrivateKey: privateKey,
}
signature, err := tm.SignMessage(message, account)
require.Error(t, err)
require.Equal(t, "0x", signature)
}
func TestBuildTransaction(t *testing.T) {
manager, transactor := setupTestSuite(t)
chainID := uint64(1)
nonce := uint64(1)
gas := uint64(21000)
sendArgs := transactions.SendTxArgs{
From: types.Address{1},
To: &types.Address{2},
Value: (*hexutil.Big)(big.NewInt(123)),
Nonce: (*hexutil.Uint64)(&nonce),
Gas: (*hexutil.Uint64)(&gas),
GasPrice: (*hexutil.Big)(big.NewInt(1000000000)),
MaxFeePerGas: (*hexutil.Big)(big.NewInt(2000000000)),
MaxPriorityFeePerGas: (*hexutil.Big)(big.NewInt(1000000000)),
}
expectedTx := gethtypes.NewTransaction(nonce, common.Address(*sendArgs.To), sendArgs.Value.ToInt(), gas, sendArgs.GasPrice.ToInt(), nil)
transactor.EXPECT().ValidateAndBuildTransaction(chainID, sendArgs).Return(expectedTx, nil)
response, err := manager.BuildTransaction(chainID, sendArgs)
require.NoError(t, err)
require.NotNil(t, response)
accDB := manager.accountsDB.(*dummyAccountsStorage)
signer := dummySigner{}
expectedKeyUID := accDB.keypair.KeyUID
expectedAddress := accDB.account.Address
expectedAddressPath := ""
expectedSignOnKeycard := false
expectedMessageToSign := signer.Hash(expectedTx)
require.Equal(t, expectedKeyUID, response.KeyUID)
require.Equal(t, expectedAddress, response.Address)
require.Equal(t, expectedAddressPath, response.AddressPath)
require.Equal(t, expectedSignOnKeycard, response.SignOnKeycard)
require.Equal(t, chainID, response.ChainID)
require.Equal(t, expectedMessageToSign, response.MessageToSign)
require.True(t, reflect.DeepEqual(sendArgs, response.TxArgs))
}
func TestBuildTransaction_AccountNotFound(t *testing.T) {
manager, _ := setupTestSuite(t)
chainID := uint64(1)
nonce := uint64(1)
gas := uint64(21000)
sendArgs := transactions.SendTxArgs{
From: types.Address{2},
To: &types.Address{2},
Value: (*hexutil.Big)(big.NewInt(123)),
Nonce: (*hexutil.Uint64)(&nonce),
Gas: (*hexutil.Uint64)(&gas),
GasPrice: (*hexutil.Big)(big.NewInt(1000000000)),
MaxFeePerGas: (*hexutil.Big)(big.NewInt(2000000000)),
MaxPriorityFeePerGas: (*hexutil.Big)(big.NewInt(1000000000)),
}
_, err := manager.BuildTransaction(chainID, sendArgs)
require.Error(t, err)
}
func TestBuildTransaction_InvalidSendTxArgs(t *testing.T) {
manager, transactor := setupTestSuite(t)
chainID := uint64(1)
sendArgs := transactions.SendTxArgs{
From: types.Address{1},
To: &types.Address{2},
}
expectedErr := fmt.Errorf("invalid SendTxArgs")
transactor.EXPECT().ValidateAndBuildTransaction(chainID, sendArgs).Return(nil, expectedErr)
tx, err := manager.BuildTransaction(chainID, sendArgs)
require.Equal(t, expectedErr, err)
require.Nil(t, tx)
}
func TestBuildRawTransaction(t *testing.T) {
manager, transactor := setupTestSuite(t)
chainID := uint64(1)
nonce := uint64(1)
gas := uint64(21000)
sendArgs := transactions.SendTxArgs{
From: types.Address{1},
To: &types.Address{2},
Value: (*hexutil.Big)(big.NewInt(123)),
Nonce: (*hexutil.Uint64)(&nonce),
Gas: (*hexutil.Uint64)(&gas),
GasPrice: (*hexutil.Big)(big.NewInt(1000000000)),
MaxFeePerGas: (*hexutil.Big)(big.NewInt(2000000000)),
MaxPriorityFeePerGas: (*hexutil.Big)(big.NewInt(1000000000)),
}
expectedTx := gethtypes.NewTransaction(1, common.Address(*sendArgs.To), sendArgs.Value.ToInt(), 21000, sendArgs.GasPrice.ToInt(), nil)
signature := []byte("signature")
transactor.EXPECT().BuildTransactionWithSignature(chainID, sendArgs, signature).Return(expectedTx, nil)
response, err := manager.BuildRawTransaction(chainID, sendArgs, signature)
require.NoError(t, err)
require.NotNil(t, response)
expectedData, _ := expectedTx.MarshalBinary()
expectedHash := expectedTx.Hash()
require.Equal(t, chainID, response.ChainID)
require.Equal(t, sendArgs, response.TxArgs)
require.Equal(t, types.EncodeHex(expectedData), response.RawTx)
require.Equal(t, expectedHash, response.TxHash)
}