feat: allows to pass array of chain ids (#2733)

This commit is contained in:
Anthony Laibe 2022-07-04 09:48:30 +02:00 committed by GitHub
parent 72d2a97449
commit 4a3c4ad0ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 34 additions and 22 deletions

View File

@ -74,7 +74,7 @@ func (api *API) LoadTransferByHash(ctx context.Context, address common.Address,
}
func (api *API) GetTransfersByAddressAndChainID(ctx context.Context, chainID uint64, address common.Address, toBlock, limit *hexutil.Big, fetchMore bool) ([]transfer.View, error) {
log.Debug("[WalletAPI:: GetTransfersByAddressAndChainID] get transfers for an address", "address", address)
log.Debug("[WalletAPI:: GetTransfersByAddressAndChainIDs] get transfers for an address", "address", address)
return api.s.transferController.GetTransfersByAddress(ctx, chainID, address, toBlock, limit, fetchMore)
}
@ -189,28 +189,28 @@ func (api *API) DeleteSavedAddress(ctx context.Context, address common.Address)
func (api *API) GetPendingTransactions(ctx context.Context) ([]*PendingTransaction, error) {
log.Debug("call to get pending transactions")
rst, err := api.s.transactionManager.getAllPendings(api.s.rpcClient.UpstreamChainID)
rst, err := api.s.transactionManager.getAllPendings([]uint64{api.s.rpcClient.UpstreamChainID})
log.Debug("result from database for pending transactions", "len", len(rst))
return rst, err
}
func (api *API) GetPendingTransactionsByChainID(ctx context.Context, chainID uint64) ([]*PendingTransaction, error) {
func (api *API) GetPendingTransactionsByChainIDs(ctx context.Context, chainIDs []uint64) ([]*PendingTransaction, error) {
log.Debug("call to get pending transactions")
rst, err := api.s.transactionManager.getAllPendings(chainID)
rst, err := api.s.transactionManager.getAllPendings(chainIDs)
log.Debug("result from database for pending transactions", "len", len(rst))
return rst, err
}
func (api *API) GetPendingOutboundTransactionsByAddress(ctx context.Context, address common.Address) ([]*PendingTransaction, error) {
log.Debug("call to get pending outbound transactions by address")
rst, err := api.s.transactionManager.getPendingByAddress(api.s.rpcClient.UpstreamChainID, address)
rst, err := api.s.transactionManager.getPendingByAddress([]uint64{api.s.rpcClient.UpstreamChainID}, address)
log.Debug("result from database for pending transactions by address", "len", len(rst))
return rst, err
}
func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainID uint64, address common.Address) ([]*PendingTransaction, error) {
func (api *API) GetPendingOutboundTransactionsByAddressAndChainID(ctx context.Context, chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) {
log.Debug("call to get pending outbound transactions by address")
rst, err := api.s.transactionManager.getPendingByAddress(chainID, address)
rst, err := api.s.transactionManager.getPendingByAddress(chainIDs, address)
log.Debug("result from database for pending transactions by address", "len", len(rst))
return rst, err
}

View File

@ -206,7 +206,7 @@ func (r *Reader) GetWallet(ctx context.Context, chainIDs []uint64) (*Wallet, err
pendingTransactions := make(map[uint64][]*PendingTransaction)
for _, chainID := range chainIDs {
pendingTx, err := r.s.transactionManager.getAllPendings(chainID)
pendingTx, err := r.s.transactionManager.getAllPendings([]uint64{chainID})
if err != nil {
return nil, err
}

View File

@ -4,7 +4,9 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"math/big"
"strings"
"time"
"github.com/ethereum/go-ethereum/common"
@ -43,12 +45,21 @@ type PendingTransaction struct {
ChainID uint64 `json:"network_id"`
}
func (tm *TransactionManager) getAllPendings(chainID uint64) ([]*PendingTransaction, error) {
func arrayToString(a []uint64, delim string) string {
res := make([]string, len(a))
for i, v := range a {
res[i] = fmt.Sprint(v)
}
return strings.Join(res, ",")
}
func (tm *TransactionManager) getAllPendings(chainIDs []uint64) ([]*PendingTransaction, error) {
rows, err := tm.db.Query(`SELECT hash, timestamp, value, from_address, to_address, data,
symbol, gas_price, gas_limit, type, additional_data,
network_id
FROM pending_transactions
WHERE network_id = ?`, chainID)
WHERE network_id in (?)`, arrayToString(chainIDs, ","))
if err != nil {
return nil, err
}
@ -84,12 +95,12 @@ func (tm *TransactionManager) getAllPendings(chainID uint64) ([]*PendingTransact
return transactions, nil
}
func (tm *TransactionManager) getPendingByAddress(chainID uint64, address common.Address) ([]*PendingTransaction, error) {
func (tm *TransactionManager) getPendingByAddress(chainIDs []uint64, address common.Address) ([]*PendingTransaction, error) {
rows, err := tm.db.Query(`SELECT hash, timestamp, value, from_address, to_address, data,
symbol, gas_price, gas_limit, type, additional_data,
network_id
FROM pending_transactions
WHERE network_id = ? AND from_address = ?`, chainID, address)
WHERE network_id in (?) AND from_address = ?`, arrayToString(chainIDs, ","), address)
if err != nil {
return nil, err
}

View File

@ -41,39 +41,39 @@ func TestPendingTransactions(t *testing.T) {
ChainID: 777,
}
rst, err := manager.getAllPendings(777)
rst, err := manager.getAllPendings([]uint64{777})
require.NoError(t, err)
require.Nil(t, rst)
rst, err = manager.getPendingByAddress(777, trx.From)
rst, err = manager.getPendingByAddress([]uint64{777}, trx.From)
require.NoError(t, err)
require.Nil(t, rst)
err = manager.addPending(trx)
require.NoError(t, err)
rst, err = manager.getPendingByAddress(777, trx.From)
rst, err = manager.getPendingByAddress([]uint64{777}, trx.From)
require.NoError(t, err)
require.Equal(t, 1, len(rst))
require.Equal(t, trx, *rst[0])
rst, err = manager.getAllPendings(777)
rst, err = manager.getAllPendings([]uint64{777})
require.NoError(t, err)
require.Equal(t, 1, len(rst))
require.Equal(t, trx, *rst[0])
rst, err = manager.getPendingByAddress(777, common.Address{2})
rst, err = manager.getPendingByAddress([]uint64{777}, common.Address{2})
require.NoError(t, err)
require.Nil(t, rst)
err = manager.deletePending(777, trx.Hash)
require.NoError(t, err)
rst, err = manager.getPendingByAddress(777, trx.From)
rst, err = manager.getPendingByAddress([]uint64{777}, trx.From)
require.NoError(t, err)
require.Equal(t, 0, len(rst))
rst, err = manager.getAllPendings(777)
rst, err = manager.getAllPendings([]uint64{777})
require.NoError(t, err)
require.Equal(t, 0, len(rst))
}

View File

@ -71,6 +71,7 @@ type ETHPayload struct {
Method string `json:"method"`
Params []interface{} `json:"params"`
Password string `json:"password,omitempty"`
ChainID uint64 `json:"chainId,omitempty"`
}
type JSONRPCResponse struct {
@ -326,7 +327,7 @@ func (api *API) ProcessWeb3ReadOnlyRequest(request Web3SendAsyncReadOnlyRequest)
return nil, err
}
hash, err := api.sendTransaction(trxArgs, request.Payload.Password)
hash, err := api.sendTransaction(request.Payload.ChainID, trxArgs, request.Payload.Password)
if err != nil {
log.Error("could not send transaction message", "err", err)
return &Web3SendAsyncReadOnlyResponse{

View File

@ -71,13 +71,13 @@ func (api *API) signTypedDataV4(typed signercore.TypedData, address string, pass
}
// SendTransaction creates a new transaction and waits until it's complete.
func (api *API) sendTransaction(sendArgs transactions.SendTxArgs, password string) (hash types.Hash, err error) {
func (api *API) sendTransaction(chainID uint64, sendArgs transactions.SendTxArgs, password string) (hash types.Hash, err error) {
verifiedAccount, err := api.getVerifiedWalletAccount(sendArgs.From.String(), password)
if err != nil {
return hash, err
}
hash, err = api.s.transactor.SendTransaction(sendArgs, verifiedAccount)
hash, err = api.s.transactor.SendTransactionWithChainID(chainID, sendArgs, verifiedAccount)
if err != nil {
return
}