status-go/services/wallet/walletconnect/database_test.go

175 lines
4.3 KiB
Go
Raw Normal View History

package walletconnect
import (
"strconv"
"testing"
"database/sql"
"github.com/status-im/status-go/t/helpers"
"github.com/status-im/status-go/walletdatabase"
"github.com/stretchr/testify/require"
)
// setupTestDB creates a test database through helpers.SetupTestMemorySQLDB,
// initialised with walletdatabase.DbInitializer, and returns it together with
// a cleanup function that closes it.
//
// The test fails immediately if the database cannot be created, or if closing
// it in the returned cleanup function reports an error. Both this function and
// the cleanup closure are marked as helpers so that failures point at the
// calling test rather than at this file's internals.
func setupTestDB(t *testing.T) (db *sql.DB, close func()) {
	t.Helper()
	db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
	require.NoError(t, err)
	return db, func() {
		t.Helper()
		require.NoError(t, db.Close())
	}
}
// generateTestData builds count session fixtures.
//
// Session i has topic "<i>aaaaaa1234567890" and expiry 1234567890+i, and it is
// active exactly when i is even. Active and disconnected sessions therefore
// alternate, starting with an active one. Sessions are grouped four at a time
// under a shared pairing topic: indexes 0-3 use "1bbbbbb1234567890", 4-7 use
// "2bbbbbb1234567890", and so on. The Dapp* fields are derived from i so each
// session is distinguishable.
func generateTestData(count int) []DbSession {
	sessions := make([]DbSession, count)
	for idx := range sessions {
		suffix := strconv.Itoa(idx)
		// Four consecutive sessions share one pairing; groups are numbered from 1.
		pairingGroup := strconv.Itoa(idx/4 + 1)

		sessions[idx] = DbSession{
			Topic:           Topic(suffix + "aaaaaa1234567890"),
			PairingTopic:    Topic(pairingGroup + "bbbbbb1234567890"),
			Expiry:          1234567890 + int64(idx),
			Active:          idx%2 == 0,
			DappName:        "TestApp" + suffix,
			DappURL:         "https://test.url/" + suffix,
			DappDescription: "Test Description" + suffix,
			DappIcon:        "https://test.icon" + suffix,
			DappVerifyURL:   "https://test.verify.url/" + suffix,
			DappPublicKey:   suffix + "1234567890",
		}
	}
	return sessions
}
// insertTestData stores every entry in db with UpsertSession.
//
// The test fails on the first upsert error. The failure message names the
// index of the offending entry, and the function is marked as a helper so the
// failure is reported at the calling test.
func insertTestData(t *testing.T, db *sql.DB, entries []DbSession) {
	t.Helper()
	for i, entry := range entries {
		require.NoError(t, UpsertSession(db, entry), "upserting test session %d", i)
	}
}
// TestInsertUpdateAndGetSession checks that UpsertSession inserts a new session
// and that a second upsert with the same topic replaces the stored values.
func TestInsertUpdateAndGetSession(t *testing.T) {
	db, closeDB := setupTestDB(t)
	defer closeDB()

	session := generateTestData(1)[0]

	// First upsert: the row does not exist yet and must be inserted verbatim.
	require.NoError(t, UpsertSession(db, session))
	stored, err := GetSessionByTopic(db, session.Topic)
	require.NoError(t, err)
	require.Equal(t, session, *stored)

	// Second upsert with the same topic: the changed fields must be persisted.
	session.Active = false
	session.Expiry = 1111111111
	require.NoError(t, UpsertSession(db, session))
	stored, err = GetSessionByTopic(db, session.Topic)
	require.NoError(t, err)
	require.Equal(t, session, *stored)
}
// TestInsertAndGetSessionsByPairingTopic checks that GetSessionsByPairingTopic
// returns exactly the sessions sharing the requested pairing topic.
//
// generateTestData groups sessions four per pairing topic, so of the ten
// generated sessions, entries 4-7 share the pairing topic of entry 4.
func TestInsertAndGetSessionsByPairingTopic(t *testing.T) {
	db, closeDB := setupTestDB(t)
	defer closeDB()

	generatedSessions := generateTestData(10)
	insertTestData(t, db, generatedSessions)

	// A session topic is not a pairing topic, so querying by it must match nothing.
	retrievedSessions, err := GetSessionsByPairingTopic(db, generatedSessions[4].Topic)
	require.NoError(t, err)
	require.Empty(t, retrievedSessions)

	retrievedSessions, err = GetSessionsByPairingTopic(db, generatedSessions[4].PairingTopic)
	require.NoError(t, err)
	require.Len(t, retrievedSessions, 4)

	// The result order is not asserted, so index results by topic and look up
	// each expected session directly.
	byTopic := make(map[Topic]DbSession, len(retrievedSessions))
	for _, session := range retrievedSessions {
		byTopic[session.Topic] = session
	}
	for _, want := range generatedSessions[4:8] {
		got, ok := byTopic[want.Topic]
		require.True(t, ok, "session %q missing from results", want.Topic)
		require.Equal(t, want, got)
	}
}
// TestChangeSessionState checks that ChangeSessionState(..., false) marks a
// stored, previously active session as inactive.
func TestChangeSessionState(t *testing.T) {
	db, closeDB := setupTestDB(t)
	defer closeDB()

	entry := generateTestData(1)[0]
	// Guard the precondition. If the fixture were already inactive, the
	// assertion below would pass without ChangeSessionState doing anything.
	require.True(t, entry.Active, "fixture session must start active")
	require.NoError(t, UpsertSession(db, entry))

	require.NoError(t, ChangeSessionState(db, entry.Topic, false))

	retrievedSession, err := GetSessionByTopic(db, entry.Topic)
	require.NoError(t, err)
	require.False(t, retrievedSession.Active)
}
// TestGet checks that GetSessionByTopic returns the matching session when
// several sessions are stored.
func TestGet(t *testing.T) {
	db, closeDB := setupTestDB(t)
	defer closeDB()

	sessions := generateTestData(3)
	insertTestData(t, db, sessions)

	// Target the middle entry so the lookup is neither the first nor the last row.
	want := sessions[1]
	got, err := GetSessionByTopic(db, want.Topic)
	require.NoError(t, err)
	require.Equal(t, want, *got)
}
// TestGetActiveSessions checks that GetActiveSessions leaves out disconnected
// sessions and sessions whose expiry is earlier than the given timestamp, and
// that it returns the remaining sessions newest first.
func TestGetActiveSessions(t *testing.T) {
	db, closeDB := setupTestDB(t)
	defer closeDB()

	// Five sessions: indexes 0, 2 and 4 are active, 1 and 3 are disconnected.
	// Expiries run from 1234567890 (index 0) to 1234567894 (index 4).
	sessions := generateTestData(5)
	insertTestData(t, db, sessions)

	// At this cut-off, active session 0 (expiry 1234567890) has expired.
	// Only sessions 2 and 4 remain.
	const cutoff = 1234567892
	active, err := GetActiveSessions(db, cutoff)
	require.NoError(t, err)
	require.Len(t, active, 2)

	// Newest (highest expiry) first.
	require.Equal(t, sessions[4], active[0])
	require.Equal(t, sessions[2], active[1])
}
// func TestHasActivePairings(t *testing.T) {
// db, close := setupTestDB(t)
// defer close()
// // insert one active and one disconnected session, both sharing the same pairing topic
// entries := generateTestData(2)
// insertTestData(t, db, entries)
// hasActivePairings, err := HasActivePairings(db, 1234567890)
// require.NoError(t, err)
// require.True(t, hasActivePairings)
// hasActivePairings, err = HasActivePairings(db, 1234567891)
// require.NoError(t, err)
// require.False(t, hasActivePairings)
// }