test(store): add fixture for sqlite and postgres

This commit is contained in:
harsh-98 2023-10-03 23:02:23 +07:00 committed by harsh jain
parent d268b2e403
commit 5d0692b339
8 changed files with 75 additions and 48 deletions

View File

@ -1,19 +1,17 @@
package persistence package postgres
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"log" "log"
"os" "os"
_ "github.com/jackc/pgx/v5/stdlib" // Blank import to register the postgres driver
) )
// var dbUrlTemplate = "postgres://postgres@localhost:%s/%s?sslmode=disable" // var dbUrlTemplate = "postgres://postgres@localhost:%s/%s?sslmode=disable"
var dbUrlTemplate = "postgres://harshjain@localhost:%s/%s?sslmode=disable" var dbUrlTemplate = "postgres://harshjain@localhost:%s/%s?sslmode=disable"
func ResetDefaultTestPostgresDB(dropDBUrl string) error { func ResetDefaultTestPostgresDB(dropDBUrl string) error {
db, err := sql.Open("postgres", dropDBUrl) db, err := sql.Open("pgx", dropDBUrl)
if err != nil { if err != nil {
return err return err
} }
@ -41,9 +39,8 @@ func NewMockPgDB() *sql.DB {
// //
dropDBUrl := fmt.Sprintf(dbUrlTemplate, mockPgDBPort, "template1") dropDBUrl := fmt.Sprintf(dbUrlTemplate, mockPgDBPort, "template1")
fmt.Println(dropDBUrl)
if err := ResetDefaultTestPostgresDB(dropDBUrl); err != nil { if err := ResetDefaultTestPostgresDB(dropDBUrl); err != nil {
log.Fatalf("an error '%s' was not expected when opening a stub database connection", err) log.Fatalf("an error '%s' while reseting the db", err)
} }
mockDBUrl := fmt.Sprintf(dbUrlTemplate, mockPgDBPort, "postgres") mockDBUrl := fmt.Sprintf(dbUrlTemplate, mockPgDBPort, "postgres")
db, err := sql.Open("pgx", mockDBUrl) db, err := sql.Open("pgx", mockDBUrl)

View File

@ -71,5 +71,5 @@ func NewQueries(tbl string, db *sql.DB) (*persistence.Queries, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return persistence.CreateQueries(tbl, db), nil return persistence.CreateQueries(tbl), nil
} }

View File

@ -1,7 +1,6 @@
package persistence package persistence
import ( import (
"database/sql"
"fmt" "fmt"
) )
@ -20,7 +19,7 @@ type Queries struct {
// CreateQueries Function creates a set of queries for an SQL table. // CreateQueries Function creates a set of queries for an SQL table.
// Note: Do not use this function to create queries for a table, rather use <rdb>.NewQueries to create table as well as queries. // Note: Do not use this function to create queries for a table, rather use <rdb>.NewQueries to create table as well as queries.
func CreateQueries(tbl string, db *sql.DB) *Queries { func CreateQueries(tbl string) *Queries {
return &Queries{ return &Queries{
deleteQuery: fmt.Sprintf("DELETE FROM %s WHERE key = $1", tbl), deleteQuery: fmt.Sprintf("DELETE FROM %s WHERE key = $1", tbl),
existsQuery: fmt.Sprintf("SELECT exists(SELECT 1 FROM %s WHERE key=$1)", tbl), existsQuery: fmt.Sprintf("SELECT exists(SELECT 1 FROM %s WHERE key=$1)", tbl),

View File

@ -0,0 +1,17 @@
package sqlite
import (
"database/sql"
"log"
_ "github.com/mattn/go-sqlite3"
)
func NewMockSqliteDB() *sql.DB {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
log.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
return db
}

View File

@ -91,5 +91,5 @@ func NewQueries(tbl string, db *sql.DB) (*persistence.Queries, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return persistence.CreateQueries(tbl, db), nil return persistence.CreateQueries(tbl), nil
} }

View File

@ -1,24 +1,13 @@
package sqlite package sqlite
import ( import (
"database/sql"
"log"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func NewMock() *sql.DB {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
log.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
return db
}
func TestQueries(t *testing.T) { func TestQueries(t *testing.T) {
db := NewMock() db := NewMockSqliteDB()
queries, err := NewQueries("test_queries", db) queries, err := NewQueries("test_queries", db)
require.NoError(t, err) require.NoError(t, err)
@ -51,7 +40,7 @@ func TestQueries(t *testing.T) {
} }
func TestCreateTable(t *testing.T) { func TestCreateTable(t *testing.T) {
db := NewMock() db := NewMockSqliteDB()
err := CreateTable(db, "test_create_table") err := CreateTable(db, "test_create_table")
require.NoError(t, err) require.NoError(t, err)

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"reflect"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -224,8 +225,8 @@ func (d *DBStore) cleanOlderRecords(ctx context.Context) error {
// Limit number of records to a max N // Limit number of records to a max N
if d.maxMessages > 0 { if d.maxMessages > 0 {
start := time.Now() start := time.Now()
sqlStmt := `DELETE FROM message WHERE id IN (SELECT id FROM message ORDER BY receiverTimestamp DESC OFFSET $1)`
_, err := d.db.Exec(sqlStmt, d.maxMessages) _, err := d.db.Exec(d.getDeleteOldRowsQuery(), d.maxMessages)
if err != nil { if err != nil {
d.metrics.RecordError(retPolicyFailure) d.metrics.RecordError(retPolicyFailure)
return err return err
@ -238,6 +239,16 @@ func (d *DBStore) cleanOlderRecords(ctx context.Context) error {
return nil return nil
} }
func (d *DBStore) getDeleteOldRowsQuery() string {
sqlStmt := `DELETE FROM message WHERE id IN (SELECT id FROM message ORDER BY receiverTimestamp DESC %s OFFSET $1)`
switch reflect.TypeOf(d.db.Driver()).String() {
case "*sqlite3.SQLiteDriver":
sqlStmt = fmt.Sprintf(sqlStmt, "LIMIT -1")
case "*stdlib.Driver":
sqlStmt = fmt.Sprintf(sqlStmt, "")
}
return sqlStmt
}
func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) { func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) {
defer d.wg.Done() defer d.wg.Done()

View File

@ -1,7 +1,7 @@
//go:build include_postgres_tests //go:build include_postgres_tests
// +build include_postgres_tests // +build include_postgres_tests
package persistence package utils
import ( import (
"context" "context"
@ -10,32 +10,49 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/mattn/go-sqlite3" // Blank import to register the sqlite3 driver
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/waku-org/go-waku/tests" "github.com/waku-org/go-waku/tests"
"github.com/waku-org/go-waku/waku/persistence/migrate" "github.com/waku-org/go-waku/waku/persistence"
postgresmigration "github.com/waku-org/go-waku/waku/persistence/postgres/migrations" "github.com/waku-org/go-waku/waku/persistence/postgres"
"github.com/waku-org/go-waku/waku/persistence/sqlite"
"github.com/waku-org/go-waku/waku/v2/protocol" "github.com/waku-org/go-waku/waku/v2/protocol"
"github.com/waku-org/go-waku/waku/v2/protocol/store/pb" "github.com/waku-org/go-waku/waku/v2/protocol/store/pb"
"github.com/waku-org/go-waku/waku/v2/timesource" "github.com/waku-org/go-waku/waku/v2/timesource"
"github.com/waku-org/go-waku/waku/v2/utils" "github.com/waku-org/go-waku/waku/v2/utils"
) )
func Migrate(db *sql.DB) error { func TestStore(t *testing.T) {
migrationDriver, err := postgres.WithInstance(db, &postgres.Config{ tests := []struct {
MigrationsTable: "gowaku_" + postgres.DefaultMigrationsTable, name string
}) fn func(t *testing.T, db *sql.DB, migrationFn func(*sql.DB) error)
if err != nil { }{
return err {"testDbStore", testDbStore},
{"testStoreRetention", testStoreRetention},
{"testQuery", testQuery},
}
for _, driverName := range []string{"postgres", "sqlite"} {
// all tests are run for each db
for _, tc := range tests {
db, migrationFn := getDB(driverName)
t.Run(driverName+"_"+tc.name, func(t *testing.T) {
tc.fn(t, db, migrationFn)
})
}
} }
return migrate.Migrate(db, migrationDriver, postgresmigration.AssetNames(), postgresmigration.Asset)
} }
func TestDbStore(t *testing.T) { func getDB(driver string) (*sql.DB, func(*sql.DB) error) {
db := NewMockPgDB() switch driver {
store, err := NewDBStore(prometheus.DefaultRegisterer, utils.Logger(), WithDB(db), WithMigrations(Migrate)) case "postgres":
return postgres.NewMockPgDB(), postgres.Migrations
case "sqlite":
return sqlite.NewMockSqliteDB(), sqlite.Migrations
}
return nil, nil
}
func testDbStore(t *testing.T, db *sql.DB, migrationFn func(*sql.DB) error) {
store, err := persistence.NewDBStore(prometheus.DefaultRegisterer, utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migrationFn))
require.NoError(t, err) require.NoError(t, err)
err = store.Start(context.Background(), timesource.NewDefaultClock()) err = store.Start(context.Background(), timesource.NewDefaultClock())
@ -53,9 +70,8 @@ func TestDbStore(t *testing.T) {
require.NotEmpty(t, res) require.NotEmpty(t, res)
} }
func TestStoreRetention(t *testing.T) { func testStoreRetention(t *testing.T, db *sql.DB, migrationFn func(*sql.DB) error) {
db := NewMockPgDB() store, err := persistence.NewDBStore(prometheus.DefaultRegisterer, utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migrationFn), persistence.WithRetentionPolicy(5, 20*time.Second))
store, err := NewDBStore(prometheus.DefaultRegisterer, utils.Logger(), WithDB(db), WithMigrations(Migrate), WithRetentionPolicy(5, 20*time.Second))
require.NoError(t, err) require.NoError(t, err)
err = store.Start(context.Background(), timesource.NewDefaultClock()) err = store.Start(context.Background(), timesource.NewDefaultClock())
@ -78,7 +94,7 @@ func TestStoreRetention(t *testing.T) {
// This step simulates starting go-waku again from scratch // This step simulates starting go-waku again from scratch
store, err = NewDBStore(prometheus.DefaultRegisterer, utils.Logger(), WithDB(db), WithRetentionPolicy(5, 40*time.Second)) store, err = persistence.NewDBStore(prometheus.DefaultRegisterer, utils.Logger(), persistence.WithDB(db), persistence.WithRetentionPolicy(5, 40*time.Second))
require.NoError(t, err) require.NoError(t, err)
err = store.Start(context.Background(), timesource.NewDefaultClock()) err = store.Start(context.Background(), timesource.NewDefaultClock())
@ -96,10 +112,8 @@ func TestStoreRetention(t *testing.T) {
require.Equal(t, msgCount, 3) require.Equal(t, msgCount, 3)
} }
func TestQuery(t *testing.T) { func testQuery(t *testing.T, db *sql.DB, migrationFn func(*sql.DB) error) {
db := NewMockPgDB() store, err := persistence.NewDBStore(prometheus.DefaultRegisterer, utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migrationFn), persistence.WithRetentionPolicy(5, 20*time.Second))
store, err := NewDBStore(prometheus.DefaultRegisterer, utils.Logger(), WithDB(db), WithMigrations(Migrate), WithRetentionPolicy(5, 20*time.Second))
require.NoError(t, err) require.NoError(t, err)
insertTime := time.Now() insertTime := time.Now()