From e06c490ec870a70ae72ede2b37f1235a3d903ed8 Mon Sep 17 00:00:00 2001 From: Stefan Date: Thu, 23 May 2024 15:19:00 +0300 Subject: [PATCH] feat(dapps)_: implement basic dApp persistance Implement required basic CRUD APIs - Add session to wallet connect - Delete session used in tests only - Get active dApps: the order of retrieval is based on the first time the DApp was added in descending order. Also add tests to validate the main requirements Closes: #14615 --- services/wallet/api.go | 20 + services/wallet/walletconnect/database.go | 310 ++++++++------- .../wallet/walletconnect/database_test.go | 353 ++++++++++++++---- .../wallet/walletconnect/walletconnect.go | 75 +++- .../walletconnect/walletconnect_test.go | 167 ++++++--- 5 files changed, 676 insertions(+), 249 deletions(-) diff --git a/services/wallet/api.go b/services/wallet/api.go index b8a2df677..b3d5eb57a 100644 --- a/services/wallet/api.go +++ b/services/wallet/api.go @@ -29,6 +29,7 @@ import ( "github.com/status-im/status-go/services/wallet/thirdparty" "github.com/status-im/status-go/services/wallet/token" "github.com/status-im/status-go/services/wallet/transfer" + "github.com/status-im/status-go/services/wallet/walletconnect" "github.com/status-im/status-go/services/wallet/walletevent" "github.com/status-im/status-go/transactions" ) @@ -768,3 +769,22 @@ func (api *API) getVerifiedWalletAccount(address, password string) (*account.Sel AccountKey: key, }, nil } + +// AddWalletConnectSession adds or updates a session wallet connect session +func (api *API) AddWalletConnectSession(ctx context.Context, session_json string) error { + log.Debug("wallet.api.AddWalletConnectSession", "rpcURL", len(session_json)) + return walletconnect.AddSession(api.s.db, api.s.config.Networks, session_json) +} + +// DisconnectWalletConnectSession removes a wallet connect session +func (api *API) DisconnectWalletConnectSession(ctx context.Context, topic walletconnect.Topic) error { + log.Debug("wallet.api.DisconnectWalletConnectSession", "topic", topic) + return walletconnect.DisconnectSession(api.s.db, topic) +} + +// GetWalletConnectDapps returns all active wallet connect dapps +// Active dApp are those having active sessions (not expired and not disconnected) +func (api *API) GetWalletConnectDapps(ctx context.Context, validAtTimestamp int64, testChains bool) ([]walletconnect.DBDApp, error) { + log.Debug("wallet.api.GetWalletConnectDapps", "validAtTimestamp", validAtTimestamp, "testChains", testChains) + return walletconnect.GetActiveDapps(api.s.db, validAtTimestamp, testChains) +} diff --git a/services/wallet/walletconnect/database.go b/services/wallet/walletconnect/database.go index 0d0d88979..10e47ac0d 100644 --- a/services/wallet/walletconnect/database.go +++ b/services/wallet/walletconnect/database.go @@ -2,70 +2,87 @@ package walletconnect import ( "database/sql" - "errors" + "fmt" + + "github.com/ethereum/go-ethereum/log" ) -type DbSession struct { - Topic Topic `json:"topic"` - PairingTopic Topic `json:"pairingTopic"` - Expiry int64 `json:"expiry"` - Active bool `json:"active"` - DappName string `json:"dappName"` - DappURL string `json:"dappUrl"` - DappDescription string `json:"dappDescription"` - DappIcon string `json:"dappIcon"` - DappVerifyURL string `json:"dappVerifyUrl"` - DappPublicKey string `json:"dappPublicKey"` +type DBSession struct { + Topic Topic + Disconnected bool + SessionJSON string + Expiry int64 + CreatedTimestamp int64 + PairingTopic Topic + TestChains bool + DBDApp } -func UpsertSession(db *sql.DB, session DbSession) error { - insertSQL := ` - INSERT OR IGNORE INTO - wallet_connect_sessions (topic, pairing_topic, expiry, active) - VALUES - (?, ?, ?, ?); +type DBDApp struct { + URL string `json:"url"` + Name string `json:"name"` + IconURL string `json:"iconUrl"` +} - UPDATE - wallet_connect_sessions - SET - pairing_topic = ?, - expiry = ?, - active = ?, - dapp_name = ?, - dapp_url = ?, - dapp_description = ?, - dapp_icon = ?, - dapp_verify_url = ?, - dapp_publicKey = ? - WHERE - topic = ?;` +func UpsertSession(db *sql.DB, data DBSession) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("begin transaction: %v", err) + } + defer func() { + if err != nil { + rollErr := tx.Rollback() + if rollErr != nil { + log.Error("error rolling back transaction", "rollErr", rollErr, "err", err) + } + } + }() - _, err := db.Exec(insertSQL, - session.Topic, - session.PairingTopic, - session.Expiry, - session.Active, - session.PairingTopic, - session.Expiry, - session.Active, - session.DappName, - session.DappURL, - session.DappDescription, - session.DappIcon, - session.DappVerifyURL, - session.DappPublicKey, - session.Topic, - ) + upsertDappStmt := `INSERT INTO wallet_connect_dapps (url, name, icon_url) VALUES (?, ?, ?) + ON CONFLICT(url) DO UPDATE SET name = excluded.name, icon_url = excluded.icon_url` + _, err = tx.Exec(upsertDappStmt, data.URL, data.Name, data.IconURL) + if err != nil { + return fmt.Errorf("upsert wallet_connect_dapps: %v", err) + } + + upsertSessionStmt := `INSERT INTO wallet_connect_sessions ( + topic, + disconnected, + session_json, + expiry, + created_timestamp, + pairing_topic, + test_chains, + dapp_url + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(topic) DO UPDATE SET + disconnected = excluded.disconnected, + session_json = excluded.session_json, + expiry = excluded.expiry, + created_timestamp = excluded.created_timestamp, + pairing_topic = excluded.pairing_topic, + test_chains = excluded.test_chains, + dapp_url = excluded.dapp_url;` + _, err = tx.Exec(upsertSessionStmt, data.Topic, data.Disconnected, data.SessionJSON, data.Expiry, data.CreatedTimestamp, data.PairingTopic, data.TestChains, data.URL) + if err != nil { + return fmt.Errorf("insert session: %v", err) + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %v", err) + } + + return nil +} + +func DeleteSession(db *sql.DB, topic Topic) error { + _, err := db.Exec("DELETE FROM wallet_connect_sessions WHERE topic = ?", topic) return err } -func ChangeSessionState(db *sql.DB, topic Topic, active bool) error { - stmt, err := db.Prepare("UPDATE wallet_connect_sessions SET active = ? WHERE topic = ?") - if err != nil { - return err - } - - res, err := stmt.Exec(active, topic) +func DisconnectSession(db *sql.DB, topic Topic) error { + res, err := db.Exec("UPDATE wallet_connect_sessions SET disconnected = 1 WHERE topic = ?", topic) if err != nil { return err } @@ -75,33 +92,71 @@ func ChangeSessionState(db *sql.DB, topic Topic, active bool) error { return err } if rowsAffected == 0 { - return errors.New("unable to locate session for DB state change") + return fmt.Errorf("topic %s not found to update state", topic) } return nil } -func GetSessionByTopic(db *sql.DB, topic Topic) (*DbSession, error) { - querySQL := ` - SELECT * +// GetSessionByTopic returns sql.ErrNoRows if no session is found. +func GetSessionByTopic(db *sql.DB, topic Topic) (*DBSession, error) { + query := selectAndJoinQueryStr + " WHERE sessions.topic = ?" + + row := db.QueryRow(query, topic) + return scanSession(singleRow{row}) +} + +// GetSessionsByPairingTopic returns sql.ErrNoRows if no session is found. +func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DBSession, error) { + query := selectAndJoinQueryStr + " WHERE sessions.pairing_topic = ?" + + rows, err := db.Query(query, pairingTopic) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanSessions(rows) +} + +type Scanner interface { + Scan(dest ...interface{}) error +} + +type singleRow struct { + *sql.Row +} + +func (r singleRow) Scan(dest ...interface{}) error { + return r.Row.Scan(dest...) +} + +const selectAndJoinQueryStr = ` + SELECT + sessions.topic, sessions.disconnected, sessions.session_json, sessions.expiry, sessions.created_timestamp, + sessions.pairing_topic, sessions.test_chains, sessions.dapp_url, dapps.name, dapps.icon_url FROM - wallet_connect_sessions - WHERE - topic = ?` + wallet_connect_sessions sessions + JOIN + wallet_connect_dapps dapps ON sessions.dapp_url = dapps.url` - row := db.QueryRow(querySQL, topic) +// scanSession scans a single session from the given scanner following selectAndJoinQueryStr. +func scanSession(scanner Scanner) (*DBSession, error) { + var session DBSession - var session DbSession - err := row.Scan(&session.Topic, - &session.PairingTopic, + err := scanner.Scan( + &session.Topic, + &session.Disconnected, + &session.SessionJSON, &session.Expiry, - &session.Active, - &session.DappName, - &session.DappURL, - &session.DappDescription, - &session.DappIcon, - &session.DappVerifyURL, - &session.DappPublicKey) + &session.CreatedTimestamp, + &session.PairingTopic, + &session.TestChains, + &session.URL, + &session.Name, + &session.IconURL, + ) + if err != nil { return nil, err } @@ -109,38 +164,16 @@ func GetSessionByTopic(db *sql.DB, topic Topic) (*DbSession, error) { return &session, nil } -func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DbSession, error) { - querySQL := ` - SELECT * - FROM - wallet_connect_sessions - WHERE - pairing_topic = ?` +// scanSessions returns sql.ErrNoRows if nothing is scanned. +func scanSessions(rows *sql.Rows) ([]DBSession, error) { + var sessions []DBSession - rows, err := db.Query(querySQL, pairingTopic) - if err != nil { - return nil, err - } - defer rows.Close() - - sessions := make([]DbSession, 0, 2) for rows.Next() { - var session DbSession - err := rows.Scan(&session.Topic, - &session.PairingTopic, - &session.Expiry, - &session.Active, - &session.DappName, - &session.DappURL, - &session.DappDescription, - &session.DappIcon, - &session.DappVerifyURL, - &session.DappPublicKey) + session, err := scanSession(rows) if err != nil { return nil, err } - - sessions = append(sessions, session) + sessions = append(sessions, *session) } if err := rows.Err(); err != nil { @@ -150,47 +183,68 @@ func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DbSession, err return sessions, nil } -// GetActiveSessions returns all active sessions (active and not expired) that have an expiry timestamp newer or equal to the given timestamp. -func GetActiveSessions(db *sql.DB, expiryNotOlderThanTimestamp int64) ([]DbSession, error) { - querySQL := ` - SELECT * - FROM - wallet_connect_sessions - WHERE - active != 0 AND - expiry >= ? - ORDER BY - expiry DESC` +// GetActiveSessions returns all active sessions (not disconnected and not expired) that have an expiry timestamp newer or equal to the given timestamp. +func GetActiveSessions(db *sql.DB, validAtTimestamp int64) ([]DBSession, error) { + querySQL := selectAndJoinQueryStr + ` + WHERE + sessions.disconnected = 0 AND + sessions.expiry >= ? + ORDER BY + sessions.expiry DESC` - rows, err := db.Query(querySQL, expiryNotOlderThanTimestamp) + rows, err := db.Query(querySQL, validAtTimestamp) + if err != nil { + return nil, err + } + defer rows.Close() + return scanSessions(rows) +} + +// GetSessions returns all sessions in the ascending order of creation time +func GetSessions(db *sql.DB) ([]DBSession, error) { + querySQL := selectAndJoinQueryStr + ` + ORDER BY + sessions.created_timestamp DESC` + + rows, err := db.Query(querySQL) + if err != nil { + return nil, err + } + defer rows.Close() + return scanSessions(rows) +} + +// GetActiveDapps returns all dapps in the order of last first time connected (first session creation time) +func GetActiveDapps(db *sql.DB, validAtTimestamp int64, testChains bool) ([]DBDApp, error) { + query := `SELECT dapps.url, dapps.name, dapps.icon_url, MIN(sessions.created_timestamp) as dapp_creation_time + FROM + wallet_connect_dapps dapps + JOIN + wallet_connect_sessions sessions ON dapps.url = sessions.dapp_url + WHERE sessions.disconnected = 0 AND sessions.expiry >= ? AND sessions.test_chains = ? + GROUP BY dapps.url + ORDER BY dapp_creation_time DESC;` + + rows, err := db.Query(query, validAtTimestamp, testChains) if err != nil { return nil, err } defer rows.Close() - sessions := make([]DbSession, 0, 2) + var dapps []DBDApp + for rows.Next() { - var session DbSession - err := rows.Scan(&session.Topic, - &session.PairingTopic, - &session.Expiry, - &session.Active, - &session.DappName, - &session.DappURL, - &session.DappDescription, - &session.DappIcon, - &session.DappVerifyURL, - &session.DappPublicKey) - if err != nil { + var dapp DBDApp + var creationTime sql.NullInt64 + if err := rows.Scan(&dapp.URL, &dapp.Name, &dapp.IconURL, &creationTime); err != nil { return nil, err } - - sessions = append(sessions, session) + dapps = append(dapps, dapp) } if err := rows.Err(); err != nil { return nil, err } - return sessions, nil + return dapps, nil } diff --git a/services/wallet/walletconnect/database_test.go b/services/wallet/walletconnect/database_test.go index 06c3b605b..6a70cade6 100644 --- a/services/wallet/walletconnect/database_test.go +++ b/services/wallet/walletconnect/database_test.go @@ -6,6 +6,7 @@ import ( "database/sql" + "github.com/status-im/status-go/services/wallet/common" "github.com/status-im/status-go/t/helpers" "github.com/status-im/status-go/walletdatabase" @@ -20,34 +21,80 @@ func setupTestDB(t *testing.T) (db *sql.DB, close func()) { } } +type urlOverride *string +type timestampOverride *int64 + +// testSession will override defaults for the fields that are not null +type testSession struct { + url urlOverride + created timestampOverride + expiry timestampOverride + disconnected *bool + testChains *bool +} + +const testDappUrl = "https://test.url/" + // generateTestData generates alternative disconnected and active sessions starting with the active one -// timestamps start with 1234567890 -func generateTestData(count int) []DbSession { - res := make([]DbSession, count) - j := 0 - for i := 0; i < count; i++ { +// timestamps start with 1234567890 and increase by 1 for each session +// all sessions will share the same two pairing sessions (roll over after index 1) +// testChains is false if not overridden +func generateTestData(sessions []testSession) []DBSession { + res := make([]DBSession, len(sessions)) + pairingIdx := 0 + for i := 0; i < len(res); i++ { strI := strconv.Itoa(i) - if i%4 == 0 { - j++ + if i%2 == 0 { + pairingIdx++ } - strJ := strconv.Itoa(j) - res[i] = DbSession{ - Topic: Topic(strI + "aaaaaa1234567890"), - PairingTopic: Topic(strJ + "bbbbbb1234567890"), - Expiry: 1234567890 + int64(i), - Active: (i % 2) == 0, - DappName: "TestApp" + strI, - DappURL: "https://test.url/" + strI, - DappDescription: "Test Description" + strI, - DappIcon: "https://test.icon" + strI, - DappVerifyURL: "https://test.verify.url/" + strI, - DappPublicKey: strI + "1234567890", + pairingIdxStr := strconv.Itoa(pairingIdx) + + s := sessions[i] + + url := testDappUrl + strI + if s.url != nil { + url = *s.url + } + + createdTimestamp := 1234567890 + int64(i) + if s.created != nil { + createdTimestamp = *s.created + } + + expiryTimestamp := createdTimestamp + 1000 + int64(i) + if s.expiry != nil { + expiryTimestamp = *s.expiry + } + + disconnected := (i % 2) != 0 + if s.disconnected != nil { + disconnected = *s.disconnected + } + + testChains := false + if s.testChains != nil { + testChains = *s.testChains + } + + res[i] = DBSession{ + Topic: Topic(strI + "aaaaaa1234567890"), + Disconnected: disconnected, + SessionJSON: "{}", + Expiry: expiryTimestamp, + CreatedTimestamp: createdTimestamp, + PairingTopic: Topic(pairingIdxStr + "bbbbbb1234567890"), + TestChains: testChains, + DBDApp: DBDApp{ + URL: url, + Name: "TestApp" + strI, + IconURL: "https://test.icon" + strI, + }, } } return res } -func insertTestData(t *testing.T, db *sql.DB, entries []DbSession) { +func insertTestData(t *testing.T, db *sql.DB, entries []DBSession) { for _, entry := range entries { err := UpsertSession(db, entry) require.NoError(t, err) @@ -58,7 +105,7 @@ func TestInsertUpdateAndGetSession(t *testing.T) { db, close := setupTestDB(t) defer close() - entry := generateTestData(1)[0] + entry := generateTestData(make([]testSession, 1))[0] err := UpsertSession(db, entry) require.NoError(t, err) @@ -67,37 +114,38 @@ func TestInsertUpdateAndGetSession(t *testing.T) { require.Equal(t, entry, *retrievedSession) - entry.Active = false - entry.Expiry = 1111111111 - err = UpsertSession(db, entry) + updatedEntry := entry + updatedEntry.Disconnected = true + updatedEntry.Expiry = 1111111111 + err = UpsertSession(db, updatedEntry) require.NoError(t, err) - retrievedSession, err = GetSessionByTopic(db, entry.Topic) + retrievedSession, err = GetSessionByTopic(db, updatedEntry.Topic) require.NoError(t, err) - require.Equal(t, entry, *retrievedSession) + require.Equal(t, updatedEntry, *retrievedSession) } func TestInsertAndGetSessionsByPairingTopic(t *testing.T) { db, close := setupTestDB(t) defer close() - generatedSessions := generateTestData(10) + generatedSessions := generateTestData(make([]testSession, 4)) for _, session := range generatedSessions { err := UpsertSession(db, session) require.NoError(t, err) } - retrievedSessions, err := GetSessionsByPairingTopic(db, generatedSessions[4].Topic) + retrievedSessions, err := GetSessionsByPairingTopic(db, generatedSessions[2].Topic) require.NoError(t, err) require.Equal(t, 0, len(retrievedSessions)) - retrievedSessions, err = GetSessionsByPairingTopic(db, generatedSessions[4].PairingTopic) + retrievedSessions, err = GetSessionsByPairingTopic(db, generatedSessions[2].PairingTopic) require.NoError(t, err) - require.Equal(t, 4, len(retrievedSessions)) + require.Equal(t, 2, len(retrievedSessions)) - for i := 4; i < 8; i++ { + for i := 2; i < 4; i++ { found := false for _, session := range retrievedSessions { if session.Topic == generatedSessions[i].Topic { @@ -109,34 +157,24 @@ func TestInsertAndGetSessionsByPairingTopic(t *testing.T) { } } -func TestChangeSessionState(t *testing.T) { - db, close := setupTestDB(t) - defer close() - - entry := generateTestData(1)[0] - err := UpsertSession(db, entry) - require.NoError(t, err) - - err = ChangeSessionState(db, entry.Topic, false) - require.NoError(t, err) - - retrievedSession, err := GetSessionByTopic(db, entry.Topic) - require.NoError(t, err) - - require.Equal(t, false, retrievedSession.Active) -} - func TestGet(t *testing.T) { db, close := setupTestDB(t) defer close() - entries := generateTestData(3) + entries := generateTestData(make([]testSession, 3)) insertTestData(t, db, entries) retrievedSession, err := GetSessionByTopic(db, entries[1].Topic) require.NoError(t, err) require.Equal(t, entries[1], *retrievedSession) + + err = DeleteSession(db, entries[1].Topic) + require.NoError(t, err) + + deletedSession, err := GetSessionByTopic(db, entries[1].Topic) + require.ErrorIs(t, err, sql.ErrNoRows) + require.Nil(t, deletedSession) } func TestGetActiveSessions(t *testing.T) { @@ -144,10 +182,10 @@ func TestGetActiveSessions(t *testing.T) { defer close() // insert two disconnected and three active sessions - entries := generateTestData(5) + entries := generateTestData(make([]testSession, 5)) insertTestData(t, db, entries) - activeSessions, err := GetActiveSessions(db, 1234567892) + activeSessions, err := GetActiveSessions(db, entries[2].Expiry) require.NoError(t, err) require.Equal(t, 2, len(activeSessions)) @@ -156,19 +194,206 @@ func TestGetActiveSessions(t *testing.T) { require.Equal(t, entries[2], activeSessions[1]) } -// func TestHasActivePairings(t *testing.T) { -// db, close := setupTestDB(t) -// defer close() +func TestDeleteSession(t *testing.T) { + db, close := setupTestDB(t) + defer close() -// // insert one disconnected and two active pairing -// entries := generateTestData(2) -// insertTestData(t, db, entries) + entries := generateTestData(make([]testSession, 3)) + insertTestData(t, db, entries) -// hasActivePairings, err := HasActivePairings(db, 1234567890) -// require.NoError(t, err) -// require.True(t, hasActivePairings) + err := DeleteSession(db, entries[1].Topic) + require.NoError(t, err) -// hasActivePairings, err = HasActivePairings(db, 1234567891) -// require.NoError(t, err) -// require.False(t, hasActivePairings) -// } + sessions, err := GetSessions(db) + require.NoError(t, err) + require.Equal(t, 2, len(sessions)) + + require.Equal(t, entries[0], sessions[1]) + require.Equal(t, entries[2], sessions[0]) + + err = DeleteSession(db, entries[0].Topic) + require.NoError(t, err) + err = DeleteSession(db, entries[2].Topic) + require.NoError(t, err) + + sessions, err = GetSessions(db) + require.NoError(t, err) + require.Equal(t, 0, len(sessions)) +} + +// urlFor prepares a value to be used in testSession +func urlFor(i int) urlOverride { + return common.NewAndSet(testDappUrl + strconv.Itoa(i)) +} + +// at prepares a value to be used in testSession +func at(i int) timestampOverride { + return common.NewAndSet(int64(i)) +} + +// TestGetActiveDapps_JoinWorksAsExpected also validates that GetActiveDapps returns the dapps in the order of the last first time added +func TestGetActiveDapps_JoinWorksAsExpected(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + not := common.NewAndSet(false) + // The first creation date is 1, 2, 3 but the last name update is, respectively, 1, 4, 5 + entries := generateTestData([]testSession{ + {url: urlFor(1), created: at(1), disconnected: not}, + {url: urlFor(1), created: at(2), disconnected: not}, + {url: urlFor(2), created: at(3), disconnected: not}, + {url: urlFor(3), created: at(4), disconnected: not}, + {url: urlFor(2), created: at(5), disconnected: not}, + {url: urlFor(3), created: at(6), disconnected: not}, + }) + insertTestData(t, db, entries) + + getTestnet := false + validAtTimestamp := entries[0].Expiry + dapps, err := GetActiveDapps(db, validAtTimestamp, getTestnet) + require.NoError(t, err) + require.Equal(t, 3, len(dapps)) + + require.Equal(t, 3, len(dapps)) + require.Equal(t, entries[5].Name, dapps[0].Name) + require.Equal(t, entries[4].Name, dapps[1].Name) + require.Equal(t, entries[1].Name, dapps[2].Name) +} + +// TestGetActiveDapps_ActiveWorksAsExpected tests the combination of disconnected and expired sessions +func TestGetActiveDapps_ActiveWorksAsExpected(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + not := common.NewAndSet(false) + yes := common.NewAndSet(true) + timeNow := 4 + entries := generateTestData([]testSession{ + {url: urlFor(1), expiry: at(timeNow - 3), disconnected: not}, + {url: urlFor(1), expiry: at(timeNow - 2), disconnected: yes}, + {url: urlFor(2), expiry: at(timeNow - 2), disconnected: not}, + {url: urlFor(3), expiry: at(timeNow - 1), disconnected: yes}, + // ----- timeNow + {url: urlFor(3), expiry: at(timeNow + 1), disconnected: not}, + {url: urlFor(4), expiry: at(timeNow + 1), disconnected: yes}, + {url: urlFor(4), expiry: at(timeNow + 2), disconnected: not}, + {url: urlFor(5), expiry: at(timeNow + 2), disconnected: yes}, + {url: urlFor(6), expiry: at(timeNow + 3), disconnected: not}, + }) + insertTestData(t, db, entries) + + getTestnet := false + dapps, err := GetActiveDapps(db, int64(timeNow), getTestnet) + require.NoError(t, err) + require.Equal(t, 3, len(dapps)) +} + +// TestGetActiveDapps_TestChainsWorksAsExpected tests the combination of disconnected and expired sessions +func TestGetActiveDapps_TestChainsWorksAsExpected(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + not := common.NewAndSet(false) + yes := common.NewAndSet(true) + timeNow := 4 + entries := generateTestData([]testSession{ + {url: urlFor(1), testChains: not, expiry: at(timeNow - 3), disconnected: not}, + {url: urlFor(2), testChains: yes, expiry: at(timeNow - 2), disconnected: not}, + {url: urlFor(2), testChains: not, expiry: at(timeNow - 1), disconnected: not}, + // ----- timeNow + {url: urlFor(3), testChains: not, expiry: at(timeNow + 1), disconnected: not}, + {url: urlFor(4), testChains: not, expiry: at(timeNow + 2), disconnected: not}, + {url: urlFor(4), testChains: yes, expiry: at(timeNow + 3), disconnected: not}, + {url: urlFor(5), testChains: yes, expiry: at(timeNow + 4), disconnected: not}, + }) + insertTestData(t, db, entries) + + getTestnet := true + dapps, err := GetActiveDapps(db, int64(timeNow), getTestnet) + require.NoError(t, err) + require.Equal(t, 2, len(dapps)) +} + +// TestGetDapps_EmptyDB tests that an empty database will return an empty list +func TestGetDapps_EmptyDB(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + entries := generateTestData([]testSession{}) + insertTestData(t, db, entries) + + getTestnet := false + validAtTimestamp := int64(0) + dapps, err := GetActiveDapps(db, validAtTimestamp, getTestnet) + require.NoError(t, err) + require.Equal(t, 0, len(dapps)) +} + +// TestGetDapps_OrphanDapps tests that missing session will place the dapp at the end +func TestGetDapps_OrphanDapps(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + not := common.NewAndSet(false) + entries := generateTestData([]testSession{ + {url: urlFor(1), disconnected: not}, + {url: urlFor(2), disconnected: not}, + {url: urlFor(2), disconnected: not}, + }) + insertTestData(t, db, entries) + + err := DeleteSession(db, entries[1].Topic) + require.NoError(t, err) + err = DeleteSession(db, entries[2].Topic) + require.NoError(t, err) + + getTestnet := false + validAtTimestamp := entries[0].Expiry + dapps, err := GetActiveDapps(db, validAtTimestamp, getTestnet) + require.NoError(t, err) + // The orphan dapp is not considered active + require.Equal(t, 1, len(dapps)) + require.Equal(t, entries[0].Name, dapps[0].Name) +} + +func TestDisconnectSession(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + not := common.NewAndSet(false) + entries := generateTestData([]testSession{ + {url: urlFor(1), disconnected: not}, + {url: urlFor(2), disconnected: not}, + {url: urlFor(2), disconnected: not}, + }) + insertTestData(t, db, entries) + + activeSessions, err := GetActiveSessions(db, 0) + require.NoError(t, err) + require.Equal(t, 3, len(activeSessions)) + + getTestnet := false + validAtTimestamp := entries[0].Expiry + dapps, err := GetActiveDapps(db, validAtTimestamp, getTestnet) + require.NoError(t, err) + require.Equal(t, 2, len(dapps)) + + err = DisconnectSession(db, entries[1].Topic) + require.NoError(t, err) + + activeSessions, err = GetActiveSessions(db, 0) + require.NoError(t, err) + require.Equal(t, 2, len(activeSessions)) + + err = DisconnectSession(db, entries[2].Topic) + require.NoError(t, err) + + activeSessions, err = GetActiveSessions(db, 0) + require.NoError(t, err) + require.Equal(t, 1, len(activeSessions)) + + dapps, err = GetActiveDapps(db, validAtTimestamp, getTestnet) + require.NoError(t, err) + require.Equal(t, 1, len(dapps)) + require.Equal(t, entries[0].Name, dapps[0].Name) +} diff --git a/services/wallet/walletconnect/walletconnect.go b/services/wallet/walletconnect/walletconnect.go index a193456de..6ae96f7cf 100644 --- a/services/wallet/walletconnect/walletconnect.go +++ b/services/wallet/walletconnect/walletconnect.go @@ -1,15 +1,18 @@ package walletconnect import ( + "database/sql" "encoding/json" "errors" "fmt" "strconv" "strings" + "time" "github.com/ethereum/go-ethereum/log" "github.com/status-im/status-go/multiaccounts/accounts" + "github.com/status-im/status-go/params" "github.com/status-im/status-go/services/wallet/walletevent" ) @@ -61,6 +64,8 @@ type VerifyContext struct { Verified Verified `json:"verified"` } +// Params has RequiredNamespaces entries if part of "proposal namespace" and Namespaces entries if part of "session namespace" +// see https://specs.walletconnect.com/2.0/specs/clients/sign/namespaces#controller-side-validation-of-incoming-proposal-namespaces-wallet type Params struct { ID int64 `json:"id"` PairingTopic Topic `json:"pairingTopic"` @@ -138,8 +143,8 @@ func (n *Namespace) Valid(namespaceName string, chainID *uint64) bool { return true } -// Valid params -func (p *Params) Valid() bool { +// ValidateForProposal validates params part of the Proposal Namespace +func (p *Params) ValidateForProposal() bool { for key, ns := range p.RequiredNamespaces { var chainID *uint64 if strings.Contains(key, ":") { @@ -165,15 +170,62 @@ func (p *Params) Valid() bool { return true } -// Valid session propsal +// ValidateProposal validates params part of the Proposal Namespace // https://specs.walletconnect.com/2.0/specs/clients/sign/namespaces#controller-side-validation-of-incoming-proposal-namespaces-wallet -func (p *SessionProposal) Valid() bool { - return p.Params.Valid() +func (p *SessionProposal) ValidateProposal() bool { + return p.Params.ValidateForProposal() } -func sessionProposalToSupportedChain(caipChains []string, supportsChain func(uint64) bool) (chains []uint64, eipChains []string) { - chains = make([]uint64, 0, 1) - eipChains = make([]string, 0, 1) +// AddSession adds a new active session to the database +func AddSession(db *sql.DB, networks []params.Network, session_json string) error { + var session Session + err := json.Unmarshal([]byte(session_json), &session) + if err != nil { + return fmt.Errorf("unmarshal session: %v", err) + } + + chains := supportedChainsInSession(session) + testChains, err := areTestChains(networks, chains) + if err != nil { + return fmt.Errorf("areTestChains: %v", err) + } + + rowEntry := DBSession{ + Topic: session.PairingTopic, + Disconnected: false, + SessionJSON: session_json, + Expiry: session.Expiry, + CreatedTimestamp: time.Now().Unix(), + PairingTopic: session.PairingTopic, + TestChains: testChains, + DBDApp: DBDApp{ + URL: session.Peer.Metadata.URL, + Name: session.Peer.Metadata.Name, + }, + } + if len(session.Peer.Metadata.Icons) > 0 { + rowEntry.IconURL = session.Peer.Metadata.Icons[0] + } + + return UpsertSession(db, rowEntry) +} + +// areTestChains assumes chains to tests are all testnets or all mainnets +func areTestChains(networks []params.Network, chainIDs []uint64) (isTest bool, err error) { + for _, n := range networks { + for _, chainID := range chainIDs { + if n.ChainID == chainID { + return n.IsTest, nil + } + } + } + + return false, fmt.Errorf("no network found for chainIDs %v", chainIDs) +} + +func supportedChainsInSession(session Session) []uint64 { + caipChains := session.Namespaces[SupportedEip155Namespace].Chains + chains := make([]uint64, 0, len(caipChains)) for _, caip2Str := range caipChains { _, chainID, err := parseCaip2ChainID(caip2Str) if err != nil { @@ -181,14 +233,9 @@ func sessionProposalToSupportedChain(caipChains []string, supportsChain func(uin continue } - if !supportsChain(chainID) { - continue - } - - eipChains = append(eipChains, caip2Str) chains = append(chains, chainID) } - return + return chains } func caip10Accounts(accounts []*accounts.Account, chains []uint64) []string { diff --git a/services/wallet/walletconnect/walletconnect_test.go b/services/wallet/walletconnect/walletconnect_test.go index 67613e654..8e71af86f 100644 --- a/services/wallet/walletconnect/walletconnect_test.go +++ b/services/wallet/walletconnect/walletconnect_test.go @@ -2,16 +2,101 @@ package walletconnect import ( "reflect" + "strconv" "testing" "encoding/json" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/multiaccounts/accounts" + "github.com/status-im/status-go/params" "github.com/stretchr/testify/assert" ) +func getSessionJSONFor(chains []int, expiry int) string { + chainsStr := "[" + for i, chain := range chains { + chainsStr += `"eip155:` + strconv.Itoa(chain) + `"` + if i != len(chains)-1 { + chainsStr += "," + } + } + chainsStr += "]" + expiryStr := strconv.Itoa(expiry) + + return `{ + "expiry": ` + expiryStr + `, + "namespaces": { + "eip155": { + "accounts": [ + "eip155:1:0x7F47C2e18a4BBf5487E6fb082eC2D9Ab0E6d7240", + "eip155:10:0x7F47C2e18a4BBf5487E6fb082eC2D9Ab0E6d7240", + "eip155:42161:0x7F47C2e18a4BBf5487E6fb082eC2D9Ab0E6d7240" + ], + "chains": ` + chainsStr + `, + "events": [ + "accountsChanged", + "chainChanged" + ], + "methods": [ + "eth_sendTransaction", + "personal_sign" + ] + } + }, + "optionalNamespaces": { + "eip155": { + "chains": [], + "events": [], + "methods": [], + "rpcMap": {} + } + }, + "pairingTopic": "50fba141cdb5c015493c2907c46bacf9f7cbd7c8e3d4e97df891f18dddcff69c", + "peer": { + "metadata": { + "description": "Test Dapp Description", + "icons": [ "https://test.org/test.png"], + "name": "Test Dapp", + "url": "https://dapp.test.org" + }, + "publicKey": "1234567890aeb6081cabed26faf48919162fd70cc66d639f118a60507ae0463d" + }, + "relay": { "protocol": "irn"}, + "requiredNamespaces": { + "eip155": { + "chains": [ + "eip155:1" + ], + "events": [ + "chainChanged", + "accountsChanged" + ], + "methods": [ + "eth_sendTransaction", + "personal_sign" + ], + "rpcMap": { + "1": "https://mainnet.infura.io/v3/099fc58e0de9451d80b18d7c74caa7c1" + } + } + }, + "self": { + "metadata": { + "description": "Test Wallet Description", + "icons": [ + "https://wallet.test.org/test.svg" + ], + "name": "Test Wallet", + "url": "http://localhost" + }, + "publicKey": "da4a87d5f0f54951afe870ebf020cf03f8a3522fbd219398c3fa159a37e16d54" + }, + "topic": "e39e1f435a46b5ee6b31484d1751cfbc35be1275653af2ea340974a7592f1a19" + }` +} + func Test_sessionProposalValidity(t *testing.T) { tests := []struct { name string @@ -183,68 +268,44 @@ func Test_sessionProposalValidity(t *testing.T) { err := json.Unmarshal([]byte(tt.sessionProposalJSON), &sessionProposal) assert.NoError(t, err) + validRes := sessionProposal.ValidateProposal() if tt.expectedValidity { - assert.True(t, sessionProposal.Valid()) + assert.True(t, validRes) } else { - assert.False(t, sessionProposal.Valid()) + assert.False(t, validRes) } }) } } -func Test_sessionProposalToSupportedChain(t *testing.T) { +func Test_supportedChainInSession(t *testing.T) { type args struct { - chains []string - supportsChain func(uint64) bool + sessionProposal Session } tests := []struct { - name string - args args - wantChains []uint64 - wantEipChains []string + name string + args args + expectedChains []uint64 }{ { - name: "filter_out_unsupported_chains_and_invalid_chains", + name: "supported_chain", args: args{ - chains: []string{"eip155:1", "eip155:3", "eip155:invalid"}, - supportsChain: func(chainID uint64) bool { - return chainID == 1 + sessionProposal: Session{ + Namespaces: map[string]Namespace{ + "eip155": { + Chains: []string{"eip155:1", "eip155:2", "eip155:3", "eip155:4", "eip155:5"}, + }, + }, }, }, - wantChains: []uint64{1}, - wantEipChains: []string{"eip155:1"}, - }, - { - name: "no_supported_chains", - args: args{ - chains: []string{"eip155:3", "eip155:5"}, - supportsChain: func(chainID uint64) bool { - return false - }, - }, - wantChains: []uint64{}, - wantEipChains: []string{}, - }, - { - name: "empty_proposal", - args: args{ - chains: []string{}, - supportsChain: func(chainID uint64) bool { - return true - }, - }, - wantChains: []uint64{}, - wantEipChains: []string{}, + expectedChains: []uint64{1, 2, 3, 4, 5}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotChains, gotEipChains := sessionProposalToSupportedChain(tt.args.chains, tt.args.supportsChain) - if !reflect.DeepEqual(gotChains, tt.wantChains) { - t.Errorf("sessionProposalToSupportedChain() gotChains = %v, want %v", gotChains, tt.wantChains) - } - if !reflect.DeepEqual(gotEipChains, tt.wantEipChains) { - t.Errorf("sessionProposalToSupportedChain() gotEipChains = %v, want %v", gotEipChains, tt.wantEipChains) + gotChains := supportedChainsInSession(tt.args.sessionProposal) + if !reflect.DeepEqual(gotChains, tt.expectedChains) { + t.Errorf("supportedChainInSessionProposal() gotChains = %v, want %v", gotChains, tt.expectedChains) } }) } @@ -316,3 +377,23 @@ func Test_caip10Accounts(t *testing.T) { }) } } + +// Test_AddSession validates that the new added session is active (not expired and not disconnected) +func Test_AddSession(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + // Add session for testnet + expiry := 1716581732 + sessionJSON := getSessionJSONFor([]int{11155111}, expiry) + networks := []params.Network{ + {ChainID: 1, IsTest: false}, + {ChainID: 11155111, IsTest: true}, + } + err := AddSession(db, networks, sessionJSON) + assert.NoError(t, err) + + dapps, err := GetActiveDapps(db, int64(expiry-1), true) + assert.NoError(t, err) + assert.Equal(t, 1, len(dapps)) +}