refactor: remove unused function and simplify code related to creating db and migrations

This commit is contained in:
Richard Ramos 2023-08-09 13:23:44 -04:00 committed by richΛrd
parent e56f54252f
commit e0e4a2fa87
9 changed files with 29 additions and 76 deletions

View File

@ -12,10 +12,10 @@ import (
func MemoryDB(t *testing.T) *persistence.DBStore {
var db *sql.DB
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger())
db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migration))
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(sqlite.Migrations))
require.NoError(t, err)
return dbStore

View File

@ -13,29 +13,6 @@ import (
"go.uber.org/zap"
)
// WithDB is a DBOption that lets you use a postgresql DBStore and run migrations
func WithDB(dburl string, migrate bool, shouldVacuum bool) persistence.DBOption {
return func(d *persistence.DBStore) error {
driverOption := persistence.WithDriver("pgx", dburl)
err := driverOption(d)
if err != nil {
return err
}
if !migrate {
return nil
}
migrationOpt := persistence.WithMigrations(Migrate)
err = migrationOpt(d)
if err != nil {
return err
}
return nil
}
}
func executeVacuum(db *sql.DB, logger *zap.Logger) error {
logger.Info("starting PostgreSQL database vacuuming")
_, err := db.Exec("VACUUM FULL")
@ -47,20 +24,20 @@ func executeVacuum(db *sql.DB, logger *zap.Logger) error {
}
// NewDB connects to postgres DB in the specified path
func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, func(*sql.DB) error, error) {
func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, error) {
db, err := sql.Open("pgx", dburl)
if err != nil {
return nil, nil, err
return nil, err
}
if shouldVacuum {
err := executeVacuum(db, logger)
if err != nil {
return nil, nil, err
return nil, err
}
}
return db, Migrate, nil
return db, nil
}
func migrationDriver(db *sql.DB) (database.Driver, error) {
@ -69,8 +46,8 @@ func migrationDriver(db *sql.DB) (database.Driver, error) {
})
}
// Migrate is the function used for DB migration with postgres driver
func Migrate(db *sql.DB) error {
// Migrations is the function used for DB migration with postgres driver
func Migrations(db *sql.DB) error {
migrationDriver, err := migrationDriver(db)
if err != nil {
return err

View File

@ -30,32 +30,6 @@ func addSqliteURLDefaults(dburl string) string {
return dburl
}
// WithDB is a DBOption that lets you use a sqlite3 DBStore and run migrations
func WithDB(dburl string, migrate bool) persistence.DBOption {
return func(d *persistence.DBStore) error {
driverOption := persistence.WithDriver("sqlite3", addSqliteURLDefaults(dburl), persistence.ConnectionPoolOptions{
// Disable concurrent access as not supported by the driver
MaxOpenConnections: 1,
})
err := driverOption(d)
if err != nil {
return err
}
if !migrate {
return nil
}
migrationOpt := persistence.WithMigrations(Migrate)
err = migrationOpt(d)
if err != nil {
return err
}
return nil
}
}
func executeVacuum(db *sql.DB, logger *zap.Logger) error {
logger.Info("starting sqlite database vacuuming")
_, err := db.Exec("VACUUM")
@ -67,10 +41,10 @@ func executeVacuum(db *sql.DB, logger *zap.Logger) error {
}
// NewDB creates a sqlite3 DB in the specified path
func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, func(*sql.DB) error, error) {
func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, error) {
db, err := sql.Open("sqlite3", addSqliteURLDefaults(dburl))
if err != nil {
return nil, nil, err
return nil, err
}
// Disable concurrent access as not supported by the driver
@ -79,11 +53,11 @@ func NewDB(dburl string, shouldVacuum bool, logger *zap.Logger) (*sql.DB, func(*
if shouldVacuum {
err := executeVacuum(db, logger)
if err != nil {
return nil, nil, err
return nil, err
}
}
return db, Migrate, nil
return db, nil
}
func migrationDriver(db *sql.DB) (database.Driver, error) {
@ -92,8 +66,8 @@ func migrationDriver(db *sql.DB) (database.Driver, error) {
})
}
// Migrate is the function used for DB migration with sqlite driver
func Migrate(db *sql.DB) error {
// Migrations is the function used for DB migration with sqlite driver
func Migrations(db *sql.DB) error {
migrationDriver, err := migrationDriver(db)
if err != nil {
return err

View File

@ -121,9 +121,11 @@ func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption {
}
}
type MigrationFn func(db *sql.DB) error
// WithMigrations is a DBOption used to determine if migrations should
// be executed, and what driver to use
func WithMigrations(migrationFn func(db *sql.DB) error) DBOption {
func WithMigrations(migrationFn MigrationFn) DBOption {
return func(d *DBStore) error {
d.enableMigrations = true
d.migrationFn = migrationFn

View File

@ -49,9 +49,11 @@ func ExtractDBAndMigration(databaseURL string, dbSettings DBSettings, logger *za
dbParams := dbURLParts[1]
switch dbEngine {
case "sqlite3":
db, migrationFn, err = sqlite.NewDB(dbParams, dbSettings.Vacuum, logger)
db, err = sqlite.NewDB(dbParams, dbSettings.Vacuum, logger)
migrationFn = sqlite.Migrations
case "postgresql":
db, migrationFn, err = postgres.NewDB(dbURL, dbSettings.Vacuum, logger)
db, err = postgres.NewDB(dbURL, dbSettings.Vacuum, logger)
migrationFn = postgres.Migrations
default:
err = errors.New("unsupported database engine")
}

View File

@ -69,9 +69,9 @@ func TestConnectionStatusChanges(t *testing.T) {
err = node2.Start(ctx)
require.NoError(t, err)
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger())
db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migration))
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(sqlite.Migrations))
require.NoError(t, err)
// Node3: Relay + Store

View File

@ -230,9 +230,9 @@ func TestDecoupledStoreFromRelay(t *testing.T) {
subs.Unsubscribe()
// NODE2: Filter Client/Store
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger())
db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migration))
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(sqlite.Migrations))
require.NoError(t, err)
hostAddr2, err := net.ResolveTCPAddr("tcp", "0.0.0.0:0")

View File

@ -12,10 +12,10 @@ import (
func MemoryDB(t *testing.T) *persistence.DBStore {
var db *sql.DB
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger())
db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err)
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(migration))
dbStore, err := persistence.NewDBStore(utils.Logger(), persistence.WithDB(db), persistence.WithMigrations(sqlite.Migrations))
require.NoError(t, err)
return dbStore

View File

@ -3,7 +3,6 @@ package rendezvous
import (
"context"
"crypto/rand"
"database/sql"
"fmt"
"sync"
"testing"
@ -46,11 +45,10 @@ func TestRendezvous(t *testing.T) {
host1, err := tests.MakeHost(ctx, port1, rand.Reader)
require.NoError(t, err)
var db *sql.DB
db, migration, err := sqlite.NewDB(":memory:", false, utils.Logger())
db, err := sqlite.NewDB(":memory:", false, utils.Logger())
require.NoError(t, err)
err = migration(db)
err = sqlite.Migrations(db)
require.NoError(t, err)
rdb := NewDB(ctx, db, utils.Logger())