mirror of https://github.com/status-im/migrate.git
340 lines
9.0 KiB
Go
340 lines
9.0 KiB
Go
package sqlserver
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
nurl "net/url"
|
|
|
|
mssql "github.com/denisenkom/go-mssqldb" // mssql support
|
|
"github.com/golang-migrate/migrate/v4"
|
|
"github.com/golang-migrate/migrate/v4/database"
|
|
"github.com/hashicorp/go-multierror"
|
|
)
|
|
|
|
func init() {
|
|
database.Register("sqlserver", &SQLServer{})
|
|
}
|
|
|
|
// DefaultMigrationsTable is the migrations table name that WithInstance
// falls back to when Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"
|
|
|
|
// Sentinel errors exposed by this package.
var (
	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName is returned by WithInstance when DB_NAME() yields an
	// empty string.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrNoSchema is returned by WithInstance when SCHEMA_NAME() yields an
	// empty string.
	ErrNoSchema = fmt.Errorf("no schema")

	// ErrDatabaseDirty signals a dirty database.
	// NOTE(review): no code visible in this file returns it; confirm usage.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
|
|
|
|
var lockErrorMap = map[mssql.ReturnStatus]string{
|
|
-1: "The lock request timed out.",
|
|
-2: "The lock request was canceled.",
|
|
-3: "The lock request was chosen as a deadlock victim.",
|
|
-999: "Parameter validation or other call error.",
|
|
}
|
|
|
|
// Config holds the settings of a SQLServer driver instance.
type Config struct {
	// MigrationsTable is the name of the table that stores the migration
	// version. WithInstance replaces an empty value with
	// DefaultMigrationsTable.
	MigrationsTable string

	// DatabaseName is overwritten by WithInstance with the result of
	// DB_NAME() on the supplied connection.
	DatabaseName string

	// SchemaName is overwritten by WithInstance with the result of
	// SCHEMA_NAME() on the supplied connection.
	SchemaName string
}
|
|
|
|
// SQL Server connection
|
|
type SQLServer struct {
|
|
// Locking and unlocking need to use the same connection
|
|
conn *sql.Conn
|
|
db *sql.DB
|
|
isLocked bool
|
|
|
|
// Open and WithInstance need to garantuee that config is never nil
|
|
config *Config
|
|
}
|
|
|
|
// WithInstance returns a database instance from an already created database connection.
|
|
//
|
|
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
|
|
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
|
if config == nil {
|
|
return nil, ErrNilConfig
|
|
}
|
|
|
|
if err := instance.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := `SELECT DB_NAME()`
|
|
var databaseName string
|
|
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
|
|
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
if len(databaseName) == 0 {
|
|
return nil, ErrNoDatabaseName
|
|
}
|
|
|
|
config.DatabaseName = databaseName
|
|
|
|
query = `SELECT SCHEMA_NAME()`
|
|
var schemaName string
|
|
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
|
|
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
if len(schemaName) == 0 {
|
|
return nil, ErrNoSchema
|
|
}
|
|
|
|
config.SchemaName = schemaName
|
|
|
|
if len(config.MigrationsTable) == 0 {
|
|
config.MigrationsTable = DefaultMigrationsTable
|
|
}
|
|
|
|
conn, err := instance.Conn(context.Background())
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ss := &SQLServer{
|
|
conn: conn,
|
|
db: instance,
|
|
config: config,
|
|
}
|
|
|
|
if err := ss.ensureVersionTable(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ss, nil
|
|
}
|
|
|
|
// Open a connection to the database
|
|
func (ss *SQLServer) Open(url string) (database.Driver, error) {
|
|
purl, err := nurl.Parse(url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
db, err := sql.Open("sqlserver", migrate.FilterCustomQuery(purl).String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
migrationsTable := purl.Query().Get("x-migrations-table")
|
|
|
|
px, err := WithInstance(db, &Config{
|
|
DatabaseName: purl.Path,
|
|
MigrationsTable: migrationsTable,
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return px, nil
|
|
}
|
|
|
|
// Close the database connection
|
|
func (ss *SQLServer) Close() error {
|
|
connErr := ss.conn.Close()
|
|
dbErr := ss.db.Close()
|
|
if connErr != nil || dbErr != nil {
|
|
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
|
|
func (ss *SQLServer) Lock() error {
|
|
if ss.isLocked {
|
|
return database.ErrLocked
|
|
}
|
|
|
|
aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// This will either obtain the lock immediately and return true,
|
|
// or return false if the lock cannot be acquired immediately.
|
|
// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
|
|
query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
|
|
|
|
var status mssql.ReturnStatus
|
|
if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
|
|
ss.isLocked = true
|
|
return nil
|
|
} else if err != nil {
|
|
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
|
|
} else {
|
|
return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
|
|
}
|
|
}
|
|
|
|
// Unlock froms the migration lock from the database
|
|
func (ss *SQLServer) Unlock() error {
|
|
if !ss.isLocked {
|
|
return nil
|
|
}
|
|
|
|
aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
|
|
query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
|
|
if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
ss.isLocked = false
|
|
|
|
return nil
|
|
}
|
|
|
|
// Run the migrations for the database
|
|
func (ss *SQLServer) Run(migration io.Reader) error {
|
|
migr, err := ioutil.ReadAll(migration)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// run migration
|
|
query := string(migr[:])
|
|
if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
|
|
if msErr, ok := err.(mssql.Error); ok {
|
|
message := fmt.Sprintf("migration failed: %s", msErr.Message)
|
|
if msErr.ProcName != "" {
|
|
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
|
|
}
|
|
return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
|
|
}
|
|
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetVersion for the current database
|
|
func (ss *SQLServer) SetVersion(version int, dirty bool) error {
|
|
|
|
tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
|
if err != nil {
|
|
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
|
}
|
|
|
|
query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
|
|
if _, err := tx.Exec(query); err != nil {
|
|
if errRollback := tx.Rollback(); errRollback != nil {
|
|
err = multierror.Append(err, errRollback)
|
|
}
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
if version >= 0 {
|
|
var dirtyBit int
|
|
if dirty {
|
|
dirtyBit = 1
|
|
}
|
|
query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
|
|
if _, err := tx.Exec(query, version, dirtyBit); err != nil {
|
|
if errRollback := tx.Rollback(); errRollback != nil {
|
|
err = multierror.Append(err, errRollback)
|
|
}
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Version of the current database state
|
|
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
|
|
query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
|
|
err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
|
|
switch {
|
|
case err == sql.ErrNoRows:
|
|
return database.NilVersion, false, nil
|
|
|
|
case err != nil:
|
|
// FIXME: convert to MSSQL error
|
|
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
|
|
|
default:
|
|
return version, dirty, nil
|
|
}
|
|
}
|
|
|
|
// Drop all tables from the database.
|
|
func (ss *SQLServer) Drop() error {
|
|
|
|
// drop all referential integrity constraints
|
|
query := `
|
|
DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
|
|
|
|
SET @Cursor = CURSOR FAST_FORWARD FOR
|
|
SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
|
|
FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
|
|
LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
|
|
|
|
OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
|
|
|
|
WHILE (@@FETCH_STATUS = 0)
|
|
BEGIN
|
|
Exec sp_executesql @Sql
|
|
FETCH NEXT FROM @Cursor INTO @Sql
|
|
END
|
|
|
|
CLOSE @Cursor DEALLOCATE @Cursor`
|
|
|
|
if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
// drop the tables
|
|
query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
|
|
if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ss *SQLServer) ensureVersionTable() (err error) {
|
|
if err = ss.Lock(); err != nil {
|
|
return err
|
|
}
|
|
|
|
defer func() {
|
|
if e := ss.Unlock(); e != nil {
|
|
if err == nil {
|
|
err = e
|
|
} else {
|
|
err = multierror.Append(err, e)
|
|
}
|
|
}
|
|
}()
|
|
|
|
query := `IF NOT EXISTS
|
|
(SELECT *
|
|
FROM sysobjects
|
|
WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
|
|
AND OBJECTPROPERTY(id, N'IsUserTable') = 1
|
|
)
|
|
CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
|
|
|
|
if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
|
|
return nil
|
|
}
|