diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 0bbbb07..73c68f3 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -16,7 +16,7 @@ func init() { database.Register("postgres", &Postgres{}) } -var MigrationsTable = "schema_migrations" +var DefaultMigrationsTable = "schema_migrations" var ( ErrNilConfig = fmt.Errorf("no config") @@ -26,7 +26,8 @@ var ( ) type Config struct { - DatabaseName string + MigrationsTable string + DatabaseName string } func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { @@ -46,6 +47,10 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config.DatabaseName = databaseName + if len(config.MigrationsTable) == 0 { + config.MigrationsTable = DefaultMigrationsTable + } + px := &Postgres{ db: instance, config: config, @@ -77,8 +82,14 @@ func (p *Postgres) Open(url string) (database.Driver, error) { return nil, err } + migrationsTable := purl.Query().Get("x-migrations-table") + if len(migrationsTable) == 0 { + migrationsTable = DefaultMigrationsTable + } + px, err := WithInstance(db, &Config{ - DatabaseName: purl.Path, + DatabaseName: purl.Path, + MigrationsTable: migrationsTable, }) if err != nil { return nil, err @@ -179,14 +190,14 @@ func (p *Postgres) saveVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } - query := `TRUNCATE "` + MigrationsTable + `"` + query := `TRUNCATE "` + p.config.MigrationsTable + `"` if _, err := p.db.Exec(query); err != nil { tx.Rollback() return &database.Error{OrigErr: err, Query: []byte(query)} } if version >= 0 { - query = `INSERT INTO "` + MigrationsTable + `" (version, dirty) VALUES ($1, $2)` + query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)` if _, err := p.db.Exec(query, version, dirty); err != nil { tx.Rollback() return &database.Error{OrigErr: err, Query: []byte(query)} @@ -201,7 +212,7 @@ func (p *Postgres) saveVersion(version int, dirty bool) error { } func (p *Postgres) isDirty() (bool, error) { - query := `SELECT dirty FROM "` + MigrationsTable + `" LIMIT 1` + query := `SELECT dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` var dirty bool err := p.db.QueryRow(query).Scan(&dirty) switch { @@ -222,7 +233,7 @@ func (p *Postgres) isDirty() (bool, error) { } func (p *Postgres) Version() (int, error) { - query := `SELECT version FROM "` + MigrationsTable + `" LIMIT 1` + query := `SELECT version FROM "` + p.config.MigrationsTable + `" LIMIT 1` var version uint64 err := p.db.QueryRow(query).Scan(&version) switch { @@ -283,7 +294,7 @@ func (p *Postgres) ensureVersionTable() error { // check if migration table exists var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` - if err := p.db.QueryRow(query, MigrationsTable).Scan(&count); err != nil { + if err := p.db.QueryRow(query, p.config.MigrationsTable).Scan(&count); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { @@ -291,7 +302,7 @@ func (p *Postgres) ensureVersionTable() error { } // if not, create the empty migration table - query = `CREATE TABLE "` + MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` + query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` if _, err := p.db.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} }