diff --git a/multiaccounts/accounts/database.go b/multiaccounts/accounts/database.go index f65d03642..6b94fce65 100644 --- a/multiaccounts/accounts/database.go +++ b/multiaccounts/accounts/database.go @@ -555,8 +555,8 @@ func (db *Database) getKeypairs(tx *sql.Tx, keyUID string, includeRemoved bool) return keypairs, nil } -func (db *Database) getKeypairByKeyUID(tx *sql.Tx, keyUID string) (*Keypair, error) { - keypairs, err := db.getKeypairs(tx, keyUID, false) +func (db *Database) getKeypairByKeyUID(tx *sql.Tx, keyUID string, includeRemoved bool) (*Keypair, error) { + keypairs, err := db.getKeypairs(tx, keyUID, includeRemoved) if err != nil && err != sql.ErrNoRows { return nil, err } @@ -700,7 +700,7 @@ func (db *Database) markKeypairRemoved(tx *sql.Tx, keyUID string, clock uint64) return errDbTransactionIsNil } - keypair, err := db.getKeypairByKeyUID(tx, keyUID) + keypair, err := db.getKeypairByKeyUID(tx, keyUID, false) if err != nil { return err } @@ -752,7 +752,7 @@ func (db *Database) GetAllKeypairs() ([]*Keypair, error) { // Returns keypair if it is not marked as removed and its accounts which are not marked as removed. func (db *Database) GetKeypairByKeyUID(keyUID string) (*Keypair, error) { - return db.getKeypairByKeyUID(nil, keyUID) + return db.getKeypairByKeyUID(nil, keyUID, false) } // Returns active accounts (excluding removed). @@ -799,7 +799,7 @@ func (db *Database) GetAllWatchOnlyAccounts() (res []*Account, err error) { } func (db *Database) IsAnyAccountPartiallyOrFullyOperableForKeyUID(keyUID string) (bool, error) { - kp, err := db.getKeypairByKeyUID(nil, keyUID) + kp, err := db.getKeypairByKeyUID(nil, keyUID, false) if err != nil { return false, err } @@ -847,7 +847,7 @@ func (db *Database) RemoveAccount(address types.Address, clock uint64) error { return err } - kp, err := db.getKeypairByKeyUID(tx, acc.KeyUID) + kp, err := db.getKeypairByKeyUID(tx, acc.KeyUID, false) if err != nil && err != ErrDbKeypairNotFound { return err } @@ -937,7 +937,7 @@ func (db *Database) saveOrUpdateAccounts(tx *sql.Tx, accounts []*Account, update // only watch only accounts have an empty `KeyUID` field var keyUID *string if acc.KeyUID != "" { - relatedKeypair, err = db.getKeypairByKeyUID(tx, acc.KeyUID) + relatedKeypair, err = db.getKeypairByKeyUID(tx, acc.KeyUID, true) if err != nil { if err == sql.ErrNoRows { // all accounts, except watch only accounts, must have a row in `keypairs` table with the same key uid @@ -948,7 +948,7 @@ func (db *Database) saveOrUpdateAccounts(tx *sql.Tx, accounts []*Account, update keyUID = &acc.KeyUID } var exists bool - err = tx.QueryRow("SELECT EXISTS (SELECT 1 FROM keypairs_accounts WHERE address = ?)", acc.Address).Scan(&exists) + err = tx.QueryRow("SELECT EXISTS (SELECT 1 FROM keypairs_accounts WHERE address = ? AND removed = 0)", acc.Address).Scan(&exists) if err != nil { return err } @@ -1089,7 +1089,7 @@ func (db *Database) SaveOrUpdateKeypair(keypair *Keypair) error { }() // If keypair is being saved, not updated, then it must be at least one account and all accounts must have the same key uid. - dbKeypair, err := db.getKeypairByKeyUID(tx, keypair.KeyUID) + dbKeypair, err := db.getKeypairByKeyUID(tx, keypair.KeyUID, true) if err != nil && err != ErrDbKeypairNotFound { return err } @@ -1141,7 +1141,7 @@ func (db *Database) UpdateKeypairName(keyUID string, name string, clock uint64, _ = tx.Rollback() }() - _, err = db.getKeypairByKeyUID(tx, keyUID) + _, err = db.getKeypairByKeyUID(tx, keyUID, false) if err != nil { return err } diff --git a/protocol/messenger_handler.go b/protocol/messenger_handler.go index 1cc7955db..45fc0b13a 100644 --- a/protocol/messenger_handler.go +++ b/protocol/messenger_handler.go @@ -3193,10 +3193,6 @@ func (m *Messenger) handleSyncKeypair(message *protobuf.SyncKeypair) (*accounts. } // in case of keypair update, we need to keep `synced_from` field as it was when keypair was introduced to this device for the first time kp.SyncedFrom = dbKeypair.SyncedFrom - } else { - if kp.Removed { - return nil, nil - } } for _, sAcc := range message.Accounts { @@ -3221,11 +3217,9 @@ func (m *Messenger) handleSyncKeypair(message *protobuf.SyncKeypair) (*accounts. if kp.Removed { // delete all keystore files - for _, dbAcc := range dbKeypair.Accounts { - err = m.deleteKeystoreFileForAddress(dbAcc.Address) - if err != nil { - return nil, err - } + err = m.deleteKeystoreFilesForKeypair(dbKeypair) + if err != nil { + return nil, err } } else if !accountReceivedFromLocalPairing && dbKeypair != nil { for _, dbAcc := range dbKeypair.Accounts { @@ -3251,8 +3245,8 @@ func (m *Messenger) handleSyncKeypair(message *protobuf.SyncKeypair) (*accounts. return nil, err } - // if entire keypair was removed, there is no point to continue - if kp.Removed { + // if entire keypair was removed and keypair is already in db, there is no point to continue + if kp.Removed && dbKeypair != nil { // if keypair is retrieved from backed up data, no need for resolving accounts positions if message.SyncedFrom != accounts.SyncedFromBackup { err = m.settings.ResolveAccountsPositions(message.Clock) diff --git a/protocol/messenger_wallet.go b/protocol/messenger_wallet.go index 4ba127b07..3100b5c1e 100644 --- a/protocol/messenger_wallet.go +++ b/protocol/messenger_wallet.go @@ -190,8 +190,6 @@ func (m *Messenger) deleteKeystoreFileForAddress(address types.Address) error { return err } - lastAcccountOfKeypairWithTheSameKey := len(kp.Accounts) == 1 - if len(kp.Keycards) == 0 { err = m.accountsManager.DeleteAccount(address) var e *account.ErrCannotLocateKeyFile @@ -200,6 +198,7 @@ func (m *Messenger) deleteKeystoreFileForAddress(address types.Address) error { } if acc.Type != accounts.AccountTypeKey { + lastAcccountOfKeypairWithTheSameKey := len(kp.Accounts) == 1 if lastAcccountOfKeypairWithTheSameKey { err = m.accountsManager.DeleteAccount(types.Address(ethcommon.HexToAddress(kp.DerivedFrom))) var e *account.ErrCannotLocateKeyFile @@ -214,6 +213,40 @@ func (m *Messenger) deleteKeystoreFileForAddress(address types.Address) error { return nil } +func (m *Messenger) deleteKeystoreFilesForKeypair(keypair *accounts.Keypair) (err error) { + if keypair == nil || len(keypair.Keycards) > 0 { + return + } + + anyAccountFullyOrPartiallyOperable := false + for _, acc := range keypair.Accounts { + if acc.Removed || acc.Operable == accounts.AccountNonOperable { + continue + } + if !anyAccountFullyOrPartiallyOperable { + anyAccountFullyOrPartiallyOperable = true + } + if acc.Operable == accounts.AccountPartiallyOperable { + continue + } + err = m.accountsManager.DeleteAccount(acc.Address) + var e *account.ErrCannotLocateKeyFile + if err != nil && !errors.As(err, &e) { + return err + } + } + + if anyAccountFullyOrPartiallyOperable && keypair.Type != accounts.KeypairTypeKey { + err = m.accountsManager.DeleteAccount(types.Address(ethcommon.HexToAddress(keypair.DerivedFrom))) + var e *account.ErrCannotLocateKeyFile + if err != nil && !errors.As(err, &e) { + return err + } + } + + return +} + func (m *Messenger) DeleteAccount(address types.Address) error { acc, err := m.settings.GetAccountByAddress(address) if err != nil { @@ -266,11 +299,9 @@ func (m *Messenger) DeleteKeypair(keyUID string) error { return accounts.ErrCannotRemoveProfileKeypair } - for _, acc := range kp.Accounts { - err = m.deleteKeystoreFileForAddress(acc.Address) - if err != nil { - return err - } + err = m.deleteKeystoreFilesForKeypair(kp) + if err != nil { + return err } clock, _ := m.getLastClockWithRelatedChat()