Replace 'redshift' with 'redshift2'

This addresses https://github.com/golang-migrate/migrate/issues/90 . The
exported Redshift object no longer exports an embedde 'Driver' however,
so some more work is needed to make this backwards compatible.
This commit is contained in:
Andrei Mackenzie 2018-11-03 19:01:37 -04:00
parent 4e098f74cd
commit 1d54cf5f18
21 changed files with 308 additions and 367 deletions

View File

@ -1,5 +1,5 @@
SOURCE ?= file go_bindata github aws_s3 google_cloud_storage godoc_vfs
DATABASE ?= postgres mysql redshift redshift2 cassandra spanner cockroachdb clickhouse
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb clickhouse
VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-)
TEST_FLAGS ?=
REPO_OWNER ?= $(shell cd .. && basename "$$(pwd)")

View File

@ -1,7 +0,0 @@
// +build redshift2
package main
import (
_ "github.com/golang-migrate/migrate/v4/database/redshift2"
)

View File

@ -1,6 +1,21 @@
Redshift
===
# Redshift
This provides a Redshift driver for migrations. It is used whenever the URL of the database starts with `redshift://`.
`redshift://user:password@host:port/dbname?query`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 5439) |
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
Redshift is PostgreSQL compatible but has some specific features (or lack thereof) that require slightly different behavior.

View File

@ -1,46 +1,312 @@
// +build go1.9
package redshift
import (
"net/url"
"context"
"database/sql"
"fmt"
"io"
"io/ioutil"
nurl "net/url"
"strconv"
"strings"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/lib/pq"
)
// init registers the driver under the name 'redshift'
func init() {
db := new(Redshift)
db.Driver = new(postgres.Postgres)
database.Register("redshift", db)
db := Redshift{}
database.Register("redshift", &db)
}
var DefaultMigrationsTable = "schema_migrations"
var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoSchema = fmt.Errorf("no schema")
ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
type Config struct {
MigrationsTable string
DatabaseName string
}
// Redshift is a wrapper around the PostgreSQL driver which implements Redshift-specific behavior.
//
// Currently, the only different behaviour is the lack of locking in Redshift. The (Un)Lock() method(s) have been overridden from the PostgreSQL adapter to simply return nil.
type Redshift struct {
// The wrapped PostgreSQL driver.
database.Driver
isLocked bool
conn *sql.Conn
db *sql.DB
// Open and WithInstance need to garantuee that config is never nil
config *Config
}
// Open implements the database.Driver interface by parsing the URL, switching the scheme from "redshift" to "postgres", and delegating to the underlying PostgreSQL driver.
func (driver *Redshift) Open(dsn string) (database.Driver, error) {
parsed, err := url.Parse(dsn)
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
conn, err := instance.Conn(context.Background())
if err != nil {
return nil, err
}
parsed.Scheme = "postgres"
psql, err := driver.Driver.Open(parsed.String())
px := &Redshift{
conn: conn,
db: instance,
config: config,
}
if err := px.ensureVersionTable(); err != nil {
return nil, err
}
return px, nil
}
func (p *Redshift) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
purl.Scheme = "postgres"
db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
return &Redshift{Driver: psql}, nil
migrationsTable := purl.Query().Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
})
if err != nil {
return nil, err
}
return px, nil
}
// Lock implements the database.Driver interface by not locking and returning nil.
func (driver *Redshift) Lock() error { return nil }
func (p *Redshift) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
// Unlock implements the database.Driver interface by not unlocking and returning nil.
func (driver *Redshift) Unlock() error { return nil }
// Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html
func (p *Redshift) Lock() error {
if p.isLocked {
return database.ErrLocked
}
p.isLocked = true
return nil
}
func (p *Redshift) Unlock() error {
p.isLocked = false
return nil
}
func (p *Redshift) Run(migration io.Reader) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
}
// run migration
query := string(migr[:])
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
var lineColOK bool
if pgErr.Position != "" {
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
line, col, lineColOK = computeLineFromPos(query, int(pos))
}
}
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
}
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}
const newLine = '\n'
func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}
func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}
func (p *Redshift) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := `DELETE FROM "` + p.config.MigrationsTable + `"`
if _, err := tx.Exec(query); err != nil {
tx.Rollback()
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if version >= 0 {
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
tx.Rollback()
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (p *Redshift) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*pq.Error); ok {
if e.Code.Name() == "undefined_table" {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (p *Redshift) Drop() error {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer tables.Close()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := p.ensureVersionTable(); err != nil {
return err
}
}
return nil
}
func (p *Redshift) ensureVersionTable() error {
// check if migration table exists
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
// if not, create the empty migration table
query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}

View File

@ -1,4 +1,4 @@
package redshift2
package redshift
// error codes https://github.com/lib/pq/blob/master/error.go

View File

@ -1,21 +0,0 @@
# Redshift
`redshift://user:password@host:port/dbname?query`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 5439) |
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
Redshift is PostgreSQL compatible but has some specific features (or lack thereof) that require slightly different behavior.

View File

@ -1,312 +0,0 @@
// +build go1.9
package redshift2
import (
"context"
"database/sql"
"fmt"
"io"
"io/ioutil"
nurl "net/url"
"strconv"
"strings"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/lib/pq"
)
func init() {
db := Redshift{}
database.Register("redshift2", &db)
}
var DefaultMigrationsTable = "schema_migrations"
var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoSchema = fmt.Errorf("no schema")
ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
type Config struct {
MigrationsTable string
DatabaseName string
}
type Redshift struct {
isLocked bool
conn *sql.Conn
db *sql.DB
// Open and WithInstance need to garantuee that config is never nil
config *Config
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
conn, err := instance.Conn(context.Background())
if err != nil {
return nil, err
}
px := &Redshift{
conn: conn,
db: instance,
config: config,
}
if err := px.ensureVersionTable(); err != nil {
return nil, err
}
return px, nil
}
func (p *Redshift) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
purl.Scheme = "postgres"
db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
migrationsTable := purl.Query().Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
})
if err != nil {
return nil, err
}
return px, nil
}
func (p *Redshift) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
// Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html
func (p *Redshift) Lock() error {
if p.isLocked {
return database.ErrLocked
}
p.isLocked = true
return nil
}
func (p *Redshift) Unlock() error {
p.isLocked = false
return nil
}
func (p *Redshift) Run(migration io.Reader) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
}
// run migration
query := string(migr[:])
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
var lineColOK bool
if pgErr.Position != "" {
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
line, col, lineColOK = computeLineFromPos(query, int(pos))
}
}
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
}
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}
const newLine = '\n'
func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}
func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}
func (p *Redshift) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := `DELETE FROM "` + p.config.MigrationsTable + `"`
if _, err := tx.Exec(query); err != nil {
tx.Rollback()
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if version >= 0 {
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
tx.Rollback()
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (p *Redshift) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*pq.Error); ok {
if e.Code.Name() == "undefined_table" {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (p *Redshift) Drop() error {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer tables.Close()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := p.ensureVersionTable(); err != nil {
return err
}
}
return nil
}
func (p *Redshift) ensureVersionTable() error {
// check if migration table exists
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
// if not, create the empty migration table
query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}