diff --git a/services/wallet/activity/activity.go b/services/wallet/activity/activity.go index 979ed1989..1c870ec3d 100644 --- a/services/wallet/activity/activity.go +++ b/services/wallet/activity/activity.go @@ -725,6 +725,8 @@ func getTrInAndOutAmounts(activityType Type, trAmount sql.NullString, pTrAmount if ok { switch activityType { + case ApproveAT: + fallthrough case ContractDeploymentAT: fallthrough case SendAT: diff --git a/services/wallet/activity/activity_test.go b/services/wallet/activity/activity_test.go index 06f349726..03fd9b0fb 100644 --- a/services/wallet/activity/activity_test.go +++ b/services/wallet/activity/activity_test.go @@ -108,6 +108,7 @@ type testData struct { multiTx1Tr2 transfer.TestTransfer // index 5, USDC/Optimism multiTx2Tr2 transfer.TestTransfer // index 6, SNT/Mainnet multiTx2PendingTr transfer.TestTransfer // index 7, DAI/Mainnet + multiTx3Tr1 transfer.TestTransfer // index 8, DAI/Goerli multiTx1 transfer.MultiTransaction multiTx1ID common.MultiTransactionIDType @@ -115,14 +116,17 @@ type testData struct { multiTx2 transfer.MultiTransaction multiTx2ID common.MultiTransactionIDType + multiTx3 transfer.MultiTransaction + multiTx3ID common.MultiTransactionIDType + nextIndex int } -// Generates and adds to the DB 7 transfers and 2 multitransactions. -// There are only 4 extractable activity entries (transactions + multi-transactions) with timestamps 1-4. The others are associated with a multi-transaction +// Generates and adds to the DB 8 transfers and 3 multitransactions. +// There are only 5 extractable activity entries (transactions + multi-transactions) with timestamps 1-5. The others are associated with a multi-transaction func fillTestData(t *testing.T, db *sql.DB) (td testData, fromAddresses, toAddresses []eth.Address) { - // Generates ETH/Goerli, ETH/Optimism, USDC/Mainnet, USDC/Goerli, USDC/Optimism, SNT/Mainnet, DAI/Mainnet - trs, fromAddresses, toAddresses := transfer.GenerateTestTransfers(t, db, 1, 7) + // Generates ETH/Goerli, ETH/Optimism, USDC/Mainnet, USDC/Goerli, USDC/Optimism, SNT/Mainnet, DAI/Mainnet, DAI/Goerli + trs, fromAddresses, toAddresses := transfer.GenerateTestTransfers(t, db, 1, 8) // Plain transfer td.tr1 = trs[0] @@ -134,7 +138,7 @@ func fillTestData(t *testing.T, db *sql.DB) (td testData, fromAddresses, toAddre // Send Multitransaction containing 2 x Plain transfers td.multiTx1Tr1 = trs[2] - td.multiTx1Tr2 = trs[4] + td.multiTx1Tr2 = trs[7] td.multiTx1 = transfer.GenerateTestSendMultiTransaction(td.multiTx1Tr1) td.multiTx1.ToAsset = testutils.DaiSymbol @@ -165,7 +169,17 @@ func fillTestData(t *testing.T, db *sql.DB) (td testData, fromAddresses, toAddre td.multiTx2PendingTr.MultiTransactionID = td.multiTx2ID transfer.InsertTestPendingTransaction(t, db, &td.multiTx2PendingTr) - td.nextIndex = 8 + // Approve Multitransaction containing 1 x Plain transfer + td.multiTx3Tr1 = trs[4] + + td.multiTx3 = transfer.GenerateTestApproveMultiTransaction(td.multiTx3Tr1) + + td.multiTx3ID = transfer.InsertTestMultiTransaction(t, db, &td.multiTx3) + + td.multiTx3Tr1.MultiTransactionID = td.multiTx3ID + transfer.InsertTestTransfer(t, db, td.multiTx3Tr1.From, &td.multiTx3Tr1) + + td.nextIndex = 9 return td, fromAddresses, toAddresses } @@ -201,85 +215,107 @@ func TestGetActivityEntriesAll(t *testing.T) { var filter Filter entries, err := getActivityEntries(context.Background(), deps, append(toAddresses, fromAddresses...), true, []common.ChainID{}, filter, 0, 10) require.NoError(t, err) - require.Equal(t, 4, len(entries)) + require.Equal(t, 5, len(entries)) // Ensure we have the correct order - var curTimestamp int64 = 4 + var curTimestamp int64 = 5 for _, entry := range entries { require.Equal(t, curTimestamp, entry.timestamp, "entries are sorted by timestamp; expected %d, got %d", curTimestamp, entry.timestamp) curTimestamp-- } - require.Equal(t, Entry{ - payloadType: SimpleTransactionPT, - transaction: &transfer.TransactionIdentity{ChainID: td.tr1.ChainID, Hash: td.tr1.Hash, Address: td.tr1.To}, - id: td.tr1.MultiTransactionID, - timestamp: td.tr1.Timestamp, - activityType: ReceiveAT, - activityStatus: FinalizedAS, - amountOut: (*hexutil.Big)(big.NewInt(0)), - amountIn: (*hexutil.Big)(big.NewInt(td.tr1.Value)), - tokenOut: nil, - tokenIn: TTrToToken(t, &td.tr1.TestTransaction), - symbolOut: nil, - symbolIn: common.NewAndSet("ETH"), - sender: &td.tr1.From, - recipient: &td.tr1.To, - chainIDOut: nil, - chainIDIn: &td.tr1.ChainID, - transferType: expectedTokenType(td.tr1.Token.Address), - }, entries[3]) - require.Equal(t, Entry{ - payloadType: PendingTransactionPT, - transaction: &transfer.TransactionIdentity{ChainID: td.pendingTr.ChainID, Hash: td.pendingTr.Hash}, - id: td.pendingTr.MultiTransactionID, - timestamp: td.pendingTr.Timestamp, - activityType: SendAT, - activityStatus: PendingAS, - amountOut: (*hexutil.Big)(big.NewInt(td.pendingTr.Value)), - amountIn: (*hexutil.Big)(big.NewInt(0)), - tokenOut: TTrToToken(t, &td.pendingTr.TestTransaction), - tokenIn: nil, - symbolOut: common.NewAndSet("ETH"), - symbolIn: nil, - sender: &td.pendingTr.From, - recipient: &td.pendingTr.To, - chainIDOut: &td.pendingTr.ChainID, - chainIDIn: nil, - transferType: expectedTokenType(eth.Address{}), - }, entries[2]) - require.Equal(t, Entry{ - payloadType: MultiTransactionPT, - transaction: nil, - id: td.multiTx1ID, - timestamp: int64(td.multiTx1.Timestamp), - activityType: SendAT, - activityStatus: FinalizedAS, - amountOut: td.multiTx1.FromAmount, - amountIn: td.multiTx1.ToAmount, - tokenOut: tokenFromSymbol(nil, td.multiTx1.FromAsset), - tokenIn: tokenFromSymbol(nil, td.multiTx1.ToAsset), - symbolOut: common.NewAndSet("USDC"), - symbolIn: common.NewAndSet("DAI"), - sender: &td.multiTx1.FromAddress, - recipient: &td.multiTx1.ToAddress, - }, entries[1]) - require.Equal(t, Entry{ - payloadType: MultiTransactionPT, - transaction: nil, - id: td.multiTx2ID, - timestamp: int64(td.multiTx2.Timestamp), - activityType: SendAT, - activityStatus: PendingAS, - amountOut: td.multiTx2.FromAmount, - amountIn: td.multiTx2.ToAmount, - symbolOut: common.NewAndSet("USDC"), - symbolIn: common.NewAndSet("SNT"), - tokenOut: tokenFromSymbol(nil, td.multiTx2.FromAsset), - tokenIn: tokenFromSymbol(nil, td.multiTx2.ToAsset), - sender: &td.multiTx2.FromAddress, - recipient: &td.multiTx2.ToAddress, - }, entries[0]) + expectedEntries := []Entry{ + Entry{ + payloadType: MultiTransactionPT, + transaction: nil, + id: td.multiTx3ID, + timestamp: int64(td.multiTx3.Timestamp), + activityType: ApproveAT, + activityStatus: FinalizedAS, + amountOut: td.multiTx3.FromAmount, + amountIn: td.multiTx3.ToAmount, + tokenOut: tokenFromSymbol(nil, td.multiTx3.FromAsset), + tokenIn: tokenFromSymbol(nil, td.multiTx3.ToAsset), + symbolOut: common.NewAndSet("USDC"), + symbolIn: common.NewAndSet("USDC"), + sender: &td.multiTx3.FromAddress, + recipient: &td.multiTx3.ToAddress, + }, + Entry{ + payloadType: MultiTransactionPT, + transaction: nil, + id: td.multiTx2ID, + timestamp: int64(td.multiTx2.Timestamp), + activityType: SendAT, + activityStatus: PendingAS, + amountOut: td.multiTx2.FromAmount, + amountIn: td.multiTx2.ToAmount, + symbolOut: common.NewAndSet("USDC"), + symbolIn: common.NewAndSet("SNT"), + tokenOut: tokenFromSymbol(nil, td.multiTx2.FromAsset), + tokenIn: tokenFromSymbol(nil, td.multiTx2.ToAsset), + sender: &td.multiTx2.FromAddress, + recipient: &td.multiTx2.ToAddress, + }, + Entry{ + payloadType: MultiTransactionPT, + transaction: nil, + id: td.multiTx1ID, + timestamp: int64(td.multiTx1.Timestamp), + activityType: SendAT, + activityStatus: FinalizedAS, + amountOut: td.multiTx1.FromAmount, + amountIn: td.multiTx1.ToAmount, + tokenOut: tokenFromSymbol(nil, td.multiTx1.FromAsset), + tokenIn: tokenFromSymbol(nil, td.multiTx1.ToAsset), + symbolOut: common.NewAndSet("USDC"), + symbolIn: common.NewAndSet("DAI"), + sender: &td.multiTx1.FromAddress, + recipient: &td.multiTx1.ToAddress, + }, + Entry{ + payloadType: PendingTransactionPT, + transaction: &transfer.TransactionIdentity{ChainID: td.pendingTr.ChainID, Hash: td.pendingTr.Hash}, + id: td.pendingTr.MultiTransactionID, + timestamp: td.pendingTr.Timestamp, + activityType: SendAT, + activityStatus: PendingAS, + amountOut: (*hexutil.Big)(big.NewInt(td.pendingTr.Value)), + amountIn: (*hexutil.Big)(big.NewInt(0)), + tokenOut: TTrToToken(t, &td.pendingTr.TestTransaction), + tokenIn: nil, + symbolOut: common.NewAndSet("ETH"), + symbolIn: nil, + sender: &td.pendingTr.From, + recipient: &td.pendingTr.To, + chainIDOut: &td.pendingTr.ChainID, + chainIDIn: nil, + transferType: expectedTokenType(eth.Address{}), + }, + Entry{ + payloadType: SimpleTransactionPT, + transaction: &transfer.TransactionIdentity{ChainID: td.tr1.ChainID, Hash: td.tr1.Hash, Address: td.tr1.To}, + id: td.tr1.MultiTransactionID, + timestamp: td.tr1.Timestamp, + activityType: ReceiveAT, + activityStatus: FinalizedAS, + amountOut: (*hexutil.Big)(big.NewInt(0)), + amountIn: (*hexutil.Big)(big.NewInt(td.tr1.Value)), + tokenOut: nil, + tokenIn: TTrToToken(t, &td.tr1.TestTransaction), + symbolOut: nil, + symbolIn: common.NewAndSet("ETH"), + sender: &td.tr1.From, + recipient: &td.tr1.To, + chainIDOut: nil, + chainIDIn: &td.tr1.ChainID, + transferType: expectedTokenType(td.tr1.Token.Address), + }, + } + + for idx, expectedEntry := range expectedEntries { + require.Equal(t, expectedEntry, entries[idx], "entry %d", idx) + } } // TestGetActivityEntriesWithSenderFilter covers the corner-case of having both sender and receiver in the filter. @@ -323,7 +359,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) { td, fromTds, toTds := fillTestData(t, deps.db) - // Add 6 extractable transactions with timestamps 6-12 + // Add 6 extractable transactions with timestamps 7-13 trs, fromTrs, toTrs := transfer.GenerateTestTransfers(t, deps.db, td.nextIndex, 6) for i := range trs { transfer.InsertTestTransfer(t, deps.db, trs[i].To, &trs[i]) @@ -337,7 +373,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) { filter.Period.EndTimestamp = NoLimitTimestampForPeriod entries, err := getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 8, len(entries)) + require.Equal(t, 9, len(entries)) const simpleTrIndex = 5 // Check start and end content @@ -378,13 +414,13 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) { chainIDOut: nil, chainIDIn: nil, transferType: nil, - }, entries[7]) + }, entries[8]) // Test complete interval filter.Period.EndTimestamp = trs[2].Timestamp entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 5, len(entries)) + require.Equal(t, 6, len(entries)) // Check start and end content require.Equal(t, Entry{ @@ -424,13 +460,13 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) { chainIDOut: nil, chainIDIn: nil, transferType: nil, - }, entries[4]) + }, entries[5]) // Test end only filter.Period.StartTimestamp = NoLimitTimestampForPeriod entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 7, len(entries)) + require.Equal(t, 8, len(entries)) // Check start and end content require.Equal(t, Entry{ payloadType: SimpleTransactionPT, @@ -469,7 +505,7 @@ func TestGetActivityEntriesFilterByTime(t *testing.T) { chainIDOut: nil, chainIDIn: &td.tr1.ChainID, transferType: expectedTokenType(td.tr1.Token.Address), - }, entries[6]) + }, entries[7]) } func TestGetActivityEntriesCheckOffsetAndLimit(t *testing.T) { @@ -821,7 +857,7 @@ func TestStatusMintCustomEvent(t *testing.T) { entries, err := getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 7, len(entries)) + require.Equal(t, 8, len(entries)) filter.Types = []Type{MintAT} @@ -849,7 +885,7 @@ func TestGetActivityEntriesFilterByAddresses(t *testing.T) { entries, err := getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 10, len(entries)) + require.Equal(t, 11, len(entries)) addressesFilter := []eth.Address{td.multiTx1.ToAddress, td.multiTx2.FromAddress, trs[1].From, trs[4].From, trs[3].To} // The td.multiTx1.ToAddress and trs[3].To are missing not having them as owner address @@ -944,7 +980,7 @@ func TestGetActivityEntriesFilterByStatus(t *testing.T) { filter.Statuses = allActivityStatusesFilter() entries, err := getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 11, len(entries)) + require.Equal(t, 12, len(entries)) filter.Statuses = []Status{PendingAS} entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) @@ -967,7 +1003,7 @@ func TestGetActivityEntriesFilterByStatus(t *testing.T) { filter.Statuses = []Status{FinalizedAS} entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 4, len(entries)) + require.Equal(t, 5, len(entries)) // Combined filter filter.Statuses = []Status{FailedAS, PendingAS} @@ -1004,7 +1040,7 @@ func TestGetActivityEntriesFilterByTokenType(t *testing.T) { filter.Assets = allTokensFilter() entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 13, len(entries)) + require.Equal(t, 14, len(entries)) // Native tokens are network agnostic, hence all are returned filter.Assets = []Token{{TokenType: Native, ChainID: common.ChainID(transfer.EthMainnet.ChainID)}} @@ -1025,8 +1061,8 @@ func TestGetActivityEntriesFilterByTokenType(t *testing.T) { }} entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - // Two MT for which ChainID is ignored and one transfer on the main net and the Goerli is ignored - require.Equal(t, 3, len(entries)) + // Three MT for which ChainID is ignored and one transfer on the main net and the Goerli is ignored + require.Equal(t, 4, len(entries)) require.Equal(t, Erc20, entries[0].tokenIn.TokenType) require.Equal(t, transfer.UsdcMainnet.Address, entries[0].tokenIn.Address) require.Nil(t, entries[0].tokenOut) @@ -1034,11 +1070,13 @@ func TestGetActivityEntriesFilterByTokenType(t *testing.T) { require.Equal(t, Erc20, entries[1].tokenOut.TokenType) require.Equal(t, transfer.UsdcMainnet.Address, entries[1].tokenOut.Address) require.Equal(t, Erc20, entries[1].tokenIn.TokenType) - require.Equal(t, transfer.SntMainnet.Address, entries[1].tokenIn.Address) + require.Equal(t, transfer.UsdcMainnet.Address, entries[1].tokenIn.Address) require.Equal(t, Erc20, entries[2].tokenOut.TokenType) - require.Equal(t, transfer.UsdcMainnet.Address, entries[1].tokenOut.Address) + require.Equal(t, transfer.UsdcMainnet.Address, entries[2].tokenOut.Address) require.Equal(t, Erc20, entries[2].tokenIn.TokenType) - require.Equal(t, transfer.UsdcMainnet.Address, entries[1].tokenOut.Address) + require.Equal(t, transfer.SntMainnet.Address, entries[2].tokenIn.Address) + require.Equal(t, Erc20, entries[3].tokenOut.TokenType) + require.Equal(t, transfer.UsdcMainnet.Address, entries[3].tokenOut.Address) filter.Assets = []Token{{ TokenType: Erc20, @@ -1051,8 +1089,8 @@ func TestGetActivityEntriesFilterByTokenType(t *testing.T) { }} entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - // Two MT for which ChainID is ignored and two transfers on the main net and Goerli - require.Equal(t, 4, len(entries)) + // Three MT for which ChainID is ignored and two transfers on the main net and Goerli + require.Equal(t, 5, len(entries)) require.Equal(t, Erc20, entries[0].tokenIn.TokenType) require.Equal(t, transfer.UsdcGoerli.Address, entries[0].tokenIn.Address) require.Nil(t, entries[0].tokenOut) @@ -1087,7 +1125,7 @@ func TestGetActivityEntriesFilterByCollectibles(t *testing.T) { filter.Collectibles = allTokensFilter() entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 8, len(entries)) + require.Equal(t, 9, len(entries)) // Search for a specific collectible filter.Collectibles = []Token{tokenFromCollectible(&transfer.TestCollectibles[0])} @@ -1134,7 +1172,7 @@ func TestGetActivityEntriesFilterByToAddresses(t *testing.T) { filter.CounterpartyAddresses = allAddresses entries, err := getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 10, len(entries)) + require.Equal(t, 11, len(entries)) filter.CounterpartyAddresses = []eth.Address{eth.HexToAddress("0x567890")} entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, []common.ChainID{}, filter, 0, 15) @@ -1185,11 +1223,12 @@ func TestGetActivityEntriesFilterByNetworks(t *testing.T) { if td.multiTx2PendingTr.ChainID != td.multiTx2Tr1.ChainID && td.multiTx2PendingTr.ChainID != td.multiTx2Tr2.ChainID { recordPresence(td.multiTx2PendingTr.ChainID, 3) } + recordPresence(td.multiTx3Tr1.ChainID, 4) // Add 6 extractable transactions trs, fromTrs, toTrs := transfer.GenerateTestTransfers(t, deps.db, td.nextIndex, 6) for i := range trs { - recordPresence(trs[i].ChainID, 4+i) + recordPresence(trs[i].ChainID, 5+i) transfer.InsertTestTransfer(t, deps.db, trs[i].To, &trs[i]) } allAddresses := append(append(append(fromTds, toTds...), fromTrs...), toTrs...) @@ -1198,14 +1237,14 @@ func TestGetActivityEntriesFilterByNetworks(t *testing.T) { chainIDs := allNetworksFilter() entries, err := getActivityEntries(context.Background(), deps, allAddresses, true, chainIDs, filter, 0, 15) require.NoError(t, err) - require.Equal(t, 10, len(entries)) + require.Equal(t, 11, len(entries)) chainIDs = []common.ChainID{5674839210} entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, chainIDs, filter, 0, 15) require.NoError(t, err) require.Equal(t, 0, len(entries)) - chainIDs = []common.ChainID{td.pendingTr.ChainID, td.multiTx2Tr1.ChainID, trs[3].ChainID} + chainIDs = []common.ChainID{td.pendingTr.ChainID, td.multiTx2Tr1.ChainID, trs[3].ChainID, td.multiTx3Tr1.ChainID} entries, err = getActivityEntries(context.Background(), deps, allAddresses, true, chainIDs, filter, 0, 15) require.NoError(t, err) expectedResults := make(map[int]int) diff --git a/services/wallet/transfer/commands.go b/services/wallet/transfer/commands.go index 43ce64f66..aa97335fc 100644 --- a/services/wallet/transfer/commands.go +++ b/services/wallet/transfer/commands.go @@ -240,18 +240,20 @@ func (c *transfersCommand) Run(ctx context.Context) (err error) { c.processUnknownErc20CommunityTransactions(ctx, allTransfers) - err = c.processMultiTransactions(ctx, allTransfers) - if err != nil { - log.Error("processMultiTransactions error", "error", err) - return err - } - if len(allTransfers) > 0 { + // First, try to match to any pre-existing pending/multi-transaction err := c.saveAndConfirmPending(allTransfers, blockNum) if err != nil { log.Error("saveAndConfirmPending error", "error", err) return err } + + // Check if multi transaction needs to be created + err = c.processMultiTransactions(ctx, allTransfers) + if err != nil { + log.Error("processMultiTransactions error", "error", err) + return err + } } else { // If no transfers found, that is suspecting, because downloader returned this block as containing transfers log.Error("no transfers found in block", "chain", c.chainClient.NetworkID(), "address", c.address, "block", blockNum) @@ -467,6 +469,11 @@ func (c *transfersCommand) processMultiTransactions(ctx context.Context, allTran // Detect / Generate multitransactions // Iterate over all detected transactions for _, tx := range txByTxHash { + // Check if already matched to a multi transaction + if tx[0].MultiTransactionID > 0 { + continue + } + // Then check for a Swap transaction txProcessed, err := c.checkAndProcessSwapMultiTx(ctx, tx) if err != nil { diff --git a/services/wallet/transfer/helpers.go b/services/wallet/transfer/helpers.go index 6d7f10f7b..6a8ee9547 100644 --- a/services/wallet/transfer/helpers.go +++ b/services/wallet/transfer/helpers.go @@ -102,6 +102,10 @@ func addSignaturesToTransactions(transactions map[common.Hash]*TransactionDescri } func multiTransactionFromCommand(command *MultiTransactionCommand) *MultiTransaction { + toAmount := new(hexutil.Big) + if command.ToAmount != nil { + toAmount = command.ToAmount + } multiTransaction := NewMultiTransaction( /* Timestamp: */ uint64(time.Now().Unix()), /* FromNetworkID: */ 0, @@ -113,7 +117,7 @@ func multiTransactionFromCommand(command *MultiTransactionCommand) *MultiTransac /* FromAsset: */ command.FromAsset, /* ToAsset: */ command.ToAsset, /* FromAmount: */ command.FromAmount, - /* ToAmount: */ new(hexutil.Big), + /* ToAmount: */ toAmount, /* Type: */ command.Type, /* CrossTxID: */ "", ) diff --git a/services/wallet/transfer/transaction_manager.go b/services/wallet/transfer/transaction_manager.go index 3b77d3286..184cb9718 100644 --- a/services/wallet/transfer/transaction_manager.go +++ b/services/wallet/transfer/transaction_manager.go @@ -111,6 +111,7 @@ type MultiTransactionCommand struct { FromAsset string `json:"fromAsset"` ToAsset string `json:"toAsset"` FromAmount *hexutil.Big `json:"fromAmount"` + ToAmount *hexutil.Big `json:"toAmount"` Type MultiTransactionType `json:"type"` } diff --git a/services/wallet/transfer/transaction_manager_multitransaction.go b/services/wallet/transfer/transaction_manager_multitransaction.go index 1d8096e48..07102535b 100644 --- a/services/wallet/transfer/transaction_manager_multitransaction.go +++ b/services/wallet/transfer/transaction_manager_multitransaction.go @@ -39,8 +39,12 @@ func (tm *TransactionManager) CreateMultiTransactionFromCommand(command *MultiTr multiTransaction := multiTransactionFromCommand(command) - if multiTransaction.Type == MultiTransactionSend && multiTransaction.FromNetworkID == 0 && len(data) == 1 { - multiTransaction.FromNetworkID = data[0].ChainID + // Set network for single chain transactions + switch multiTransaction.Type { + case MultiTransactionSend, MultiTransactionApprove, MultiTransactionSwap: + if multiTransaction.FromNetworkID == wallet_common.UnknownChainID && len(data) == 1 { + multiTransaction.FromNetworkID = data[0].ChainID + } } return multiTransaction, nil diff --git a/services/wallet/transfer/transaction_manager_multitransaction_test.go b/services/wallet/transfer/transaction_manager_multitransaction_test.go index 4f8b1187c..e54c3694e 100644 --- a/services/wallet/transfer/transaction_manager_multitransaction_test.go +++ b/services/wallet/transfer/transaction_manager_multitransaction_test.go @@ -120,6 +120,44 @@ func setupTransactionData(_ *testing.T, transactor transactions.TransactorIface) return &multiTransaction, data, bridges, expectedData } +func setupApproveTransactionData(_ *testing.T, transactor transactions.TransactorIface) (*MultiTransaction, []*pathprocessor.MultipathProcessorTxArgs, map[string]pathprocessor.PathProcessor, []*pathprocessor.MultipathProcessorTxArgs) { + SetMultiTransactionIDGenerator(StaticIDCounter()) + + // Create mock data for the test + tokenTransfer := generateTestTransfer(4) + multiTransaction := GenerateTestApproveMultiTransaction(tokenTransfer) + + // Initialize the bridges + var rpcClient *rpc.Client = nil + bridges := make(map[string]pathprocessor.PathProcessor) + transferBridge := pathprocessor.NewTransferProcessor(rpcClient, transactor) + bridges[transferBridge.Name()] = transferBridge + + data := []*pathprocessor.MultipathProcessorTxArgs{ + { + //ChainID: 1, // This will be set by transaction manager + Name: transferBridge.Name(), + TransferTx: &transactions.SendTxArgs{ + From: types.Address(tokenTransfer.From), + To: (*types.Address)(&tokenTransfer.To), + Value: (*hexutil.Big)(big.NewInt(tokenTransfer.Value)), + Data: types.HexBytes("0x0"), + // Symbol: multiTransaction.FromAsset, // This will be set by transaction manager + // MultiTransactionID: multiTransaction.ID, // This will be set by transaction manager + }, + }, + } + + expectedData := make([]*pathprocessor.MultipathProcessorTxArgs, 0) + for _, tx := range data { + txCopy := deepCopyTransactionBridgeWithTransferTx(tx) + updateDataFromMultiTx([]*pathprocessor.MultipathProcessorTxArgs{txCopy}, &multiTransaction) + expectedData = append(expectedData, txCopy) + } + + return &multiTransaction, data, bridges, expectedData +} + func TestSendTransactionsETHSuccess(t *testing.T) { tm, transactor, _ := setupTransactionManager(t) account := setupAccount(t, common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")) @@ -136,6 +174,22 @@ func TestSendTransactionsETHSuccess(t *testing.T) { require.NoError(t, err) } +func TestSendTransactionsApproveSuccess(t *testing.T) { + tm, transactor, _ := setupTransactionManager(t) + account := setupAccount(t, common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")) + multiTransaction, data, bridges, expectedData := setupApproveTransactionData(t, transactor) + + // Verify that the SendTransactionWithChainID method is called for each transaction with proper arguments + // Return values are not checked, because they must be checked in Transactor tests + for _, tx := range expectedData { + transactor.EXPECT().SendTransactionWithChainID(tx.ChainID, *(tx.TransferTx), account).Return(types.Hash{}, nil) + } + + // Call the SendTransactions method + _, err := tm.SendTransactions(context.Background(), multiTransaction, data, bridges, account) + require.NoError(t, err) +} + func TestSendTransactionsETHFailOnBridge(t *testing.T) { tm, transactor, ctrl := setupTransactionManager(t) account := setupAccount(t, common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")) @@ -258,3 +312,43 @@ func TestWatchTransaction_Timeout(t *testing.T) { err = tm.WatchTransaction(ctx, chainID, transactionHash) require.ErrorIs(t, err, ErrWatchPendingTxTimeout) } + +func TestCreateMultiTransactionFromCommand(t *testing.T) { + tm, _, _ := setupTransactionManager(t) + + var command *MultiTransactionCommand + + // Test types that should get chainID from the data + mtTypes := []MultiTransactionType{MultiTransactionSend, MultiTransactionApprove, MultiTransactionSwap} + + for _, mtType := range mtTypes { + fromAmount := hexutil.Big(*big.NewInt(1000000000000000000)) + toAmount := hexutil.Big(*big.NewInt(123)) + command = &MultiTransactionCommand{ + Type: mtType, + FromAddress: common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678"), + ToAddress: common.HexToAddress("0xabcdef1234567890abcdef1234567890abcdef12"), + FromAsset: "DAI", + ToAsset: "USDT", + FromAmount: &fromAmount, + ToAmount: &toAmount, + } + + data := make([]*pathprocessor.MultipathProcessorTxArgs, 0) + data = append(data, &pathprocessor.MultipathProcessorTxArgs{ + ChainID: 1, + }) + + multiTransaction, err := tm.CreateMultiTransactionFromCommand(command, data) + require.NoError(t, err) + require.NotNil(t, multiTransaction) + require.Equal(t, command.FromAddress, multiTransaction.FromAddress) + require.Equal(t, command.ToAddress, multiTransaction.ToAddress) + require.Equal(t, command.FromAsset, multiTransaction.FromAsset) + require.Equal(t, command.ToAsset, multiTransaction.ToAsset) + require.Equal(t, command.FromAmount, multiTransaction.FromAmount) + require.Equal(t, command.ToAmount, multiTransaction.ToAmount) + require.Equal(t, command.Type, multiTransaction.Type) + require.Equal(t, data[0].ChainID, multiTransaction.FromNetworkID) + } +}