migrate/driver/postgres/postgres.go

122 lines
2.6 KiB
Go
Raw Normal View History

2014-08-13 02:38:29 +02:00
// Package postgres implements the Driver interface.
2014-08-11 03:42:57 +02:00
package postgres
import (
"database/sql"
"errors"
"fmt"
"github.com/lib/pq"
2014-08-11 03:42:57 +02:00
"github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction"
"strconv"
2014-08-11 03:42:57 +02:00
)
// tableName names the table in which applied migration versions are
// recorded, one row per version.
const tableName = "schema_migrations"

// Driver runs migrations against PostgreSQL via database/sql.
type Driver struct {
	db *sql.DB // populated by Initialize, released by Close
}
// Initialize opens a database/sql pool for url using the "postgres" driver,
// verifies connectivity with a ping, stores the pool on driver and makes
// sure the version table exists.
//
// If the ping fails the freshly opened pool is closed again so it does not
// leak; the ping error is returned.
func (driver *Driver) Initialize(url string) error {
	db, err := sql.Open("postgres", url)
	if err != nil {
		return err
	}
	if err := db.Ping(); err != nil {
		// The ping failure is the error worth reporting; a Close error on a
		// pool that never connected adds nothing useful.
		_ = db.Close()
		return err
	}
	driver.db = db
	return driver.ensureVersionTableExists()
}
2014-08-25 17:44:45 +02:00
// Close releases the connection pool opened by Initialize. Calling Close on
// a driver that was never initialized is a no-op instead of a nil-pointer
// panic.
func (driver *Driver) Close() error {
	if driver.db == nil {
		return nil
	}
	return driver.db.Close()
}
2014-08-11 03:42:57 +02:00
func (driver *Driver) ensureVersionTableExists() error {
2014-08-25 16:49:25 +02:00
if _, err := driver.db.Exec("CREATE TABLE IF NOT EXISTS " + tableName + " (version int not null primary key);"); err != nil {
2014-08-11 03:42:57 +02:00
return err
}
return nil
}
// FilenameExtension reports the extension, without a leading dot, of the
// migration files this driver executes; their contents are run as SQL.
func (driver *Driver) FilenameExtension() string {
	const extension = "sql"
	return extension
}
func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
2014-08-12 00:58:30 +02:00
defer close(pipe)
pipe <- f
2014-08-11 03:42:57 +02:00
tx, err := driver.db.Begin()
if err != nil {
pipe <- err
return
}
2014-08-11 03:42:57 +02:00
if f.Direction == direction.Up {
2014-08-25 16:49:25 +02:00
if _, err := tx.Exec("INSERT INTO "+tableName+" (version) VALUES ($1)", f.Version); err != nil {
2014-08-12 22:20:17 +02:00
pipe <- err
if err := tx.Rollback(); err != nil {
pipe <- err
}
2014-08-12 22:20:17 +02:00
return
}
} else if f.Direction == direction.Down {
2014-08-25 16:49:25 +02:00
if _, err := tx.Exec("DELETE FROM "+tableName+" WHERE version=$1", f.Version); err != nil {
pipe <- err
2014-08-11 03:42:57 +02:00
if err := tx.Rollback(); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
2014-08-11 03:42:57 +02:00
}
2014-08-12 00:58:30 +02:00
return
2014-08-11 03:42:57 +02:00
}
}
if err := f.ReadContent(); err != nil {
pipe <- err
return
}
if _, err := tx.Exec(string(f.Content)); err != nil {
pqErr := err.(*pq.Error)
offset, err := strconv.Atoi(pqErr.Position)
if err == nil && offset >= 0 {
lineNo, columnNo := file.LineColumnFromOffset(f.Content, offset-1)
errorPart := file.LinesBeforeAndAfter(f.Content, lineNo, 5, 5, true)
pipe <- errors.New(fmt.Sprintf("%s %v: %s in line %v, column %v:\n\n%s", pqErr.Severity, pqErr.Code, pqErr.Message, lineNo, columnNo, string(errorPart)))
} else {
pipe <- errors.New(fmt.Sprintf("%s %v: %s", pqErr.Severity, pqErr.Code, pqErr.Message))
}
2014-08-11 03:42:57 +02:00
if err := tx.Rollback(); err != nil {
2014-08-12 00:58:30 +02:00
pipe <- err
2014-08-11 03:42:57 +02:00
}
return
}
if err := tx.Commit(); err != nil {
pipe <- err
return
2014-08-11 03:42:57 +02:00
}
}
func (driver *Driver) Version() (uint64, error) {
var version uint64
2014-08-25 16:49:25 +02:00
err := driver.db.QueryRow("SELECT version FROM " + tableName + " ORDER BY version DESC").Scan(&version)
2014-08-11 03:42:57 +02:00
switch {
case err == sql.ErrNoRows:
return 0, nil
case err != nil:
return 0, err
default:
return version, nil
}
}