fix: `GetAllKnownKeycards` new keypair endpoint added

Handling results of `GetAllMigratedKeyPairs` and `GetMigratedKeyPairByKeyUID`
endpoints updated in a way that account address is unique in the address list.
This commit is contained in:
Sale Djenic 2022-12-07 18:14:19 +01:00 committed by saledjenic
parent d1a4b53d5c
commit 691c930828
3 changed files with 69 additions and 7 deletions

View File

@ -211,12 +211,21 @@ func TestKeypairs(t *testing.T) {
AccountsAddresses: []types.Address{{0x01}, {0x02}},
KeyUID: "0000000000000000000000000000000000000000000000000000000000000002",
}
keyPair3 := keypairs.KeyPair{
KeycardUID: "00000000000000000000000000000003",
KeycardName: "Card02 Copy",
KeycardLocked: false,
AccountsAddresses: []types.Address{{0x01}, {0x02}},
KeyUID: "0000000000000000000000000000000000000000000000000000000000000002",
}
// Test adding key pairs
err := db.AddMigratedKeyPair(keyPair1.KeycardUID, keyPair1.KeycardName, keyPair1.KeyUID, keyPair1.AccountsAddresses)
require.NoError(t, err)
err = db.AddMigratedKeyPair(keyPair2.KeycardUID, keyPair2.KeycardName, keyPair2.KeyUID, keyPair2.AccountsAddresses)
require.NoError(t, err)
err = db.AddMigratedKeyPair(keyPair3.KeycardUID, keyPair3.KeycardName, keyPair3.KeyUID, keyPair3.AccountsAddresses)
require.NoError(t, err)
// Test reading migrated key pairs
rows, err := db.GetAllMigratedKeyPairs()
@ -245,6 +254,28 @@ func TestKeypairs(t *testing.T) {
require.Equal(t, keyPair1.KeycardLocked, rows[0].KeycardLocked)
require.Equal(t, len(keyPair1.AccountsAddresses), len(rows[0].AccountsAddresses))
rows, err = db.GetAllKnownKeycards()
require.NoError(t, err)
require.Equal(t, 3, len(rows))
for _, kp := range rows {
if kp.KeycardUID == keyPair1.KeycardUID {
require.Equal(t, keyPair1.KeycardUID, kp.KeycardUID)
require.Equal(t, keyPair1.KeycardName, kp.KeycardName)
require.Equal(t, keyPair1.KeycardLocked, kp.KeycardLocked)
require.Equal(t, len(keyPair1.AccountsAddresses), len(kp.AccountsAddresses))
} else if kp.KeycardUID == keyPair2.KeycardUID {
require.Equal(t, keyPair2.KeycardUID, kp.KeycardUID)
require.Equal(t, keyPair2.KeycardName, kp.KeycardName)
require.Equal(t, keyPair2.KeycardLocked, kp.KeycardLocked)
require.Equal(t, len(keyPair2.AccountsAddresses), len(kp.AccountsAddresses))
} else {
require.Equal(t, keyPair3.KeycardUID, kp.KeycardUID)
require.Equal(t, keyPair3.KeycardName, kp.KeycardName)
require.Equal(t, keyPair3.KeycardLocked, kp.KeycardLocked)
require.Equal(t, len(keyPair3.AccountsAddresses), len(kp.AccountsAddresses))
}
}
// Test seting a new keycard name
err = db.SetKeycardName(keyPair1.KeycardUID, "Card101")
require.NoError(t, err)

View File

@ -25,7 +25,16 @@ func NewKeyPairs(db *sql.DB) *KeyPairs {
}
}
func (kp *KeyPairs) processResult(rows *sql.Rows) ([]*KeyPair, error) {
func containsAddress(addresses []types.Address, address types.Address) bool {
for _, addr := range addresses {
if addr == address {
return true
}
}
return false
}
func (kp *KeyPairs) processResult(rows *sql.Rows, groupByKeycard bool) ([]*KeyPair, error) {
keyPairs := []*KeyPair{}
for rows.Next() {
keyPair := &KeyPair{}
@ -37,15 +46,25 @@ func (kp *KeyPairs) processResult(rows *sql.Rows) ([]*KeyPair, error) {
foundAtIndex := -1
for i := range keyPairs {
if keyPairs[i].KeyUID == keyPair.KeyUID {
foundAtIndex = i
break
if groupByKeycard {
if keyPairs[i].KeycardUID == keyPair.KeycardUID {
foundAtIndex = i
break
}
} else {
if keyPairs[i].KeyUID == keyPair.KeyUID {
foundAtIndex = i
break
}
}
}
if foundAtIndex == -1 {
keyPair.AccountsAddresses = append(keyPair.AccountsAddresses, addr)
keyPairs = append(keyPairs, keyPair)
} else {
if containsAddress(keyPairs[foundAtIndex].AccountsAddresses, addr) {
continue
}
keyPairs[foundAtIndex].AccountsAddresses = append(keyPairs[foundAtIndex].AccountsAddresses, addr)
}
}
@ -53,7 +72,7 @@ func (kp *KeyPairs) processResult(rows *sql.Rows) ([]*KeyPair, error) {
return keyPairs, nil
}
func (kp *KeyPairs) GetAllMigratedKeyPairs() ([]*KeyPair, error) {
func (kp *KeyPairs) getAllRows(groupByKeycard bool) ([]*KeyPair, error) {
rows, err := kp.db.Query(`
SELECT
keycard_uid,
@ -71,7 +90,15 @@ func (kp *KeyPairs) GetAllMigratedKeyPairs() ([]*KeyPair, error) {
}
defer rows.Close()
return kp.processResult(rows)
return kp.processResult(rows, groupByKeycard)
}
func (kp *KeyPairs) GetAllKnownKeycards() ([]*KeyPair, error) {
return kp.getAllRows(true)
}
func (kp *KeyPairs) GetAllMigratedKeyPairs() ([]*KeyPair, error) {
return kp.getAllRows(false)
}
func (kp *KeyPairs) GetMigratedKeyPairByKeyUID(keyUID string) ([]*KeyPair, error) {
@ -94,7 +121,7 @@ func (kp *KeyPairs) GetMigratedKeyPairByKeyUID(keyUID string) ([]*KeyPair, error
}
defer rows.Close()
return kp.processResult(rows)
return kp.processResult(rows, false)
}
func (kp *KeyPairs) AddMigratedKeyPair(kcUID string, kpName string, KeyUID string, accountAddresses []types.Address) (err error) {

View File

@ -445,6 +445,10 @@ func (api *API) AddMigratedKeyPair(ctx context.Context, kcUID string, kpName str
return nil
}
func (api *API) GetAllKnownKeycards(ctx context.Context) ([]*keypairs.KeyPair, error) {
return api.db.GetAllKnownKeycards()
}
func (api *API) GetAllMigratedKeyPairs(ctx context.Context) ([]*keypairs.KeyPair, error) {
return api.db.GetAllMigratedKeyPairs()
}