From 883f9677c593494eceadef004932c0434575a5be Mon Sep 17 00:00:00 2001 From: Anthony Laibe Date: Mon, 19 Dec 2022 13:37:37 +0100 Subject: [PATCH] feat: nonce management in multi tx --- contracts/contracts.go | 13 +++ services/wallet/bridge/bridge.go | 1 + services/wallet/bridge/cbridge.go | 42 +++++---- services/wallet/bridge/hop.go | 87 +++++++++++------- services/wallet/bridge/simple.go | 5 + services/wallet/router.go | 148 ++++++++++++++++++++++++------ transactions/nonce.go | 67 ++++++++++++++ transactions/transactor.go | 115 +++++------------------ transactions/transactor_test.go | 25 ++--- 9 files changed, 325 insertions(+), 178 deletions(-) create mode 100644 transactions/nonce.go diff --git a/contracts/contracts.go b/contracts/contracts.go index 32a8126fa..5fff617c5 100644 --- a/contracts/contracts.go +++ b/contracts/contracts.go @@ -12,6 +12,7 @@ import ( hopBridge "github.com/status-im/status-go/contracts/hop/bridge" hopSwap "github.com/status-im/status-go/contracts/hop/swap" hopWrapper "github.com/status-im/status-go/contracts/hop/wrapper" + "github.com/status-im/status-go/contracts/ierc20" "github.com/status-im/status-go/contracts/registrar" "github.com/status-im/status-go/contracts/resolver" "github.com/status-im/status-go/contracts/snt" @@ -64,6 +65,18 @@ func (c *ContractMaker) NewUsernameRegistrar(chainID uint64, contractAddr common ) } +func (c *ContractMaker) NewERC20(chainID uint64, contractAddr common.Address) (*ierc20.IERC20, error) { + backend, err := c.RPCClient.EthClient(chainID) + if err != nil { + return nil, err + } + + return ierc20.NewIERC20( + contractAddr, + backend, + ) +} + func (c *ContractMaker) NewSNT(chainID uint64) (*snt.SNT, error) { contractAddr, err := snt.ContractAddress(chainID) if err != nil { diff --git a/services/wallet/bridge/bridge.go b/services/wallet/bridge/bridge.go index 8c9b14ccc..50a7e6d5c 100644 --- a/services/wallet/bridge/bridge.go +++ b/services/wallet/bridge/bridge.go @@ -84,4 +84,5 @@ type Bridge interface { EstimateGas(from *params.Network, to *params.Network, token *token.Token, amountIn *big.Int) (uint64, error) CalculateAmountOut(from, to *params.Network, amountIn *big.Int, symbol string) (*big.Int, error) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (types.Hash, error) + GetContractAddress(network *params.Network, token *token.Token) *common.Address } diff --git a/services/wallet/bridge/cbridge.go b/services/wallet/bridge/cbridge.go index 24eb59531..728ba3086 100644 --- a/services/wallet/bridge/cbridge.go +++ b/services/wallet/bridge/cbridge.go @@ -36,14 +36,16 @@ type CBridgeTxArgs struct { type CBridge struct { rpcClient *rpc.Client + transactor *transactions.Transactor tokenManager *token.Manager prodTransferConfig *cbridge.GetTransferConfigsResponse testTransferConfig *cbridge.GetTransferConfigsResponse } -func NewCbridge(rpcClient *rpc.Client, tokenManager *token.Manager) *CBridge { +func NewCbridge(rpcClient *rpc.Client, transactor *transactions.Transactor, tokenManager *token.Manager) *CBridge { return &CBridge{ rpcClient: rpcClient, + transactor: transactor, tokenManager: tokenManager, } } @@ -219,6 +221,25 @@ func (s *CBridge) EstimateGas(from, to *params.Network, token *token.Token, amou return 200000, nil //default gas limit for erc20 transaction } +func (s *CBridge) GetContractAddress(network *params.Network, token *token.Token) *common.Address { + transferConfig, err := s.getTransferConfig(network.IsTest) + if err != nil { + return nil + } + if transferConfig.Err != nil { + return nil + } + + for _, chain := range transferConfig.Chains { + if uint64(chain.Id) == network.ChainID { + addr := common.HexToAddress(chain.ContractAddr) + return &addr + } + } + + return nil +} + func (s *CBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (types.Hash, error) { fromNetwork := s.rpcClient.NetworkManager.Find(sendArgs.ChainID) if fromNetwork == nil { @@ -228,27 +249,16 @@ func (s *CBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.Sel if tk == nil { return types.HexToHash(""), errors.New("token not found") } - transferConfig, err := s.getTransferConfig(fromNetwork.IsTest) - if err != nil { - return types.HexToHash(""), err - } - if transferConfig.Err != nil { - return types.HexToHash(""), errors.New(transferConfig.Err.Msg) - } - - addrs := "" - for _, chain := range transferConfig.Chains { - if uint64(chain.Id) == sendArgs.ChainID { - addrs = chain.ContractAddr - break - } + addrs := s.GetContractAddress(fromNetwork, nil) + if addrs == nil { + return types.HexToHash(""), errors.New("contract not found") } backend, err := s.rpcClient.EthClient(sendArgs.ChainID) if err != nil { return types.HexToHash(""), err } - contract, err := celer.NewCeler(common.HexToAddress(addrs), backend) + contract, err := celer.NewCeler(*addrs, backend) if err != nil { return types.HexToHash(""), err } diff --git a/services/wallet/bridge/hop.go b/services/wallet/bridge/hop.go index c169f30dc..1eba42a99 100644 --- a/services/wallet/bridge/hop.go +++ b/services/wallet/bridge/hop.go @@ -11,6 +11,7 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/status-im/status-go/account" "github.com/status-im/status-go/contracts" + "github.com/status-im/status-go/contracts/hop" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/params" "github.com/status-im/status-go/rpc" @@ -83,12 +84,16 @@ type HopTxArgs struct { } type HopBridge struct { + transactor *transactions.Transactor + tokenManager *token.Manager contractMaker *contracts.ContractMaker } -func NewHopBridge(rpcClient *rpc.Client) *HopBridge { +func NewHopBridge(rpcClient *rpc.Client, transactor *transactions.Transactor, tokenManager *token.Manager) *HopBridge { return &HopBridge{ contractMaker: &contracts.ContractMaker{RPCClient: rpcClient}, + transactor: transactor, + tokenManager: tokenManager, } } @@ -157,39 +162,57 @@ func (h *HopBridge) EstimateGas(from, to *params.Network, token *token.Token, am return 500000 + 1000, nil } -func (h *HopBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { - networks, err := h.contractMaker.RPCClient.NetworkManager.Get(false) - if err != nil { - return hash, err - } - var fromNetwork *params.Network - for _, network := range networks { - if network.ChainID == sendArgs.ChainID { - fromNetwork = network - break - } +func (h *HopBridge) GetContractAddress(network *params.Network, token *token.Token) *common.Address { + var address common.Address + if network.Layer == 1 { + address, _ = hop.L1BridgeContractAddress(network.ChainID, token.Symbol) + } else { + address, _ = hop.L2AmmWrapperContractAddress(network.ChainID, token.Symbol) } - if fromNetwork.Layer == 1 { - return h.sendToL2(sendArgs.ChainID, sendArgs.HopTx, verifiedAccount) - } - return h.swapAndSend(sendArgs.ChainID, sendArgs.HopTx, verifiedAccount) + return &address } -func (h *HopBridge) sendToL2(chainID uint64, sendArgs *HopTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { - bridge, err := h.contractMaker.NewHopL1Bridge(chainID, sendArgs.Symbol) +func (h *HopBridge) Send(sendArgs *TransactionBridge, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { + fromNetwork := h.contractMaker.RPCClient.NetworkManager.Find(sendArgs.ChainID) + if fromNetwork == nil { + return hash, err + } + + nonce, unlock, err := h.transactor.NextNonce(h.contractMaker.RPCClient, sendArgs.ChainID, sendArgs.HopTx.From) if err != nil { return hash, err } - txOpts := sendArgs.ToTransactOpts(getSigner(chainID, sendArgs.From, verifiedAccount)) - txOpts.Value = (*big.Int)(sendArgs.Amount) + defer func() { + unlock(err == nil, nonce) + }() + argNonce := hexutil.Uint64(nonce) + sendArgs.HopTx.Nonce = &argNonce + + token := h.tokenManager.FindToken(fromNetwork, sendArgs.HopTx.Symbol) + if fromNetwork.Layer == 1 { + hash, err = h.sendToL2(sendArgs.ChainID, sendArgs.HopTx, verifiedAccount, token) + } + hash, err = h.swapAndSend(sendArgs.ChainID, sendArgs.HopTx, verifiedAccount, token) + return hash, err +} + +func (h *HopBridge) sendToL2(chainID uint64, hopArgs *HopTxArgs, verifiedAccount *account.SelectedExtKey, token *token.Token) (hash types.Hash, err error) { + bridge, err := h.contractMaker.NewHopL1Bridge(chainID, hopArgs.Symbol) + if err != nil { + return hash, err + } + txOpts := hopArgs.ToTransactOpts(getSigner(chainID, hopArgs.From, verifiedAccount)) + if token.IsNative() { + txOpts.Value = (*big.Int)(hopArgs.Amount) + } now := time.Now() deadline := big.NewInt(now.Unix() + 604800) tx, err := bridge.SendToL2( txOpts, - big.NewInt(int64(sendArgs.ChainID)), - sendArgs.Recipient, - sendArgs.Amount.ToInt(), + big.NewInt(int64(hopArgs.ChainID)), + hopArgs.Recipient, + hopArgs.Amount.ToInt(), big.NewInt(0), deadline, common.HexToAddress("0x0"), @@ -202,22 +225,24 @@ func (h *HopBridge) sendToL2(chainID uint64, sendArgs *HopTxArgs, verifiedAccoun return types.Hash(tx.Hash()), nil } -func (h *HopBridge) swapAndSend(chainID uint64, sendArgs *HopTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { - ammWrapper, err := h.contractMaker.NewHopL2AmmWrapper(chainID, sendArgs.Symbol) +func (h *HopBridge) swapAndSend(chainID uint64, hopArgs *HopTxArgs, verifiedAccount *account.SelectedExtKey, token *token.Token) (hash types.Hash, err error) { + ammWrapper, err := h.contractMaker.NewHopL2AmmWrapper(chainID, hopArgs.Symbol) if err != nil { return hash, err } - txOpts := sendArgs.ToTransactOpts(getSigner(chainID, sendArgs.From, verifiedAccount)) - txOpts.Value = (*big.Int)(sendArgs.Amount) + txOpts := hopArgs.ToTransactOpts(getSigner(chainID, hopArgs.From, verifiedAccount)) + if token.IsNative() { + txOpts.Value = (*big.Int)(hopArgs.Amount) + } now := time.Now() deadline := big.NewInt(now.Unix() + 604800) tx, err := ammWrapper.SwapAndSend( txOpts, - big.NewInt(int64(sendArgs.ChainID)), - sendArgs.Recipient, - sendArgs.Amount.ToInt(), - sendArgs.BonderFee.ToInt(), + big.NewInt(int64(hopArgs.ChainID)), + hopArgs.Recipient, + hopArgs.Amount.ToInt(), + hopArgs.BonderFee.ToInt(), big.NewInt(0), deadline, big.NewInt(0), diff --git a/services/wallet/bridge/simple.go b/services/wallet/bridge/simple.go index bb06e6e9a..78e42ecd2 100644 --- a/services/wallet/bridge/simple.go +++ b/services/wallet/bridge/simple.go @@ -3,6 +3,7 @@ package bridge import ( "math/big" + "github.com/ethereum/go-ethereum/common" "github.com/status-im/status-go/account" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/params" @@ -46,3 +47,7 @@ func (s *SimpleBridge) Send(sendArgs *TransactionBridge, verifiedAccount *accoun func (s *SimpleBridge) CalculateAmountOut(from, to *params.Network, amountIn *big.Int, symbol string) (*big.Int, error) { return amountIn, nil } + +func (s *SimpleBridge) GetContractAddress(network *params.Network, token *token.Token) *common.Address { + return nil +} diff --git a/services/wallet/router.go b/services/wallet/router.go index 9d3d1f3b1..08c9877e3 100644 --- a/services/wallet/router.go +++ b/services/wallet/router.go @@ -7,12 +7,19 @@ import ( "math" "math/big" "sort" + "strings" "sync" + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/status-im/status-go/contracts" + "github.com/status-im/status-go/contracts/ierc20" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/params" + "github.com/status-im/status-go/rpc" "github.com/status-im/status-go/services/wallet/async" "github.com/status-im/status-go/services/wallet/bigint" "github.com/status-im/status-go/services/wallet/bridge" @@ -103,19 +110,23 @@ func (s SendType) EstimateGas(service *Service, network *params.Network) uint64 var zero = big.NewInt(0) type Path struct { - BridgeName string - From *params.Network - To *params.Network - MaxAmountIn *hexutil.Big - AmountIn *hexutil.Big - AmountInLocked bool - AmountOut *hexutil.Big - GasAmount uint64 - GasFees *SuggestedFees - BonderFees *hexutil.Big - TokenFees *big.Float - Cost *big.Float - EstimatedTime TransactionEstimation + BridgeName string + From *params.Network + To *params.Network + MaxAmountIn *hexutil.Big + AmountIn *hexutil.Big + AmountInLocked bool + AmountOut *hexutil.Big + GasAmount uint64 + GasFees *SuggestedFees + BonderFees *hexutil.Big + TokenFees *big.Float + Cost *big.Float + EstimatedTime TransactionEstimation + ApprovalRequired bool + ApprovalGasFees *big.Float + ApprovalAmountRequired *hexutil.Big + ApprovalContractAddress *common.Address } func (p *Path) Equal(o *Path) bool { @@ -334,13 +345,13 @@ func newSuggestedRoutes( func NewRouter(s *Service) *Router { bridges := make(map[string]bridge.Bridge) simple := bridge.NewSimpleBridge(s.transactor) - hop := bridge.NewHopBridge(s.rpcClient) - cbridge := bridge.NewCbridge(s.rpcClient, s.tokenManager) + cbridge := bridge.NewCbridge(s.rpcClient, s.transactor, s.tokenManager) + hop := bridge.NewHopBridge(s.rpcClient, s.transactor, s.tokenManager) bridges[simple.Name()] = simple bridges[hop.Name()] = hop bridges[cbridge.Name()] = cbridge - return &Router{s, bridges} + return &Router{s, bridges, s.rpcClient} } func containsNetworkChainID(network *params.Network, chainIDs []uint64) bool { @@ -354,8 +365,66 @@ func containsNetworkChainID(network *params.Network, chainIDs []uint64) bool { } type Router struct { - s *Service - bridges map[string]bridge.Bridge + s *Service + bridges map[string]bridge.Bridge + rpcClient *rpc.Client +} + +func (r *Router) requireApproval(ctx context.Context, bridge bridge.Bridge, account common.Address, network *params.Network, token *token.Token, amountIn *big.Int) (bool, *big.Int, uint64, *common.Address, error) { + if token.IsNative() { + return false, nil, 0, nil, nil + } + contractMaker := &contracts.ContractMaker{RPCClient: r.rpcClient} + + bridgeAddress := bridge.GetContractAddress(network, token) + if bridgeAddress == nil { + return false, nil, 0, nil, nil + } + + contract, err := contractMaker.NewERC20(network.ChainID, token.Address) + if err != nil { + return false, nil, 0, nil, err + } + + allowance, err := contract.Allowance(&bind.CallOpts{ + Context: ctx, + }, account, *bridgeAddress) + + if err != nil { + return false, nil, 0, nil, err + } + + if allowance.Cmp(amountIn) >= 0 { + return false, nil, 0, nil, nil + } + + ethClient, err := r.rpcClient.EthClient(network.ChainID) + if err != nil { + return false, nil, 0, nil, err + } + + erc20ABI, err := abi.JSON(strings.NewReader(ierc20.IERC20ABI)) + if err != nil { + return false, nil, 0, nil, err + } + + data, err := erc20ABI.Pack("approve", bridgeAddress, amountIn) + if err != nil { + return false, nil, 0, nil, err + } + + estimate, err := ethClient.EstimateGas(context.Background(), ethereum.CallMsg{ + From: account, + To: &token.Address, + Value: big.NewInt(0), + Data: data, + }) + if err != nil { + return false, nil, 0, nil, err + } + + return true, amountIn, estimate, bridgeAddress, nil + } func (r *Router) getBalance(ctx context.Context, network *params.Network, token *token.Token, account common.Address) (*big.Int, error) { @@ -509,34 +578,53 @@ func (r *Router) suggestedRoutes( continue } + approvalRequired, approvalAmountRequired, approvalGasLimit, approvalContractAddress, err := r.requireApproval(ctx, bridge, account, network, token, amountIn) + if err != nil { + continue + } + approvalGasFees := new(big.Float).Mul(gweiToEth(maxFees), big.NewFloat((float64(approvalGasLimit)))) + + approvalGasCost := new(big.Float) + approvalGasCost.Mul( + approvalGasFees, + big.NewFloat(prices["ETH"]), + ) + gasCost := new(big.Float) gasCost.Mul( new(big.Float).Mul(gweiToEth(maxFees), big.NewFloat((float64(gasLimit)))), big.NewFloat(prices["ETH"]), ) + tokenFeesAsFloat := new(big.Float).Quo( new(big.Float).SetInt(tokenFees), big.NewFloat(math.Pow(10, float64(token.Decimals))), ) tokenCost := new(big.Float) tokenCost.Mul(tokenFeesAsFloat, big.NewFloat(prices[tokenSymbol])) + cost := new(big.Float) cost.Add(tokenCost, gasCost) + cost.Add(cost, approvalGasCost) mu.Lock() candidates = append(candidates, &Path{ - BridgeName: bridge.Name(), - From: network, - To: dest, - MaxAmountIn: maxAmountIn, - AmountIn: (*hexutil.Big)(zero), - AmountOut: (*hexutil.Big)(zero), - GasAmount: gasLimit, - GasFees: gasFees, - BonderFees: (*hexutil.Big)(bonderFees), - TokenFees: tokenFeesAsFloat, - Cost: cost, - EstimatedTime: estimatedTime, + BridgeName: bridge.Name(), + From: network, + To: dest, + MaxAmountIn: maxAmountIn, + AmountIn: (*hexutil.Big)(zero), + AmountOut: (*hexutil.Big)(zero), + GasAmount: gasLimit, + GasFees: gasFees, + BonderFees: (*hexutil.Big)(bonderFees), + TokenFees: tokenFeesAsFloat, + Cost: cost, + EstimatedTime: estimatedTime, + ApprovalRequired: approvalRequired, + ApprovalGasFees: approvalGasFees, + ApprovalAmountRequired: (*hexutil.Big)(approvalAmountRequired), + ApprovalContractAddress: approvalContractAddress, }) mu.Unlock() } diff --git a/transactions/nonce.go b/transactions/nonce.go new file mode 100644 index 000000000..999f971e3 --- /dev/null +++ b/transactions/nonce.go @@ -0,0 +1,67 @@ +package transactions + +import ( + "context" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/status-im/status-go/eth-node/types" +) + +type Nonce struct { + addrLock *AddrLocker + localNonce map[uint64]*sync.Map +} + +func NewNonce() *Nonce { + return &Nonce{ + addrLock: &AddrLocker{}, + localNonce: make(map[uint64]*sync.Map), + } +} + +func (n *Nonce) Next(rpcWrapper *rpcWrapper, from types.Address) (uint64, func(inc bool, nonce uint64), error) { + n.addrLock.LockAddr(from) + current, err := n.GetCurrent(rpcWrapper, from) + unlock := func(inc bool, nonce uint64) { + if inc { + if _, ok := n.localNonce[rpcWrapper.chainID]; !ok { + n.localNonce[rpcWrapper.chainID] = &sync.Map{} + } + + n.localNonce[rpcWrapper.chainID].Store(from, nonce+1) + } + n.addrLock.UnlockAddr(from) + } + + return current, unlock, err +} + +func (n *Nonce) GetCurrent(rpcWrapper *rpcWrapper, from types.Address) (uint64, error) { + var ( + localNonce uint64 + remoteNonce uint64 + ) + if _, ok := n.localNonce[rpcWrapper.chainID]; !ok { + n.localNonce[rpcWrapper.chainID] = &sync.Map{} + } + + // get the local nonce + if val, ok := n.localNonce[rpcWrapper.chainID].Load(from); ok { + localNonce = val.(uint64) + } + + // get the remote nonce + ctx := context.Background() + remoteNonce, err := rpcWrapper.PendingNonceAt(ctx, common.Address(from)) + if err != nil { + return 0, err + } + + // if upstream node returned nonce higher than ours we will use it, as it probably means + // that another client was used for sending transactions + if remoteNonce > localNonce { + return remoteNonce, nil + } + return localNonce, nil +} diff --git a/transactions/transactor.go b/transactions/transactor.go index c16fd6164..c524b7c3f 100644 --- a/transactions/transactor.go +++ b/transactions/transactor.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "math/big" - "sync" "time" ethereum "github.com/ethereum/go-ethereum" @@ -49,18 +48,15 @@ type Transactor struct { sendTxTimeout time.Duration rpcCallTimeout time.Duration networkID uint64 - - addrLock *AddrLocker - localNonce sync.Map - log log.Logger + nonce *Nonce + log log.Logger } // NewTransactor returns a new Manager. func NewTransactor() *Transactor { return &Transactor{ - addrLock: &AddrLocker{}, sendTxTimeout: sendTxTimeout, - localNonce: sync.Map{}, + nonce: NewNonce(), log: log.New("package", "status-go/transactions.Manager"), } } @@ -76,6 +72,11 @@ func (t *Transactor) SetRPC(rpcClient *rpc.Client, timeout time.Duration) { t.rpcCallTimeout = timeout } +func (t *Transactor) NextNonce(rpcClient *rpc.Client, chainID uint64, from types.Address) (uint64, func(inc bool, n uint64), error) { + wrapper := newRPCWrapper(rpcClient, chainID) + return t.nonce.Next(wrapper, from) +} + // SendTransaction is an implementation of eth_sendTransaction. It queues the tx to the sign queue. func (t *Transactor) SendTransaction(sendArgs SendTxArgs, verifiedAccount *account.SelectedExtKey) (hash types.Hash, err error) { hash, err = t.validateAndPropagate(t.rpcWrapper, verifiedAccount, sendArgs) @@ -104,20 +105,13 @@ func (t *Transactor) SendTransactionWithSignature(args SendTxArgs, sig []byte) ( signer := gethtypes.NewLondonSigner(chainID) tx := t.buildTransaction(args) - t.addrLock.LockAddr(args.From) - defer func() { - // nonce should be incremented only if tx completed without error - // and if no other transactions have been sent while signing the current one. - if err == nil { - t.localNonce.Store(args.From, uint64(*args.Nonce)+1) - } - t.addrLock.UnlockAddr(args.From) - }() - - expectedNonce, err := t.getTransactionNonce(args) + expectedNonce, unlock, err := t.nonce.Next(t.rpcWrapper, args.From) if err != nil { return hash, err } + defer func() { + unlock(err == nil, expectedNonce) + }() if tx.Nonce() != expectedNonce { return hash, &ErrBadNonce{tx.Nonce(), expectedNonce} @@ -134,7 +128,6 @@ func (t *Transactor) SendTransactionWithSignature(args SendTxArgs, sig []byte) ( if err := t.rpcWrapper.SendTransaction(ctx, signedTx); err != nil { return hash, err } - return types.Hash(signedTx.Hash()), nil } @@ -145,15 +138,11 @@ func (t *Transactor) HashTransaction(args SendTxArgs) (validatedArgs SendTxArgs, validatedArgs = args - t.addrLock.LockAddr(args.From) - defer func() { - t.addrLock.UnlockAddr(args.From) - }() - - nonce, err := t.getTransactionNonce(validatedArgs) + nonce, unlock, err := t.nonce.Next(t.rpcWrapper, args.From) if err != nil { return validatedArgs, hash, err } + defer unlock(false, 0) gasPrice := (*big.Int)(args.GasPrice) gasFeeCap := (*big.Int)(args.MaxFeePerGas) @@ -250,51 +239,29 @@ func (t *Transactor) validateAndPropagate(rpcWrapper *rpcWrapper, selectedAccoun if !args.Valid() { return hash, ErrInvalidSendTxArgs } - t.addrLock.LockAddr(args.From) - var localNonce uint64 - if val, ok := t.localNonce.Load(args.From); ok { - localNonce = val.(uint64) - } - var nonce uint64 - defer func() { - // nonce should be incremented only if tx completed without error - // if upstream node returned nonce higher than ours we will stick to it - if err == nil && args.Nonce == nil { - t.localNonce.Store(args.From, nonce+1) - } - t.addrLock.UnlockAddr(args.From) + nonce, unlock, err := t.nonce.Next(rpcWrapper, args.From) + if err != nil { + return hash, err + } + if args.Nonce != nil { + nonce = uint64(*args.Nonce) + } + defer func() { + unlock(err == nil, nonce) }() ctx, cancel := context.WithTimeout(context.Background(), t.rpcCallTimeout) defer cancel() - if args.Nonce == nil { - - nonce, err = rpcWrapper.PendingNonceAt(ctx, common.Address(args.From)) - if err != nil { - return hash, err - } - // if upstream node returned nonce higher than ours we will use it, as it probably means - // that another client was used for sending transactions - if localNonce > nonce { - nonce = localNonce - } - } else { - nonce = uint64(*args.Nonce) - } gasPrice := (*big.Int)(args.GasPrice) if !args.IsDynamicFeeTx() && args.GasPrice == nil { - ctx, cancel = context.WithTimeout(context.Background(), t.rpcCallTimeout) - defer cancel() gasPrice, err = rpcWrapper.SuggestGasPrice(ctx) if err != nil { return hash, err } } - chainID := big.NewInt(int64(rpcWrapper.chainID)) value := (*big.Int)(args.Value) - var gas uint64 if args.Gas != nil { gas = uint64(*args.Gas) @@ -325,15 +292,13 @@ func (t *Transactor) validateAndPropagate(rpcWrapper *rpcWrapper, selectedAccoun gas = defaultGas } } - tx := t.buildTransactionWithOverrides(nonce, value, gas, gasPrice, args) - signedTx, err := gethtypes.SignTx(tx, gethtypes.NewLondonSigner(chainID), selectedAccount.AccountKey.PrivateKey) if err != nil { return hash, err } - ctx, cancel = context.WithTimeout(context.Background(), t.rpcCallTimeout) - defer cancel() + // ctx, cancel = context.WithTimeout(context.Background(), t.rpcCallTimeout) + // defer cancel() if err := rpcWrapper.SendTransaction(ctx, signedTx); err != nil { return hash, err @@ -405,36 +370,6 @@ func (t *Transactor) buildTransactionWithOverrides(nonce uint64, value *big.Int, return tx } -func (t *Transactor) getTransactionNonce(args SendTxArgs) (newNonce uint64, err error) { - var ( - localNonce uint64 - remoteNonce uint64 - ) - - // get the local nonce - if val, ok := t.localNonce.Load(args.From); ok { - localNonce = val.(uint64) - } - - // get the remote nonce - ctx, cancel := context.WithTimeout(context.Background(), t.rpcCallTimeout) - defer cancel() - remoteNonce, err = t.rpcWrapper.PendingNonceAt(ctx, common.Address(args.From)) - if err != nil { - return newNonce, err - } - - // if upstream node returned nonce higher than ours we will use it, as it probably means - // that another client was used for sending transactions - if remoteNonce > localNonce { - newNonce = remoteNonce - } else { - newNonce = localNonce - } - - return newNonce, nil -} - func (t *Transactor) logNewTx(args SendTxArgs, gas uint64, gasPrice *big.Int, value *big.Int) { t.log.Info("New transaction", "From", args.From, diff --git a/transactions/transactor_test.go b/transactions/transactor_test.go index d997c49dc..549e7ca94 100644 --- a/transactions/transactor_test.go +++ b/transactions/transactor_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math/big" "reflect" + "sync" "testing" "time" @@ -49,10 +50,11 @@ func (s *TransactorSuite) SetupTest() { s.server, s.txServiceMock = fake.NewTestServer(s.txServiceMockCtrl) s.client = gethrpc.DialInProc(s.server) - rpcClient, _ := rpc.NewClient(s.client, 1, params.UpstreamRPCConfig{}, nil, nil) - rpcClient.UpstreamChainID = 1 + // expected by simulated backend chainID := gethparams.AllEthashProtocolChanges.ChainID.Uint64() + rpcClient, _ := rpc.NewClient(s.client, chainID, params.UpstreamRPCConfig{}, nil, nil) + rpcClient.UpstreamChainID = chainID nodeConfig, err := utils.MakeTestNodeConfigWithDataDir("", "/tmp", chainID) s.Require().NoError(err) s.nodeConfig = nodeConfig @@ -129,7 +131,7 @@ func (s *TransactorSuite) rlpEncodeTx(args SendTxArgs, config *params.NodeConfig } newTx := gethtypes.NewTx(txData) - chainID := big.NewInt(int64(1)) + chainID := big.NewInt(int64(s.nodeConfig.NetworkID)) signedTx, err := gethtypes.SignTx(newTx, gethtypes.NewLondonSigner(chainID), account.AccountKey.PrivateKey) s.NoError(err) @@ -253,6 +255,7 @@ func (s *TransactorSuite) TestAccountMismatch() { // as the last step, we verify that if tx failed nonce is not updated func (s *TransactorSuite) TestLocalNonce() { txCount := 3 + chainID := s.nodeConfig.NetworkID key, _ := gethcrypto.GenerateKey() selectedAccount := &account.SelectedExtKey{ Address: account.FromAddress(utils.TestConfig.Account1.WalletAddress), @@ -269,7 +272,7 @@ func (s *TransactorSuite) TestLocalNonce() { _, err := s.manager.SendTransaction(args, selectedAccount) s.NoError(err) - resultNonce, _ := s.manager.localNonce.Load(args.From) + resultNonce, _ := s.manager.nonce.localNonce[chainID].Load(args.From) s.Equal(uint64(i)+1, resultNonce.(uint64)) } @@ -284,7 +287,7 @@ func (s *TransactorSuite) TestLocalNonce() { _, err := s.manager.SendTransaction(args, selectedAccount) s.NoError(err) - resultNonce, _ := s.manager.localNonce.Load(args.From) + resultNonce, _ := s.manager.nonce.localNonce[chainID].Load(args.From) s.Equal(uint64(nonce)+1, resultNonce.(uint64)) testErr := errors.New("test") @@ -296,7 +299,7 @@ func (s *TransactorSuite) TestLocalNonce() { _, err = s.manager.SendTransaction(args, selectedAccount) s.EqualError(err, testErr.Error()) - resultNonce, _ = s.manager.localNonce.Load(args.From) + resultNonce, _ = s.manager.nonce.localNonce[chainID].Load(args.From) s.Equal(uint64(nonce)+1, resultNonce.(uint64)) } @@ -330,8 +333,6 @@ func (s *TransactorSuite) TestSendTransactionWithSignature() { for _, scenario := range scenarios { desc := fmt.Sprintf("local nonce: %d, tx nonce: %d, expect error: %v", scenario.localNonce, scenario.txNonce, scenario.expectError) s.T().Run(desc, func(t *testing.T) { - s.manager.localNonce.Store(address, uint64(scenario.localNonce)) - nonce := scenario.txNonce from := address to := address @@ -340,7 +341,8 @@ func (s *TransactorSuite) TestSendTransactionWithSignature() { gasPrice := (*hexutil.Big)(big.NewInt(2000000000)) data := []byte{} chainID := big.NewInt(int64(s.nodeConfig.NetworkID)) - + s.manager.nonce.localNonce[s.nodeConfig.NetworkID] = &sync.Map{} + s.manager.nonce.localNonce[s.nodeConfig.NetworkID].Store(address, uint64(scenario.localNonce)) args := SendTxArgs{ From: from, To: &to, @@ -376,12 +378,13 @@ func (s *TransactorSuite) TestSendTransactionWithSignature() { if scenario.expectError { s.Error(err) // local nonce should not be incremented - resultNonce, _ := s.manager.localNonce.Load(args.From) + resultNonce, _ := s.manager.nonce.localNonce[s.nodeConfig.NetworkID].Load(args.From) s.Equal(uint64(scenario.localNonce), resultNonce.(uint64)) } else { s.NoError(err) // local nonce should be incremented - resultNonce, _ := s.manager.localNonce.Load(args.From) + resultNonce, _ := s.manager.nonce.localNonce[s.nodeConfig.NetworkID].Load(args.From) + s.Equal(uint64(nonce)+1, resultNonce.(uint64)) } })