diff --git a/api/backend_test.go b/api/backend_test.go index 71df252b9..40991933a 100644 --- a/api/backend_test.go +++ b/api/backend_test.go @@ -1296,7 +1296,7 @@ func TestChangeDatabasePassword(t *testing.T) { func TestCreateWallet(t *testing.T) { utils.Init() - password := "some-password2" + password := "some-password2" // nolint: goconst tmpdir := t.TempDir() b := NewGethStatusBackend() @@ -1331,17 +1331,12 @@ func TestCreateWallet(t *testing.T) { walletRootAddress, err := db.GetWalletRootAddress() require.NoError(t, err) - masterRootAddress, err := db.GetMasterAddress() require.NoError(t, err) - fmt.Println("WALLET ROOT", walletRootAddress.String()) - fmt.Println("MASTER ROOT", masterRootAddress.String()) - derivedAddress, err := walletAPI.GetDerivedAddresses(context.Background(), password, walletRootAddress.String(), paths) require.NoError(t, err) require.Len(t, derivedAddress, 1) - fmt.Println("DERVIED", derivedAddress) accountsService := statusNode.AccountService() require.NotNil(t, accountsService) accountsAPI := accountsService.AccountsAPI() @@ -1356,5 +1351,63 @@ func TestCreateWallet(t *testing.T) { Path: derivedAddress[0].Path, }) require.NoError(t, err) - +} + +func TestSetFleet(t *testing.T) { + utils.Init() + password := "some-password2" // nolint: goconst + tmpdir := t.TempDir() + + b := NewGethStatusBackend() + createAccountRequest := &requests.CreateAccount{ + DisplayName: "some-display-name", + CustomizationColor: "#ffffff", + Password: password, + BackupDisabledDataDir: tmpdir, + NetworkID: 1, + LogFilePath: tmpdir + "/log", + } + c := make(chan interface{}, 10) + signal.SetMobileSignalHandler(func(data []byte) { + if strings.Contains(string(data), "node.login") { + c <- struct{}{} + } + }) + + newAccount, err := b.CreateAccountAndLogin(createAccountRequest) + require.NoError(t, err) + statusNode := b.statusNode + require.NotNil(t, statusNode) + + savedSettings, err := b.GetSettings() + require.NoError(t, err) + require.Empty(t, savedSettings.Fleet) + + accountsDB, err := b.accountsDB() + require.NoError(t, err) + err = accountsDB.SaveSettingField(settings.Fleet, statusTestFleet) + require.NoError(t, err) + + savedSettings, err = b.GetSettings() + require.NoError(t, err) + require.NotEmpty(t, savedSettings.Fleet) + require.Equal(t, statusTestFleet, *savedSettings.Fleet) + + require.NoError(t, b.Logout()) + + loginAccountRequest := &requests.Login{ + KeyUID: newAccount.KeyUID, + Password: password, + } + require.NoError(t, b.LoginAccount(loginAccountRequest)) + select { + case <-c: + break + case <-time.After(5 * time.Second): + t.FailNow() + } + // Check is using the right fleet + require.Equal(t, b.config.ClusterConfig.WakuNodes, defaultWakuNodes[statusTestFleet]) + + require.NoError(t, b.Logout()) } diff --git a/api/defaults.go b/api/defaults.go index 254c8b57e..078bb9ea3 100644 --- a/api/defaults.go +++ b/api/defaults.go @@ -106,14 +106,18 @@ func defaultSettings(generatedAccountInfo generator.GeneratedAccountInfo, derive } func SetDefaultFleet(nodeConfig *params.NodeConfig) error { - clusterConfig, err := params.LoadClusterConfigFromFleet(statusProdFleet) + return SetFleet(statusProdFleet, nodeConfig) +} + +func SetFleet(fleet string, nodeConfig *params.NodeConfig) error { + clusterConfig, err := params.LoadClusterConfigFromFleet(fleet) if err != nil { return err } nodeConfig.ClusterConfig = *clusterConfig - nodeConfig.ClusterConfig.WakuNodes = defaultWakuNodes[statusProdFleet] - nodeConfig.ClusterConfig.DiscV5BootstrapNodes = defaultWakuNodes[statusProdFleet] + nodeConfig.ClusterConfig.WakuNodes = defaultWakuNodes[fleet] + nodeConfig.ClusterConfig.DiscV5BootstrapNodes = defaultWakuNodes[fleet] return nil } diff --git a/api/geth_backend.go b/api/geth_backend.go index c29477955..52121a12e 100644 --- a/api/geth_backend.go +++ b/api/geth_backend.go @@ -622,7 +622,20 @@ func (b *GethStatusBackend) loginAccount(request *requests.Login) error { KeycardPairingDataFile: defaultKeycardPairingDataFile, } - err = SetDefaultFleet(defaultCfg) + settings, err := b.GetSettings() + if err != nil { + return err + } + + var fleet string + fleetPtr := settings.Fleet + if fleetPtr == nil || *fleetPtr == "" { + fleet = statusProdFleet + } else { + fleet = *fleetPtr + } + + err = SetFleet(fleet, defaultCfg) if err != nil { return err } @@ -762,13 +775,17 @@ func (b *GethStatusBackend) startNodeWithAccount(acc multiaccounts.Account, pass return nil } +func (b *GethStatusBackend) accountsDB() (*accounts.Database, error) { + return accounts.NewDB(b.appDB) +} + func (b *GethStatusBackend) GetSettings() (*settings.Settings, error) { - accountDB, err := accounts.NewDB(b.appDB) + accountsDB, err := b.accountsDB() if err != nil { return nil, err } - settings, err := accountDB.GetSettings() + settings, err := accountsDB.GetSettings() if err != nil { return nil, err }