mssql: fix error parsing and add tests

This commit is contained in:
nathan-c 2019-05-19 16:08:15 +01:00
parent 5ac583ba7b
commit 98e5f88b9f
2 changed files with 71 additions and 25 deletions

View File

@ -206,7 +206,7 @@ func (ss *MSSQL) Run(migration io.Reader) error {
// run migration
query := string(migr[:])
if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
if msErr, ok := err.(*mssql.Error); ok {
if msErr, ok := err.(mssql.Error); ok {
message := fmt.Sprintf("migration failed: %s", msErr.Message)
if msErr.ProcName != "" {
message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)

View File

@ -6,6 +6,7 @@ import (
sqldriver "database/sql/driver"
"fmt"
"log"
"strings"
"testing"
"github.com/dhui/dktest"
@ -32,12 +33,16 @@ var (
}
)
func msConnectionString(host, port string) string {
return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port)
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.Port(defaultPort)
if err != nil {
return false
}
uri := fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, ip, port)
uri := msConnectionString(ip, port)
db, err := sql.Open("sqlserver", uri)
if err != nil {
return false
@ -61,15 +66,13 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
}
func Test(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port)
addr := msConnectionString(ip, port)
p := &MSSQL{}
d, err := p.Open(addr)
if err != nil {
@ -78,33 +81,22 @@ func Test(t *testing.T) {
defer func() {
if err := d.Close(); err != nil {
log.Println("close error:", err)
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
// check ensureVersionTable
if err := d.(*MSSQL).ensureVersionTable(); err != nil {
t.Fatal(err)
}
// check again
if err := d.(*MSSQL).ensureVersionTable(); err != nil {
t.Fatal(err)
}
})
}
func TestMigrate(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port)
addr := msConnectionString(ip, port)
p := &MSSQL{}
d, err := p.Open(addr)
if err != nil {
@ -113,24 +105,78 @@ func TestMigrate(t *testing.T) {
defer func() {
if err := d.Close(); err != nil {
log.Println("close error:", err)
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "master", d)
if err != nil {
t.Fatalf("%v", err)
t.Fatal(err)
}
dt.TestMigrate(t, m, []byte("SELECT 1"))
})
}
// check ensureVersionTable
if err := d.(*MSSQL).ensureVersionTable(); err != nil {
func TestMultiStatement(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
// check again
if err := d.(*MSSQL).ensureVersionTable(); err != nil {
addr := msConnectionString(ip, port)
ms := &MSSQL{}
d, err := ms.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure second table exists
var exists int
if err := d.(*MSSQL).conn.QueryRowContext(context.Background(), "SELECT COUNT(1) FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())").Scan(&exists); err != nil {
t.Fatal(err)
}
if exists != 1 {
t.Fatalf("expected table bar to exist")
}
})
}
func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := msConnectionString(ip, port)
p := &MSSQL{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
wantErr := `migration failed: Unknown object type 'TABLEE' used in a CREATE, DROP, or ALTER statement. in line 1:` +
` CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text); (details: mssql: Unknown object type ` +
`'TABLEE' used in a CREATE, DROP, or ALTER statement.)`
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
t.Fatal("expected err but got nil")
} else if err.Error() != wantErr {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
})
}