diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index d28eb8b..5fdb756 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -2,13 +2,13 @@ package mysql import ( "database/sql" - // sqldriver "database/sql/driver" + sqldriver "database/sql/driver" "fmt" // "io/ioutil" // "log" "testing" - // "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" dt "github.com/golang-migrate/migrate/database/testing" mt "github.com/golang-migrate/migrate/testing" ) @@ -27,6 +27,12 @@ func isReady(i mt.Instance) bool { } defer db.Close() if err = db.Ping(); err != nil { + switch err { + case sqldriver.ErrBadConn, mysql.ErrInvalidConn: + return false + default: + fmt.Println(err) + } return false } @@ -44,6 +50,7 @@ func Test(t *testing.T) { if err != nil { t.Fatalf("%v", err) } + defer d.Close() dt.Test(t, d, []byte("SELECT 1")) // check ensureVersionTable diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 46413b9..0357d27 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -5,13 +5,14 @@ package postgres import ( "bytes" "database/sql" + sqldriver "database/sql/driver" "fmt" "io" "testing" - "github.com/lib/pq" dt "github.com/golang-migrate/migrate/database/testing" mt "github.com/golang-migrate/migrate/testing" + // "github.com/lib/pq" ) var versions = []mt.Version{ @@ -28,14 +29,14 @@ func isReady(i mt.Instance) bool { return false } defer db.Close() - err = db.Ping() - if err == io.EOF { - return false - - } else if e, ok := err.(*pq.Error); ok { - if e.Code.Name() == "cannot_connect_now" { + if err = db.Ping(); err != nil { + switch err { + case sqldriver.ErrBadConn, io.EOF: return false + default: + fmt.Println(err) } + return false } return true @@ -50,6 +51,7 @@ func Test(t *testing.T) { if err != nil { t.Fatalf("%v", err) } + defer d.Close() dt.Test(t, d, []byte("SELECT 1")) }) } @@ -63,6 +65,7 @@ func TestMultiStatement(t *testing.T) { if err != nil { t.Fatalf("%v", err) } + defer d.Close() if err := d.Run(bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"))); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -83,10 +86,11 @@ func TestFilterCustomQuery(t *testing.T) { func(t *testing.T, i mt.Instance) { p := &Postgres{} addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&x-custom=foobar", i.Host(), i.Port()) - _, err := p.Open(addr) + d, err := p.Open(addr) if err != nil { t.Fatalf("%v", err) } + defer d.Close() }) } @@ -99,6 +103,7 @@ func TestWithSchema(t *testing.T) { if err != nil { t.Fatalf("%v", err) } + defer d.Close() // create foobar schema if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA foobar AUTHORIZATION postgres"))); err != nil { @@ -113,6 +118,7 @@ func TestWithSchema(t *testing.T) { if err != nil { t.Fatalf("%v", err) } + defer d2.Close() version, _, err := d2.Version() if err != nil {