refactor: remove unused function and simplify code related to creating db and migrations

This commit is contained in:
Richard Ramos 2023-08-09 13:23:44 -04:00 committed by richΛrd
parent e56f54252f
commit e0e4a2fa87
9 changed files with 29 additions and 76 deletions

View File

@ -12,10 +12,10 @@ import (
func MemoryDB(t *testing.T) *persistence.DBStore { func MemoryDB(t *testing.T) *persistence.DBStore {
var db *sql.DB var db *sql.DB
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger()) db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err) require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migration)) dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(sqlite.Migrations))
require.NoError(t, err) require.NoError(t, err)
return dbStore return dbStore

View File

@ -13,29 +13,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// WithDB is a DBOption that lets you use a postgresql DBStore and run migrations
func WithDB(dburl string, migrate bool, shouldVacuum bool) persistence.DBOption {
return func(d *persistence.DBStore) error {
driverOption := persistence.WithDriver("pgx", dburl)
err := driverOption(d)
if err != nil {
return err
}
if !migrate {
return nil
}
migrationOpt := persistence.WithMigrations(Migrate)
err = migrationOpt(d)
if err != nil {
return err
}
return nil
}
}
func executeVacuum(db *sql.DB, logger *zap.Logger) error { func executeVacuum(db *sql.DB, logger *zap.Logger) error {
logger.Info("starting PostgreSQL database vacuuming") logger.Info("starting PostgreSQL database vacuuming")
_, err := db.Exec("VACUUM FULL") _, err := db.Exec("VACUUM FULL")
@ -47,20 +24,20 @@ func executeVacuum(db *sql.DB, logger *zap.Logger) error {
} }
// NewDB connects to postgres DB in the specified path // NewDB connects to postgres DB in the specified path
func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, func(*sql.DB) error, error) { func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, error) {
db, err := sql.Open("pgx", dburl) db, err := sql.Open("pgx", dburl)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
if shouldVacuum { if shouldVacuum {
err := executeVacuum(db, logger) err := executeVacuum(db, logger)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
} }
return db, Migrate, nil return db, nil
} }
func migrationDriver(db *sql.DB) (database.Driver, error) { func migrationDriver(db *sql.DB) (database.Driver, error) {
@ -69,8 +46,8 @@ func migrationDriver(db *sql.DB) (database.Driver, error) {
}) })
} }
// Migrate is the function used for DB migration with postgres driver // Migrations is the function used for DB migration with postgres driver
func Migrate(db *sql.DB) error { func Migrations(db *sql.DB) error {
migrationDriver, err := migrationDriver(db) migrationDriver, err := migrationDriver(db)
if err != nil { if err != nil {
return err return err

View File

@ -30,32 +30,6 @@ func addSqliteURLDefaults(dburl string) string {
return dburl return dburl
} }
// WithDB is a DBOption that lets you use a sqlite3 DBStore and run migrations
func WithDB(dburl string, migrate bool) persistence.DBOption {
return func(d *persistence.DBStore) error {
driverOption := persistence.WithDriver("sqlite3", addSqliteURLDefaults(dburl), persistence.ConnectionPoolOptions{
// Disable concurrent access as not supported by the driver
MaxOpenConnections: 1,
})
err := driverOption(d)
if err != nil {
return err
}
if !migrate {
return nil
}
migrationOpt := persistence.WithMigrations(Migrate)
err = migrationOpt(d)
if err != nil {
return err
}
return nil
}
}
func executeVacuum(db *sql.DB, logger *zap.Logger) error { func executeVacuum(db *sql.DB, logger *zap.Logger) error {
logger.Info("starting sqlite database vacuuming") logger.Info("starting sqlite database vacuuming")
_, err := db.Exec("VACUUM") _, err := db.Exec("VACUUM")
@ -67,10 +41,10 @@ func executeVacuum(db *sql.DB, logger *zap.Logger) error {
} }
// NewDB creates a sqlite3 DB in the specified path // NewDB creates a sqlite3 DB in the specified path
func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, func(*sql.DB) error, error) { func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, error) {
db, err := sql.Open("sqlite3", addSqliteURLDefaults(dburl)) db, err := sql.Open("sqlite3", addSqliteURLDefaults(dburl))
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
// Disable concurrent access as not supported by the driver // Disable concurrent access as not supported by the driver
@ -79,11 +53,11 @@ func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, func(*
if shouldVacuum { if shouldVacuum {
err := executeVacuum(db, logger) err := executeVacuum(db, logger)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
} }
return db, Migrate, nil return db, nil
} }
func migrationDriver(db *sql.DB) (database.Driver, error) { func migrationDriver(db *sql.DB) (database.Driver, error) {
@ -92,8 +66,8 @@ func migrationDriver(db *sql.DB) (database.Driver, error) {
}) })
} }
// Migrate is the function used for DB migration with sqlite driver // Migrations is the function used for DB migration with sqlite driver
func Migrate(db *sql.DB) error { func Migrations(db *sql.DB) error {
migrationDriver, err := migrationDriver(db) migrationDriver, err := migrationDriver(db)
if err != nil { if err != nil {
return err return err

View File

@ -121,9 +121,11 @@ func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption {
} }
} }
type MigrationFn func(db *sql.DB) error
// WithMigrations is a DBOption used to determine if migrations should // WithMigrations is a DBOption used to determine if migrations should
// be executed, and what driver to use // be executed, and what driver to use
func WithMigrations(migrationFn func(db *sql.DB) error) DBOption { func WithMigrations(migrationFn MigrationFn) DBOption {
return func(d *DBStore) error { return func(d *DBStore) error {
d.enableMigrations = true d.enableMigrations = true
d.migrationFn = migrationFn d.migrationFn = migrationFn

View File

@ -49,9 +49,11 @@ func ExtractDBAndMigration(databaseURL string, dbSettings DBSettings, logger *za
dbParams := dbURLParts[1] dbParams := dbURLParts[1]
switch dbEngine { switch dbEngine {
case "sqlite3": case "sqlite3":
db, migrationFn, err = sqlite.NewDB(dbParams, dbSettings.Vacuum, logger) db, err = sqlite.NewDB(dbParams, dbSettings.Vacuum, logger)
migrationFn = sqlite.Migrations
case "postgresql": case "postgresql":
db, migrationFn, err = postgres.NewDB(dbURL, dbSettings.Vacuum, logger) db, err = postgres.NewDB(dbURL, dbSettings.Vacuum, logger)
migrationFn = postgres.Migrations
default: default:
err = errors.New("unsupported database engine") err = errors.New("unsupported database engine")
} }

View File

@ -69,9 +69,9 @@ func TestConnectionStatusChanges(t *testing.T) {
err = node2.Start(ctx) err = node2.Start(ctx)
require.NoError(t, err) require.NoError(t, err)
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger()) db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err) require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migration)) dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(sqlite.Migrations))
require.NoError(t, err) require.NoError(t, err)
// Node3: Relay + Store // Node3: Relay + Store

View File

@ -230,9 +230,9 @@ func TestDecoupledStoreFromRelay(t *testing.T) {
subs.Unsubscribe() subs.Unsubscribe()
// NODE2: Filter Client/Store // NODE2: Filter Client/Store
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger()) db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err) require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migration)) dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(sqlite.Migrations))
require.NoError(t, err) require.NoError(t, err)
hostAddr2, err := net.ResolveTCPAddr("tcp", "0.0.0.0:0") hostAddr2, err := net.ResolveTCPAddr("tcp", "0.0.0.0:0")

View File

@ -12,10 +12,10 @@ import (
func MemoryDB(t *testing.T) *persistence.DBStore { func MemoryDB(t *testing.T) *persistence.DBStore {
var db *sql.DB var db *sql.DB
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger()) db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err) require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migration)) dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(sqlite.Migrations))
require.NoError(t, err) require.NoError(t, err)
return dbStore return dbStore

View File

@ -3,7 +3,6 @@ package rendezvous
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"database/sql"
"fmt" "fmt"
"sync" "sync"
"testing" "testing"
@ -46,11 +45,10 @@ func TestRendezvous(t *testing.T) {
host1, err := tests.MakeHost(ctx, port1, rand.Reader) host1, err := tests.MakeHost(ctx, port1, rand.Reader)
require.NoError(t, err) require.NoError(t, err)
var db *sql.DB db, err := sqlite.NewDB(":memory:", false, utils.Logger())
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err) require.NoError(t, err)
err = migration(db) err = sqlite.Migrations(db)
require.NoError(t, err) require.NoError(t, err)
rdb := NewDB(ctx, db, utils.Logger()) rdb := NewDB(ctx, db, utils.Logger())