migrate/driver/postgres/postgres.go

105 lines
2.0 KiB
Go
Raw Normal View History

2014-08-11 03:42:57 +02:00
package postgres
import (
"database/sql"
2014-08-12 00:58:30 +02:00
"fmt"
2014-08-11 03:42:57 +02:00
_ "github.com/lib/pq"
"github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction"
)
// Driver runs migrations against a PostgreSQL database via database/sql.
// Its db field is only set by Initialize, so a zero Driver must be
// initialized before any other method is used.
type Driver struct {
	db *sql.DB // pool opened in Initialize
}

const (
	// tableName is the table in which applied migration versions are recorded.
	tableName = "schema_migrations"
)
// Initialize opens the PostgreSQL database identified by url, verifies the
// connection with Ping, stores the handle on the driver and creates the
// version table if it does not exist yet.
//
// On any failure after sql.Open succeeds, the opened *sql.DB is closed so its
// connection pool is not leaked, and the original error is returned.
func (driver *Driver) Initialize(url string) error {
	db, err := sql.Open("postgres", url)
	if err != nil {
		return err
	}
	// sql.Open does not necessarily connect; Ping surfaces bad URLs or an
	// unreachable server now rather than on first use.
	if err := db.Ping(); err != nil {
		// The Close error is ignored on purpose: the Ping error is the one
		// the caller needs to see.
		db.Close()
		return err
	}
	driver.db = db
	if err := driver.ensureVersionTableExists(); err != nil {
		// Same reasoning as above: report the table-creation error.
		db.Close()
		return err
	}
	return nil
}
// ensureVersionTableExists creates the version table if it is missing.
//
// The version column is bigint rather than int. Versions are handled as
// uint64 elsewhere in this driver (see Version), and a 32-bit int column
// rejects any version above 2147483647. Because of IF NOT EXISTS, tables
// created earlier are left untouched.
func (driver *Driver) ensureVersionTableExists() error {
	_, err := driver.db.Exec(`CREATE TABLE IF NOT EXISTS ` + tableName + ` (
		version bigint not null primary key
	);`)
	return err
}
func (driver *Driver) FilenameExtension() string {
return "sql"
}
2014-08-12 00:58:30 +02:00
func (driver *Driver) Migrate(files file.Files, pipe chan interface{}) {
defer close(pipe)
2014-08-11 03:42:57 +02:00
for _, f := range files {
tx, err := driver.db.Begin()
if err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
return
2014-08-11 03:42:57 +02:00
}
if f.Direction == direction.Up {
if _, err := tx.Exec(`INSERT INTO `+tableName+` (version) VALUES ($1)`, f.Version); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
2014-08-11 03:42:57 +02:00
if err := tx.Rollback(); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
2014-08-11 03:42:57 +02:00
}
2014-08-12 00:58:30 +02:00
return
2014-08-11 03:42:57 +02:00
}
} else if f.Direction == direction.Down {
if _, err := tx.Exec(`DELETE FROM `+tableName+` WHERE version=$1`, f.Version); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
2014-08-11 03:42:57 +02:00
if err := tx.Rollback(); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
2014-08-11 03:42:57 +02:00
}
2014-08-12 00:58:30 +02:00
return
2014-08-11 03:42:57 +02:00
}
}
f.Read()
if _, err := tx.Exec(string(f.Content)); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
2014-08-11 03:42:57 +02:00
if err := tx.Rollback(); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
2014-08-11 03:42:57 +02:00
}
2014-08-12 00:58:30 +02:00
return
2014-08-11 03:42:57 +02:00
}
2014-08-12 00:58:30 +02:00
pipe <- fmt.Sprintf("Applied %s", f.FileName)
2014-08-11 03:42:57 +02:00
if err := tx.Commit(); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
return
2014-08-11 03:42:57 +02:00
}
}
2014-08-12 00:58:30 +02:00
return
2014-08-11 03:42:57 +02:00
}
// Version returns the highest version recorded in the version table, or 0
// (with a nil error) when the table has no rows. Query and scan errors are
// returned unchanged.
func (driver *Driver) Version() (uint64, error) {
	var version uint64
	// LIMIT 1: QueryRow only reads the first row, so the server does not need
	// to return the rest of the sorted table.
	err := driver.db.QueryRow(`SELECT version FROM ` + tableName + ` ORDER BY version DESC LIMIT 1`).Scan(&version)
	switch {
	case err == sql.ErrNoRows:
		return 0, nil
	case err != nil:
		return 0, err
	}
	return version, nil
}