feat(dapps)_: implement basic dApp persistance

Implement required basic CRUD APIs

- Add session to wallet connect
- Delete session used in tests only
- Get active dApps: the order of retrieval is
    based on the first time the DApp was added
    in descending order.

Also add tests to validate the main requirements

Closes: #14615
This commit is contained in:
Stefan 2024-05-23 15:19:00 +03:00 committed by Stefan Dunca
parent 36273bc9b2
commit e06c490ec8
5 changed files with 676 additions and 249 deletions

View File

@ -29,6 +29,7 @@ import (
"github.com/status-im/status-go/services/wallet/thirdparty"
"github.com/status-im/status-go/services/wallet/token"
"github.com/status-im/status-go/services/wallet/transfer"
"github.com/status-im/status-go/services/wallet/walletconnect"
"github.com/status-im/status-go/services/wallet/walletevent"
"github.com/status-im/status-go/transactions"
)
@ -768,3 +769,22 @@ func (api *API) getVerifiedWalletAccount(address, password string) (*account.Sel
AccountKey: key,
}, nil
}
// AddWalletConnectSession adds or updates a session wallet connect session
func (api *API) AddWalletConnectSession(ctx context.Context, session_json string) error {
log.Debug("wallet.api.AddWalletConnectSession", "rpcURL", len(session_json))
return walletconnect.AddSession(api.s.db, api.s.config.Networks, session_json)
}
// DisconnectWalletConnectSession removes a wallet connect session
func (api *API) DisconnectWalletConnectSession(ctx context.Context, topic walletconnect.Topic) error {
log.Debug("wallet.api.DisconnectWalletConnectSession", "topic", topic)
return walletconnect.DisconnectSession(api.s.db, topic)
}
// GetWalletConnectDapps returns all active wallet connect dapps
// Active dApp are those having active sessions (not expired and not disconnected)
func (api *API) GetWalletConnectDapps(ctx context.Context, validAtTimestamp int64, testChains bool) ([]walletconnect.DBDApp, error) {
log.Debug("wallet.api.GetWalletConnectDapps", "validAtTimestamp", validAtTimestamp, "testChains", testChains)
return walletconnect.GetActiveDapps(api.s.db, validAtTimestamp, testChains)
}

View File

@ -2,70 +2,87 @@ package walletconnect
import (
"database/sql"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/log"
)
type DbSession struct {
Topic Topic `json:"topic"`
PairingTopic Topic `json:"pairingTopic"`
Expiry int64 `json:"expiry"`
Active bool `json:"active"`
DappName string `json:"dappName"`
DappURL string `json:"dappUrl"`
DappDescription string `json:"dappDescription"`
DappIcon string `json:"dappIcon"`
DappVerifyURL string `json:"dappVerifyUrl"`
DappPublicKey string `json:"dappPublicKey"`
type DBSession struct {
Topic Topic
Disconnected bool
SessionJSON string
Expiry int64
CreatedTimestamp int64
PairingTopic Topic
TestChains bool
DBDApp
}
func UpsertSession(db *sql.DB, session DbSession) error {
insertSQL := `
INSERT OR IGNORE INTO
wallet_connect_sessions (topic, pairing_topic, expiry, active)
VALUES
(?, ?, ?, ?);
type DBDApp struct {
URL string `json:"url"`
Name string `json:"name"`
IconURL string `json:"iconUrl"`
}
UPDATE
wallet_connect_sessions
SET
pairing_topic = ?,
expiry = ?,
active = ?,
dapp_name = ?,
dapp_url = ?,
dapp_description = ?,
dapp_icon = ?,
dapp_verify_url = ?,
dapp_publicKey = ?
WHERE
topic = ?;`
func UpsertSession(db *sql.DB, data DBSession) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("begin transaction: %v", err)
}
defer func() {
if err != nil {
rollErr := tx.Rollback()
if rollErr != nil {
log.Error("error rolling back transaction", "rollErr", rollErr, "err", err)
}
}
}()
_, err := db.Exec(insertSQL,
session.Topic,
session.PairingTopic,
session.Expiry,
session.Active,
session.PairingTopic,
session.Expiry,
session.Active,
session.DappName,
session.DappURL,
session.DappDescription,
session.DappIcon,
session.DappVerifyURL,
session.DappPublicKey,
session.Topic,
)
upsertDappStmt := `INSERT INTO wallet_connect_dapps (url, name, icon_url) VALUES (?, ?, ?)
ON CONFLICT(url) DO UPDATE SET name = excluded.name, icon_url = excluded.icon_url`
_, err = tx.Exec(upsertDappStmt, data.URL, data.Name, data.IconURL)
if err != nil {
return fmt.Errorf("upsert wallet_connect_dapps: %v", err)
}
upsertSessionStmt := `INSERT INTO wallet_connect_sessions (
topic,
disconnected,
session_json,
expiry,
created_timestamp,
pairing_topic,
test_chains,
dapp_url
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(topic) DO UPDATE SET
disconnected = excluded.disconnected,
session_json = excluded.session_json,
expiry = excluded.expiry,
created_timestamp = excluded.created_timestamp,
pairing_topic = excluded.pairing_topic,
test_chains = excluded.test_chains,
dapp_url = excluded.dapp_url;`
_, err = tx.Exec(upsertSessionStmt, data.Topic, data.Disconnected, data.SessionJSON, data.Expiry, data.CreatedTimestamp, data.PairingTopic, data.TestChains, data.URL)
if err != nil {
return fmt.Errorf("insert session: %v", err)
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("commit transaction: %v", err)
}
return nil
}
func DeleteSession(db *sql.DB, topic Topic) error {
_, err := db.Exec("DELETE FROM wallet_connect_sessions WHERE topic = ?", topic)
return err
}
func ChangeSessionState(db *sql.DB, topic Topic, active bool) error {
stmt, err := db.Prepare("UPDATE wallet_connect_sessions SET active = ? WHERE topic = ?")
if err != nil {
return err
}
res, err := stmt.Exec(active, topic)
func DisconnectSession(db *sql.DB, topic Topic) error {
res, err := db.Exec("UPDATE wallet_connect_sessions SET disconnected = 1 WHERE topic = ?", topic)
if err != nil {
return err
}
@ -75,33 +92,71 @@ func ChangeSessionState(db *sql.DB, topic Topic, active bool) error {
return err
}
if rowsAffected == 0 {
return errors.New("unable to locate session for DB state change")
return fmt.Errorf("topic %s not found to update state", topic)
}
return nil
}
func GetSessionByTopic(db *sql.DB, topic Topic) (*DbSession, error) {
querySQL := `
SELECT *
// GetSessionByTopic returns sql.ErrNoRows if no session is found.
func GetSessionByTopic(db *sql.DB, topic Topic) (*DBSession, error) {
query := selectAndJoinQueryStr + " WHERE sessions.topic = ?"
row := db.QueryRow(query, topic)
return scanSession(singleRow{row})
}
// GetSessionsByPairingTopic returns sql.ErrNoRows if no session is found.
func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DBSession, error) {
query := selectAndJoinQueryStr + " WHERE sessions.pairing_topic = ?"
rows, err := db.Query(query, pairingTopic)
if err != nil {
return nil, err
}
defer rows.Close()
return scanSessions(rows)
}
type Scanner interface {
Scan(dest ...interface{}) error
}
type singleRow struct {
*sql.Row
}
func (r singleRow) Scan(dest ...interface{}) error {
return r.Row.Scan(dest...)
}
const selectAndJoinQueryStr = `
SELECT
sessions.topic, sessions.disconnected, sessions.session_json, sessions.expiry, sessions.created_timestamp,
sessions.pairing_topic, sessions.test_chains, sessions.dapp_url, dapps.name, dapps.icon_url
FROM
wallet_connect_sessions
WHERE
topic = ?`
wallet_connect_sessions sessions
JOIN
wallet_connect_dapps dapps ON sessions.dapp_url = dapps.url`
row := db.QueryRow(querySQL, topic)
// scanSession scans a single session from the given scanner following selectAndJoinQueryStr.
func scanSession(scanner Scanner) (*DBSession, error) {
var session DBSession
var session DbSession
err := row.Scan(&session.Topic,
&session.PairingTopic,
err := scanner.Scan(
&session.Topic,
&session.Disconnected,
&session.SessionJSON,
&session.Expiry,
&session.Active,
&session.DappName,
&session.DappURL,
&session.DappDescription,
&session.DappIcon,
&session.DappVerifyURL,
&session.DappPublicKey)
&session.CreatedTimestamp,
&session.PairingTopic,
&session.TestChains,
&session.URL,
&session.Name,
&session.IconURL,
)
if err != nil {
return nil, err
}
@ -109,38 +164,16 @@ func GetSessionByTopic(db *sql.DB, topic Topic) (*DbSession, error) {
return &session, nil
}
func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DbSession, error) {
querySQL := `
SELECT *
FROM
wallet_connect_sessions
WHERE
pairing_topic = ?`
// scanSessions returns sql.ErrNoRows if nothing is scanned.
func scanSessions(rows *sql.Rows) ([]DBSession, error) {
var sessions []DBSession
rows, err := db.Query(querySQL, pairingTopic)
if err != nil {
return nil, err
}
defer rows.Close()
sessions := make([]DbSession, 0, 2)
for rows.Next() {
var session DbSession
err := rows.Scan(&session.Topic,
&session.PairingTopic,
&session.Expiry,
&session.Active,
&session.DappName,
&session.DappURL,
&session.DappDescription,
&session.DappIcon,
&session.DappVerifyURL,
&session.DappPublicKey)
session, err := scanSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, session)
sessions = append(sessions, *session)
}
if err := rows.Err(); err != nil {
@ -150,47 +183,68 @@ func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DbSession, err
return sessions, nil
}
// GetActiveSessions returns all active sessions (active and not expired) that have an expiry timestamp newer or equal to the given timestamp.
func GetActiveSessions(db *sql.DB, expiryNotOlderThanTimestamp int64) ([]DbSession, error) {
querySQL := `
SELECT *
FROM
wallet_connect_sessions
WHERE
active != 0 AND
expiry >= ?
ORDER BY
expiry DESC`
// GetActiveSessions returns all active sessions (not disconnected and not expired) that have an expiry timestamp newer or equal to the given timestamp.
func GetActiveSessions(db *sql.DB, validAtTimestamp int64) ([]DBSession, error) {
querySQL := selectAndJoinQueryStr + `
WHERE
sessions.disconnected = 0 AND
sessions.expiry >= ?
ORDER BY
sessions.expiry DESC`
rows, err := db.Query(querySQL, expiryNotOlderThanTimestamp)
rows, err := db.Query(querySQL, validAtTimestamp)
if err != nil {
return nil, err
}
defer rows.Close()
return scanSessions(rows)
}
// GetSessions returns all sessions in the ascending order of creation time
func GetSessions(db *sql.DB) ([]DBSession, error) {
querySQL := selectAndJoinQueryStr + `
ORDER BY
sessions.created_timestamp DESC`
rows, err := db.Query(querySQL)
if err != nil {
return nil, err
}
defer rows.Close()
return scanSessions(rows)
}
// GetActiveDapps returns all dapps in the order of last first time connected (first session creation time)
func GetActiveDapps(db *sql.DB, validAtTimestamp int64, testChains bool) ([]DBDApp, error) {
query := `SELECT dapps.url, dapps.name, dapps.icon_url, MIN(sessions.created_timestamp) as dapp_creation_time
FROM
wallet_connect_dapps dapps
JOIN
wallet_connect_sessions sessions ON dapps.url = sessions.dapp_url
WHERE sessions.disconnected = 0 AND sessions.expiry >= ? AND sessions.test_chains = ?
GROUP BY dapps.url
ORDER BY dapp_creation_time DESC;`
rows, err := db.Query(query, validAtTimestamp, testChains)
if err != nil {
return nil, err
}
defer rows.Close()
sessions := make([]DbSession, 0, 2)
var dapps []DBDApp
for rows.Next() {
var session DbSession
err := rows.Scan(&session.Topic,
&session.PairingTopic,
&session.Expiry,
&session.Active,
&session.DappName,
&session.DappURL,
&session.DappDescription,
&session.DappIcon,
&session.DappVerifyURL,
&session.DappPublicKey)
if err != nil {
var dapp DBDApp
var creationTime sql.NullInt64
if err := rows.Scan(&dapp.URL, &dapp.Name, &dapp.IconURL, &creationTime); err != nil {
return nil, err
}
sessions = append(sessions, session)
dapps = append(dapps, dapp)
}
if err := rows.Err(); err != nil {
return nil, err
}
return sessions, nil
return dapps, nil
}

View File

@ -6,6 +6,7 @@ import (
"database/sql"
"github.com/status-im/status-go/services/wallet/common"
"github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/walletdatabase"
@ -20,34 +21,80 @@ func setupTestDB(t *testing.T) (db *sql.DB, close func()) {
}
}
type urlOverride *string
type timestampOverride *int64
// testSession will override defaults for the fields that are not null
type testSession struct {
url urlOverride
created timestampOverride
expiry timestampOverride
disconnected *bool
testChains *bool
}
const testDappUrl = "https://test.url/"
// generateTestData generates alternative disconnected and active sessions starting with the active one
// timestamps start with 1234567890
func generateTestData(count int) []DbSession {
res := make([]DbSession, count)
j := 0
for i := 0; i < count; i++ {
// timestamps start with 1234567890 and increase by 1 for each session
// all sessions will share the same two pairing sessions (roll over after index 1)
// testChains is false if not overridden
func generateTestData(sessions []testSession) []DBSession {
res := make([]DBSession, len(sessions))
pairingIdx := 0
for i := 0; i < len(res); i++ {
strI := strconv.Itoa(i)
if i%4 == 0 {
j++
if i%2 == 0 {
pairingIdx++
}
strJ := strconv.Itoa(j)
res[i] = DbSession{
Topic: Topic(strI + "aaaaaa1234567890"),
PairingTopic: Topic(strJ + "bbbbbb1234567890"),
Expiry: 1234567890 + int64(i),
Active: (i % 2) == 0,
DappName: "TestApp" + strI,
DappURL: "https://test.url/" + strI,
DappDescription: "Test Description" + strI,
DappIcon: "https://test.icon" + strI,
DappVerifyURL: "https://test.verify.url/" + strI,
DappPublicKey: strI + "1234567890",
pairingIdxStr := strconv.Itoa(pairingIdx)
s := sessions[i]
url := testDappUrl + strI
if s.url != nil {
url = *s.url
}
createdTimestamp := 1234567890 + int64(i)
if s.created != nil {
createdTimestamp = *s.created
}
expiryTimestamp := createdTimestamp + 1000 + int64(i)
if s.expiry != nil {
expiryTimestamp = *s.expiry
}
disconnected := (i % 2) != 0
if s.disconnected != nil {
disconnected = *s.disconnected
}
testChains := false
if s.testChains != nil {
testChains = *s.testChains
}
res[i] = DBSession{
Topic: Topic(strI + "aaaaaa1234567890"),
Disconnected: disconnected,
SessionJSON: "{}",
Expiry: expiryTimestamp,
CreatedTimestamp: createdTimestamp,
PairingTopic: Topic(pairingIdxStr + "bbbbbb1234567890"),
TestChains: testChains,
DBDApp: DBDApp{
URL: url,
Name: "TestApp" + strI,
IconURL: "https://test.icon" + strI,
},
}
}
return res
}
func insertTestData(t *testing.T, db *sql.DB, entries []DbSession) {
func insertTestData(t *testing.T, db *sql.DB, entries []DBSession) {
for _, entry := range entries {
err := UpsertSession(db, entry)
require.NoError(t, err)
@ -58,7 +105,7 @@ func TestInsertUpdateAndGetSession(t *testing.T) {
db, close := setupTestDB(t)
defer close()
entry := generateTestData(1)[0]
entry := generateTestData(make([]testSession, 1))[0]
err := UpsertSession(db, entry)
require.NoError(t, err)
@ -67,37 +114,38 @@ func TestInsertUpdateAndGetSession(t *testing.T) {
require.Equal(t, entry, *retrievedSession)
entry.Active = false
entry.Expiry = 1111111111
err = UpsertSession(db, entry)
updatedEntry := entry
updatedEntry.Disconnected = true
updatedEntry.Expiry = 1111111111
err = UpsertSession(db, updatedEntry)
require.NoError(t, err)
retrievedSession, err = GetSessionByTopic(db, entry.Topic)
retrievedSession, err = GetSessionByTopic(db, updatedEntry.Topic)
require.NoError(t, err)
require.Equal(t, entry, *retrievedSession)
require.Equal(t, updatedEntry, *retrievedSession)
}
func TestInsertAndGetSessionsByPairingTopic(t *testing.T) {
db, close := setupTestDB(t)
defer close()
generatedSessions := generateTestData(10)
generatedSessions := generateTestData(make([]testSession, 4))
for _, session := range generatedSessions {
err := UpsertSession(db, session)
require.NoError(t, err)
}
retrievedSessions, err := GetSessionsByPairingTopic(db, generatedSessions[4].Topic)
retrievedSessions, err := GetSessionsByPairingTopic(db, generatedSessions[2].Topic)
require.NoError(t, err)
require.Equal(t, 0, len(retrievedSessions))
retrievedSessions, err = GetSessionsByPairingTopic(db, generatedSessions[4].PairingTopic)
retrievedSessions, err = GetSessionsByPairingTopic(db, generatedSessions[2].PairingTopic)
require.NoError(t, err)
require.Equal(t, 4, len(retrievedSessions))
require.Equal(t, 2, len(retrievedSessions))
for i := 4; i < 8; i++ {
for i := 2; i < 4; i++ {
found := false
for _, session := range retrievedSessions {
if session.Topic == generatedSessions[i].Topic {
@ -109,34 +157,24 @@ func TestInsertAndGetSessionsByPairingTopic(t *testing.T) {
}
}
func TestChangeSessionState(t *testing.T) {
db, close := setupTestDB(t)
defer close()
entry := generateTestData(1)[0]
err := UpsertSession(db, entry)
require.NoError(t, err)
err = ChangeSessionState(db, entry.Topic, false)
require.NoError(t, err)
retrievedSession, err := GetSessionByTopic(db, entry.Topic)
require.NoError(t, err)
require.Equal(t, false, retrievedSession.Active)
}
func TestGet(t *testing.T) {
db, close := setupTestDB(t)
defer close()
entries := generateTestData(3)
entries := generateTestData(make([]testSession, 3))
insertTestData(t, db, entries)
retrievedSession, err := GetSessionByTopic(db, entries[1].Topic)
require.NoError(t, err)
require.Equal(t, entries[1], *retrievedSession)
err = DeleteSession(db, entries[1].Topic)
require.NoError(t, err)
deletedSession, err := GetSessionByTopic(db, entries[1].Topic)
require.ErrorIs(t, err, sql.ErrNoRows)
require.Nil(t, deletedSession)
}
func TestGetActiveSessions(t *testing.T) {
@ -144,10 +182,10 @@ func TestGetActiveSessions(t *testing.T) {
defer close()
// insert two disconnected and three active sessions
entries := generateTestData(5)
entries := generateTestData(make([]testSession, 5))
insertTestData(t, db, entries)
activeSessions, err := GetActiveSessions(db, 1234567892)
activeSessions, err := GetActiveSessions(db, entries[2].Expiry)
require.NoError(t, err)
require.Equal(t, 2, len(activeSessions))
@ -156,19 +194,206 @@ func TestGetActiveSessions(t *testing.T) {
require.Equal(t, entries[2], activeSessions[1])
}
// func TestHasActivePairings(t *testing.T) {
// db, close := setupTestDB(t)
// defer close()
func TestDeleteSession(t *testing.T) {
db, close := setupTestDB(t)
defer close()
// // insert one disconnected and two active pairing
// entries := generateTestData(2)
// insertTestData(t, db, entries)
entries := generateTestData(make([]testSession, 3))
insertTestData(t, db, entries)
// hasActivePairings, err := HasActivePairings(db, 1234567890)
// require.NoError(t, err)
// require.True(t, hasActivePairings)
err := DeleteSession(db, entries[1].Topic)
require.NoError(t, err)
// hasActivePairings, err = HasActivePairings(db, 1234567891)
// require.NoError(t, err)
// require.False(t, hasActivePairings)
// }
sessions, err := GetSessions(db)
require.NoError(t, err)
require.Equal(t, 2, len(sessions))
require.Equal(t, entries[0], sessions[1])
require.Equal(t, entries[2], sessions[0])
err = DeleteSession(db, entries[0].Topic)
require.NoError(t, err)
err = DeleteSession(db, entries[2].Topic)
require.NoError(t, err)
sessions, err = GetSessions(db)
require.NoError(t, err)
require.Equal(t, 0, len(sessions))
}
// urlFor prepares a value to be used in testSession
func urlFor(i int) urlOverride {
return common.NewAndSet(testDappUrl + strconv.Itoa(i))
}
// at prepares a value to be used in testSession
func at(i int) timestampOverride {
return common.NewAndSet(int64(i))
}
// TestGetActiveDapps_JoinWorksAsExpected also validates that GetActiveDapps returns the dapps in the order of the last first time added
func TestGetActiveDapps_JoinWorksAsExpected(t *testing.T) {
db, close := setupTestDB(t)
defer close()
not := common.NewAndSet(false)
// The first creation date is 1, 2, 3 but the last name update is, respectively, 1, 4, 5
entries := generateTestData([]testSession{
{url: urlFor(1), created: at(1), disconnected: not},
{url: urlFor(1), created: at(2), disconnected: not},
{url: urlFor(2), created: at(3), disconnected: not},
{url: urlFor(3), created: at(4), disconnected: not},
{url: urlFor(2), created: at(5), disconnected: not},
{url: urlFor(3), created: at(6), disconnected: not},
})
insertTestData(t, db, entries)
getTestnet := false
validAtTimestamp := entries[0].Expiry
dapps, err := GetActiveDapps(db, validAtTimestamp, getTestnet)
require.NoError(t, err)
require.Equal(t, 3, len(dapps))
require.Equal(t, 3, len(dapps))
require.Equal(t, entries[5].Name, dapps[0].Name)
require.Equal(t, entries[4].Name, dapps[1].Name)
require.Equal(t, entries[1].Name, dapps[2].Name)
}
// TestGetActiveDapps_ActiveWorksAsExpected tests the combination of disconnected and expired sessions
func TestGetActiveDapps_ActiveWorksAsExpected(t *testing.T) {
db, close := setupTestDB(t)
defer close()
not := common.NewAndSet(false)
yes := common.NewAndSet(true)
timeNow := 4
entries := generateTestData([]testSession{
{url: urlFor(1), expiry: at(timeNow - 3), disconnected: not},
{url: urlFor(1), expiry: at(timeNow - 2), disconnected: yes},
{url: urlFor(2), expiry: at(timeNow - 2), disconnected: not},
{url: urlFor(3), expiry: at(timeNow - 1), disconnected: yes},
// ----- timeNow
{url: urlFor(3), expiry: at(timeNow + 1), disconnected: not},
{url: urlFor(4), expiry: at(timeNow + 1), disconnected: yes},
{url: urlFor(4), expiry: at(timeNow + 2), disconnected: not},
{url: urlFor(5), expiry: at(timeNow + 2), disconnected: yes},
{url: urlFor(6), expiry: at(timeNow + 3), disconnected: not},
})
insertTestData(t, db, entries)
getTestnet := false
dapps, err := GetActiveDapps(db, int64(timeNow), getTestnet)
require.NoError(t, err)
require.Equal(t, 3, len(dapps))
}
// TestGetActiveDapps_TestChainsWorksAsExpected tests the combination of disconnected and expired sessions
func TestGetActiveDapps_TestChainsWorksAsExpected(t *testing.T) {
db, close := setupTestDB(t)
defer close()
not := common.NewAndSet(false)
yes := common.NewAndSet(true)
timeNow := 4
entries := generateTestData([]testSession{
{url: urlFor(1), testChains: not, expiry: at(timeNow - 3), disconnected: not},
{url: urlFor(2), testChains: yes, expiry: at(timeNow - 2), disconnected: not},
{url: urlFor(2), testChains: not, expiry: at(timeNow - 1), disconnected: not},
// ----- timeNow
{url: urlFor(3), testChains: not, expiry: at(timeNow + 1), disconnected: not},
{url: urlFor(4), testChains: not, expiry: at(timeNow + 2), disconnected: not},
{url: urlFor(4), testChains: yes, expiry: at(timeNow + 3), disconnected: not},
{url: urlFor(5), testChains: yes, expiry: at(timeNow + 4), disconnected: not},
})
insertTestData(t, db, entries)
getTestnet := true
dapps, err := GetActiveDapps(db, int64(timeNow), getTestnet)
require.NoError(t, err)
require.Equal(t, 2, len(dapps))
}
// TestGetDapps_EmptyDB tests that an empty database will return an empty list
func TestGetDapps_EmptyDB(t *testing.T) {
db, close := setupTestDB(t)
defer close()
entries := generateTestData([]testSession{})
insertTestData(t, db, entries)
getTestnet := false
validAtTimestamp := int64(0)
dapps, err := GetActiveDapps(db, validAtTimestamp, getTestnet)
require.NoError(t, err)
require.Equal(t, 0, len(dapps))
}
// TestGetDapps_OrphanDapps tests that missing session will place the dapp at the end
func TestGetDapps_OrphanDapps(t *testing.T) {
db, close := setupTestDB(t)
defer close()
not := common.NewAndSet(false)
entries := generateTestData([]testSession{
{url: urlFor(1), disconnected: not},
{url: urlFor(2), disconnected: not},
{url: urlFor(2), disconnected: not},
})
insertTestData(t, db, entries)
err := DeleteSession(db, entries[1].Topic)
require.NoError(t, err)
err = DeleteSession(db, entries[2].Topic)
require.NoError(t, err)
getTestnet := false
validAtTimestamp := entries[0].Expiry
dapps, err := GetActiveDapps(db, validAtTimestamp, getTestnet)
require.NoError(t, err)
// The orphan dapp is not considered active
require.Equal(t, 1, len(dapps))
require.Equal(t, entries[0].Name, dapps[0].Name)
}
func TestDisconnectSession(t *testing.T) {
db, close := setupTestDB(t)
defer close()
not := common.NewAndSet(false)
entries := generateTestData([]testSession{
{url: urlFor(1), disconnected: not},
{url: urlFor(2), disconnected: not},
{url: urlFor(2), disconnected: not},
})
insertTestData(t, db, entries)
activeSessions, err := GetActiveSessions(db, 0)
require.NoError(t, err)
require.Equal(t, 3, len(activeSessions))
getTestnet := false
validAtTimestamp := entries[0].Expiry
dapps, err := GetActiveDapps(db, validAtTimestamp, getTestnet)
require.NoError(t, err)
require.Equal(t, 2, len(dapps))
err = DisconnectSession(db, entries[1].Topic)
require.NoError(t, err)
activeSessions, err = GetActiveSessions(db, 0)
require.NoError(t, err)
require.Equal(t, 2, len(activeSessions))
err = DisconnectSession(db, entries[2].Topic)
require.NoError(t, err)
activeSessions, err = GetActiveSessions(db, 0)
require.NoError(t, err)
require.Equal(t, 1, len(activeSessions))
dapps, err = GetActiveDapps(db, validAtTimestamp, getTestnet)
require.NoError(t, err)
require.Equal(t, 1, len(dapps))
require.Equal(t, entries[0].Name, dapps[0].Name)
}

View File

@ -1,15 +1,18 @@
package walletconnect
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/params"
"github.com/status-im/status-go/services/wallet/walletevent"
)
@ -61,6 +64,8 @@ type VerifyContext struct {
Verified Verified `json:"verified"`
}
// Params has RequiredNamespaces entries if part of "proposal namespace" and Namespaces entries if part of "session namespace"
// see https://specs.walletconnect.com/2.0/specs/clients/sign/namespaces#controller-side-validation-of-incoming-proposal-namespaces-wallet
type Params struct {
ID int64 `json:"id"`
PairingTopic Topic `json:"pairingTopic"`
@ -138,8 +143,8 @@ func (n *Namespace) Valid(namespaceName string, chainID *uint64) bool {
return true
}
// Valid params
func (p *Params) Valid() bool {
// ValidateForProposal validates params part of the Proposal Namespace
func (p *Params) ValidateForProposal() bool {
for key, ns := range p.RequiredNamespaces {
var chainID *uint64
if strings.Contains(key, ":") {
@ -165,15 +170,62 @@ func (p *Params) Valid() bool {
return true
}
// Valid session propsal
// ValidateProposal validates params part of the Proposal Namespace
// https://specs.walletconnect.com/2.0/specs/clients/sign/namespaces#controller-side-validation-of-incoming-proposal-namespaces-wallet
func (p *SessionProposal) Valid() bool {
return p.Params.Valid()
func (p *SessionProposal) ValidateProposal() bool {
return p.Params.ValidateForProposal()
}
func sessionProposalToSupportedChain(caipChains []string, supportsChain func(uint64) bool) (chains []uint64, eipChains []string) {
chains = make([]uint64, 0, 1)
eipChains = make([]string, 0, 1)
// AddSession adds a new active session to the database
func AddSession(db *sql.DB, networks []params.Network, session_json string) error {
var session Session
err := json.Unmarshal([]byte(session_json), &session)
if err != nil {
return fmt.Errorf("unmarshal session: %v", err)
}
chains := supportedChainsInSession(session)
testChains, err := areTestChains(networks, chains)
if err != nil {
return fmt.Errorf("areTestChains: %v", err)
}
rowEntry := DBSession{
Topic: session.PairingTopic,
Disconnected: false,
SessionJSON: session_json,
Expiry: session.Expiry,
CreatedTimestamp: time.Now().Unix(),
PairingTopic: session.PairingTopic,
TestChains: testChains,
DBDApp: DBDApp{
URL: session.Peer.Metadata.URL,
Name: session.Peer.Metadata.Name,
},
}
if len(session.Peer.Metadata.Icons) > 0 {
rowEntry.IconURL = session.Peer.Metadata.Icons[0]
}
return UpsertSession(db, rowEntry)
}
// areTestChains assumes chains to tests are all testnets or all mainnets
func areTestChains(networks []params.Network, chainIDs []uint64) (isTest bool, err error) {
for _, n := range networks {
for _, chainID := range chainIDs {
if n.ChainID == chainID {
return n.IsTest, nil
}
}
}
return false, fmt.Errorf("no network found for chainIDs %v", chainIDs)
}
func supportedChainsInSession(session Session) []uint64 {
caipChains := session.Namespaces[SupportedEip155Namespace].Chains
chains := make([]uint64, 0, len(caipChains))
for _, caip2Str := range caipChains {
_, chainID, err := parseCaip2ChainID(caip2Str)
if err != nil {
@ -181,14 +233,9 @@ func sessionProposalToSupportedChain(caipChains []string, supportsChain func(uin
continue
}
if !supportsChain(chainID) {
continue
}
eipChains = append(eipChains, caip2Str)
chains = append(chains, chainID)
}
return
return chains
}
func caip10Accounts(accounts []*accounts.Account, chains []uint64) []string {

View File

@ -2,16 +2,101 @@ package walletconnect
import (
"reflect"
"strconv"
"testing"
"encoding/json"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/multiaccounts/accounts"
"github.com/status-im/status-go/params"
"github.com/stretchr/testify/assert"
)
func getSessionJSONFor(chains []int, expiry int) string {
chainsStr := "["
for i, chain := range chains {
chainsStr += `"eip155:` + strconv.Itoa(chain) + `"`
if i != len(chains)-1 {
chainsStr += ","
}
}
chainsStr += "]"
expiryStr := strconv.Itoa(expiry)
return `{
"expiry": ` + expiryStr + `,
"namespaces": {
"eip155": {
"accounts": [
"eip155:1:0x7F47C2e18a4BBf5487E6fb082eC2D9Ab0E6d7240",
"eip155:10:0x7F47C2e18a4BBf5487E6fb082eC2D9Ab0E6d7240",
"eip155:42161:0x7F47C2e18a4BBf5487E6fb082eC2D9Ab0E6d7240"
],
"chains": ` + chainsStr + `,
"events": [
"accountsChanged",
"chainChanged"
],
"methods": [
"eth_sendTransaction",
"personal_sign"
]
}
},
"optionalNamespaces": {
"eip155": {
"chains": [],
"events": [],
"methods": [],
"rpcMap": {}
}
},
"pairingTopic": "50fba141cdb5c015493c2907c46bacf9f7cbd7c8e3d4e97df891f18dddcff69c",
"peer": {
"metadata": {
"description": "Test Dapp Description",
"icons": [ "https://test.org/test.png"],
"name": "Test Dapp",
"url": "https://dapp.test.org"
},
"publicKey": "1234567890aeb6081cabed26faf48919162fd70cc66d639f118a60507ae0463d"
},
"relay": { "protocol": "irn"},
"requiredNamespaces": {
"eip155": {
"chains": [
"eip155:1"
],
"events": [
"chainChanged",
"accountsChanged"
],
"methods": [
"eth_sendTransaction",
"personal_sign"
],
"rpcMap": {
"1": "https://mainnet.infura.io/v3/099fc58e0de9451d80b18d7c74caa7c1"
}
}
},
"self": {
"metadata": {
"description": "Test Wallet Description",
"icons": [
"https://wallet.test.org/test.svg"
],
"name": "Test Wallet",
"url": "http://localhost"
},
"publicKey": "da4a87d5f0f54951afe870ebf020cf03f8a3522fbd219398c3fa159a37e16d54"
},
"topic": "e39e1f435a46b5ee6b31484d1751cfbc35be1275653af2ea340974a7592f1a19"
}`
}
func Test_sessionProposalValidity(t *testing.T) {
tests := []struct {
name string
@ -183,68 +268,44 @@ func Test_sessionProposalValidity(t *testing.T) {
err := json.Unmarshal([]byte(tt.sessionProposalJSON), &sessionProposal)
assert.NoError(t, err)
validRes := sessionProposal.ValidateProposal()
if tt.expectedValidity {
assert.True(t, sessionProposal.Valid())
assert.True(t, validRes)
} else {
assert.False(t, sessionProposal.Valid())
assert.False(t, validRes)
}
})
}
}
func Test_sessionProposalToSupportedChain(t *testing.T) {
func Test_supportedChainInSession(t *testing.T) {
type args struct {
chains []string
supportsChain func(uint64) bool
sessionProposal Session
}
tests := []struct {
name string
args args
wantChains []uint64
wantEipChains []string
name string
args args
expectedChains []uint64
}{
{
name: "filter_out_unsupported_chains_and_invalid_chains",
name: "supported_chain",
args: args{
chains: []string{"eip155:1", "eip155:3", "eip155:invalid"},
supportsChain: func(chainID uint64) bool {
return chainID == 1
sessionProposal: Session{
Namespaces: map[string]Namespace{
"eip155": {
Chains: []string{"eip155:1", "eip155:2", "eip155:3", "eip155:4", "eip155:5"},
},
},
},
},
wantChains: []uint64{1},
wantEipChains: []string{"eip155:1"},
},
{
name: "no_supported_chains",
args: args{
chains: []string{"eip155:3", "eip155:5"},
supportsChain: func(chainID uint64) bool {
return false
},
},
wantChains: []uint64{},
wantEipChains: []string{},
},
{
name: "empty_proposal",
args: args{
chains: []string{},
supportsChain: func(chainID uint64) bool {
return true
},
},
wantChains: []uint64{},
wantEipChains: []string{},
expectedChains: []uint64{1, 2, 3, 4, 5},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotChains, gotEipChains := sessionProposalToSupportedChain(tt.args.chains, tt.args.supportsChain)
if !reflect.DeepEqual(gotChains, tt.wantChains) {
t.Errorf("sessionProposalToSupportedChain() gotChains = %v, want %v", gotChains, tt.wantChains)
}
if !reflect.DeepEqual(gotEipChains, tt.wantEipChains) {
t.Errorf("sessionProposalToSupportedChain() gotEipChains = %v, want %v", gotEipChains, tt.wantEipChains)
gotChains := supportedChainsInSession(tt.args.sessionProposal)
if !reflect.DeepEqual(gotChains, tt.expectedChains) {
t.Errorf("supportedChainInSessionProposal() gotChains = %v, want %v", gotChains, tt.expectedChains)
}
})
}
@ -316,3 +377,23 @@ func Test_caip10Accounts(t *testing.T) {
})
}
}
// Test_AddSession validates that the new added session is active (not expired and not disconnected)
func Test_AddSession(t *testing.T) {
db, close := setupTestDB(t)
defer close()
// Add session for testnet
expiry := 1716581732
sessionJSON := getSessionJSONFor([]int{11155111}, expiry)
networks := []params.Network{
{ChainID: 1, IsTest: false},
{ChainID: 11155111, IsTest: true},
}
err := AddSession(db, networks, sessionJSON)
assert.NoError(t, err)
dapps, err := GetActiveDapps(db, int64(expiry-1), true)
assert.NoError(t, err)
assert.Equal(t, 1, len(dapps))
}