mirror of
https://github.com/status-im/migrate.git
synced 2025-02-24 00:38:07 +00:00
This addresses https://github.com/golang-migrate/migrate/issues/90. However, the exported Redshift object no longer exports an embedded 'Driver', so some more work is needed to make this backwards compatible.
313 lines
7.4 KiB
Go
313 lines
7.4 KiB
Go
// +build go1.9
|
|
|
|
package redshift
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
nurl "net/url"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/golang-migrate/migrate/v4"
|
|
"github.com/golang-migrate/migrate/v4/database"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
func init() {
|
|
db := Redshift{}
|
|
database.Register("redshift", &db)
|
|
}
|
|
|
|
// DefaultMigrationsTable is the migrations table name used when neither the
// Config nor the connection URL specifies one.
var DefaultMigrationsTable = "schema_migrations"
|
|
|
|
// Sentinel errors exposed by the redshift driver package.
var (
	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName is returned by WithInstance when the server reports
	// an empty current database name.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrNoSchema is exported for callers; it is not referenced by the code
	// visible in this file.
	ErrNoSchema = fmt.Errorf("no schema")

	// ErrDatabaseDirty is exported for callers; it is not referenced by the
	// code visible in this file.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
|
|
|
|
// Config carries the settings of a Redshift driver instance.
type Config struct {
	// MigrationsTable is the name of the table that stores the current
	// migration version. WithInstance replaces an empty value with
	// DefaultMigrationsTable.
	MigrationsTable string

	// DatabaseName is the name of the connected database. WithInstance
	// overwrites it with the value reported by the server.
	DatabaseName string
}
|
|
|
|
type Redshift struct {
|
|
isLocked bool
|
|
conn *sql.Conn
|
|
db *sql.DB
|
|
|
|
// Open and WithInstance need to garantuee that config is never nil
|
|
config *Config
|
|
}
|
|
|
|
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
|
if config == nil {
|
|
return nil, ErrNilConfig
|
|
}
|
|
|
|
if err := instance.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := `SELECT CURRENT_DATABASE()`
|
|
var databaseName string
|
|
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
|
|
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
if len(databaseName) == 0 {
|
|
return nil, ErrNoDatabaseName
|
|
}
|
|
|
|
config.DatabaseName = databaseName
|
|
|
|
if len(config.MigrationsTable) == 0 {
|
|
config.MigrationsTable = DefaultMigrationsTable
|
|
}
|
|
|
|
conn, err := instance.Conn(context.Background())
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
px := &Redshift{
|
|
conn: conn,
|
|
db: instance,
|
|
config: config,
|
|
}
|
|
|
|
if err := px.ensureVersionTable(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return px, nil
|
|
}
|
|
|
|
func (p *Redshift) Open(url string) (database.Driver, error) {
|
|
purl, err := nurl.Parse(url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
purl.Scheme = "postgres"
|
|
|
|
db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
migrationsTable := purl.Query().Get("x-migrations-table")
|
|
if len(migrationsTable) == 0 {
|
|
migrationsTable = DefaultMigrationsTable
|
|
}
|
|
|
|
px, err := WithInstance(db, &Config{
|
|
DatabaseName: purl.Path,
|
|
MigrationsTable: migrationsTable,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return px, nil
|
|
}
|
|
|
|
func (p *Redshift) Close() error {
|
|
connErr := p.conn.Close()
|
|
dbErr := p.db.Close()
|
|
if connErr != nil || dbErr != nil {
|
|
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html
|
|
func (p *Redshift) Lock() error {
|
|
if p.isLocked {
|
|
return database.ErrLocked
|
|
}
|
|
p.isLocked = true
|
|
return nil
|
|
}
|
|
|
|
func (p *Redshift) Unlock() error {
|
|
p.isLocked = false
|
|
return nil
|
|
}
|
|
|
|
func (p *Redshift) Run(migration io.Reader) error {
|
|
migr, err := ioutil.ReadAll(migration)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// run migration
|
|
query := string(migr[:])
|
|
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
|
|
if pgErr, ok := err.(*pq.Error); ok {
|
|
var line uint
|
|
var col uint
|
|
var lineColOK bool
|
|
if pgErr.Position != "" {
|
|
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
|
|
line, col, lineColOK = computeLineFromPos(query, int(pos))
|
|
}
|
|
}
|
|
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
|
|
if lineColOK {
|
|
message = fmt.Sprintf("%s (column %d)", message, col)
|
|
}
|
|
if pgErr.Detail != "" {
|
|
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
|
|
}
|
|
return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
|
|
}
|
|
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
|
|
// replace crlf with lf
|
|
s = strings.Replace(s, "\r\n", "\n", -1)
|
|
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
|
|
runes := []rune(s)
|
|
if pos > len(runes) {
|
|
return 0, 0, false
|
|
}
|
|
sel := runes[:pos]
|
|
line = uint(runesCount(sel, newLine) + 1)
|
|
col = uint(pos - 1 - runesLastIndex(sel, newLine))
|
|
return line, col, true
|
|
}
|
|
|
|
// newLine is the line separator rune used when computing line and column
// numbers.
const newLine = '\n'
|
|
|
|
// runesCount reports how many elements of input are equal to target.
func runesCount(input []rune, target rune) int {
	n := 0
	for i := 0; i < len(input); i++ {
		if input[i] == target {
			n++
		}
	}
	return n
}
|
|
|
|
// runesLastIndex returns the index of the last element of input equal to
// target, or -1 when target does not occur.
func runesLastIndex(input []rune, target rune) int {
	idx := len(input)
	for idx > 0 {
		idx--
		if input[idx] == target {
			return idx
		}
	}
	return -1
}
|
|
|
|
func (p *Redshift) SetVersion(version int, dirty bool) error {
|
|
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
|
if err != nil {
|
|
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
|
}
|
|
|
|
query := `DELETE FROM "` + p.config.MigrationsTable + `"`
|
|
if _, err := tx.Exec(query); err != nil {
|
|
tx.Rollback()
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
if version >= 0 {
|
|
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
|
|
if _, err := tx.Exec(query, version, dirty); err != nil {
|
|
tx.Rollback()
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Redshift) Version() (version int, dirty bool, err error) {
|
|
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
|
|
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
|
|
switch {
|
|
case err == sql.ErrNoRows:
|
|
return database.NilVersion, false, nil
|
|
|
|
case err != nil:
|
|
if e, ok := err.(*pq.Error); ok {
|
|
if e.Code.Name() == "undefined_table" {
|
|
return database.NilVersion, false, nil
|
|
}
|
|
}
|
|
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
|
|
|
default:
|
|
return version, dirty, nil
|
|
}
|
|
}
|
|
|
|
func (p *Redshift) Drop() error {
|
|
// select all tables in current schema
|
|
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
|
|
tables, err := p.conn.QueryContext(context.Background(), query)
|
|
if err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
defer tables.Close()
|
|
|
|
// delete one table after another
|
|
tableNames := make([]string, 0)
|
|
for tables.Next() {
|
|
var tableName string
|
|
if err := tables.Scan(&tableName); err != nil {
|
|
return err
|
|
}
|
|
if len(tableName) > 0 {
|
|
tableNames = append(tableNames, tableName)
|
|
}
|
|
}
|
|
|
|
if len(tableNames) > 0 {
|
|
// delete one by one ...
|
|
for _, t := range tableNames {
|
|
query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
|
|
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
}
|
|
if err := p.ensureVersionTable(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Redshift) ensureVersionTable() error {
|
|
// check if migration table exists
|
|
var count int
|
|
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
|
|
if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
if count == 1 {
|
|
return nil
|
|
}
|
|
|
|
// if not, create the empty migration table
|
|
query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
|
|
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
return nil
|
|
}
|