2023-11-19 19:29:17 +02:00
package walletconnect
import (
"database/sql"
2024-05-23 15:19:00 +03:00
"fmt"
"github.com/ethereum/go-ethereum/log"
2023-11-19 19:29:17 +02:00
)
2024-05-23 15:19:00 +03:00
type DBSession struct {
Topic Topic
Disconnected bool
SessionJSON string
Expiry int64
CreatedTimestamp int64
PairingTopic Topic
TestChains bool
DBDApp
2023-11-19 19:29:17 +02:00
}
2024-05-23 15:19:00 +03:00
// DBDApp is one row of wallet_connect_dapps. URL is the conflict key used when
// upserting, so there is at most one row per URL.
type DBDApp struct {
	URL     string `json:"url"`     // unique key, referenced by sessions.dapp_url
	Name    string `json:"name"`    // display name, refreshed on every upsert
	IconURL string `json:"iconUrl"` // icon location, refreshed on every upsert
}
2024-05-23 15:19:00 +03:00
func UpsertSession ( db * sql . DB , data DBSession ) error {
tx , err := db . Begin ( )
2023-11-26 17:50:12 +02:00
if err != nil {
2024-05-23 15:19:00 +03:00
return fmt . Errorf ( "begin transaction: %v" , err )
}
defer func ( ) {
if err != nil {
rollErr := tx . Rollback ( )
if rollErr != nil {
log . Error ( "error rolling back transaction" , "rollErr" , rollErr , "err" , err )
}
}
} ( )
upsertDappStmt := ` INSERT INTO wallet_connect_dapps ( url , name , icon_url ) VALUES ( ? , ? , ? )
ON CONFLICT ( url ) DO UPDATE SET name = excluded . name , icon_url = excluded . icon_url `
_ , err = tx . Exec ( upsertDappStmt , data . URL , data . Name , data . IconURL )
if err != nil {
return fmt . Errorf ( "upsert wallet_connect_dapps: %v" , err )
2023-11-26 17:50:12 +02:00
}
2024-05-23 15:19:00 +03:00
upsertSessionStmt := ` INSERT INTO wallet_connect_sessions (
topic ,
disconnected ,
session_json ,
expiry ,
created_timestamp ,
pairing_topic ,
test_chains ,
dapp_url
)
VALUES ( ? , ? , ? , ? , ? , ? , ? , ? )
ON CONFLICT ( topic ) DO UPDATE SET
disconnected = excluded . disconnected ,
session_json = excluded . session_json ,
expiry = excluded . expiry ,
created_timestamp = excluded . created_timestamp ,
pairing_topic = excluded . pairing_topic ,
test_chains = excluded . test_chains ,
dapp_url = excluded . dapp_url ; `
_ , err = tx . Exec ( upsertSessionStmt , data . Topic , data . Disconnected , data . SessionJSON , data . Expiry , data . CreatedTimestamp , data . PairingTopic , data . TestChains , data . URL )
if err != nil {
return fmt . Errorf ( "insert session: %v" , err )
}
if err = tx . Commit ( ) ; err != nil {
return fmt . Errorf ( "commit transaction: %v" , err )
}
return nil
}
// DeleteSession removes the session row whose topic equals topic. Deleting a
// topic that does not exist is not an error; only a failing Exec is reported.
func DeleteSession(db *sql.DB, topic Topic) error {
	const deleteStmt = "DELETE FROM wallet_connect_sessions WHERE topic = ?"
	if _, err := db.Exec(deleteStmt, topic); err != nil {
		return err
	}
	return nil
}
func DisconnectSession ( db * sql . DB , topic Topic ) error {
res , err := db . Exec ( "UPDATE wallet_connect_sessions SET disconnected = 1 WHERE topic = ?" , topic )
2023-11-26 17:50:12 +02:00
if err != nil {
return err
}
rowsAffected , err := res . RowsAffected ( )
2023-11-29 14:34:08 +01:00
if err != nil {
return err
}
2023-11-26 17:50:12 +02:00
if rowsAffected == 0 {
2024-05-23 15:19:00 +03:00
return fmt . Errorf ( "topic %s not found to update state" , topic )
2023-11-26 17:50:12 +02:00
}
return nil
}
2024-05-23 15:19:00 +03:00
// GetSessionByTopic returns sql.ErrNoRows if no session is found.
func GetSessionByTopic ( db * sql . DB , topic Topic ) ( * DBSession , error ) {
query := selectAndJoinQueryStr + " WHERE sessions.topic = ?"
2023-11-19 19:29:17 +02:00
2024-05-23 15:19:00 +03:00
row := db . QueryRow ( query , topic )
return scanSession ( singleRow { row } )
}
2023-11-19 19:29:17 +02:00
2024-05-23 15:19:00 +03:00
// GetSessionsByPairingTopic returns sql.ErrNoRows if no session is found.
func GetSessionsByPairingTopic ( db * sql . DB , pairingTopic Topic ) ( [ ] DBSession , error ) {
query := selectAndJoinQueryStr + " WHERE sessions.pairing_topic = ?"
rows , err := db . Query ( query , pairingTopic )
2023-11-19 19:29:17 +02:00
if err != nil {
return nil , err
}
2024-05-23 15:19:00 +03:00
defer rows . Close ( )
2023-11-19 19:29:17 +02:00
2024-05-23 15:19:00 +03:00
return scanSessions ( rows )
}
// Scanner is the single method scanSession needs from a result row: copy the
// current row's columns into the given destination pointers. Both *sql.Rows
// and *sql.Row provide a method with this shape.
type Scanner interface {
	Scan(targets ...interface{}) error
}
// singleRow wraps the *sql.Row returned by QueryRow so single-row lookups can
// be passed to scanSession as a Scanner, just like *sql.Rows.
type singleRow struct {
	*sql.Row
}

// Scan forwards to the wrapped *sql.Row. As with sql.Row.Scan, it reports
// sql.ErrNoRows when the query matched nothing.
func (sr singleRow) Scan(targets ...interface{}) error {
	return sr.Row.Scan(targets...)
}
// selectAndJoinQueryStr selects every DBSession column, joining each session
// with its dapp row. The column order must match the Scan order in
// scanSession. Callers append their own WHERE / ORDER BY clauses.
const selectAndJoinQueryStr = `
	SELECT
		sessions.topic, sessions.disconnected, sessions.session_json, sessions.expiry, sessions.created_timestamp,
		sessions.pairing_topic, sessions.test_chains, sessions.dapp_url, dapps.name, dapps.icon_url
	FROM
		wallet_connect_sessions sessions
	JOIN
		wallet_connect_dapps dapps ON sessions.dapp_url = dapps.url`
// scanSession scans a single session from the given scanner following selectAndJoinQueryStr.
func scanSession ( scanner Scanner ) ( * DBSession , error ) {
var session DBSession
err := scanner . Scan (
& session . Topic ,
& session . Disconnected ,
& session . SessionJSON ,
& session . Expiry ,
& session . CreatedTimestamp ,
& session . PairingTopic ,
& session . TestChains ,
& session . URL ,
& session . Name ,
& session . IconURL ,
)
2023-11-19 19:29:17 +02:00
if err != nil {
return nil , err
}
2024-05-23 15:19:00 +03:00
return & session , nil
}
// scanSessions returns sql.ErrNoRows if nothing is scanned.
func scanSessions ( rows * sql . Rows ) ( [ ] DBSession , error ) {
var sessions [ ] DBSession
2023-11-19 19:29:17 +02:00
for rows . Next ( ) {
2024-05-23 15:19:00 +03:00
session , err := scanSession ( rows )
2023-11-19 19:29:17 +02:00
if err != nil {
return nil , err
}
2024-05-23 15:19:00 +03:00
sessions = append ( sessions , * session )
2023-11-19 19:29:17 +02:00
}
2023-12-13 15:05:55 +01:00
2023-11-19 19:29:17 +02:00
if err := rows . Err ( ) ; err != nil {
return nil , err
}
2023-12-13 15:05:55 +01:00
return sessions , nil
2023-11-19 19:29:17 +02:00
}
2024-05-23 15:19:00 +03:00
// GetActiveSessions returns the sessions that are still connected
// (disconnected = 0) and whose expiry is at or after validAtTimestamp. Each
// session is joined with its dapp, and results are ordered by expiry with the
// latest-expiring session first. An empty result yields a nil slice and a nil
// error.
func GetActiveSessions(db *sql.DB, validAtTimestamp int64) ([]DBSession, error) {
	const activeFilter = `
		WHERE
			sessions.disconnected = 0 AND
			sessions.expiry >= ?
		ORDER BY
			sessions.expiry DESC`

	rows, err := db.Query(selectAndJoinQueryStr+activeFilter, validAtTimestamp)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanSessions(rows)
}
// GetSessions returns all sessions in the ascending order of creation time
func GetSessions ( db * sql . DB ) ( [ ] DBSession , error ) {
querySQL := selectAndJoinQueryStr + `
ORDER BY
sessions . created_timestamp DESC `
rows , err := db . Query ( querySQL )
2023-11-19 19:29:17 +02:00
if err != nil {
2023-12-13 15:05:55 +01:00
return nil , err
}
defer rows . Close ( )
2024-05-23 15:19:00 +03:00
return scanSessions ( rows )
}
// GetActiveDapps returns all dapps in the order of last first time connected (first session creation time)
func GetActiveDapps ( db * sql . DB , validAtTimestamp int64 , testChains bool ) ( [ ] DBDApp , error ) {
query := ` SELECT dapps . url , dapps . name , dapps . icon_url , MIN ( sessions . created_timestamp ) as dapp_creation_time
FROM
wallet_connect_dapps dapps
JOIN
wallet_connect_sessions sessions ON dapps . url = sessions . dapp_url
WHERE sessions . disconnected = 0 AND sessions . expiry >= ? AND sessions . test_chains = ?
GROUP BY dapps . url
ORDER BY dapp_creation_time DESC ; `
rows , err := db . Query ( query , validAtTimestamp , testChains )
if err != nil {
return nil , err
}
defer rows . Close ( )
var dapps [ ] DBDApp
2023-12-13 15:05:55 +01:00
for rows . Next ( ) {
2024-05-23 15:19:00 +03:00
var dapp DBDApp
var creationTime sql . NullInt64
if err := rows . Scan ( & dapp . URL , & dapp . Name , & dapp . IconURL , & creationTime ) ; err != nil {
2023-12-13 15:05:55 +01:00
return nil , err
}
2024-05-23 15:19:00 +03:00
dapps = append ( dapps , dapp )
2023-12-13 15:05:55 +01:00
}
if err := rows . Err ( ) ; err != nil {
return nil , err
2023-11-19 19:29:17 +02:00
}
2024-05-23 15:19:00 +03:00
return dapps , nil
2023-11-19 19:29:17 +02:00
}