diff --git a/multiaccounts/keypairs/database.go b/multiaccounts/keypairs/database.go index 2fc0f0011..c246b5d06 100644 --- a/multiaccounts/keypairs/database.go +++ b/multiaccounts/keypairs/database.go @@ -133,7 +133,7 @@ func (kp *KeyPairs) GetMigratedKeyPairByKeyUID(keyUID string) ([]*KeyPair, error return kp.processResult(rows, false) } -func (kp *KeyPairs) AddMigratedKeyPair(kcUID string, kpName string, KeyUID string, accountAddresses []types.Address) (err error) { +func (kp *KeyPairs) AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(keyPair KeyPair) (err error) { var ( tx *sql.Tx insertKcAcc *sql.Stmt @@ -151,7 +151,7 @@ func (kp *KeyPairs) AddMigratedKeyPair(kcUID string, kpName string, KeyUID strin }() var tmpKeyUID string - err = tx.QueryRow(`SELECT keycard_uid FROM keycards WHERE keycard_uid = ?`, kcUID).Scan(&tmpKeyUID) + err = tx.QueryRow(`SELECT keycard_uid FROM keycards WHERE keycard_uid = ?`, keyPair.KeycardUID).Scan(&tmpKeyUID) if err != nil { if err == sql.ErrNoRows { insertKc, err := tx.Prepare(` @@ -173,7 +173,7 @@ func (kp *KeyPairs) AddMigratedKeyPair(kcUID string, kpName string, KeyUID strin defer insertKc.Close() - _, err = insertKc.Exec(kcUID, kpName, false, KeyUID) + _, err = insertKc.Exec(keyPair.KeycardUID, keyPair.KeycardName, keyPair.KeycardLocked, keyPair.KeyUID) if err != nil { return err } @@ -199,10 +199,10 @@ func (kp *KeyPairs) AddMigratedKeyPair(kcUID string, kpName string, KeyUID strin } defer insertKcAcc.Close() - for i := range accountAddresses { - addr := accountAddresses[i] + for i := range keyPair.AccountsAddresses { + addr := keyPair.AccountsAddresses[i] - _, err = insertKcAcc.Exec(kcUID, addr) + _, err = insertKcAcc.Exec(keyPair.KeycardUID, addr) if err != nil { return err } diff --git a/multiaccounts/keypairs/database_test.go b/multiaccounts/keypairs/database_test.go index 5030cca7d..3cf862f3a 100644 --- a/multiaccounts/keypairs/database_test.go +++ b/multiaccounts/keypairs/database_test.go @@ -58,15 +58,18 @@ func TestKeypairs(t *testing.T) { } // Test adding key pairs - err := db.AddMigratedKeyPair(keyPair1.KeycardUID, keyPair1.KeycardName, keyPair1.KeyUID, keyPair1.AccountsAddresses) + err := db.AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(keyPair1) require.NoError(t, err) - err = db.AddMigratedKeyPair(keyPair2.KeycardUID, keyPair2.KeycardName, keyPair2.KeyUID, keyPair2.AccountsAddresses) + err = db.AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(keyPair2) require.NoError(t, err) - err = db.AddMigratedKeyPair(keyPair3.KeycardUID, keyPair3.KeycardName, keyPair3.KeyUID, keyPair3.AccountsAddresses) + err = db.AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(keyPair3) require.NoError(t, err) - err = db.AddMigratedKeyPair(keyPair3.KeycardUID, keyPair3.KeycardName, keyPair3.KeyUID, []types.Address{{0x03}}) + err = db.AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(KeyPair{ + KeycardUID: keyPair3.KeycardUID, + AccountsAddresses: []types.Address{{0x03}}, + }) require.NoError(t, err) - err = db.AddMigratedKeyPair(keyPair4.KeycardUID, keyPair4.KeycardName, keyPair4.KeyUID, keyPair4.AccountsAddresses) + err = db.AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(keyPair4) require.NoError(t, err) // Test reading migrated key pairs diff --git a/services/accounts/accounts.go b/services/accounts/accounts.go index 06a3ccc32..bd2b50333 100644 --- a/services/accounts/accounts.go +++ b/services/accounts/accounts.go @@ -441,20 +441,25 @@ func (api *API) VerifyPassword(password string) bool { return err == nil } -func (api *API) AddMigratedKeyPair(ctx context.Context, kcUID string, kpName string, keyUID string, accountAddresses []string, password string) error { - var addresses []types.Address +func (api *API) AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(ctx context.Context, kcUID string, kpName string, keyUID string, accountAddresses []string, password string) error { + kp := keypairs.KeyPair{ + KeycardUID: kcUID, + KeycardName: kpName, + KeycardLocked: false, + KeyUID: keyUID, + } for _, addr := range accountAddresses { - addresses = append(addresses, types.Address(common.HexToAddress(addr))) + kp.AccountsAddresses = append(kp.AccountsAddresses, types.Address(common.HexToAddress(addr))) } - err := api.db.AddMigratedKeyPair(kcUID, kpName, keyUID, addresses) + err := api.db.AddMigratedKeyPairOrAddAccountsIfKeyPairIsAdded(kp) if err != nil { return err } // Once we migrate a keypair, corresponding keystore files need to be deleted. if len(password) > 0 { - for _, addr := range addresses { + for _, addr := range kp.AccountsAddresses { err = api.manager.DeleteAccount(addr, password) if err != nil { return err