migrate/database/sqlserver/sqlserver.go

340 lines
9.0 KiB
Go

package sqlserver
import (
"context"
"database/sql"
"fmt"
"io"
"io/ioutil"
nurl "net/url"
mssql "github.com/denisenkom/go-mssqldb" // mssql support
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
)
func init() {
database.Register("sqlserver", &SQLServer{})
}
// DefaultMigrationsTable is the name of the migrations table in the database.
// WithInstance uses it whenever Config.MigrationsTable is left empty.
var DefaultMigrationsTable = "schema_migrations"
// Sentinel errors exposed by the sqlserver driver.
var (
// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
ErrNilConfig = fmt.Errorf("no config")
// ErrNoDatabaseName is returned by WithInstance when `SELECT DB_NAME()`
// yields an empty string.
ErrNoDatabaseName = fmt.Errorf("no database name")
// ErrNoSchema is returned by WithInstance when `SELECT SCHEMA_NAME()`
// yields an empty string.
ErrNoSchema = fmt.Errorf("no schema")
// ErrDatabaseDirty reports a dirty database.
// NOTE(review): not referenced anywhere in the visible part of this file;
// presumably kept for callers - confirm before removing.
ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
// lockErrorMap maps the negative return statuses that Lock can receive from
// sp_getapplock to the human-readable text Lock puts into its error message.
// Statuses that are missing from the map yield an empty description.
var lockErrorMap = map[mssql.ReturnStatus]string{
-1: "The lock request timed out.",
-2: "The lock request was canceled.",
-3: "The lock request was chosen as a deadlock victim.",
-999: "Parameter validation or other call error.",
}
// Config holds the settings of the SQL Server migration driver.
type Config struct {
// MigrationsTable is the name of the table that stores the current
// migration version. WithInstance replaces an empty value with
// DefaultMigrationsTable.
MigrationsTable string
// DatabaseName is overwritten by WithInstance with the result of
// `SELECT DB_NAME()` on the supplied connection.
DatabaseName string
// SchemaName is overwritten by WithInstance with the result of
// `SELECT SCHEMA_NAME()` on the supplied connection.
SchemaName string
}
// SQLServer is the migration driver for Microsoft SQL Server. It keeps both
// the connection pool and one dedicated connection taken from that pool.
type SQLServer struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
// db is the pool conn was obtained from; Close closes it as well.
db *sql.DB
// isLocked records whether Lock succeeded and Unlock has not run since.
isLocked bool
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
// WithInstance returns a database instance from an already created database connection.
//
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
//
// The supplied config is modified in place: DatabaseName and SchemaName are
// replaced by the values reported by the server (DB_NAME() and SCHEMA_NAME()),
// and an empty MigrationsTable is replaced by DefaultMigrationsTable.
//
// It returns ErrNilConfig for a nil config, ErrNoDatabaseName or ErrNoSchema
// when the server reports an empty name, and otherwise any error raised while
// pinging, querying, reserving a connection or creating the migrations table.
// The caller keeps ownership of instance on failure; the dedicated connection
// reserved here is released again before an error is returned.
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
	if config == nil {
		return nil, ErrNilConfig
	}

	if err := instance.Ping(); err != nil {
		return nil, err
	}

	query := `SELECT DB_NAME()`
	var databaseName string
	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
	}
	if len(databaseName) == 0 {
		return nil, ErrNoDatabaseName
	}
	config.DatabaseName = databaseName

	query = `SELECT SCHEMA_NAME()`
	var schemaName string
	if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
	}
	if len(schemaName) == 0 {
		return nil, ErrNoSchema
	}
	config.SchemaName = schemaName

	if len(config.MigrationsTable) == 0 {
		config.MigrationsTable = DefaultMigrationsTable
	}

	// Lock and Unlock must run on the same session, so reserve one
	// connection for the lifetime of the driver.
	conn, err := instance.Conn(context.Background())
	if err != nil {
		return nil, err
	}

	ss := &SQLServer{
		conn:   conn,
		db:     instance,
		config: config,
	}

	if err := ss.ensureVersionTable(); err != nil {
		// The driver is never handed to the caller on this path, so nobody
		// else could return the reserved connection to the pool.
		if closeErr := conn.Close(); closeErr != nil {
			err = multierror.Append(err, closeErr)
		}
		return nil, err
	}

	return ss, nil
}
// Open parses url, opens a database handle with the "sqlserver" driver and
// wraps it with WithInstance.
//
// The optional `x-migrations-table` query parameter selects the migrations
// table name. The connection string handed to sql.Open is the URL after it
// has been passed through migrate.FilterCustomQuery.
//
// When WithInstance fails, the handle opened here is closed again before the
// error is returned, so a failed Open does not leave a pool behind.
func (ss *SQLServer) Open(url string) (database.Driver, error) {
	purl, err := nurl.Parse(url)
	if err != nil {
		return nil, err
	}

	db, err := sql.Open("sqlserver", migrate.FilterCustomQuery(purl).String())
	if err != nil {
		return nil, err
	}

	migrationsTable := purl.Query().Get("x-migrations-table")

	px, err := WithInstance(db, &Config{
		DatabaseName:    purl.Path,
		MigrationsTable: migrationsTable,
	})
	if err != nil {
		// db was created by this call and is unreachable for the caller.
		if closeErr := db.Close(); closeErr != nil {
			err = multierror.Append(err, closeErr)
		}
		return nil, err
	}

	return px, nil
}
// Close closes the dedicated connection first and the underlying database
// handle second. Both are always attempted; if either fails, a single error
// describing both results is returned.
func (ss *SQLServer) Close() error {
	errConn := ss.conn.Close()
	errDB := ss.db.Close()

	if errConn == nil && errDB == nil {
		return nil
	}
	return fmt.Errorf("conn: %v, db: %v", errConn, errDB)
}
// Lock takes a session-owned application lock on the database so that only
// one migration runs at a time.
//
// It returns database.ErrLocked when this driver already holds the lock. The
// lock is requested with a zero timeout, so the call does not wait; a negative
// status from sp_getapplock is reported as a *database.Error.
func (ss *SQLServer) Lock() error {
	if ss.isLocked {
		return database.ErrLocked
	}

	lockID, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
	if err != nil {
		return err
	}

	// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
	query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`

	// status is passed as an extra argument so the call can fill in the
	// procedure's return status.
	var status mssql.ReturnStatus
	_, err = ss.conn.ExecContext(context.Background(), query, lockID, &status)

	switch {
	case err != nil:
		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
	case status < 0:
		// Negative statuses mean the lock was not granted.
		return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
	}

	ss.isLocked = true
	return nil
}
// Unlock releases the migration lock previously taken by Lock. It is a no-op
// when the driver does not currently hold the lock.
func (ss *SQLServer) Unlock() error {
	if !ss.isLocked {
		return nil
	}

	lockID, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
	if err != nil {
		return err
	}

	// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
	query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`

	_, execErr := ss.conn.ExecContext(context.Background(), query, lockID)
	if execErr != nil {
		// isLocked stays true so a later Unlock can try again.
		return &database.Error{OrigErr: execErr, Query: []byte(query)}
	}

	ss.isLocked = false
	return nil
}
// Run reads the whole migration from the reader and executes it as a single
// statement batch on the driver's dedicated connection.
//
// A read failure is returned unchanged. An execution failure is returned as a
// database.Error carrying the migration text; when the underlying error is an
// mssql.Error, the server message, the line number and (if present) the
// procedure name are included as well.
func (ss *SQLServer) Run(migration io.Reader) error {
	migr, err := ioutil.ReadAll(migration)
	if err != nil {
		return err
	}

	// run migration
	query := string(migr)
	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
		if msErr, ok := err.(mssql.Error); ok {
			message := fmt.Sprintf("migration failed: %s", msErr.Message)
			if msErr.ProcName != "" {
				// Extend the message built above instead of starting over from
				// msErr.Message, so the "migration failed" prefix is kept.
				message = fmt.Sprintf("%s (proc name %s)", message, msErr.ProcName)
			}
			return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
		}
		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
	}

	return nil
}
// SetVersion replaces the content of the migrations table inside a single
// transaction: the table is truncated and, for a non-negative version, one
// row with the version and the dirty flag is inserted. A negative version
// therefore leaves the table empty.
func (ss *SQLServer) SetVersion(version int, dirty bool) error {
	tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return &database.Error{OrigErr: err, Err: "transaction start failed"}
	}

	// abort rolls the transaction back and wraps the failure of stmt; a
	// rollback error is appended to the cause rather than replacing it.
	abort := func(cause error, stmt string) error {
		if rbErr := tx.Rollback(); rbErr != nil {
			cause = multierror.Append(cause, rbErr)
		}
		return &database.Error{OrigErr: cause, Query: []byte(stmt)}
	}

	quotedTable := `"` + ss.config.MigrationsTable + `"`

	truncateStmt := `TRUNCATE TABLE ` + quotedTable
	if _, execErr := tx.Exec(truncateStmt); execErr != nil {
		return abort(execErr, truncateStmt)
	}

	if version >= 0 {
		// The dirty flag is sent as 0/1.
		dirtyBit := 0
		if dirty {
			dirtyBit = 1
		}

		insertStmt := `INSERT INTO ` + quotedTable + ` (version, dirty) VALUES (@p1, @p2)`
		if _, execErr := tx.Exec(insertStmt, version, dirtyBit); execErr != nil {
			return abort(execErr, insertStmt)
		}
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return &database.Error{OrigErr: commitErr, Err: "transaction commit failed"}
	}
	return nil
}
// Version reads the current migration version and dirty flag from the
// migrations table. An empty table is not an error: it is reported as
// database.NilVersion with dirty set to false.
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
	query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`

	row := ss.conn.QueryRowContext(context.Background(), query)
	if err = row.Scan(&version, &dirty); err == nil {
		return version, dirty, nil
	}

	if err == sql.ErrNoRows {
		return database.NilVersion, false, nil
	}

	// FIXME: convert to MSSQL error
	return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
}
// Drop removes all tables from the database in two steps: it first drops
// every referential integrity constraint listed in INFORMATION_SCHEMA, then
// drops the tables themselves through sp_MSforeachtable. The first failing
// statement aborts the operation and is returned as a *database.Error.
func (ss *SQLServer) Drop() error {
	// Builds one ALTER TABLE ... DROP statement per referential constraint and
	// executes them one by one through a cursor.
	dropConstraints := `
DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
SET @Cursor = CURSOR FAST_FORWARD FOR
SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
WHILE (@@FETCH_STATUS = 0)
BEGIN
Exec sp_executesql @Sql
FETCH NEXT FROM @Cursor INTO @Sql
END
CLOSE @Cursor DEALLOCATE @Cursor`

	dropTables := `EXEC sp_MSforeachtable 'DROP TABLE ?'`

	// Order matters: constraints have to go before the tables they reference.
	for _, stmt := range []string{dropConstraints, dropTables} {
		if _, err := ss.conn.ExecContext(context.Background(), stmt); err != nil {
			return &database.Error{OrigErr: err, Query: []byte(stmt)}
		}
	}
	return nil
}
// ensureVersionTable creates the migrations table if it does not exist yet.
//
// The function takes the migration lock itself and releases it on return; an
// unlock failure is returned, combined with any earlier error.
//
// Both the existence check and the CREATE TABLE statement address the table
// as [SchemaName].[MigrationsTable]. SchemaName is what WithInstance read from
// SCHEMA_NAME(). Checking a fixed [dbo] schema while creating the table
// unqualified would make the check miss the table, and the CREATE fail,
// whenever that schema is not dbo.
func (ss *SQLServer) ensureVersionTable() (err error) {
	if err = ss.Lock(); err != nil {
		return err
	}

	defer func() {
		if e := ss.Unlock(); e != nil {
			if err == nil {
				err = e
			} else {
				err = multierror.Append(err, e)
			}
		}
	}()

	table := `[` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `]`

	query := `IF NOT EXISTS
(SELECT *
FROM sysobjects
WHERE id = object_id(N'` + table + `')
AND OBJECTPROPERTY(id, N'IsUserTable') = 1
)
CREATE TABLE ` + table + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`

	if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
		return &database.Error{OrigErr: err, Query: []byte(query)}
	}

	return nil
}