mirror of https://github.com/status-im/migrate.git
251 lines
5.4 KiB
Go
251 lines
5.4 KiB
Go
// +build go1.9
|
|
|
|
package firebird
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"github.com/golang-migrate/migrate/v4"
|
|
"github.com/golang-migrate/migrate/v4/database"
|
|
"github.com/hashicorp/go-multierror"
|
|
_ "github.com/nakagami/firebirdsql"
|
|
"io"
|
|
"io/ioutil"
|
|
nurl "net/url"
|
|
)
|
|
|
|
func init() {
|
|
db := Firebird{}
|
|
database.Register("firebird", &db)
|
|
database.Register("firebirdsql", &db)
|
|
}
|
|
|
|
// DefaultMigrationsTable is the table name WithInstance falls back to when
// Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"
|
|
|
|
// ErrNilConfig is returned by WithInstance when it is called with a nil
// *Config.
var ErrNilConfig = fmt.Errorf("no config")
|
|
|
|
// Config carries the settings a Firebird driver instance is created with.
type Config struct {
	// DatabaseName is filled by Open from the path component of the DSN.
	DatabaseName string

	// MigrationsTable is the name of the table holding the current version
	// and dirty flag. WithInstance replaces an empty value with
	// DefaultMigrationsTable.
	MigrationsTable string
}
|
|
|
|
type Firebird struct {
|
|
// Locking and unlocking need to use the same connection
|
|
conn *sql.Conn
|
|
db *sql.DB
|
|
isLocked bool
|
|
|
|
// Open and WithInstance need to guarantee that config is never nil
|
|
config *Config
|
|
}
|
|
|
|
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
|
if config == nil {
|
|
return nil, ErrNilConfig
|
|
}
|
|
|
|
if err := instance.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(config.MigrationsTable) == 0 {
|
|
config.MigrationsTable = DefaultMigrationsTable
|
|
}
|
|
|
|
conn, err := instance.Conn(context.Background())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fb := &Firebird{
|
|
conn: conn,
|
|
db: instance,
|
|
config: config,
|
|
}
|
|
|
|
if err := fb.ensureVersionTable(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return fb, nil
|
|
}
|
|
|
|
func (f *Firebird) Open(dsn string) (database.Driver, error) {
|
|
purl, err := nurl.Parse(dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
px, err := WithInstance(db, &Config{
|
|
MigrationsTable: purl.Query().Get("x-migrations-table"),
|
|
DatabaseName: purl.Path,
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return px, nil
|
|
}
|
|
|
|
func (f *Firebird) Close() error {
|
|
connErr := f.conn.Close()
|
|
dbErr := f.db.Close()
|
|
if connErr != nil || dbErr != nil {
|
|
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *Firebird) Lock() error {
|
|
if f.isLocked {
|
|
return database.ErrLocked
|
|
}
|
|
f.isLocked = true
|
|
return nil
|
|
}
|
|
|
|
func (f *Firebird) Unlock() error {
|
|
f.isLocked = false
|
|
return nil
|
|
}
|
|
|
|
func (f *Firebird) Run(migration io.Reader) error {
|
|
migr, err := ioutil.ReadAll(migration)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// run migration
|
|
query := string(migr[:])
|
|
if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
|
|
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (f *Firebird) SetVersion(version int, dirty bool) error {
|
|
if version < 0 {
|
|
return nil
|
|
}
|
|
|
|
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
|
|
DELETE FROM "%v";
|
|
INSERT INTO "%v" (version, dirty) VALUES (%v, %v);
|
|
END;`,
|
|
f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty))
|
|
|
|
if _, err := f.conn.ExecContext(context.Background(), query, version, btoi(dirty)); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (f *Firebird) Version() (version int, dirty bool, err error) {
|
|
var d int
|
|
query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable)
|
|
err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d)
|
|
switch {
|
|
case err == sql.ErrNoRows:
|
|
return database.NilVersion, false, nil
|
|
case err != nil:
|
|
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
|
|
|
default:
|
|
return version, itob(d), nil
|
|
}
|
|
}
|
|
|
|
func (f *Firebird) Drop() (err error) {
|
|
// select all tables
|
|
query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);`
|
|
tables, err := f.conn.QueryContext(context.Background(), query)
|
|
if err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
defer func() {
|
|
if errClose := tables.Close(); errClose != nil {
|
|
err = multierror.Append(err, errClose)
|
|
}
|
|
}()
|
|
|
|
// delete one table after another
|
|
tableNames := make([]string, 0)
|
|
for tables.Next() {
|
|
var tableName string
|
|
if err := tables.Scan(&tableName); err != nil {
|
|
return err
|
|
}
|
|
if len(tableName) > 0 {
|
|
tableNames = append(tableNames, tableName)
|
|
}
|
|
}
|
|
|
|
// delete one by one ...
|
|
for _, t := range tableNames {
|
|
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
|
|
if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
|
|
execute statement 'drop table "%v"';
|
|
END;`,
|
|
t, t)
|
|
|
|
if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
|
func (f *Firebird) ensureVersionTable() (err error) {
|
|
if err = f.Lock(); err != nil {
|
|
return err
|
|
}
|
|
|
|
defer func() {
|
|
if e := f.Unlock(); e != nil {
|
|
if err == nil {
|
|
err = e
|
|
} else {
|
|
err = multierror.Append(err, e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
|
|
if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
|
|
execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)';
|
|
END;`,
|
|
f.config.MigrationsTable, f.config.MigrationsTable)
|
|
|
|
if _, err = f.conn.ExecContext(context.Background(), query); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// btoi maps true to 1 and false to 0.
func btoi(v bool) int {
	n := 0
	if v {
		n = 1
	}
	return n
}
|
|
|
|
// itob reports whether v is non-zero.
func itob(v int) bool {
	if v == 0 {
		return false
	}
	return true
}
|