Let database.Open() use schemeFromURL as well (#271)

* Let database.Open() use schemeFromURL as well

Otherwise it will fail on MySQL DSNs.

Moved schemeFromURL into the database package. Also removed databaseSchemeFromURL
and sourceSchemeFromURL as they were just calling schemeFromURL.

Fixes https://github.com/golang-migrate/migrate/pull/265#issuecomment-522301237

* Moved url functions into internal/url

Also merged the test cases.

* Add some database tests to improve coverage

* Fix suggestions
This commit is contained in:
Erik Dubbelboer 2019-08-20 18:59:15 +02:00 committed by Dale Hui
parent d5960ade4a
commit e5b4be7771
7 changed files with 191 additions and 152 deletions

View File

@ -7,8 +7,9 @@ package database
import (
"fmt"
"io"
nurl "net/url"
"sync"
iurl "github.com/golang-migrate/migrate/v4/internal/url"
)
var (
@ -81,21 +82,16 @@ type Driver interface {
// Open returns a new driver instance.
func Open(url string) (Driver, error) {
u, err := nurl.Parse(url)
scheme, err := iurl.SchemeFromURL(url)
if err != nil {
return nil, fmt.Errorf("Unable to parse URL. Did you escape all reserved URL characters? "+
"See: https://github.com/golang-migrate/migrate#database-urls Error: %v", err)
}
if u.Scheme == "" {
return nil, fmt.Errorf("database driver: invalid URL scheme")
return nil, err
}
driversMu.RLock()
d, ok := drivers[u.Scheme]
d, ok := drivers[scheme]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", u.Scheme)
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme)
}
return d.Open(url)

View File

@ -1,8 +1,115 @@
package database
import (
"io"
"testing"
)
func ExampleDriver() {
// see database/stub for an example
// database/stub/stub.go has the driver implementation
// database/stub/stub_test.go runs database/testing/test.go:Test
}
// Using database/stub here is not possible as it
// results in an import cycle.
type mockDriver struct {
url string
}
func (m *mockDriver) Open(url string) (Driver, error) {
return &mockDriver{
url: url,
}, nil
}
func (m *mockDriver) Close() error {
return nil
}
func (m *mockDriver) Lock() error {
return nil
}
func (m *mockDriver) Unlock() error {
return nil
}
func (m *mockDriver) Run(migration io.Reader) error {
return nil
}
func (m *mockDriver) SetVersion(version int, dirty bool) error {
return nil
}
func (m *mockDriver) Version() (version int, dirty bool, err error) {
return 0, false, nil
}
func (m *mockDriver) Drop() error {
return nil
}
func TestRegisterTwice(t *testing.T) {
Register("mock", &mockDriver{})
var err interface{}
func() {
defer func() {
err = recover()
}()
Register("mock", &mockDriver{})
}()
if err == nil {
t.Fatal("expected a panic when calling Register twice")
}
}
func TestOpen(t *testing.T) {
// Make sure the driver is registered.
// But if the previous test already registered it just ignore the panic.
// If we don't do this it will be impossible to run this test standalone.
func() {
defer func() {
_ = recover()
}()
Register("mock", &mockDriver{})
}()
cases := []struct {
url string
err bool
}{
{
"mock://user:pass@tcp(host:1337)/db",
false,
},
{
"unknown://bla",
true,
},
}
for _, c := range cases {
t.Run(c.url, func(t *testing.T) {
d, err := Open(c.url)
if err == nil {
if c.err {
t.Fatal("expected an error for an unknown driver")
} else {
if md, ok := d.(*mockDriver); !ok {
t.Fatalf("expected *mockDriver got %T", d)
} else if md.url != c.url {
t.Fatalf("expected %q got %q", c.url, md.url)
}
}
} else if !c.err {
t.Fatalf("did not expect %q", err)
}
})
}
}

25
internal/url/url.go Normal file
View File

@ -0,0 +1,25 @@
package url
import (
"errors"
"strings"
)
var errNoScheme = errors.New("no scheme")
var errEmptyURL = errors.New("URL cannot be empty")
// schemeFromURL returns the scheme from a URL string
func SchemeFromURL(url string) (string, error) {
if url == "" {
return "", errEmptyURL
}
i := strings.Index(url, ":")
// No : or : is the first character.
if i < 1 {
return "", errNoScheme
}
return url[0:i], nil
}

48
internal/url/url_test.go Normal file
View File

@ -0,0 +1,48 @@
package url
import (
"testing"
)
func TestSchemeFromUrl(t *testing.T) {
cases := []struct {
name string
urlStr string
expected string
expectErr error
}{
{
name: "Simple",
urlStr: "protocol://path",
expected: "protocol",
},
{
// See issue #264
name: "MySQLWithPort",
urlStr: "mysql://user:pass@tcp(host:1337)/db",
expected: "mysql",
},
{
name: "Empty",
urlStr: "",
expectErr: errEmptyURL,
},
{
name: "NoScheme",
urlStr: "hello",
expectErr: errNoScheme,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
s, err := SchemeFromURL(tc.urlStr)
if err != tc.expectErr {
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
}
if s != tc.expected {
t.Fatalf("expected %q, but received %q", tc.expected, s)
}
})
}
}

View File

@ -13,6 +13,7 @@ import (
"time"
"github.com/golang-migrate/migrate/v4/database"
iurl "github.com/golang-migrate/migrate/v4/internal/url"
"github.com/golang-migrate/migrate/v4/source"
)
@ -85,13 +86,13 @@ type Migrate struct {
func New(sourceURL, databaseURL string) (*Migrate, error) {
m := newCommon()
sourceName, err := sourceSchemeFromURL(sourceURL)
sourceName, err := iurl.SchemeFromURL(sourceURL)
if err != nil {
return nil, err
}
m.sourceName = sourceName
databaseName, err := databaseSchemeFromURL(databaseURL)
databaseName, err := iurl.SchemeFromURL(databaseURL)
if err != nil {
return nil, err
}
@ -119,7 +120,7 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
m := newCommon()
sourceName, err := schemeFromURL(sourceURL)
sourceName, err := iurl.SchemeFromURL(sourceURL)
if err != nil {
return nil, err
}
@ -145,7 +146,7 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
m := newCommon()
databaseName, err := schemeFromURL(databaseURL)
databaseName, err := iurl.SchemeFromURL(databaseURL)
if err != nil {
return nil, err
}

36
util.go
View File

@ -1,7 +1,6 @@
package migrate
import (
"errors"
"fmt"
nurl "net/url"
"strings"
@ -49,41 +48,6 @@ func suint(n int) uint {
return uint(n)
}
var errNoScheme = errors.New("no scheme")
var errEmptyURL = errors.New("URL cannot be empty")
func sourceSchemeFromURL(url string) (string, error) {
u, err := schemeFromURL(url)
if err != nil {
return "", fmt.Errorf("source: %v", err)
}
return u, nil
}
func databaseSchemeFromURL(url string) (string, error) {
u, err := schemeFromURL(url)
if err != nil {
return "", fmt.Errorf("database: %v", err)
}
return u, nil
}
// schemeFromURL returns the scheme from a URL string
func schemeFromURL(url string) (string, error) {
if url == "" {
return "", errEmptyURL
}
i := strings.Index(url, ":")
// No : or : is the first character.
if i < 1 {
return "", errNoScheme
}
return url[0:i], nil
}
// FilterCustomQuery filters all query values starting with `x-`
func FilterCustomQuery(u *nurl.URL) *nurl.URL {
ux := *u

View File

@ -1,7 +1,6 @@
package migrate
import (
"errors"
nurl "net/url"
"testing"
)
@ -31,104 +30,3 @@ func TestFilterCustomQuery(t *testing.T) {
t.Fatalf("didn't expect x-custom")
}
}
func TestSourceSchemeFromUrlSuccess(t *testing.T) {
urlStr := "protocol://path"
expected := "protocol"
u, err := sourceSchemeFromURL(urlStr)
if err != nil {
t.Fatalf("expected no error, but received %q", err)
}
if u != expected {
t.Fatalf("expected %q, but received %q", expected, u)
}
}
func TestSourceSchemeFromUrlFailure(t *testing.T) {
cases := []struct {
name string
urlStr string
expectErr error
}{
{
name: "Empty",
urlStr: "",
expectErr: errors.New("source: URL cannot be empty"),
},
{
name: "NoScheme",
urlStr: "hello",
expectErr: errors.New("source: no scheme"),
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := sourceSchemeFromURL(tc.urlStr)
if err.Error() != tc.expectErr.Error() {
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
}
})
}
}
func TestDatabaseSchemeFromUrlSuccess(t *testing.T) {
cases := []struct {
name string
urlStr string
expected string
}{
{
name: "Simple",
urlStr: "protocol://path",
expected: "protocol",
},
{
// See issue #264
name: "MySQLWithPort",
urlStr: "mysql://user:pass@tcp(host:1337)/db",
expected: "mysql",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
u, err := databaseSchemeFromURL(tc.urlStr)
if err != nil {
t.Fatalf("expected no error, but received %q", err)
}
if u != tc.expected {
t.Fatalf("expected %q, but received %q", tc.expected, u)
}
})
}
}
func TestDatabaseSchemeFromUrlFailure(t *testing.T) {
cases := []struct {
name string
urlStr string
expectErr error
}{
{
name: "Empty",
urlStr: "",
expectErr: errors.New("database: URL cannot be empty"),
},
{
name: "NoScheme",
urlStr: "hello",
expectErr: errors.New("database: no scheme"),
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := databaseSchemeFromURL(tc.urlStr)
if err.Error() != tc.expectErr.Error() {
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
}
})
}
}