From 98e5f88b9fa47837524890c56b5bb06323aedf59 Mon Sep 17 00:00:00 2001 From: nathan-c Date: Sun, 19 May 2019 16:08:15 +0100 Subject: [PATCH] mssql: fix error parsing and add tests --- database/mssql/mssql.go | 2 +- database/mssql/mssql_test.go | 94 +++++++++++++++++++++++++++--------- 2 files changed, 71 insertions(+), 25 deletions(-) diff --git a/database/mssql/mssql.go b/database/mssql/mssql.go index 64f8eb1..fd12630 100644 --- a/database/mssql/mssql.go +++ b/database/mssql/mssql.go @@ -206,7 +206,7 @@ func (ss *MSSQL) Run(migration io.Reader) error { // run migration query := string(migr[:]) if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { - if msErr, ok := err.(*mssql.Error); ok { + if msErr, ok := err.(mssql.Error); ok { message := fmt.Sprintf("migration failed: %s", msErr.Message) if msErr.ProcName != "" { message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) diff --git a/database/mssql/mssql_test.go b/database/mssql/mssql_test.go index 817fa58..013a234 100755 --- a/database/mssql/mssql_test.go +++ b/database/mssql/mssql_test.go @@ -6,6 +6,7 @@ import ( sqldriver "database/sql/driver" "fmt" "log" + "strings" "testing" "github.com/dhui/dktest" @@ -32,12 +33,16 @@ var ( } ) +func msConnectionString(host, port string) string { + return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port) +} + func isReady(ctx context.Context, c dktest.ContainerInfo) bool { ip, port, err := c.Port(defaultPort) if err != nil { return false } - uri := fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, ip, port) + uri := msConnectionString(ip, port) db, err := sql.Open("sqlserver", uri) if err != nil { return false @@ -61,15 +66,13 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool { } func Test(t *testing.T) { - // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) } - addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port) + addr := msConnectionString(ip, port) p := &MSSQL{} d, err := p.Open(addr) if err != nil { @@ -78,33 +81,22 @@ func Test(t *testing.T) { defer func() { if err := d.Close(); err != nil { - log.Println("close error:", err) + t.Error(err) } }() dt.Test(t, d, []byte("SELECT 1")) - - // check ensureVersionTable - if err := d.(*MSSQL).ensureVersionTable(); err != nil { - t.Fatal(err) - } - // check again - if err := d.(*MSSQL).ensureVersionTable(); err != nil { - t.Fatal(err) - } }) } func TestMigrate(t *testing.T) { - // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.Port(defaultPort) if err != nil { t.Fatal(err) } - addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port) + addr := msConnectionString(ip, port) p := &MSSQL{} d, err := p.Open(addr) if err != nil { @@ -113,24 +105,78 @@ func TestMigrate(t *testing.T) { defer func() { if err := d.Close(); err != nil { - log.Println("close error:", err) + t.Error(err) } }() - m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "master", d) if err != nil { - t.Fatalf("%v", err) + t.Fatal(err) } dt.TestMigrate(t, m, []byte("SELECT 1")) + }) +} - // check ensureVersionTable - if err := d.(*MSSQL).ensureVersionTable(); err != nil { +func TestMultiStatement(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { t.Fatal(err) } - // check again - if err := d.(*MSSQL).ensureVersionTable(); err != nil { + + addr := msConnectionString(ip, port) + ms := &MSSQL{} + d, err := ms.Open(addr) + if err != nil { t.Fatal(err) } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } + + // make sure second table exists + var exists int + if err := d.(*MSSQL).conn.QueryRowContext(context.Background(), "SELECT COUNT(1) FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())").Scan(&exists); err != nil { + t.Fatal(err) + } + if exists != 1 { + t.Fatalf("expected table bar to exist") + } + }) +} + +func TestErrorParsing(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := msConnectionString(ip, port) + p := &MSSQL{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + wantErr := `migration failed: Unknown object type 'TABLEE' used in a CREATE, DROP, or ALTER statement. in line 1:` + + ` CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text); (details: mssql: Unknown object type ` + + `'TABLEE' used in a CREATE, DROP, or ALTER statement.)` + if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil { + t.Fatal("expected err but got nil") + } else if err.Error() != wantErr { + t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error()) + } }) }