allow postgres migrations_table config

This commit is contained in:
Matthias Kadenbach 2017-02-16 11:06:11 -08:00
parent 6394299937
commit 2031939bfc
No known key found for this signature in database
GPG Key ID: DC1F4DC6D31A7031
1 changed files with 20 additions and 9 deletions

View File

@ -16,7 +16,7 @@ func init() {
database.Register("postgres", &Postgres{})
}
var MigrationsTable = "schema_migrations"
var DefaultMigrationsTable = "schema_migrations"
var (
ErrNilConfig = fmt.Errorf("no config")
@ -26,7 +26,8 @@ var (
)
type Config struct {
DatabaseName string
MigrationsTable string
DatabaseName string
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
@ -46,6 +47,10 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config.DatabaseName = databaseName
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
px := &Postgres{
db: instance,
config: config,
@ -77,8 +82,14 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
return nil, err
}
migrationsTable := purl.Query().Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
})
if err != nil {
return nil, err
@ -179,14 +190,14 @@ func (p *Postgres) saveVersion(version int, dirty bool) error {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := `TRUNCATE "` + MigrationsTable + `"`
query := `TRUNCATE "` + p.config.MigrationsTable + `"`
if _, err := p.db.Exec(query); err != nil {
tx.Rollback()
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if version >= 0 {
query = `INSERT INTO "` + MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
if _, err := p.db.Exec(query, version, dirty); err != nil {
tx.Rollback()
return &database.Error{OrigErr: err, Query: []byte(query)}
@ -201,7 +212,7 @@ func (p *Postgres) saveVersion(version int, dirty bool) error {
}
func (p *Postgres) isDirty() (bool, error) {
query := `SELECT dirty FROM "` + MigrationsTable + `" LIMIT 1`
query := `SELECT dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
var dirty bool
err := p.db.QueryRow(query).Scan(&dirty)
switch {
@ -222,7 +233,7 @@ func (p *Postgres) isDirty() (bool, error) {
}
func (p *Postgres) Version() (int, error) {
query := `SELECT version FROM "` + MigrationsTable + `" LIMIT 1`
query := `SELECT version FROM "` + p.config.MigrationsTable + `" LIMIT 1`
var version uint64
err := p.db.QueryRow(query).Scan(&version)
switch {
@ -283,7 +294,7 @@ func (p *Postgres) ensureVersionTable() error {
// check if migration table exists
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.db.QueryRow(query, MigrationsTable).Scan(&count); err != nil {
if err := p.db.QueryRow(query, p.config.MigrationsTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
@ -291,7 +302,7 @@ func (p *Postgres) ensureVersionTable() error {
}
// if not, create the empty migration table
query = `CREATE TABLE "` + MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
if _, err := p.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}