From 4a3c4ad0cadbda41a655c86c53c09fcf59c8bc60 Mon Sep 17 00:00:00 2001 From: Anthony Laibe <491074+alaibe@users.noreply.github.com> Date: Mon, 4 Jul 2022 09:48:30 +0200 Subject: [PATCH] feat: allows to pass array of chain ids (#2733) --- services/wallet/api.go | 14 +++++++------- services/wallet/reader.go | 2 +- services/wallet/transaction.go | 19 +++++++++++++++---- services/wallet/transaction_test.go | 14 +++++++------- services/web3provider/api.go | 3 ++- services/web3provider/signature.go | 4 ++-- 6 files changed, 34 insertions(+), 22 deletions(-) diff --git a/services/wallet/api.go b/services/wallet/api.go index 56649a6df..d8db6fb2d 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -74,7 +74,7 @@ func (api *API) LoadTransferByHash(ctx context.Context, address common.Address, } func (api *API) GetTransfersByAddressAndChainID(ctx context.Context, chainID uint64, address common.Address, toBlock, limit *hexutil.Big, fetchMore bool) ([]transfer.View, error) { - log.Debug("[WalletAPI:: GetTransfersByAddressAndChainID] get transfers for an address", "address", address) + log.Debug("[WalletAPI:: GetTransfersByAddressAndChainIDs] get transfers for an address", "address", address) return api.s.transferController.GetTransfersByAddress(ctx, chainID, address, toBlock, limit, fetchMore) } @@ -189,28 +189,28 @@ func (api *API) DeleteSavedAddress(ctx context.Context, address common.Address) func (api *API) GetPendingTransactions(ctx context.Context) ([]*PendingTransaction, error) { log.Debug("call to get pending transactions") - rst, err := api.s.transactionManager.getAllPendings(api.s.rpcClient.UpstreamChainID) + rst, err := api.s.transactionManager.getAllPendings([]uint64{api.s.rpcClient.UpstreamChainID}) log.Debug("result from database for pending transactions", "len", len(rst)) return rst, err } -func (api *API) GetPendingTransactionsByChainID(ctx context.Context, chainID uint64) ([]*PendingTransaction, error) { +func (api *API) GetPendingTransactionsByChainIDs(ctx context.Context, chainIDs []uint64) ([]*PendingTransaction, error) { log.Debug("call to get pending transactions") - rst, err := api.s.transactionManager.getAllPendings(chainID) + rst, err := api.s.transactionManager.getAllPendings(chainIDs) log.Debug("result from database for pending transactions", "len", len(rst)) return rst, err } func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ([]*PendingTransaction, error) { log.Debug("call to get pending outbound transactions by address") - rst, err := api.s.transactionManager.getPendingByAddress(api.s.rpcClient.UpstreamChainID, address) + rst, err := api.s.transactionManager.getPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address) log.Debug("result from database for pending transactions by address", "len", len(rst)) return rst, err } -func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainID uint64, address common.Address) ([]*PendingTransaction, error) { +func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) { log.Debug("call to get pending outbound transactions by address") - rst, err := api.s.transactionManager.getPendingByAddress(chainID, address) + rst, err := api.s.transactionManager.getPendingByAddress(chainIDs, address) log.Debug("result from database for pending transactions by address", "len", len(rst)) return rst, err } diff --git a/services/wallet/reader.go b/services/wallet/reader.go index afa551626..62589908b 100644 --- a/services/wallet/reader.go +++ b/services/wallet/reader.go @@ -206,7 +206,7 @@ func (r *Reader) GetWallet(ctx context.Context, chainIDs []uint64) (*Wallet, err pendingTransactions := make(map[uint64][]*PendingTransaction) for _, chainID := range chainIDs { - pendingTx, err := r.s.transactionManager.getAllPendings(chainID) + pendingTx, err := r.s.transactionManager.getAllPendings([]uint64{chainID}) if err != nil { return nil, err } diff --git a/services/wallet/transaction.go b/services/wallet/transaction.go index d8d31a1f7..c1a98581e 100644 --- a/services/wallet/transaction.go +++ b/services/wallet/transaction.go @@ -4,7 +4,9 @@ import ( "context" "database/sql" "errors" + "fmt" "math/big" + "strings" "time" "github.com/ethereum/go-ethereum/common" @@ -43,12 +45,21 @@ type PendingTransaction struct { ChainID uint64 `json:"network_id"` } -func (tm *TransactionManager) getAllPendings(chainID uint64) ([]*PendingTransaction, error) { +func arrayToString(a []uint64, delim string) string { + res := make([]string, len(a)) + for i, v := range a { + res[i] = fmt.Sprint(v) + } + + return strings.Join(res, ",") +} + +func (tm *TransactionManager) getAllPendings(chainIDs []uint64) ([]*PendingTransaction, error) { rows, err := tm.db.Query(`SELECT hash, timestamp, value, from_address, to_address, data, symbol, gas_price, gas_limit, type, additional_data, network_id FROM pending_transactions - WHERE network_id = ?`, chainID) + WHERE network_id in (?)`, arrayToString(chainIDs, ",")) if err != nil { return nil, err } @@ -84,12 +95,12 @@ func (tm *TransactionManager) getAllPendings(chainID uint64) ([]*PendingTransact return transactions, nil } -func (tm *TransactionManager) getPendingByAddress(chainID uint64, address common.Address) ([]*PendingTransaction, error) { +func (tm *TransactionManager) getPendingByAddress(chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) { rows, err := tm.db.Query(`SELECT hash, timestamp, value, from_address, to_address, data, symbol, gas_price, gas_limit, type, additional_data, network_id FROM pending_transactions - WHERE network_id = ? AND from_address = ?`, chainID, address) + WHERE network_id in (?) AND from_address = ?`, arrayToString(chainIDs, ","), address) if err != nil { return nil, err } diff --git a/services/wallet/transaction_test.go b/services/wallet/transaction_test.go index 81bcaa046..f54046a4f 100644 --- a/services/wallet/transaction_test.go +++ b/services/wallet/transaction_test.go @@ -41,39 +41,39 @@ func TestPendingTransactions(t *testing.T) { ChainID: 777, } - rst, err := manager.getAllPendings(777) + rst, err := manager.getAllPendings([]uint64{777}) require.NoError(t, err) require.Nil(t, rst) - rst, err = manager.getPendingByAddress(777, trx.From) + rst, err = manager.getPendingByAddress([]uint64{777}, trx.From) require.NoError(t, err) require.Nil(t, rst) err = manager.addPending(trx) require.NoError(t, err) - rst, err = manager.getPendingByAddress(777, trx.From) + rst, err = manager.getPendingByAddress([]uint64{777}, trx.From) require.NoError(t, err) require.Equal(t, 1, len(rst)) require.Equal(t, trx, *rst[0]) - rst, err = manager.getAllPendings(777) + rst, err = manager.getAllPendings([]uint64{777}) require.NoError(t, err) require.Equal(t, 1, len(rst)) require.Equal(t, trx, *rst[0]) - rst, err = manager.getPendingByAddress(777, common.Address{2}) + rst, err = manager.getPendingByAddress([]uint64{777}, common.Address{2}) require.NoError(t, err) require.Nil(t, rst) err = manager.deletePending(777, trx.Hash) require.NoError(t, err) - rst, err = manager.getPendingByAddress(777, trx.From) + rst, err = manager.getPendingByAddress([]uint64{777}, trx.From) require.NoError(t, err) require.Equal(t, 0, len(rst)) - rst, err = manager.getAllPendings(777) + rst, err = manager.getAllPendings([]uint64{777}) require.NoError(t, err) require.Equal(t, 0, len(rst)) } diff --git a/services/web3provider/api.go b/services/web3provider/api.go index 6dca30c45..7cdff0a86 100644 --- a/services/web3provider/api.go +++ b/services/web3provider/api.go @@ -71,6 +71,7 @@ type ETHPayload struct { Method string `json:"method"` Params []interface{} `json:"params"` Password string `json:"password,omitempty"` + ChainID uint64 `json:"chainId,omitempty"` } type JSONRPCResponse struct { @@ -326,7 +327,7 @@ func (api *API) ProcessWeb3ReadOnlyRequest(request Web3SendAsyncReadOnlyRequest) return nil, err } - hash, err := api.sendTransaction(trxArgs, request.Payload.Password) + hash, err := api.sendTransaction(request.Payload.ChainID, trxArgs, request.Payload.Password) if err != nil { log.Error("could not send transaction message", "err", err) return &Web3SendAsyncReadOnlyResponse{ diff --git a/services/web3provider/signature.go b/services/web3provider/signature.go index 827427442..3bd47dce9 100644 --- a/services/web3provider/signature.go +++ b/services/web3provider/signature.go @@ -71,13 +71,13 @@ func (api *API) signTypedDataV4(typed signercore.TypedData, address string, pass } // SendTransaction creates a new transaction and waits until it's complete. -func (api *API) sendTransaction(sendArgs transactions.SendTxArgs, password string) (hash types.Hash, err error) { +func (api *API) sendTransaction(chainID uint64, sendArgs transactions.SendTxArgs, password string) (hash types.Hash, err error) { verifiedAccount, err := api.getVerifiedWalletAccount(sendArgs.From.String(), password) if err != nil { return hash, err } - hash, err = api.s.transactor.SendTransaction(sendArgs, verifiedAccount) + hash, err = api.s.transactor.SendTransactionWithChainID(chainID, sendArgs, verifiedAccount) if err != nil { return }