feat: nonce management in multi tx

This commit is contained in:
Anthony Laibe 2022-12-19 13:37:37 +01:00 committed by Anthony Laibe
parent d60c1d00ed
commit 883f9677c5
9 changed files with 325 additions and 178 deletions

View File

@ -12,6 +12,7 @@ import (
hopBridge "github.com/status-im/status-go/contracts/hop/bridge" hopBridge "github.com/status-im/status-go/contracts/hop/bridge"
hopSwap "github.com/status-im/status-go/contracts/hop/swap" hopSwap "github.com/status-im/status-go/contracts/hop/swap"
hopWrapper "github.com/status-im/status-go/contracts/hop/wrapper" hopWrapper "github.com/status-im/status-go/contracts/hop/wrapper"
"github.com/status-im/status-go/contracts/ierc20"
"github.com/status-im/status-go/contracts/registrar" "github.com/status-im/status-go/contracts/registrar"
"github.com/status-im/status-go/contracts/resolver" "github.com/status-im/status-go/contracts/resolver"
"github.com/status-im/status-go/contracts/snt" "github.com/status-im/status-go/contracts/snt"
@ -64,6 +65,18 @@ func (c *ContractMaker) NewUsernameRegistrar(chainID uint64, contractAddr common
) )
} }
func (c *ContractMaker) NewERC20(chainID uint64, contractAddr common.Address) (*ierc20.IERC20, error) {
backend, err := c.RPCClient.EthClient(chainID)
if err != nil {
return nil, err
}
return ierc20.NewIERC20(
contractAddr,
backend,
)
}
func (c *ContractMaker) NewSNT(chainID uint64) (*snt.SNT, error) { func (c *ContractMaker) NewSNT(chainID uint64) (*snt.SNT, error) {
contractAddr, err := snt.ContractAddress(chainID) contractAddr, err := snt.ContractAddress(chainID)
if err != nil { if err != nil {

View File

@ -84,4 +84,5 @@ type Bridge interface {
EstimateGas(from *params.Network, to *params.Network, token *token.Token, amountIn *big.Int) (uint64, error) EstimateGas(from *params.Network, to *params.Network, token *token.Token, amountIn *big.Int) (uint64, error)
CalculateAmountOut(from, to *params.Network, amountIn *big.Int, symbol string) (*big.Int, error) CalculateAmountOut(from, to *params.Network, amountIn *big.Int, symbol string) (*big.Int, error)
Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (types.Hash, error) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (types.Hash, error)
GetContractAddress(network *params.Network, token *token.Token) *common.Address
} }

View File

@ -36,14 +36,16 @@ type CBridgeTxArgs struct {
type CBridge struct { type CBridge struct {
rpcClient *rpc.Client rpcClient *rpc.Client
transactor *transactions.Transactor
tokenManager *token.Manager tokenManager *token.Manager
prodTransferConfig *cbridge.GetTransferConfigsResponse prodTransferConfig *cbridge.GetTransferConfigsResponse
testTransferConfig *cbridge.GetTransferConfigsResponse testTransferConfig *cbridge.GetTransferConfigsResponse
} }
func NewCbridge(rpcClient *rpc.Client, tokenManager *token.Manager) *CBridge { func NewCbridge(rpcClient *rpc.Client, transactor *transactions.Transactor, tokenManager *token.Manager) *CBridge {
return &CBridge{ return &CBridge{
rpcClient: rpcClient, rpcClient: rpcClient,
transactor: transactor,
tokenManager: tokenManager, tokenManager: tokenManager,
} }
} }
@ -219,6 +221,25 @@ func (s *CBridge) EstimateGas(from, to *params.Network, token *token.Token, amou
return 200000, nil //default gas limit for erc20 transaction return 200000, nil //default gas limit for erc20 transaction
} }
func (s *CBridge) GetContractAddress(network *params.Network, token *token.Token) *common.Address {
transferConfig, err := s.getTransferConfig(network.IsTest)
if err != nil {
return nil
}
if transferConfig.Err != nil {
return nil
}
for _, chain := range transferConfig.Chains {
if uint64(chain.Id) == network.ChainID {
addr := common.HexToAddress(chain.ContractAddr)
return &addr
}
}
return nil
}
func (s *CBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (types.Hash, error) { func (s *CBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (types.Hash, error) {
fromNetwork := s.rpcClient.NetworkManager.Find(sendArgs.ChainID) fromNetwork := s.rpcClient.NetworkManager.Find(sendArgs.ChainID)
if fromNetwork == nil { if fromNetwork == nil {
@ -228,27 +249,16 @@ func (s *CBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.Sel
if tk == nil { if tk == nil {
return types.HexToHash(""), errors.New("token not found") return types.HexToHash(""), errors.New("token not found")
} }
transferConfig, err := s.getTransferConfig(fromNetwork.IsTest) addrs := s.GetContractAddress(fromNetwork, nil)
if err != nil { if addrs == nil {
return types.HexToHash(""), err return types.HexToHash(""), errors.New("contract not found")
}
if transferConfig.Err != nil {
return types.HexToHash(""), errors.New(transferConfig.Err.Msg)
}
addrs := ""
for _, chain := range transferConfig.Chains {
if uint64(chain.Id) == sendArgs.ChainID {
addrs = chain.ContractAddr
break
}
} }
backend, err := s.rpcClient.EthClient(sendArgs.ChainID) backend, err := s.rpcClient.EthClient(sendArgs.ChainID)
if err != nil { if err != nil {
return types.HexToHash(""), err return types.HexToHash(""), err
} }
contract, err := celer.NewCeler(common.HexToAddress(addrs), backend) contract, err := celer.NewCeler(*addrs, backend)
if err != nil { if err != nil {
return types.HexToHash(""), err return types.HexToHash(""), err
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/status-im/status-go/account" "github.com/status-im/status-go/account"
"github.com/status-im/status-go/contracts" "github.com/status-im/status-go/contracts"
"github.com/status-im/status-go/contracts/hop"
"github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/rpc" "github.com/status-im/status-go/rpc"
@ -83,12 +84,16 @@ type HopTxArgs struct {
} }
type HopBridge struct { type HopBridge struct {
transactor *transactions.Transactor
tokenManager *token.Manager
contractMaker *contracts.ContractMaker contractMaker *contracts.ContractMaker
} }
func NewHopBridge(rpcClient *rpc.Client) *HopBridge { func NewHopBridge(rpcClient *rpc.Client, transactor *transactions.Transactor, tokenManager *token.Manager) *HopBridge {
return &HopBridge{ return &HopBridge{
contractMaker: &contracts.ContractMaker{RPCClient: rpcClient}, contractMaker: &contracts.ContractMaker{RPCClient: rpcClient},
transactor: transactor,
tokenManager: tokenManager,
} }
} }
@ -157,39 +162,57 @@ func (h *HopBridge) EstimateGas(from, to *params.Network, token *token.Token, am
return 500000 + 1000, nil return 500000 + 1000, nil
} }
func (h *HopBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { func (h *HopBridge) GetContractAddress(network *params.Network, token *token.Token) *common.Address {
networks, err := h.contractMaker.RPCClient.NetworkManager.Get(false) var address common.Address
if err != nil { if network.Layer == 1 {
return hash, err address, _ = hop.L1BridgeContractAddress(network.ChainID, token.Symbol)
} } else {
var fromNetwork *params.Network address, _ = hop.L2AmmWrapperContractAddress(network.ChainID, token.Symbol)
for _, network := range networks {
if network.ChainID == sendArgs.ChainID {
fromNetwork = network
break
}
} }
if fromNetwork.Layer == 1 { return &address
return h.sendToL2(sendArgs.ChainID, sendArgs.HopTx, verifiedAccount)
}
return h.swapAndSend(sendArgs.ChainID, sendArgs.HopTx, verifiedAccount)
} }
func (h *HopBridge) sendToL2(chainID uint64, sendArgs *HopTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { func (h *HopBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) {
bridge, err := h.contractMaker.NewHopL1Bridge(chainID, sendArgs.Symbol) fromNetwork := h.contractMaker.RPCClient.NetworkManager.Find(sendArgs.ChainID)
if fromNetwork == nil {
return hash, err
}
nonce, unlock, err := h.transactor.NextNonce(h.contractMaker.RPCClient, sendArgs.ChainID, sendArgs.HopTx.From)
if err != nil { if err != nil {
return hash, err return hash, err
} }
txOpts := sendArgs.ToTransactOpts(getSigner(chainID, sendArgs.From, verifiedAccount)) defer func() {
txOpts.Value = (*big.Int)(sendArgs.Amount) unlock(err == nil, nonce)
}()
argNonce := hexutil.Uint64(nonce)
sendArgs.HopTx.Nonce = &argNonce
token := h.tokenManager.FindToken(fromNetwork, sendArgs.HopTx.Symbol)
if fromNetwork.Layer == 1 {
hash, err = h.sendToL2(sendArgs.ChainID, sendArgs.HopTx, verifiedAccount, token)
}
hash, err = h.swapAndSend(sendArgs.ChainID, sendArgs.HopTx, verifiedAccount, token)
return hash, err
}
func (h *HopBridge) sendToL2(chainID uint64, hopArgs *HopTxArgs, verifiedAccount *account.SelectedExtKey, token *token.Token) (hash types.Hash, err error) {
bridge, err := h.contractMaker.NewHopL1Bridge(chainID, hopArgs.Symbol)
if err != nil {
return hash, err
}
txOpts := hopArgs.ToTransactOpts(getSigner(chainID, hopArgs.From, verifiedAccount))
if token.IsNative() {
txOpts.Value = (*big.Int)(hopArgs.Amount)
}
now := time.Now() now := time.Now()
deadline := big.NewInt(now.Unix() + 604800) deadline := big.NewInt(now.Unix() + 604800)
tx, err := bridge.SendToL2( tx, err := bridge.SendToL2(
txOpts, txOpts,
big.NewInt(int64(sendArgs.ChainID)), big.NewInt(int64(hopArgs.ChainID)),
sendArgs.Recipient, hopArgs.Recipient,
sendArgs.Amount.ToInt(), hopArgs.Amount.ToInt(),
big.NewInt(0), big.NewInt(0),
deadline, deadline,
common.HexToAddress("0x0"), common.HexToAddress("0x0"),
@ -202,22 +225,24 @@ func (h *HopBridge) sendToL2(chainID uint64, sendArgs *HopTxArgs, verifiedAccoun
return types.Hash(tx.Hash()), nil return types.Hash(tx.Hash()), nil
} }
func (h *HopBridge) swapAndSend(chainID uint64, sendArgs *HopTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { func (h *HopBridge) swapAndSend(chainID uint64, hopArgs *HopTxArgs, verifiedAccount *account.SelectedExtKey, token *token.Token) (hash types.Hash, err error) {
ammWrapper, err := h.contractMaker.NewHopL2AmmWrapper(chainID, sendArgs.Symbol) ammWrapper, err := h.contractMaker.NewHopL2AmmWrapper(chainID, hopArgs.Symbol)
if err != nil { if err != nil {
return hash, err return hash, err
} }
txOpts := sendArgs.ToTransactOpts(getSigner(chainID, sendArgs.From, verifiedAccount)) txOpts := hopArgs.ToTransactOpts(getSigner(chainID, hopArgs.From, verifiedAccount))
txOpts.Value = (*big.Int)(sendArgs.Amount) if token.IsNative() {
txOpts.Value = (*big.Int)(hopArgs.Amount)
}
now := time.Now() now := time.Now()
deadline := big.NewInt(now.Unix() + 604800) deadline := big.NewInt(now.Unix() + 604800)
tx, err := ammWrapper.SwapAndSend( tx, err := ammWrapper.SwapAndSend(
txOpts, txOpts,
big.NewInt(int64(sendArgs.ChainID)), big.NewInt(int64(hopArgs.ChainID)),
sendArgs.Recipient, hopArgs.Recipient,
sendArgs.Amount.ToInt(), hopArgs.Amount.ToInt(),
sendArgs.BonderFee.ToInt(), hopArgs.BonderFee.ToInt(),
big.NewInt(0), big.NewInt(0),
deadline, deadline,
big.NewInt(0), big.NewInt(0),

View File

@ -3,6 +3,7 @@ package bridge
import ( import (
"math/big" "math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/status-im/status-go/account" "github.com/status-im/status-go/account"
"github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
@ -46,3 +47,7 @@ func (s *SimpleBridge) Send(sendArgs *TransactionBridge, verifiedAccount *accoun
func (s *SimpleBridge) CalculateAmountOut(from, to *params.Network, amountIn *big.Int, symbol string) (*big.Int, error) { func (s *SimpleBridge) CalculateAmountOut(from, to *params.Network, amountIn *big.Int, symbol string) (*big.Int, error) {
return amountIn, nil return amountIn, nil
} }
func (s *SimpleBridge) GetContractAddress(network *params.Network, token *token.Token) *common.Address {
return nil
}

View File

@ -7,12 +7,19 @@ import (
"math" "math"
"math/big" "math/big"
"sort" "sort"
"strings"
"sync" "sync"
"github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/status-im/status-go/contracts"
"github.com/status-im/status-go/contracts/ierc20"
"github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/params" "github.com/status-im/status-go/params"
"github.com/status-im/status-go/rpc"
"github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/async"
"github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/bigint"
"github.com/status-im/status-go/services/wallet/bridge" "github.com/status-im/status-go/services/wallet/bridge"
@ -116,6 +123,10 @@ type Path struct {
TokenFees *big.Float TokenFees *big.Float
Cost *big.Float Cost *big.Float
EstimatedTime TransactionEstimation EstimatedTime TransactionEstimation
ApprovalRequired bool
ApprovalGasFees *big.Float
ApprovalAmountRequired *hexutil.Big
ApprovalContractAddress *common.Address
} }
func (p *Path) Equal(o *Path) bool { func (p *Path) Equal(o *Path) bool {
@ -334,13 +345,13 @@ func newSuggestedRoutes(
func NewRouter(s *Service) *Router { func NewRouter(s *Service) *Router {
bridges := make(map[string]bridge.Bridge) bridges := make(map[string]bridge.Bridge)
simple := bridge.NewSimpleBridge(s.transactor) simple := bridge.NewSimpleBridge(s.transactor)
hop := bridge.NewHopBridge(s.rpcClient) cbridge := bridge.NewCbridge(s.rpcClient, s.transactor, s.tokenManager)
cbridge := bridge.NewCbridge(s.rpcClient, s.tokenManager) hop := bridge.NewHopBridge(s.rpcClient, s.transactor, s.tokenManager)
bridges[simple.Name()] = simple bridges[simple.Name()] = simple
bridges[hop.Name()] = hop bridges[hop.Name()] = hop
bridges[cbridge.Name()] = cbridge bridges[cbridge.Name()] = cbridge
return &Router{s, bridges} return &Router{s, bridges, s.rpcClient}
} }
func containsNetworkChainID(network *params.Network, chainIDs []uint64) bool { func containsNetworkChainID(network *params.Network, chainIDs []uint64) bool {
@ -356,6 +367,64 @@ func containsNetworkChainID(network *params.Network, chainIDs []uint64) bool {
type Router struct { type Router struct {
s *Service s *Service
bridges map[string]bridge.Bridge bridges map[string]bridge.Bridge
rpcClient *rpc.Client
}
func (r *Router) requireApproval(ctx context.Context, bridge bridge.Bridge, account common.Address, network *params.Network, token *token.Token, amountIn *big.Int) (bool, *big.Int, uint64, *common.Address, error) {
if token.IsNative() {
return false, nil, 0, nil, nil
}
contractMaker := &contracts.ContractMaker{RPCClient: r.rpcClient}
bridgeAddress := bridge.GetContractAddress(network, token)
if bridgeAddress == nil {
return false, nil, 0, nil, nil
}
contract, err := contractMaker.NewERC20(network.ChainID, token.Address)
if err != nil {
return false, nil, 0, nil, err
}
allowance, err := contract.Allowance(&bind.CallOpts{
Context: ctx,
}, account, *bridgeAddress)
if err != nil {
return false, nil, 0, nil, err
}
if allowance.Cmp(amountIn) >= 0 {
return false, nil, 0, nil, nil
}
ethClient, err := r.rpcClient.EthClient(network.ChainID)
if err != nil {
return false, nil, 0, nil, err
}
erc20ABI, err := abi.JSON(strings.NewReader(ierc20.IERC20ABI))
if err != nil {
return false, nil, 0, nil, err
}
data, err := erc20ABI.Pack("approve", bridgeAddress, amountIn)
if err != nil {
return false, nil, 0, nil, err
}
estimate, err := ethClient.EstimateGas(context.Background(), ethereum.CallMsg{
From: account,
To: &token.Address,
Value: big.NewInt(0),
Data: data,
})
if err != nil {
return false, nil, 0, nil, err
}
return true, amountIn, estimate, bridgeAddress, nil
} }
func (r *Router) getBalance(ctx context.Context, network *params.Network, token *token.Token, account common.Address) (*big.Int, error) { func (r *Router) getBalance(ctx context.Context, network *params.Network, token *token.Token, account common.Address) (*big.Int, error) {
@ -509,19 +578,34 @@ func (r *Router) suggestedRoutes(
continue continue
} }
approvalRequired, approvalAmountRequired, approvalGasLimit, approvalContractAddress, err := r.requireApproval(ctx, bridge, account, network, token, amountIn)
if err != nil {
continue
}
approvalGasFees := new(big.Float).Mul(gweiToEth(maxFees), big.NewFloat((float64(approvalGasLimit))))
approvalGasCost := new(big.Float)
approvalGasCost.Mul(
approvalGasFees,
big.NewFloat(prices["ETH"]),
)
gasCost := new(big.Float) gasCost := new(big.Float)
gasCost.Mul( gasCost.Mul(
new(big.Float).Mul(gweiToEth(maxFees), big.NewFloat((float64(gasLimit)))), new(big.Float).Mul(gweiToEth(maxFees), big.NewFloat((float64(gasLimit)))),
big.NewFloat(prices["ETH"]), big.NewFloat(prices["ETH"]),
) )
tokenFeesAsFloat := new(big.Float).Quo( tokenFeesAsFloat := new(big.Float).Quo(
new(big.Float).SetInt(tokenFees), new(big.Float).SetInt(tokenFees),
big.NewFloat(math.Pow(10, float64(token.Decimals))), big.NewFloat(math.Pow(10, float64(token.Decimals))),
) )
tokenCost := new(big.Float) tokenCost := new(big.Float)
tokenCost.Mul(tokenFeesAsFloat, big.NewFloat(prices[tokenSymbol])) tokenCost.Mul(tokenFeesAsFloat, big.NewFloat(prices[tokenSymbol]))
cost := new(big.Float) cost := new(big.Float)
cost.Add(tokenCost, gasCost) cost.Add(tokenCost, gasCost)
cost.Add(cost, approvalGasCost)
mu.Lock() mu.Lock()
candidates = append(candidates, &Path{ candidates = append(candidates, &Path{
@ -537,6 +621,10 @@ func (r *Router) suggestedRoutes(
TokenFees: tokenFeesAsFloat, TokenFees: tokenFeesAsFloat,
Cost: cost, Cost: cost,
EstimatedTime: estimatedTime, EstimatedTime: estimatedTime,
ApprovalRequired: approvalRequired,
ApprovalGasFees: approvalGasFees,
ApprovalAmountRequired: (*hexutil.Big)(approvalAmountRequired),
ApprovalContractAddress: approvalContractAddress,
}) })
mu.Unlock() mu.Unlock()
} }

67
transactions/nonce.go Normal file
View File

@ -0,0 +1,67 @@
package transactions
import (
"context"
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/status-im/status-go/eth-node/types"
)
type Nonce struct {
addrLock *AddrLocker
localNonce map[uint64]*sync.Map
}
func NewNonce() *Nonce {
return &Nonce{
addrLock: &AddrLocker{},
localNonce: make(map[uint64]*sync.Map),
}
}
func (n *Nonce) Next(rpcWrapper *rpcWrapper, from types.Address) (uint64, func(inc bool, nonce uint64), error) {
n.addrLock.LockAddr(from)
current, err := n.GetCurrent(rpcWrapper, from)
unlock := func(inc bool, nonce uint64) {
if inc {
if _, ok := n.localNonce[rpcWrapper.chainID]; !ok {
n.localNonce[rpcWrapper.chainID] = &sync.Map{}
}
n.localNonce[rpcWrapper.chainID].Store(from, nonce+1)
}
n.addrLock.UnlockAddr(from)
}
return current, unlock, err
}
func (n *Nonce) GetCurrent(rpcWrapper *rpcWrapper, from types.Address) (uint64, error) {
var (
localNonce uint64
remoteNonce uint64
)
if _, ok := n.localNonce[rpcWrapper.chainID]; !ok {
n.localNonce[rpcWrapper.chainID] = &sync.Map{}
}
// get the local nonce
if val, ok := n.localNonce[rpcWrapper.chainID].Load(from); ok {
localNonce = val.(uint64)
}
// get the remote nonce
ctx := context.Background()
remoteNonce, err := rpcWrapper.PendingNonceAt(ctx, common.Address(from))
if err != nil {
return 0, err
}
// if upstream node returned nonce higher than ours we will use it, as it probably means
// that another client was used for sending transactions
if remoteNonce > localNonce {
return remoteNonce, nil
}
return localNonce, nil
}

View File

@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"sync"
"time" "time"
ethereum "github.com/ethereum/go-ethereum" ethereum "github.com/ethereum/go-ethereum"
@ -49,18 +48,15 @@ type Transactor struct {
sendTxTimeout time.Duration sendTxTimeout time.Duration
rpcCallTimeout time.Duration rpcCallTimeout time.Duration
networkID uint64 networkID uint64
nonce *Nonce
addrLock *AddrLocker
localNonce sync.Map
log log.Logger log log.Logger
} }
// NewTransactor returns a new Manager. // NewTransactor returns a new Manager.
func NewTransactor() *Transactor { func NewTransactor() *Transactor {
return &Transactor{ return &Transactor{
addrLock: &AddrLocker{},
sendTxTimeout: sendTxTimeout, sendTxTimeout: sendTxTimeout,
localNonce: sync.Map{}, nonce: NewNonce(),
log: log.New("package", "status-go/transactions.Manager"), log: log.New("package", "status-go/transactions.Manager"),
} }
} }
@ -76,6 +72,11 @@ func (t *Transactor) SetRPC(rpcClient *rpc.Client, timeout time.Duration) {
t.rpcCallTimeout = timeout t.rpcCallTimeout = timeout
} }
func (t *Transactor) NextNonce(rpcClient *rpc.Client, chainID uint64, from types.Address) (uint64, func(inc bool, n uint64), error) {
wrapper := newRPCWrapper(rpcClient, chainID)
return t.nonce.Next(wrapper, from)
}
// SendTransaction is an implementation of eth_sendTransaction. It queues the tx to the sign queue. // SendTransaction is an implementation of eth_sendTransaction. It queues the tx to the sign queue.
func (t *Transactor) SendTransaction(sendArgs SendTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { func (t *Transactor) SendTransaction(sendArgs SendTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) {
hash, err = t.validateAndPropagate(t.rpcWrapper, verifiedAccount, sendArgs) hash, err = t.validateAndPropagate(t.rpcWrapper, verifiedAccount, sendArgs)
@ -104,20 +105,13 @@ func (t *Transactor) SendTransactionWithSignature(args SendTxArgs, sig []byte) (
signer := gethtypes.NewLondonSigner(chainID) signer := gethtypes.NewLondonSigner(chainID)
tx := t.buildTransaction(args) tx := t.buildTransaction(args)
t.addrLock.LockAddr(args.From) expectedNonce, unlock, err := t.nonce.Next(t.rpcWrapper, args.From)
defer func() {
// nonce should be incremented only if tx completed without error
// and if no other transactions have been sent while signing the current one.
if err == nil {
t.localNonce.Store(args.From, uint64(*args.Nonce)+1)
}
t.addrLock.UnlockAddr(args.From)
}()
expectedNonce, err := t.getTransactionNonce(args)
if err != nil { if err != nil {
return hash, err return hash, err
} }
defer func() {
unlock(err == nil, expectedNonce)
}()
if tx.Nonce() != expectedNonce { if tx.Nonce() != expectedNonce {
return hash, &ErrBadNonce{tx.Nonce(), expectedNonce} return hash, &ErrBadNonce{tx.Nonce(), expectedNonce}
@ -134,7 +128,6 @@ func (t *Transactor) SendTransactionWithSignature(args SendTxArgs, sig []byte) (
if err := t.rpcWrapper.SendTransaction(ctx, signedTx); err != nil { if err := t.rpcWrapper.SendTransaction(ctx, signedTx); err != nil {
return hash, err return hash, err
} }
return types.Hash(signedTx.Hash()), nil return types.Hash(signedTx.Hash()), nil
} }
@ -145,15 +138,11 @@ func (t *Transactor) HashTransaction(args SendTxArgs) (validatedArgs SendTxArgs,
validatedArgs = args validatedArgs = args
t.addrLock.LockAddr(args.From) nonce, unlock, err := t.nonce.Next(t.rpcWrapper, args.From)
defer func() {
t.addrLock.UnlockAddr(args.From)
}()
nonce, err := t.getTransactionNonce(validatedArgs)
if err != nil { if err != nil {
return validatedArgs, hash, err return validatedArgs, hash, err
} }
defer unlock(false, 0)
gasPrice := (*big.Int)(args.GasPrice) gasPrice := (*big.Int)(args.GasPrice)
gasFeeCap := (*big.Int)(args.MaxFeePerGas) gasFeeCap := (*big.Int)(args.MaxFeePerGas)
@ -250,51 +239,29 @@ func (t *Transactor) validateAndPropagate(rpcWrapper *rpcWrapper, selectedAccoun
if !args.Valid() { if !args.Valid() {
return hash, ErrInvalidSendTxArgs return hash, ErrInvalidSendTxArgs
} }
t.addrLock.LockAddr(args.From)
var localNonce uint64
if val, ok := t.localNonce.Load(args.From); ok {
localNonce = val.(uint64)
}
var nonce uint64
defer func() {
// nonce should be incremented only if tx completed without error
// if upstream node returned nonce higher than ours we will stick to it
if err == nil && args.Nonce == nil {
t.localNonce.Store(args.From, nonce+1)
}
t.addrLock.UnlockAddr(args.From)
nonce, unlock, err := t.nonce.Next(rpcWrapper, args.From)
if err != nil {
return hash, err
}
if args.Nonce != nil {
nonce = uint64(*args.Nonce)
}
defer func() {
unlock(err == nil, nonce)
}() }()
ctx, cancel := context.WithTimeout(context.Background(), t.rpcCallTimeout) ctx, cancel := context.WithTimeout(context.Background(), t.rpcCallTimeout)
defer cancel() defer cancel()
if args.Nonce == nil {
nonce, err = rpcWrapper.PendingNonceAt(ctx, common.Address(args.From))
if err != nil {
return hash, err
}
// if upstream node returned nonce higher than ours we will use it, as it probably means
// that another client was used for sending transactions
if localNonce > nonce {
nonce = localNonce
}
} else {
nonce = uint64(*args.Nonce)
}
gasPrice := (*big.Int)(args.GasPrice) gasPrice := (*big.Int)(args.GasPrice)
if !args.IsDynamicFeeTx() && args.GasPrice == nil { if !args.IsDynamicFeeTx() && args.GasPrice == nil {
ctx, cancel = context.WithTimeout(context.Background(), t.rpcCallTimeout)
defer cancel()
gasPrice, err = rpcWrapper.SuggestGasPrice(ctx) gasPrice, err = rpcWrapper.SuggestGasPrice(ctx)
if err != nil { if err != nil {
return hash, err return hash, err
} }
} }
chainID := big.NewInt(int64(rpcWrapper.chainID)) chainID := big.NewInt(int64(rpcWrapper.chainID))
value := (*big.Int)(args.Value) value := (*big.Int)(args.Value)
var gas uint64 var gas uint64
if args.Gas != nil { if args.Gas != nil {
gas = uint64(*args.Gas) gas = uint64(*args.Gas)
@ -325,15 +292,13 @@ func (t *Transactor) validateAndPropagate(rpcWrapper *rpcWrapper, selectedAccoun
gas = defaultGas gas = defaultGas
} }
} }
tx := t.buildTransactionWithOverrides(nonce, value, gas, gasPrice, args) tx := t.buildTransactionWithOverrides(nonce, value, gas, gasPrice, args)
signedTx, err := gethtypes.SignTx(tx, gethtypes.NewLondonSigner(chainID), selectedAccount.AccountKey.PrivateKey) signedTx, err := gethtypes.SignTx(tx, gethtypes.NewLondonSigner(chainID), selectedAccount.AccountKey.PrivateKey)
if err != nil { if err != nil {
return hash, err return hash, err
} }
ctx, cancel = context.WithTimeout(context.Background(), t.rpcCallTimeout) // ctx, cancel = context.WithTimeout(context.Background(), t.rpcCallTimeout)
defer cancel() // defer cancel()
if err := rpcWrapper.SendTransaction(ctx, signedTx); err != nil { if err := rpcWrapper.SendTransaction(ctx, signedTx); err != nil {
return hash, err return hash, err
@ -405,36 +370,6 @@ func (t *Transactor) buildTransactionWithOverrides(nonce uint64, value *big.Int,
return tx return tx
} }
func (t *Transactor) getTransactionNonce(args SendTxArgs) (newNonce uint64, err error) {
var (
localNonce uint64
remoteNonce uint64
)
// get the local nonce
if val, ok := t.localNonce.Load(args.From); ok {
localNonce = val.(uint64)
}
// get the remote nonce
ctx, cancel := context.WithTimeout(context.Background(), t.rpcCallTimeout)
defer cancel()
remoteNonce, err = t.rpcWrapper.PendingNonceAt(ctx, common.Address(args.From))
if err != nil {
return newNonce, err
}
// if upstream node returned nonce higher than ours we will use it, as it probably means
// that another client was used for sending transactions
if remoteNonce > localNonce {
newNonce = remoteNonce
} else {
newNonce = localNonce
}
return newNonce, nil
}
func (t *Transactor) logNewTx(args SendTxArgs, gas uint64, gasPrice *big.Int, value *big.Int) { func (t *Transactor) logNewTx(args SendTxArgs, gas uint64, gasPrice *big.Int, value *big.Int) {
t.log.Info("New transaction", t.log.Info("New transaction",
"From", args.From, "From", args.From,

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"reflect" "reflect"
"sync"
"testing" "testing"
"time" "time"
@ -49,10 +50,11 @@ func (s *TransactorSuite) SetupTest() {
s.server, s.txServiceMock = fake.NewTestServer(s.txServiceMockCtrl) s.server, s.txServiceMock = fake.NewTestServer(s.txServiceMockCtrl)
s.client = gethrpc.DialInProc(s.server) s.client = gethrpc.DialInProc(s.server)
rpcClient, _ := rpc.NewClient(s.client, 1, params.UpstreamRPCConfig{}, nil, nil)
rpcClient.UpstreamChainID = 1
// expected by simulated backend // expected by simulated backend
chainID := gethparams.AllEthashProtocolChanges.ChainID.Uint64() chainID := gethparams.AllEthashProtocolChanges.ChainID.Uint64()
rpcClient, _ := rpc.NewClient(s.client, chainID, params.UpstreamRPCConfig{}, nil, nil)
rpcClient.UpstreamChainID = chainID
nodeConfig, err := utils.MakeTestNodeConfigWithDataDir("", "/tmp", chainID) nodeConfig, err := utils.MakeTestNodeConfigWithDataDir("", "/tmp", chainID)
s.Require().NoError(err) s.Require().NoError(err)
s.nodeConfig = nodeConfig s.nodeConfig = nodeConfig
@ -129,7 +131,7 @@ func (s *TransactorSuite) rlpEncodeTx(args SendTxArgs, config *params.NodeConfig
} }
newTx := gethtypes.NewTx(txData) newTx := gethtypes.NewTx(txData)
chainID := big.NewInt(int64(1)) chainID := big.NewInt(int64(s.nodeConfig.NetworkID))
signedTx, err := gethtypes.SignTx(newTx, gethtypes.NewLondonSigner(chainID), account.AccountKey.PrivateKey) signedTx, err := gethtypes.SignTx(newTx, gethtypes.NewLondonSigner(chainID), account.AccountKey.PrivateKey)
s.NoError(err) s.NoError(err)
@ -253,6 +255,7 @@ func (s *TransactorSuite) TestAccountMismatch() {
// as the last step, we verify that if tx failed nonce is not updated // as the last step, we verify that if tx failed nonce is not updated
func (s *TransactorSuite) TestLocalNonce() { func (s *TransactorSuite) TestLocalNonce() {
txCount := 3 txCount := 3
chainID := s.nodeConfig.NetworkID
key, _ := gethcrypto.GenerateKey() key, _ := gethcrypto.GenerateKey()
selectedAccount := &account.SelectedExtKey{ selectedAccount := &account.SelectedExtKey{
Address: account.FromAddress(utils.TestConfig.Account1.WalletAddress), Address: account.FromAddress(utils.TestConfig.Account1.WalletAddress),
@ -269,7 +272,7 @@ func (s *TransactorSuite) TestLocalNonce() {
_, err := s.manager.SendTransaction(args, selectedAccount) _, err := s.manager.SendTransaction(args, selectedAccount)
s.NoError(err) s.NoError(err)
resultNonce, _ := s.manager.localNonce.Load(args.From) resultNonce, _ := s.manager.nonce.localNonce[chainID].Load(args.From)
s.Equal(uint64(i)+1, resultNonce.(uint64)) s.Equal(uint64(i)+1, resultNonce.(uint64))
} }
@ -284,7 +287,7 @@ func (s *TransactorSuite) TestLocalNonce() {
_, err := s.manager.SendTransaction(args, selectedAccount) _, err := s.manager.SendTransaction(args, selectedAccount)
s.NoError(err) s.NoError(err)
resultNonce, _ := s.manager.localNonce.Load(args.From) resultNonce, _ := s.manager.nonce.localNonce[chainID].Load(args.From)
s.Equal(uint64(nonce)+1, resultNonce.(uint64)) s.Equal(uint64(nonce)+1, resultNonce.(uint64))
testErr := errors.New("test") testErr := errors.New("test")
@ -296,7 +299,7 @@ func (s *TransactorSuite) TestLocalNonce() {
_, err = s.manager.SendTransaction(args, selectedAccount) _, err = s.manager.SendTransaction(args, selectedAccount)
s.EqualError(err, testErr.Error()) s.EqualError(err, testErr.Error())
resultNonce, _ = s.manager.localNonce.Load(args.From) resultNonce, _ = s.manager.nonce.localNonce[chainID].Load(args.From)
s.Equal(uint64(nonce)+1, resultNonce.(uint64)) s.Equal(uint64(nonce)+1, resultNonce.(uint64))
} }
@ -330,8 +333,6 @@ func (s *TransactorSuite) TestSendTransactionWithSignature() {
for _, scenario := range scenarios { for _, scenario := range scenarios {
desc := fmt.Sprintf("local nonce: %d, tx nonce: %d, expect error: %v", scenario.localNonce, scenario.txNonce, scenario.expectError) desc := fmt.Sprintf("local nonce: %d, tx nonce: %d, expect error: %v", scenario.localNonce, scenario.txNonce, scenario.expectError)
s.T().Run(desc, func(t *testing.T) { s.T().Run(desc, func(t *testing.T) {
s.manager.localNonce.Store(address, uint64(scenario.localNonce))
nonce := scenario.txNonce nonce := scenario.txNonce
from := address from := address
to := address to := address
@ -340,7 +341,8 @@ func (s *TransactorSuite) TestSendTransactionWithSignature() {
gasPrice := (*hexutil.Big)(big.NewInt(2000000000)) gasPrice := (*hexutil.Big)(big.NewInt(2000000000))
data := []byte{} data := []byte{}
chainID := big.NewInt(int64(s.nodeConfig.NetworkID)) chainID := big.NewInt(int64(s.nodeConfig.NetworkID))
s.manager.nonce.localNonce[s.nodeConfig.NetworkID] = &sync.Map{}
s.manager.nonce.localNonce[s.nodeConfig.NetworkID].Store(address, uint64(scenario.localNonce))
args := SendTxArgs{ args := SendTxArgs{
From: from, From: from,
To: &to, To: &to,
@ -376,12 +378,13 @@ func (s *TransactorSuite) TestSendTransactionWithSignature() {
if scenario.expectError { if scenario.expectError {
s.Error(err) s.Error(err)
// local nonce should not be incremented // local nonce should not be incremented
resultNonce, _ := s.manager.localNonce.Load(args.From) resultNonce, _ := s.manager.nonce.localNonce[s.nodeConfig.NetworkID].Load(args.From)
s.Equal(uint64(scenario.localNonce), resultNonce.(uint64)) s.Equal(uint64(scenario.localNonce), resultNonce.(uint64))
} else { } else {
s.NoError(err) s.NoError(err)
// local nonce should be incremented // local nonce should be incremented
resultNonce, _ := s.manager.localNonce.Load(args.From) resultNonce, _ := s.manager.nonce.localNonce[s.nodeConfig.NetworkID].Load(args.From)
s.Equal(uint64(nonce)+1, resultNonce.(uint64)) s.Equal(uint64(nonce)+1, resultNonce.(uint64))
} }
}) })