package walletconnect

import (
	"database/sql"
	"fmt"

	"github.com/ethereum/go-ethereum/log"
)

// DBSession is one row of wallet_connect_sessions together with the dapp
// (wallet_connect_dapps row) it references through dapp_url.
type DBSession struct {
	Topic            Topic
	Disconnected     bool
	SessionJSON      string
	Expiry           int64
	CreatedTimestamp int64
	PairingTopic     Topic
	TestChains       bool
	DBDApp
}

// DBDApp is one row of wallet_connect_dapps. URL is the conflict key used by
// UpsertSession and the column sessions reference via dapp_url.
type DBDApp struct {
	URL     string `json:"url"`
	Name    string `json:"name"`
	IconURL string `json:"iconUrl"`
}

// UpsertSession inserts or updates, in a single transaction, the dapp
// described by data (keyed by URL) and the session described by data (keyed by
// Topic).
//
// The transaction is rolled back if either statement fails. Returned errors
// wrap the underlying database error, so callers can inspect it with
// errors.Is / errors.As.
func UpsertSession(db *sql.DB, data DBSession) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// The closure watches err, so the statements below must assign to it with
	// `=` (never shadow it with `:=`).
	defer func() {
		if err == nil {
			return
		}
		if rollErr := tx.Rollback(); rollErr != nil {
			log.Error("error rolling back transaction", "rollErr", rollErr, "err", err)
		}
	}()

	upsertDappStmt := `INSERT INTO wallet_connect_dapps (url, name, icon_url) VALUES (?, ?, ?) ON CONFLICT(url) DO UPDATE SET name = excluded.name, icon_url = excluded.icon_url`
	if _, err = tx.Exec(upsertDappStmt, data.URL, data.Name, data.IconURL); err != nil {
		return fmt.Errorf("upsert wallet_connect_dapps: %w", err)
	}

	upsertSessionStmt := `INSERT INTO wallet_connect_sessions ( topic, disconnected, session_json, expiry, created_timestamp, pairing_topic, test_chains, dapp_url ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
ON CONFLICT(topic) DO UPDATE SET disconnected = excluded.disconnected, session_json = excluded.session_json, expiry = excluded.expiry, created_timestamp = excluded.created_timestamp, pairing_topic = excluded.pairing_topic, test_chains = excluded.test_chains, dapp_url = excluded.dapp_url;`
	if _, err = tx.Exec(
		upsertSessionStmt,
		data.Topic,
		data.Disconnected,
		data.SessionJSON,
		data.Expiry,
		data.CreatedTimestamp,
		data.PairingTopic,
		data.TestChains,
		data.URL,
	); err != nil {
		return fmt.Errorf("insert session: %w", err)
	}

	// The commit error is deliberately kept out of err: once Commit has been
	// called the transaction is finished, so a rollback from the deferred
	// closure could only fail with sql.ErrTxDone and log a misleading error.
	if commitErr := tx.Commit(); commitErr != nil {
		return fmt.Errorf("commit transaction: %w", commitErr)
	}
	return nil
}

// DeleteSession removes the session with the given topic. Deleting a topic
// that does not exist is not reported as an error.
func DeleteSession(db *sql.DB, topic Topic) error {
	_, err := db.Exec("DELETE FROM wallet_connect_sessions WHERE topic = ?", topic)
	return err
}

// DisconnectSession marks the session with the given topic as disconnected.
// It returns an error if no row was updated, i.e. the topic is unknown.
func DisconnectSession(db *sql.DB, topic Topic) error {
	res, err := db.Exec("UPDATE wallet_connect_sessions SET disconnected = 1 WHERE topic = ?", topic)
	if err != nil {
		return err
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected == 0 {
		return fmt.Errorf("topic %s not found to update state", topic)
	}
	return nil
}

// GetSessionByTopic returns the session with the given topic, or
// sql.ErrNoRows if no session is found.
func GetSessionByTopic(db *sql.DB, topic Topic) (*DBSession, error) {
	query := selectAndJoinQueryStr + " WHERE sessions.topic = ?"
	row := db.QueryRow(query, topic)
	return scanSession(singleRow{row})
}

// GetSessionsByPairingTopic returns every session created from the given
// pairing topic. When no session matches, it returns a nil slice and a nil
// error (not sql.ErrNoRows).
func GetSessionsByPairingTopic(db *sql.DB, pairingTopic Topic) ([]DBSession, error) {
	query := selectAndJoinQueryStr + " WHERE sessions.pairing_topic = ?"
	rows, err := db.Query(query, pairingTopic)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanSessions(rows)
}

// Scanner abstracts the Scan method shared by *sql.Row and *sql.Rows so that
// scanSession can read from either.
type Scanner interface {
	Scan(dest ...interface{}) error
}

// singleRow adapts a *sql.Row to Scanner.
type singleRow struct {
	*sql.Row
}

// Scan forwards to the wrapped *sql.Row.
func (r singleRow) Scan(dest ...interface{}) error {
	return r.Row.Scan(dest...)
}

// selectAndJoinQueryStr selects every column scanSession expects, in the order
// it expects them. The inner join means only sessions whose dapp_url has a
// matching wallet_connect_dapps row are returned.
const selectAndJoinQueryStr = ` SELECT sessions.topic, sessions.disconnected, sessions.session_json, sessions.expiry, sessions.created_timestamp, sessions.pairing_topic, sessions.test_chains, sessions.dapp_url, dapps.name, dapps.icon_url FROM wallet_connect_sessions sessions JOIN wallet_connect_dapps dapps ON sessions.dapp_url = dapps.url`

// scanSession scans a single session from the given scanner following
// selectAndJoinQueryStr. Any error from the scanner is returned unchanged.
func scanSession(scanner Scanner) (*DBSession, error) {
	var session DBSession

	err := scanner.Scan(
		&session.Topic,
		&session.Disconnected,
		&session.SessionJSON,
		&session.Expiry,
		&session.CreatedTimestamp,
		&session.PairingTopic,
		&session.TestChains,
		&session.URL,
		&session.Name,
		&session.IconURL,
	)
	if err != nil {
		return nil, err
	}

	return &session, nil
}

// scanSessions scans all remaining rows following selectAndJoinQueryStr. An
// empty result set yields a nil slice and a nil error. The caller remains
// responsible for closing rows.
func scanSessions(rows *sql.Rows) ([]DBSession, error) {
	var sessions []DBSession

	for rows.Next() {
		session, err := scanSession(rows)
		if err != nil {
			return nil, err
		}
		sessions = append(sessions, *session)
	}
	// Next returns false both on exhaustion and on failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return sessions, nil
}

// GetActiveSessions returns all active sessions, i.e. sessions that are not
// disconnected and whose expiry is greater than or equal to validAtTimestamp,
// ordered by expiry, latest first.
func GetActiveSessions(db *sql.DB, validAtTimestamp int64) ([]DBSession, error) {
	querySQL := selectAndJoinQueryStr + ` WHERE sessions.disconnected = 0 AND sessions.expiry >= ? 
ORDER BY sessions.expiry DESC`

	rows, err := db.Query(querySQL, validAtTimestamp)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanSessions(rows)
}

// GetSessions returns all sessions in descending order of creation time
// (most recently created first).
func GetSessions(db *sql.DB) ([]DBSession, error) {
	querySQL := selectAndJoinQueryStr + ` ORDER BY sessions.created_timestamp DESC`

	rows, err := db.Query(querySQL)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanSessions(rows)
}

// GetActiveDapps returns every dapp that has at least one session which is not
// disconnected, has an expiry greater than or equal to validAtTimestamp and
// whose test_chains flag equals testChains.
//
// Dapps are ordered by the creation time of their earliest matching session,
// most recent first.
func GetActiveDapps(db *sql.DB, validAtTimestamp int64, testChains bool) ([]DBDApp, error) {
	query := `SELECT dapps.url, dapps.name, dapps.icon_url, MIN(sessions.created_timestamp) as dapp_creation_time FROM wallet_connect_dapps dapps JOIN wallet_connect_sessions sessions ON dapps.url = sessions.dapp_url WHERE sessions.disconnected = 0 AND sessions.expiry >= ? AND sessions.test_chains = ? GROUP BY dapps.url ORDER BY dapp_creation_time DESC;`

	rows, err := db.Query(query, validAtTimestamp, testChains)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var dapps []DBDApp
	for rows.Next() {
		var dapp DBDApp
		// dapp_creation_time is selected only to drive ORDER BY; it still has
		// to be scanned because Scan requires one destination per column.
		var creationTime sql.NullInt64
		if err := rows.Scan(&dapp.URL, &dapp.Name, &dapp.IconURL, &creationTime); err != nil {
			return nil, err
		}
		dapps = append(dapps, dapp)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return dapps, nil
}