rename mssql driver to sqlserver

This commit is contained in:
Thomas Lokshall 2019-05-24 15:16:12 +02:00
parent e08ae0e996
commit 293bfec844
21 changed files with 30 additions and 31 deletions

View File

@ -1,5 +1,5 @@
SOURCE ?= file go_bindata github aws_s3 google_cloud_storage godoc_vfs gitlab SOURCE ?= file go_bindata github aws_s3 google_cloud_storage godoc_vfs gitlab
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb clickhouse mongodb mssql DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb clickhouse mongodb sqlserver
VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-) VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-)
TEST_FLAGS ?= TEST_FLAGS ?=
REPO_OWNER ?= $(shell cd .. && basename "$$(pwd)") REPO_OWNER ?= $(shell cd .. && basename "$$(pwd)")

View File

@ -37,7 +37,7 @@ Database drivers run migrations. [Add a new database?](database/driver.go)
* [CockroachDB](database/cockroachdb) * [CockroachDB](database/cockroachdb)
* [ClickHouse](database/clickhouse) * [ClickHouse](database/clickhouse)
* [Firebird](database/firebird) ([todo #49](https://github.com/golang-migrate/migrate/issues/49)) * [Firebird](database/firebird) ([todo #49](https://github.com/golang-migrate/migrate/issues/49))
* [MS SQL Server](database/mssql) * [MS SQL Server](database/sqlserver)
### Database URLs ### Database URLs

View File

@ -1,4 +1,4 @@
package mssql package sqlserver
import ( import (
"context" "context"
@ -15,8 +15,7 @@ import (
) )
func init() { func init() {
db := MSSQL{} database.Register("sqlserver", &SQLServer{})
database.Register("sqlserver", &db)
} }
// DefaultMigrationsTable is the name of the migrations table in the database // DefaultMigrationsTable is the name of the migrations table in the database
@ -44,7 +43,7 @@ type Config struct {
} }
// MSSQL connection // MSSQL connection
type MSSQL struct { type SQLServer struct {
// Locking and unlocking need to use the same connection // Locking and unlocking need to use the same connection
conn *sql.Conn conn *sql.Conn
db *sql.DB db *sql.DB
@ -100,7 +99,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return nil, err return nil, err
} }
ss := &MSSQL{ ss := &SQLServer{
conn: conn, conn: conn,
db: instance, db: instance,
config: config, config: config,
@ -114,7 +113,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
} }
// Open a connection to the database // Open a connection to the database
func (ss *MSSQL) Open(url string) (database.Driver, error) { func (ss *SQLServer) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url) purl, err := nurl.Parse(url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -140,7 +139,7 @@ func (ss *MSSQL) Open(url string) (database.Driver, error) {
} }
// Close the database connection // Close the database connection
func (ss *MSSQL) Close() error { func (ss *SQLServer) Close() error {
connErr := ss.conn.Close() connErr := ss.conn.Close()
dbErr := ss.db.Close() dbErr := ss.db.Close()
if connErr != nil || dbErr != nil { if connErr != nil || dbErr != nil {
@ -150,7 +149,7 @@ func (ss *MSSQL) Close() error {
} }
// Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
func (ss *MSSQL) Lock() error { func (ss *SQLServer) Lock() error {
if ss.isLocked { if ss.isLocked {
return database.ErrLocked return database.ErrLocked
} }
@ -177,7 +176,7 @@ func (ss *MSSQL) Lock() error {
} }
// Unlock froms the migration lock from the database // Unlock froms the migration lock from the database
func (ss *MSSQL) Unlock() error { func (ss *SQLServer) Unlock() error {
if !ss.isLocked { if !ss.isLocked {
return nil return nil
} }
@ -198,7 +197,7 @@ func (ss *MSSQL) Unlock() error {
} }
// Run the migrations for the database // Run the migrations for the database
func (ss *MSSQL) Run(migration io.Reader) error { func (ss *SQLServer) Run(migration io.Reader) error {
migr, err := ioutil.ReadAll(migration) migr, err := ioutil.ReadAll(migration)
if err != nil { if err != nil {
return err return err
@ -221,7 +220,7 @@ func (ss *MSSQL) Run(migration io.Reader) error {
} }
// SetVersion for the current database // SetVersion for the current database
func (ss *MSSQL) SetVersion(version int, dirty bool) error { func (ss *SQLServer) SetVersion(version int, dirty bool) error {
tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil { if err != nil {
@ -258,7 +257,7 @@ func (ss *MSSQL) SetVersion(version int, dirty bool) error {
} }
// Version of the current database state // Version of the current database state
func (ss *MSSQL) Version() (version int, dirty bool, err error) { func (ss *SQLServer) Version() (version int, dirty bool, err error) {
query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"` query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch { switch {
@ -275,7 +274,7 @@ func (ss *MSSQL) Version() (version int, dirty bool, err error) {
} }
// Drop all tables from the database. // Drop all tables from the database.
func (ss *MSSQL) Drop() error { func (ss *SQLServer) Drop() error {
// drop all referential integrity constraints // drop all referential integrity constraints
query := ` query := `
@ -309,7 +308,7 @@ func (ss *MSSQL) Drop() error {
return nil return nil
} }
func (ss *MSSQL) ensureVersionTable() (err error) { func (ss *SQLServer) ensureVersionTable() (err error) {
if err = ss.Lock(); err != nil { if err = ss.Lock(); err != nil {
return err return err
} }

View File

@ -1,4 +1,4 @@
package mssql package sqlserver
import ( import (
"context" "context"
@ -74,7 +74,7 @@ func Test(t *testing.T) {
} }
addr := msConnectionString(ip, port) addr := msConnectionString(ip, port)
p := &MSSQL{} p := &SQLServer{}
d, err := p.Open(addr) d, err := p.Open(addr)
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
@ -98,7 +98,7 @@ func TestMigrate(t *testing.T) {
} }
addr := msConnectionString(ip, port) addr := msConnectionString(ip, port)
p := &MSSQL{} p := &SQLServer{}
d, err := p.Open(addr) d, err := p.Open(addr)
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
@ -126,7 +126,7 @@ func TestMultiStatement(t *testing.T) {
} }
addr := msConnectionString(ip, port) addr := msConnectionString(ip, port)
ms := &MSSQL{} ms := &SQLServer{}
d, err := ms.Open(addr) d, err := ms.Open(addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -142,7 +142,7 @@ func TestMultiStatement(t *testing.T) {
// make sure second table exists // make sure second table exists
var exists int var exists int
if err := d.(*MSSQL).conn.QueryRowContext(context.Background(), "SELECT COUNT(1) FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())").Scan(&exists); err != nil { if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "SELECT COUNT(1) FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())").Scan(&exists); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if exists != 1 { if exists != 1 {
@ -159,7 +159,7 @@ func TestErrorParsing(t *testing.T) {
} }
addr := msConnectionString(ip, port) addr := msConnectionString(ip, port)
p := &MSSQL{} p := &SQLServer{}
d, err := p.Open(addr) d, err := p.Open(addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -189,14 +189,14 @@ func TestLockWorks(t *testing.T) {
} }
addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port) addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port)
p := &MSSQL{} p := &SQLServer{}
d, err := p.Open(addr) d, err := p.Open(addr)
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }
dt.Test(t, d, []byte("SELECT 1")) dt.Test(t, d, []byte("SELECT 1"))
ms := d.(*MSSQL) ms := d.(*SQLServer)
err = ms.Lock() err = ms.Lock()
if err != nil { if err != nil {

View File

@ -1,7 +0,0 @@
// +build mssql
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/mssql"
)

View File

@ -0,0 +1,7 @@
// +build sqlserver
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/sqlserver"
)