diff --git a/rpc/client.go b/rpc/client.go index 255451cbb..4a61a496f 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -76,6 +76,9 @@ type ClientInterface interface { EthClient(chainID uint64) (chain.ClientInterface, error) EthClients(chainIDs []uint64) (map[uint64]chain.ClientInterface, error) CallContext(context context.Context, result interface{}, chainID uint64, method string, args ...interface{}) error + Call(result interface{}, chainID uint64, method string, args ...interface{}) error + CallRaw(body string) string + GetNetworkManager() *network.Manager } // Client represents RPC client with custom routing @@ -177,6 +180,10 @@ func NewClient(client *gethrpc.Client, upstreamChainID uint64, upstream params.U return &c, nil } +func (c *Client) GetNetworkManager() *network.Manager { + return c.NetworkManager +} + func (c *Client) SetWalletNotifier(notifier func(chainID uint64, message string)) { c.walletNotifier = notifier } diff --git a/rpc/mock/client/client.go b/rpc/mock/client/client.go index e7af522e9..52b471e10 100644 --- a/rpc/mock/client/client.go +++ b/rpc/mock/client/client.go @@ -1,10 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. // Source: rpc/client.go -// -// Generated by this command: -// -// mockgen -package=mock_rpcclient -destination=rpc/mock/client/client.go -source=rpc/client.go -// // Package mock_rpcclient is a generated GoMock package. package mock_rpcclient @@ -14,6 +9,7 @@ import ( reflect "reflect" chain "github.com/status-im/status-go/rpc/chain" + network "github.com/status-im/status-go/rpc/network" common "github.com/status-im/status-go/services/wallet/common" gomock "go.uber.org/mock/gomock" ) @@ -51,15 +47,34 @@ func (m *MockClientInterface) AbstractEthClient(chainID common.ChainID) (chain.B } // AbstractEthClient indicates an expected call of AbstractEthClient. -func (mr *MockClientInterfaceMockRecorder) AbstractEthClient(chainID any) *gomock.Call { +func (mr *MockClientInterfaceMockRecorder) AbstractEthClient(chainID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AbstractEthClient", reflect.TypeOf((*MockClientInterface)(nil).AbstractEthClient), chainID) } -// CallContext mocks base method. -func (m *MockClientInterface) CallContext(context context.Context, result any, chainID uint64, method string, args ...any) error { +// Call mocks base method. +func (m *MockClientInterface) Call(result interface{}, chainID uint64, method string, args ...interface{}) error { m.ctrl.T.Helper() - varargs := []any{context, result, chainID, method} + varargs := []interface{}{result, chainID, method} + for _, a := range args { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Call", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Call indicates an expected call of Call. +func (mr *MockClientInterfaceMockRecorder) Call(result, chainID, method interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{result, chainID, method}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Call", reflect.TypeOf((*MockClientInterface)(nil).Call), varargs...) +} + +// CallContext mocks base method. +func (m *MockClientInterface) CallContext(context context.Context, result interface{}, chainID uint64, method string, args ...interface{}) error { + m.ctrl.T.Helper() + varargs := []interface{}{context, result, chainID, method} for _, a := range args { varargs = append(varargs, a) } @@ -69,12 +84,26 @@ func (m *MockClientInterface) CallContext(context context.Context, result any, c } // CallContext indicates an expected call of CallContext. -func (mr *MockClientInterfaceMockRecorder) CallContext(context, result, chainID, method any, args ...any) *gomock.Call { +func (mr *MockClientInterfaceMockRecorder) CallContext(context, result, chainID, method interface{}, args ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{context, result, chainID, method}, args...) + varargs := append([]interface{}{context, result, chainID, method}, args...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallContext", reflect.TypeOf((*MockClientInterface)(nil).CallContext), varargs...) } +// CallRaw mocks base method. +func (m *MockClientInterface) CallRaw(body string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CallRaw", body) + ret0, _ := ret[0].(string) + return ret0 +} + +// CallRaw indicates an expected call of CallRaw. +func (mr *MockClientInterfaceMockRecorder) CallRaw(body interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallRaw", reflect.TypeOf((*MockClientInterface)(nil).CallRaw), body) +} + // EthClient mocks base method. func (m *MockClientInterface) EthClient(chainID uint64) (chain.ClientInterface, error) { m.ctrl.T.Helper() @@ -85,7 +114,7 @@ func (m *MockClientInterface) EthClient(chainID uint64) (chain.ClientInterface, } // EthClient indicates an expected call of EthClient. -func (mr *MockClientInterfaceMockRecorder) EthClient(chainID any) *gomock.Call { +func (mr *MockClientInterfaceMockRecorder) EthClient(chainID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EthClient", reflect.TypeOf((*MockClientInterface)(nil).EthClient), chainID) } @@ -100,7 +129,21 @@ func (m *MockClientInterface) EthClients(chainIDs []uint64) (map[uint64]chain.Cl } // EthClients indicates an expected call of EthClients. -func (mr *MockClientInterfaceMockRecorder) EthClients(chainIDs any) *gomock.Call { +func (mr *MockClientInterfaceMockRecorder) EthClients(chainIDs interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EthClients", reflect.TypeOf((*MockClientInterface)(nil).EthClients), chainIDs) } + +// GetNetworkManager mocks base method. +func (m *MockClientInterface) GetNetworkManager() *network.Manager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNetworkManager") + ret0, _ := ret[0].(*network.Manager) + return ret0 +} + +// GetNetworkManager indicates an expected call of GetNetworkManager. +func (mr *MockClientInterfaceMockRecorder) GetNetworkManager() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkManager", reflect.TypeOf((*MockClientInterface)(nil).GetNetworkManager)) +} diff --git a/services/connector/api.go b/services/connector/api.go index 69d82c15e..58b851b86 100644 --- a/services/connector/api.go +++ b/services/connector/api.go @@ -1,6 +1,7 @@ package connector import ( + "context" "encoding/json" "errors" "fmt" @@ -25,6 +26,7 @@ func NewAPI(s *Service) *API { // Transactions and signing r.Register("eth_sendTransaction", &commands.SendTransactionCommand{ + RpcClient: s.rpc, Db: s.db, ClientHandler: c, }) @@ -65,7 +67,7 @@ func NewAPI(s *Service) *API { } } -func (api *API) forwardRPC(URL string, inputJSON string) (interface{}, error) { +func (api *API) forwardRPC(URL string, request commands.RPCRequest) (interface{}, error) { dApp, err := persistence.SelectDAppByUrl(api.s.db, URL) if err != nil { return "", err @@ -75,8 +77,17 @@ func (api *API) forwardRPC(URL string, inputJSON string) (interface{}, error) { return "", commands.ErrDAppIsNotPermittedByUser } + if request.ChainID != dApp.ChainID { + request.ChainID = dApp.ChainID + } + var response map[string]interface{} - rawResponse := api.s.rpc.CallRaw(inputJSON) + byteRequest, err := json.Marshal(request) + if err != nil { + return "", err + } + + rawResponse := api.s.rpc.CallRaw(string(byteRequest)) if err := json.Unmarshal([]byte(rawResponse), &response); err != nil { return "", err } @@ -95,17 +106,17 @@ func (api *API) forwardRPC(URL string, inputJSON string) (interface{}, error) { return nil, ErrInvalidResponseFromForwardedRpc } -func (api *API) CallRPC(inputJSON string) (interface{}, error) { +func (api *API) CallRPC(ctx context.Context, inputJSON string) (interface{}, error) { request, err := commands.RPCRequestFromJSON(inputJSON) if err != nil { return "", err } if command, exists := api.r.GetCommand(request.Method); exists { - return command.Execute(request) + return command.Execute(ctx, request) } - return api.forwardRPC(request.URL, inputJSON) + return api.forwardRPC(request.URL, request) } func (api *API) RecallDAppPermission(origin string) error { diff --git a/services/connector/chainutils/utils.go b/services/connector/chainutils/utils.go index 4e454a9b8..e4887e1d8 100644 --- a/services/connector/chainutils/utils.go +++ b/services/connector/chainutils/utils.go @@ -5,20 +5,16 @@ import ( "fmt" "strconv" - "github.com/status-im/status-go/params" + "github.com/status-im/status-go/rpc/network" ) -type NetworkManagerInterface interface { - GetActiveNetworks() ([]*params.Network, error) -} - var ( ErrNoActiveNetworks = errors.New("no active networks available") ErrUnsupportedNetwork = errors.New("unsupported network") ) // GetSupportedChainIDs retrieves the chain IDs from the provided NetworkManager. -func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, error) { +func GetSupportedChainIDs(networkManager *network.Manager) ([]uint64, error) { activeNetworks, err := networkManager.GetActiveNetworks() if err != nil { return nil, err @@ -36,7 +32,7 @@ func GetSupportedChainIDs(networkManager NetworkManagerInterface) ([]uint64, err return chainIDs, nil } -func GetDefaultChainID(networkManager NetworkManagerInterface) (uint64, error) { +func GetDefaultChainID(networkManager *network.Manager) (uint64, error) { chainIDs, err := GetSupportedChainIDs(networkManager) if err != nil { return 0, err diff --git a/services/connector/commands/accounts.go b/services/connector/commands/accounts.go index de2a96149..39f3ea1e8 100644 --- a/services/connector/commands/accounts.go +++ b/services/connector/commands/accounts.go @@ -1,6 +1,7 @@ package commands import ( + "context" "database/sql" "strings" @@ -16,7 +17,7 @@ func FormatAccountAddressToResponse(address types.Address) []string { return []string{strings.ToLower(address.Hex())} } -func (c *AccountsCommand) Execute(request RPCRequest) (interface{}, error) { +func (c *AccountsCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) { err := request.Validate() if err != nil { return "", err diff --git a/services/connector/commands/chain_id.go b/services/connector/commands/chain_id.go index 0d2d17cb7..d3cbcd000 100644 --- a/services/connector/commands/chain_id.go +++ b/services/connector/commands/chain_id.go @@ -1,19 +1,21 @@ package commands import ( + "context" "database/sql" + "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/services/connector/chainutils" persistence "github.com/status-im/status-go/services/connector/database" walletCommon "github.com/status-im/status-go/services/wallet/common" ) type ChainIDCommand struct { - NetworkManager NetworkManagerInterface + NetworkManager *network.Manager Db *sql.DB } -func (c *ChainIDCommand) Execute(request RPCRequest) (interface{}, error) { +func (c *ChainIDCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) { err := request.Validate() if err != nil { return "", err diff --git a/services/connector/commands/personal_sign.go b/services/connector/commands/personal_sign.go index 16a1b541f..2c14784e3 100644 --- a/services/connector/commands/personal_sign.go +++ b/services/connector/commands/personal_sign.go @@ -1,6 +1,7 @@ package commands import ( + "context" "database/sql" "errors" "fmt" @@ -51,7 +52,7 @@ func (r *RPCRequest) getPersonalSignParams() (*PersonalSignParams, error) { }, nil } -func (c *PersonalSignCommand) Execute(request RPCRequest) (interface{}, error) { +func (c *PersonalSignCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) { err := request.Validate() if err != nil { return "", err diff --git a/services/connector/commands/request_accounts.go b/services/connector/commands/request_accounts.go index d34d3150d..0521a3ede 100644 --- a/services/connector/commands/request_accounts.go +++ b/services/connector/commands/request_accounts.go @@ -1,6 +1,7 @@ package commands import ( + "context" "database/sql" "errors" @@ -26,7 +27,7 @@ type RawAccountsResponse struct { Result []accounts.Account `json:"result"` } -func (c *RequestAccountsCommand) Execute(request RPCRequest) (interface{}, error) { +func (c *RequestAccountsCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) { err := request.Validate() if err != nil { return "", err diff --git a/services/connector/commands/request_permissions.go b/services/connector/commands/request_permissions.go index 220a04ea0..836d148d6 100644 --- a/services/connector/commands/request_permissions.go +++ b/services/connector/commands/request_permissions.go @@ -1,6 +1,7 @@ package commands import ( + "context" "errors" "fmt" "time" @@ -51,7 +52,7 @@ func (c *RequestPermissionsCommand) getPermissionResponse(methodName string) Per return response } -func (c *RequestPermissionsCommand) Execute(request RPCRequest) (interface{}, error) { +func (c *RequestPermissionsCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) { err := request.Validate() if err != nil { return "", err diff --git a/services/connector/commands/revoke_permissions.go b/services/connector/commands/revoke_permissions.go index 997218fe6..9e92d5ade 100644 --- a/services/connector/commands/revoke_permissions.go +++ b/services/connector/commands/revoke_permissions.go @@ -1,6 +1,7 @@ package commands import ( + "context" "database/sql" persistence "github.com/status-im/status-go/services/connector/database" @@ -11,7 +12,7 @@ type RevokePermissionsCommand struct { Db *sql.DB } -func (c *RevokePermissionsCommand) Execute(request RPCRequest) (interface{}, error) { +func (c *RevokePermissionsCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) { err := request.Validate() if err != nil { return "", err diff --git a/services/connector/commands/rpc_traits.go b/services/connector/commands/rpc_traits.go index 932c03c35..8c229f089 100644 --- a/services/connector/commands/rpc_traits.go +++ b/services/connector/commands/rpc_traits.go @@ -1,6 +1,7 @@ package commands import ( + "context" "encoding/json" "errors" "fmt" @@ -11,6 +12,17 @@ import ( "github.com/status-im/status-go/transactions" ) +const ( + Method_EthAccounts = "eth_accounts" + Method_EthRequestAccounts = "eth_requestAccounts" + Method_EthChainId = "eth_chainId" + Method_PersonalSign = "personal_sign" + Method_EthSendTransaction = "eth_sendTransaction" + Method_RequestPermissions = "wallet_requestPermissions" + Method_RevokePermissions = "wallet_revokePermissions" + Method_SwitchEthereumChain = "wallet_switchEthereumChain" +) + // errors var ( ErrRequestMissingDAppData = errors.New("request missing dApp data") @@ -26,10 +38,11 @@ type RPCRequest struct { URL string `json:"url"` Name string `json:"name"` IconURL string `json:"iconUrl"` + ChainID uint64 `json:"chainId"` } type RPCCommand interface { - Execute(request RPCRequest) (interface{}, error) + Execute(ctx context.Context, request RPCRequest) (interface{}, error) } type RequestAccountsAcceptedArgs struct { diff --git a/services/connector/commands/send_transaction.go b/services/connector/commands/send_transaction.go index b22b475cf..332da2a8a 100644 --- a/services/connector/commands/send_transaction.go +++ b/services/connector/commands/send_transaction.go @@ -1,12 +1,18 @@ package commands import ( + "context" "database/sql" "encoding/json" "errors" "fmt" + "math/big" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/status-im/status-go/rpc" persistence "github.com/status-im/status-go/services/connector/database" + "github.com/status-im/status-go/services/wallet/router/fees" "github.com/status-im/status-go/signal" "github.com/status-im/status-go/transactions" ) @@ -14,9 +20,11 @@ import ( var ( ErrParamsFromAddressIsNotShared = errors.New("from parameter address is not dApp's shared account") ErrNoTransactionParamsFound = errors.New("no transaction in params found") + ErrSendingParamsInvalid = errors.New("sending params are invalid") ) type SendTransactionCommand struct { + RpcClient rpc.ClientInterface Db *sql.DB ClientHandler ClientSideHandlerInterface } @@ -45,7 +53,7 @@ func (r *RPCRequest) getSendTransactionParams() (*transactions.SendTxArgs, error return &sendTxArgs, nil } -func (c *SendTransactionCommand) Execute(request RPCRequest) (interface{}, error) { +func (c *SendTransactionCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) { err := request.Validate() if err != nil { return "", err @@ -65,10 +73,49 @@ func (c *SendTransactionCommand) Execute(request RPCRequest) (interface{}, error return "", err } + if !params.Valid() { + return "", ErrSendingParamsInvalid + } + if params.From != dApp.SharedAccount { return "", ErrParamsFromAddressIsNotShared } + if params.Value == nil { + params.Value = (*hexutil.Big)(big.NewInt(0)) + } + + if params.GasPrice == nil || (params.MaxFeePerGas == nil && params.MaxPriorityFeePerGas == nil) { + feeManager := &fees.FeeManager{ + RPCClient: c.RpcClient, + } + fetchedFees, err := feeManager.SuggestedFees(ctx, dApp.ChainID) + if err != nil { + return "", err + } + + if !fetchedFees.EIP1559Enabled { + params.GasPrice = (*hexutil.Big)(fetchedFees.GasPrice) + } else { + params.MaxFeePerGas = (*hexutil.Big)(fetchedFees.FeeFor(fees.GasFeeMedium)) + params.MaxPriorityFeePerGas = (*hexutil.Big)(fetchedFees.MaxPriorityFeePerGas) + } + } + + if params.Nonce == nil { + ethClient, err := c.RpcClient.EthClient(dApp.ChainID) + if err != nil { + return "", err + } + + nonce, err := ethClient.PendingNonceAt(ctx, common.Address(dApp.SharedAccount)) + if err != nil { + return "", err + } + + params.Nonce = (*hexutil.Uint64)(&nonce) + } + hash, err := c.ClientHandler.RequestSendTransaction(signal.ConnectorDApp{ URL: request.URL, Name: request.Name, diff --git a/services/connector/commands/switch_ethereum_chain.go b/services/connector/commands/switch_ethereum_chain.go index f2ec93522..8824f547e 100644 --- a/services/connector/commands/switch_ethereum_chain.go +++ b/services/connector/commands/switch_ethereum_chain.go @@ -1,11 +1,13 @@ package commands import ( + "context" "database/sql" "errors" "slices" "strconv" + "github.com/status-im/status-go/rpc/network" "github.com/status-im/status-go/services/connector/chainutils" persistence "github.com/status-im/status-go/services/connector/database" walletCommon "github.com/status-im/status-go/services/wallet/common" @@ -20,7 +22,7 @@ var ( ) type SwitchEthereumChainCommand struct { - NetworkManager NetworkManagerInterface + NetworkManager *network.Manager Db *sql.DB } @@ -53,7 +55,7 @@ func (c *SwitchEthereumChainCommand) getSupportedChainIDs() ([]uint64, error) { return chainutils.GetSupportedChainIDs(c.NetworkManager) } -func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (interface{}, error) { +func (c *SwitchEthereumChainCommand) Execute(ctx context.Context, request RPCRequest) (interface{}, error) { err := request.Validate() if err != nil { return "", err diff --git a/services/connector/service.go b/services/connector/service.go index aede99c5e..85d3e3f44 100644 --- a/services/connector/service.go +++ b/services/connector/service.go @@ -6,10 +6,11 @@ import ( "github.com/ethereum/go-ethereum/p2p" gethrpc "github.com/ethereum/go-ethereum/rpc" - "github.com/status-im/status-go/services/connector/commands" + "github.com/status-im/status-go/rpc" + "github.com/status-im/status-go/rpc/network" ) -func NewService(db *sql.DB, rpc commands.RPCClientInterface, nm commands.NetworkManagerInterface) *Service { +func NewService(db *sql.DB, rpc rpc.ClientInterface, nm *network.Manager) *Service { return &Service{ db: db, rpc: rpc, @@ -19,8 +20,8 @@ func NewService(db *sql.DB, rpc commands.RPCClientInterface, nm commands.Network type Service struct { db *sql.DB - rpc commands.RPCClientInterface - nm commands.NetworkManagerInterface + rpc rpc.ClientInterface + nm *network.Manager } func (s *Service) Start() error { diff --git a/services/wallet/router/fees/fees.go b/services/wallet/router/fees/fees.go index 07387ca25..264dc3b6d 100644 --- a/services/wallet/router/fees/fees.go +++ b/services/wallet/router/fees/fees.go @@ -87,7 +87,7 @@ type FeeHistory struct { } type FeeManager struct { - RPCClient *rpc.Client + RPCClient rpc.ClientInterface } func (f *FeeManager) SuggestedFees(ctx context.Context, chainID uint64) (*SuggestedFees, error) {