diff --git a/services/connector/api.go b/services/connector/api.go index 7212abe25..69d82c15e 100644 --- a/services/connector/api.go +++ b/services/connector/api.go @@ -36,8 +36,8 @@ func NewAPI(s *Service) *API { // Accounts query and dapp permissions // NOTE: Some dApps expect same behavior for both eth_accounts and eth_requestAccounts accountsCommand := &commands.RequestAccountsCommand{ - ClientHandler: c, - AccountsCommand: commands.AccountsCommand{Db: s.db}, + ClientHandler: c, + Db: s.db, } r.Register("eth_accounts", accountsCommand) r.Register("eth_requestAccounts", accountsCommand) @@ -109,6 +109,7 @@ func (api *API) CallRPC(inputJSON string) (interface{}, error) { } func (api *API) RecallDAppPermission(origin string) error { + // TODO: close the websocket connection return persistence.DeleteDApp(api.s.db, origin) } diff --git a/services/connector/commands/chain_id.go b/services/connector/commands/chain_id.go index f72206b15..0d2d17cb7 100644 --- a/services/connector/commands/chain_id.go +++ b/services/connector/commands/chain_id.go @@ -2,7 +2,6 @@ package commands import ( "database/sql" - "strconv" "github.com/status-im/status-go/services/connector/chainutils" persistence "github.com/status-im/status-go/services/connector/database" @@ -25,24 +24,20 @@ func (c *ChainIDCommand) Execute(request RPCRequest) (interface{}, error) { return "", err } + var chainId uint64 if dApp == nil { - defaultChainID, err := chainutils.GetDefaultChainID(c.NetworkManager) + chainId, err = chainutils.GetDefaultChainID(c.NetworkManager) if err != nil { return "", err } - - chainId, err := chainutils.GetHexChainID(strconv.FormatUint(defaultChainID, 16)) - if err != nil { - return "", err - } - - return chainId, nil + } else { + chainId = dApp.ChainID } - chainId, err := chainutils.GetHexChainID(walletCommon.ChainID(dApp.ChainID).String()) + chainIdHex, err := chainutils.GetHexChainID(walletCommon.ChainID(chainId).String()) if err != nil { return "", err } - return chainId, nil + return chainIdHex, nil } diff --git a/services/connector/commands/client_handler.go b/services/connector/commands/client_handler.go index 4dda88bd1..a2114f345 100644 --- a/services/connector/commands/client_handler.go +++ b/services/connector/commands/client_handler.go @@ -98,6 +98,10 @@ func (c *ClientSideHandler) RequestShareAccountForDApp(dApp signal.ConnectorDApp } func (c *ClientSideHandler) RequestAccountsAccepted(args RequestAccountsAcceptedArgs) error { + if args.RequestID == "" { + return ErrEmptyRequestID + } + c.responseChannel <- Message{Type: RequestAccountsAccepted, Data: args} return nil } diff --git a/services/connector/commands/request_accounts.go b/services/connector/commands/request_accounts.go index 8e53e6b5a..d34d3150d 100644 --- a/services/connector/commands/request_accounts.go +++ b/services/connector/commands/request_accounts.go @@ -1,6 +1,7 @@ package commands import ( + "database/sql" "errors" "github.com/status-im/status-go/multiaccounts/accounts" @@ -16,7 +17,7 @@ var ( type RequestAccountsCommand struct { ClientHandler ClientSideHandlerInterface - AccountsCommand + Db *sql.DB } type RawAccountsResponse struct { diff --git a/services/connector/commands/request_accounts_test.go b/services/connector/commands/request_accounts_test.go index e055d45ec..ada6d2400 100644 --- a/services/connector/commands/request_accounts_test.go +++ b/services/connector/commands/request_accounts_test.go @@ -17,7 +17,7 @@ func TestFailToRequestAccountsWithMissingDAppFields(t *testing.T) { db, close := SetupTestDB(t) defer close() - cmd := &RequestAccountsCommand{AccountsCommand: AccountsCommand{Db: db}} + cmd := &RequestAccountsCommand{Db: db} // Missing DApp fields request, err := ConstructRPCRequest("eth_requestAccounts", []interface{}{}, nil) @@ -35,8 +35,8 @@ func TestRequestAccountsWithSignalTimeout(t *testing.T) { clientHandler := NewClientSideHandler() cmd := &RequestAccountsCommand{ - ClientHandler: clientHandler, - AccountsCommand: AccountsCommand{Db: db}, + ClientHandler: clientHandler, + Db: db, } request, err := prepareSendTransactionRequest(testDAppData, types.Address{0x01}) @@ -57,8 +57,8 @@ func TestRequestAccountsAcceptedAndRequestAgain(t *testing.T) { clientHandler := NewClientSideHandler() cmd := &RequestAccountsCommand{ - ClientHandler: clientHandler, - AccountsCommand: AccountsCommand{Db: db}, + ClientHandler: clientHandler, + Db: db, } request, err := ConstructRPCRequest("eth_requestAccounts", []interface{}{}, &testDAppData) @@ -118,8 +118,8 @@ func TestRequestAccountsRejected(t *testing.T) { clientHandler := NewClientSideHandler() cmd := &RequestAccountsCommand{ - ClientHandler: clientHandler, - AccountsCommand: AccountsCommand{Db: db}, + ClientHandler: clientHandler, + Db: db, } request, err := ConstructRPCRequest("eth_requestAccounts", []interface{}{}, &testDAppData) diff --git a/services/connector/commands/switch_ethereum_chain.go b/services/connector/commands/switch_ethereum_chain.go index 0f86f6ab6..f2ec93522 100644 --- a/services/connector/commands/switch_ethereum_chain.go +++ b/services/connector/commands/switch_ethereum_chain.go @@ -9,6 +9,7 @@ import ( "github.com/status-im/status-go/services/connector/chainutils" persistence "github.com/status-im/status-go/services/connector/database" walletCommon "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/signal" ) // errors @@ -93,5 +94,10 @@ func (c *SwitchEthereumChainCommand) Execute(request RPCRequest) (interface{}, e return "", err } + signal.SendConnectorDAppChainIdSwitched(signal.ConnectorDAppChainIdSwitchedSignal{ + URL: request.URL, + ChainId: chainId, + }) + return chainId, nil } diff --git a/services/connector/commands/switch_ethereum_chain_test.go b/services/connector/commands/switch_ethereum_chain_test.go index 479ca8bed..774b1d13b 100644 --- a/services/connector/commands/switch_ethereum_chain_test.go +++ b/services/connector/commands/switch_ethereum_chain_test.go @@ -1,6 +1,7 @@ package commands import ( + "encoding/json" "fmt" "testing" @@ -9,6 +10,7 @@ import ( "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/params" walletCommon "github.com/status-im/status-go/services/wallet/common" + "github.com/status-im/status-go/signal" ) func TestFailToSwitchEthereumChainWithMissingDAppFields(t *testing.T) { @@ -67,7 +69,7 @@ func TestFailToSwitchEthereumChainWithUnsupportedChainId(t *testing.T) { assert.Equal(t, ErrUnsupportedNetwork, err) } -func TestSwitchEthereumChain(t *testing.T) { +func TestSwitchEthereumChainSuccess(t *testing.T) { db, close := SetupTestDB(t) defer close() @@ -81,6 +83,26 @@ func TestSwitchEthereumChain(t *testing.T) { }, }) + chainId := fmt.Sprintf(`0x%s`, walletCommon.ChainID(walletCommon.EthereumMainnet).String()) + chainIdSwitched := false + + signal.SetMobileSignalHandler(signal.MobileSignalHandler(func(s []byte) { + var evt EventType + err := json.Unmarshal(s, &evt) + assert.NoError(t, err) + + switch evt.Type { + case signal.EventConnectorDAppChainIdSwitched: + var ev signal.ConnectorDAppChainIdSwitchedSignal + err := json.Unmarshal(evt.Event, &ev) + assert.NoError(t, err) + + assert.Equal(t, chainId, ev.ChainId) + assert.Equal(t, testDAppData.URL, ev.URL) + chainIdSwitched = true + } + })) + cmd := &SwitchEthereumChainCommand{ Db: db, NetworkManager: &nm, @@ -99,6 +121,6 @@ func TestSwitchEthereumChain(t *testing.T) { response, err := cmd.Execute(request) assert.NoError(t, err) - chainId := fmt.Sprintf(`0x%s`, walletCommon.ChainID(walletCommon.EthereumMainnet).String()) assert.Equal(t, chainId, response) + assert.True(t, chainIdSwitched) } diff --git a/services/connector/connector_flows_test.go b/services/connector/connector_flows_test.go index 0f2661a78..d5905c449 100644 --- a/services/connector/connector_flows_test.go +++ b/services/connector/connector_flows_test.go @@ -172,3 +172,72 @@ func TestForwardedRPCs(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expectedResponse, response) } + +func TestRequestAccountsAfterPermisasionsRevokeTest(t *testing.T) { + db, close := createDB(t) + defer close() + + nm := commands.NetworkManagerMock{} + nm.SetNetworks([]*params.Network{ + { + ChainID: walletCommon.EthereumMainnet, + Layer: 1, + }, + { + ChainID: walletCommon.OptimismMainnet, + Layer: 1, + }, + }) + rpc := commands.RPCClientMock{} + + service := NewService(db, &rpc, &nm) + + api := NewAPI(service) + + accountAddress := types.BytesToAddress(types.FromHex("0x6d0aa2a774b74bb1d36f97700315adf962c69fcg")) + dAppPermissionRevoked := false + dAppPermissionGranted := false + + signal.SetMobileSignalHandler(signal.MobileSignalHandler(func(s []byte) { + var evt commands.EventType + err := json.Unmarshal(s, &evt) + assert.NoError(t, err) + + switch evt.Type { + case signal.EventConnectorDAppPermissionRevoked: + dAppPermissionRevoked = true + case signal.EventConnectorDAppPermissionGranted: + dAppPermissionGranted = true + case signal.EventConnectorSendRequestAccounts: + var ev signal.ConnectorSendRequestAccountsSignal + err := json.Unmarshal(evt.Event, &ev) + assert.NoError(t, err) + + err = api.RequestAccountsAccepted(commands.RequestAccountsAcceptedArgs{ + RequestID: ev.RequestID, + Account: accountAddress, + ChainID: 0x1, + }) + assert.NoError(t, err) + } + })) + + for range [10]int{} { + dAppPermissionRevoked = false + dAppPermissionGranted = false + + // Request accounts + request := "{\"method\": \"eth_requestAccounts\", \"params\": [], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }" + response, err := api.CallRPC(request) + assert.NoError(t, err) + assert.Equal(t, commands.FormatAccountAddressToResponse(accountAddress), response) + assert.Equal(t, true, dAppPermissionGranted) + assert.Equal(t, false, dAppPermissionRevoked) + + // Revoke permissions + request = "{\"method\": \"wallet_revokePermissions\", \"params\": [], \"url\": \"http://testDAppURL123\", \"name\": \"testDAppName\", \"iconUrl\": \"http://testDAppIconUrl\" }" + _, err = api.CallRPC(request) + assert.NoError(t, err) + assert.Equal(t, true, dAppPermissionRevoked) + } +} diff --git a/signal/events_connector.go b/signal/events_connector.go index 27deaae60..de2f16247 100644 --- a/signal/events_connector.go +++ b/signal/events_connector.go @@ -6,6 +6,7 @@ const ( EventConnectorPersonalSign = "connector.personalSign" EventConnectorDAppPermissionGranted = "connector.dAppPermissionGranted" EventConnectorDAppPermissionRevoked = "connector.dAppPermissionRevoked" + EventConnectorDAppChainIdSwitched = "connector.dAppChainIdSwitched" ) type ConnectorDApp struct { @@ -35,6 +36,11 @@ type ConnectorPersonalSignSignal struct { Address string `json:"address"` } +type ConnectorDAppChainIdSwitchedSignal struct { + URL string `json:"url"` + ChainId string `json:"chainId"` +} + func SendConnectorSendRequestAccounts(dApp ConnectorDApp, requestID string) { send(EventConnectorSendRequestAccounts, ConnectorSendRequestAccountsSignal{ ConnectorDApp: dApp, @@ -67,3 +73,7 @@ func SendConnectorDAppPermissionGranted(dApp ConnectorDApp) { func SendConnectorDAppPermissionRevoked(dApp ConnectorDApp) { send(EventConnectorDAppPermissionRevoked, dApp) } + +func SendConnectorDAppChainIdSwitched(payload ConnectorDAppChainIdSwitchedSignal) { + send(EventConnectorDAppChainIdSwitched, payload) +}