fix: `GetAllKnownKeycards` new keypair endpoint added

Handling results of `GetAllMigratedKeyPairs` and `GetMigratedKeyPairByKeyUID`
endpoints updated in a way that account address is unique in the address list.
This commit is contained in:
Sale Djenic 2022-12-07 18:14:19 +01:00 committed by saledjenic
parent d1a4b53d5c
commit 691c930828
3 changed files with 69 additions and 7 deletions

View File

@ -211,12 +211,21 @@ func TestKeypairs(t *testing.T) {
AccountsAddresses: []types.Address{{0x01}, {0x02}}, AccountsAddresses: []types.Address{{0x01}, {0x02}},
KeyUID: "0000000000000000000000000000000000000000000000000000000000000002", KeyUID: "0000000000000000000000000000000000000000000000000000000000000002",
} }
keyPair3 := keypairs.KeyPair{
KeycardUID: "00000000000000000000000000000003",
KeycardName: "Card02 Copy",
KeycardLocked: false,
AccountsAddresses: []types.Address{{0x01}, {0x02}},
KeyUID: "0000000000000000000000000000000000000000000000000000000000000002",
}
// Test adding key pairs // Test adding key pairs
err := db.AddMigratedKeyPair(keyPair1.KeycardUID, keyPair1.KeycardName, keyPair1.KeyUID, keyPair1.AccountsAddresses) err := db.AddMigratedKeyPair(keyPair1.KeycardUID, keyPair1.KeycardName, keyPair1.KeyUID, keyPair1.AccountsAddresses)
require.NoError(t, err) require.NoError(t, err)
err = db.AddMigratedKeyPair(keyPair2.KeycardUID, keyPair2.KeycardName, keyPair2.KeyUID, keyPair2.AccountsAddresses) err = db.AddMigratedKeyPair(keyPair2.KeycardUID, keyPair2.KeycardName, keyPair2.KeyUID, keyPair2.AccountsAddresses)
require.NoError(t, err) require.NoError(t, err)
err = db.AddMigratedKeyPair(keyPair3.KeycardUID, keyPair3.KeycardName, keyPair3.KeyUID, keyPair3.AccountsAddresses)
require.NoError(t, err)
// Test reading migrated key pairs // Test reading migrated key pairs
rows, err := db.GetAllMigratedKeyPairs() rows, err := db.GetAllMigratedKeyPairs()
@ -245,6 +254,28 @@ func TestKeypairs(t *testing.T) {
require.Equal(t, keyPair1.KeycardLocked, rows[0].KeycardLocked) require.Equal(t, keyPair1.KeycardLocked, rows[0].KeycardLocked)
require.Equal(t, len(keyPair1.AccountsAddresses), len(rows[0].AccountsAddresses)) require.Equal(t, len(keyPair1.AccountsAddresses), len(rows[0].AccountsAddresses))
rows, err = db.GetAllKnownKeycards()
require.NoError(t, err)
require.Equal(t, 3, len(rows))
for _, kp := range rows {
if kp.KeycardUID == keyPair1.KeycardUID {
require.Equal(t, keyPair1.KeycardUID, kp.KeycardUID)
require.Equal(t, keyPair1.KeycardName, kp.KeycardName)
require.Equal(t, keyPair1.KeycardLocked, kp.KeycardLocked)
require.Equal(t, len(keyPair1.AccountsAddresses), len(kp.AccountsAddresses))
} else if kp.KeycardUID == keyPair2.KeycardUID {
require.Equal(t, keyPair2.KeycardUID, kp.KeycardUID)
require.Equal(t, keyPair2.KeycardName, kp.KeycardName)
require.Equal(t, keyPair2.KeycardLocked, kp.KeycardLocked)
require.Equal(t, len(keyPair2.AccountsAddresses), len(kp.AccountsAddresses))
} else {
require.Equal(t, keyPair3.KeycardUID, kp.KeycardUID)
require.Equal(t, keyPair3.KeycardName, kp.KeycardName)
require.Equal(t, keyPair3.KeycardLocked, kp.KeycardLocked)
require.Equal(t, len(keyPair3.AccountsAddresses), len(kp.AccountsAddresses))
}
}
// Test seting a new keycard name // Test seting a new keycard name
err = db.SetKeycardName(keyPair1.KeycardUID, "Card101") err = db.SetKeycardName(keyPair1.KeycardUID, "Card101")
require.NoError(t, err) require.NoError(t, err)

View File

@ -25,7 +25,16 @@ func NewKeyPairs(db *sql.DB) *KeyPairs {
} }
} }
func (kp *KeyPairs) processResult(rows *sql.Rows) ([]*KeyPair, error) { func containsAddress(addresses []types.Address, address types.Address) bool {
for _, addr := range addresses {
if addr == address {
return true
}
}
return false
}
func (kp *KeyPairs) processResult(rows *sql.Rows, groupByKeycard bool) ([]*KeyPair, error) {
keyPairs := []*KeyPair{} keyPairs := []*KeyPair{}
for rows.Next() { for rows.Next() {
keyPair := &KeyPair{} keyPair := &KeyPair{}
@ -37,15 +46,25 @@ func (kp *KeyPairs) processResult(rows *sql.Rows) ([]*KeyPair, error) {
foundAtIndex := -1 foundAtIndex := -1
for i := range keyPairs { for i := range keyPairs {
if keyPairs[i].KeyUID == keyPair.KeyUID { if groupByKeycard {
foundAtIndex = i if keyPairs[i].KeycardUID == keyPair.KeycardUID {
break foundAtIndex = i
break
}
} else {
if keyPairs[i].KeyUID == keyPair.KeyUID {
foundAtIndex = i
break
}
} }
} }
if foundAtIndex == -1 { if foundAtIndex == -1 {
keyPair.AccountsAddresses = append(keyPair.AccountsAddresses, addr) keyPair.AccountsAddresses = append(keyPair.AccountsAddresses, addr)
keyPairs = append(keyPairs, keyPair) keyPairs = append(keyPairs, keyPair)
} else { } else {
if containsAddress(keyPairs[foundAtIndex].AccountsAddresses, addr) {
continue
}
keyPairs[foundAtIndex].AccountsAddresses = append(keyPairs[foundAtIndex].AccountsAddresses, addr) keyPairs[foundAtIndex].AccountsAddresses = append(keyPairs[foundAtIndex].AccountsAddresses, addr)
} }
} }
@ -53,7 +72,7 @@ func (kp *KeyPairs) processResult(rows *sql.Rows) ([]*KeyPair, error) {
return keyPairs, nil return keyPairs, nil
} }
func (kp *KeyPairs) GetAllMigratedKeyPairs() ([]*KeyPair, error) { func (kp *KeyPairs) getAllRows(groupByKeycard bool) ([]*KeyPair, error) {
rows, err := kp.db.Query(` rows, err := kp.db.Query(`
SELECT SELECT
keycard_uid, keycard_uid,
@ -71,7 +90,15 @@ func (kp *KeyPairs) GetAllMigratedKeyPairs() ([]*KeyPair, error) {
} }
defer rows.Close() defer rows.Close()
return kp.processResult(rows) return kp.processResult(rows, groupByKeycard)
}
func (kp *KeyPairs) GetAllKnownKeycards() ([]*KeyPair, error) {
return kp.getAllRows(true)
}
func (kp *KeyPairs) GetAllMigratedKeyPairs() ([]*KeyPair, error) {
return kp.getAllRows(false)
} }
func (kp *KeyPairs) GetMigratedKeyPairByKeyUID(keyUID string) ([]*KeyPair, error) { func (kp *KeyPairs) GetMigratedKeyPairByKeyUID(keyUID string) ([]*KeyPair, error) {
@ -94,7 +121,7 @@ func (kp *KeyPairs) GetMigratedKeyPairByKeyUID(keyUID string) ([]*KeyPair, error
} }
defer rows.Close() defer rows.Close()
return kp.processResult(rows) return kp.processResult(rows, false)
} }
func (kp *KeyPairs) AddMigratedKeyPair(kcUID string, kpName string, KeyUID string, accountAddresses []types.Address) (err error) { func (kp *KeyPairs) AddMigratedKeyPair(kcUID string, kpName string, KeyUID string, accountAddresses []types.Address) (err error) {

View File

@ -445,6 +445,10 @@ func (api *API) AddMigratedKeyPair(ctx context.Context, kcUID string, kpName str
return nil return nil
} }
func (api *API) GetAllKnownKeycards(ctx context.Context) ([]*keypairs.KeyPair, error) {
return api.db.GetAllKnownKeycards()
}
func (api *API) GetAllMigratedKeyPairs(ctx context.Context) ([]*keypairs.KeyPair, error) { func (api *API) GetAllMigratedKeyPairs(ctx context.Context) ([]*keypairs.KeyPair, error) {
return api.db.GetAllMigratedKeyPairs() return api.db.GetAllMigratedKeyPairs()
} }