2019-02-15 16:08:38 -06:00
package mssql
import (
"context"
"database/sql"
"fmt"
"io"
"io/ioutil"
nurl "net/url"
2019-05-19 13:37:40 +01:00
mssql "github.com/denisenkom/go-mssqldb" // mssql support
2019-02-15 16:08:38 -06:00
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
2019-05-19 13:37:40 +01:00
"github.com/hashicorp/go-multierror"
2019-02-15 16:08:38 -06:00
)
// init registers one shared MSSQL driver value with the migrate database
// registry under both URL schemes it answers to: "mssql" and "sqlserver".
func init() {
	driver := &MSSQL{}
	for _, scheme := range []string{"mssql", "sqlserver"} {
		database.Register(scheme, driver)
	}
}
// DefaultMigrationsTable is the migrations table name that WithInstance
// falls back to when Config.MigrationsTable is left empty.
var DefaultMigrationsTable = "schema_migrations"
2019-02-15 16:08:38 -06:00
// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
var ErrNilConfig = fmt.Errorf("no config")

// ErrNoDatabaseName is returned by WithInstance when DB_NAME() yields an
// empty string for the connection.
var ErrNoDatabaseName = fmt.Errorf("no database name")

// ErrNoSchema is returned by WithInstance when SCHEMA_NAME() yields an empty
// string for the connection.
var ErrNoSchema = fmt.Errorf("no schema")

// ErrDatabaseDirty signals that the database is in a dirty state.
var ErrDatabaseDirty = fmt.Errorf("database is dirty")
2019-05-19 13:37:40 +01:00
// lockErrorMap maps the negative status codes that sp_getapplock can return
// to the human-readable text Lock embeds in its error message.
var lockErrorMap = map[int]string{
	-1:   "The lock request timed out.",
	-2:   "The lock request was canceled.",
	-3:   "The lock request was chosen as a deadlock victim.",
	-999: "Parameter validation or other call error.",
}
2019-02-15 16:08:38 -06:00
// Config carries the settings of an MSSQL driver instance.
type Config struct {
	// MigrationsTable is the name of the table holding the migration
	// version. WithInstance replaces an empty value with
	// DefaultMigrationsTable.
	MigrationsTable string

	// DatabaseName is overwritten by WithInstance with the result of
	// DB_NAME() on the supplied connection.
	DatabaseName string

	// SchemaName is overwritten by WithInstance with the result of
	// SCHEMA_NAME() on the supplied connection.
	SchemaName string
}
// MSSQL is the migrate database driver for Microsoft SQL Server.
type MSSQL struct {
	// conn is the one dedicated connection every statement runs on. The
	// application lock is owned by the session, so Lock and Unlock have to
	// use the same connection.
	conn *sql.Conn

	// db is the pool conn was obtained from; Close closes it as well.
	db *sql.DB

	// isLocked is true between a successful Lock and the matching Unlock.
	isLocked bool

	// config is never nil: Open and WithInstance need to guarantee that.
	config *Config
}
2019-05-19 13:37:40 +01:00
// WithInstance returns a database instance from an already created database connection
2019-02-15 16:08:38 -06:00
func WithInstance ( instance * sql . DB , config * Config ) ( database . Driver , error ) {
if config == nil {
return nil , ErrNilConfig
}
if err := instance . Ping ( ) ; err != nil {
return nil , err
}
query := ` SELECT DB_NAME() `
var databaseName string
if err := instance . QueryRow ( query ) . Scan ( & databaseName ) ; err != nil {
return nil , & database . Error { OrigErr : err , Query : [ ] byte ( query ) }
}
if len ( databaseName ) == 0 {
return nil , ErrNoDatabaseName
}
config . DatabaseName = databaseName
query = ` SELECT SCHEMA_NAME() `
var schemaName string
if err := instance . QueryRow ( query ) . Scan ( & schemaName ) ; err != nil {
return nil , & database . Error { OrigErr : err , Query : [ ] byte ( query ) }
}
if len ( schemaName ) == 0 {
return nil , ErrNoSchema
}
config . SchemaName = schemaName
if len ( config . MigrationsTable ) == 0 {
config . MigrationsTable = DefaultMigrationsTable
}
conn , err := instance . Conn ( context . Background ( ) )
if err != nil {
return nil , err
}
ss := & MSSQL {
conn : conn ,
db : instance ,
config : config ,
}
if err := ss . ensureVersionTable ( ) ; err != nil {
return nil , err
}
return ss , nil
}
// Open a connection to the database
func ( ss * MSSQL ) Open ( url string ) ( database . Driver , error ) {
purl , err := nurl . Parse ( url )
if err != nil {
return nil , err
}
db , err := sql . Open ( "mssql" , migrate . FilterCustomQuery ( purl ) . String ( ) )
if err != nil {
return nil , err
}
migrationsTable := purl . Query ( ) . Get ( "x-migrations-table" )
px , err := WithInstance ( db , & Config {
DatabaseName : purl . Path ,
MigrationsTable : migrationsTable ,
} )
2019-05-19 13:37:40 +01:00
2019-02-15 16:08:38 -06:00
if err != nil {
return nil , err
}
return px , nil
}
// Close releases the dedicated connection and then closes the underlying
// database handle. Both closes are always attempted; if either fails, a
// single error describing both results is returned.
func (ss *MSSQL) Close() error {
	connErr := ss.conn.Close()
	dbErr := ss.db.Close()

	if connErr == nil && dbErr == nil {
		return nil
	}
	return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
// Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
func ( ss * MSSQL ) Lock ( ) error {
if ss . isLocked {
return database . ErrLocked
}
aid , err := database . GenerateAdvisoryLockId ( ss . config . DatabaseName , ss . config . SchemaName )
if err != nil {
return err
}
// This will either obtain the lock immediately and return true,
// or return false if the lock cannot be acquired immediately.
// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
2019-05-19 13:37:40 +01:00
query := ` EXEC sp_getapplock @Resource = ?, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0 `
var status mssql . ReturnStatus
if _ , err = ss . conn . ExecContext ( context . Background ( ) , query , aid , & status ) ; err == nil && status > - 1 {
ss . isLocked = true
return nil
} else if err != nil {
2019-02-15 16:08:38 -06:00
return & database . Error { OrigErr : err , Err : "try lock failed" , Query : [ ] byte ( query ) }
2019-05-19 13:37:40 +01:00
} else {
return & database . Error { Err : fmt . Sprintf ( "try lock failed with error %v" , lockErrorMap [ int ( status ) ] ) , Query : [ ] byte ( query ) }
2019-02-15 16:08:38 -06:00
}
}
// Unlock froms the migration lock from the database
func ( ss * MSSQL ) Unlock ( ) error {
if ! ss . isLocked {
return nil
}
aid , err := database . GenerateAdvisoryLockId ( ss . config . DatabaseName , ss . config . SchemaName )
if err != nil {
return err
}
// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
2019-05-19 13:37:40 +01:00
query := ` EXEC sp_releaseapplock @Resource = ?, @LockOwner = 'Session' `
2019-02-15 16:08:38 -06:00
if _ , err := ss . conn . ExecContext ( context . Background ( ) , query , aid ) ; err != nil {
return & database . Error { OrigErr : err , Query : [ ] byte ( query ) }
}
ss . isLocked = false
return nil
}
// Run the migrations for the database
func ( ss * MSSQL ) Run ( migration io . Reader ) error {
migr , err := ioutil . ReadAll ( migration )
if err != nil {
return err
}
// run migration
query := string ( migr [ : ] )
if _ , err := ss . conn . ExecContext ( context . Background ( ) , query ) ; err != nil {
2019-05-19 16:08:15 +01:00
if msErr , ok := err . ( mssql . Error ) ; ok {
2019-05-19 14:08:10 +01:00
message := fmt . Sprintf ( "migration failed: %s" , msErr . Message )
if msErr . ProcName != "" {
message = fmt . Sprintf ( "%s (proc name %s)" , msErr . Message , msErr . ProcName )
}
return database . Error { OrigErr : err , Err : message , Query : migr , Line : uint ( msErr . LineNo ) }
}
2019-02-15 16:08:38 -06:00
return database . Error { OrigErr : err , Err : "migration failed" , Query : migr }
}
return nil
}
// SetVersion for the current database
func ( ss * MSSQL ) SetVersion ( version int , dirty bool ) error {
2019-05-19 13:37:40 +01:00
tx , err := ss . conn . BeginTx ( context . Background ( ) , & sql . TxOptions { } )
if err != nil {
return & database . Error { OrigErr : err , Err : "transaction start failed" }
}
2019-02-15 16:08:38 -06:00
query := ` TRUNCATE TABLE " ` + ss . config . MigrationsTable + ` " `
if _ , err := tx . Exec ( query ) ; err != nil {
2019-05-19 13:37:40 +01:00
if errRollback := tx . Rollback ( ) ; errRollback != nil {
err = multierror . Append ( err , errRollback )
}
2019-02-15 16:08:38 -06:00
return & database . Error { OrigErr : err , Query : [ ] byte ( query ) }
}
if version >= 0 {
var dirtyBit int
if dirty {
dirtyBit = 1
}
query = ` INSERT INTO " ` + ss . config . MigrationsTable + ` " (version, dirty) VALUES ($1, $2) `
if _ , err := tx . Exec ( query , version , dirtyBit ) ; err != nil {
2019-05-19 13:37:40 +01:00
if errRollback := tx . Rollback ( ) ; errRollback != nil {
err = multierror . Append ( err , errRollback )
}
2019-02-15 16:08:38 -06:00
return & database . Error { OrigErr : err , Query : [ ] byte ( query ) }
}
}
2019-05-19 13:37:40 +01:00
if err := tx . Commit ( ) ; err != nil {
return & database . Error { OrigErr : err , Err : "transaction commit failed" }
}
2019-02-15 16:08:38 -06:00
return nil
}
// Version reads the current migration state from the migrations table.
//
// It returns database.NilVersion with dirty == false when the table has no
// row, and a *database.Error for any other query or scan failure.
func (ss *MSSQL) Version() (version int, dirty bool, err error) {
	query := ` SELECT TOP 1 version, dirty FROM " ` + ss.config.MigrationsTable + ` " `

	row := ss.conn.QueryRowContext(context.Background(), query)
	if err = row.Scan(&version, &dirty); err == nil {
		return version, dirty, nil
	}

	if err == sql.ErrNoRows {
		return database.NilVersion, false, nil
	}

	// FIXME: convert to MSSQL error
	return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
}
// Drop all tables from the database.
//
// Two batches run on the dedicated connection, in order: the first drops the
// referential (foreign key) constraints, the second drops the tables. The
// first failure is returned as a *database.Error carrying the batch text, and
// the second batch is not attempted if the first one fails.
func ( ss * MSSQL ) Drop ( ) error {
// drop all referential integrity constraints
// The batch builds one ALTER TABLE ... DROP statement per row of
// INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS (joined to TABLE_CONSTRAINTS by
// constraint name to obtain the owning table) and executes each statement
// through sp_executesql while walking a cursor over them.
// NOTE(review): the generated statements use the bare table name without a
// schema, and @Sql is declared as NVARCHAR(500) — confirm both are adequate
// for the schemas this is run against.
query := `
DECLARE @ Sql NVARCHAR ( 500 ) DECLARE @ Cursor CURSOR
SET @ Cursor = CURSOR FAST_FORWARD FOR
SELECT DISTINCT sql = ' ALTER TABLE [ ' + tc2 . TABLE_NAME + ' ] DROP [ ' + rc1 . CONSTRAINT_NAME + ']'
FROM INFORMATION_SCHEMA . REFERENTIAL_CONSTRAINTS rc1
LEFT JOIN INFORMATION_SCHEMA . TABLE_CONSTRAINTS tc2 ON tc2 . CONSTRAINT_NAME = rc1 . CONSTRAINT_NAME
OPEN @ Cursor FETCH NEXT FROM @ Cursor INTO @ Sql
WHILE ( @ @ FETCH_STATUS = 0 )
BEGIN
Exec sp_executesql @ Sql
FETCH NEXT FROM @ Cursor INTO @ Sql
END
CLOSE @ Cursor DEALLOCATE @ Cursor `
if _ , err := ss . conn . ExecContext ( context . Background ( ) , query ) ; err != nil {
return & database . Error { OrigErr : err , Query : [ ] byte ( query ) }
}
// drop the tables
// sp_MSforeachtable is invoked with the command 'DROP TABLE ?'.
query = ` EXEC sp_MSforeachtable 'DROP TABLE ?' `
if _ , err := ss . conn . ExecContext ( context . Background ( ) , query ) ; err != nil {
return & database . Error { OrigErr : err , Query : [ ] byte ( query ) }
}
return nil
}
func ( ss * MSSQL ) ensureVersionTable ( ) ( err error ) {
2019-05-19 13:37:40 +01:00
if err = ss . Lock ( ) ; err != nil {
return err
}
defer func ( ) {
if e := ss . Unlock ( ) ; e != nil {
if err == nil {
err = e
} else {
err = multierror . Append ( err , e )
}
}
} ( )
2019-02-15 16:08:38 -06:00
query := ` IF NOT EXISTS
( SELECT *
FROM sysobjects
WHERE id = object_id ( N ' [ dbo ] . [ ` + ss.config.MigrationsTable + ` ] ' )
AND OBJECTPROPERTY ( id , N ' IsUserTable ' ) = 1
)
CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL , dirty BIT NOT NULL ) ; `
if _ , err = ss . conn . ExecContext ( context . Background ( ) , query ) ; err != nil {
return & database . Error { OrigErr : err , Query : [ ] byte ( query ) }
}
return nil
}