diff --git a/database/driver.go b/database/driver.go index 901e5dd..2c673ca 100644 --- a/database/driver.go +++ b/database/driver.go @@ -7,8 +7,9 @@ package database import ( "fmt" "io" - nurl "net/url" "sync" + + iurl "github.com/golang-migrate/migrate/v4/internal/url" ) var ( @@ -81,21 +82,16 @@ type Driver interface { // Open returns a new driver instance. func Open(url string) (Driver, error) { - u, err := nurl.Parse(url) + scheme, err := iurl.SchemeFromURL(url) if err != nil { - return nil, fmt.Errorf("Unable to parse URL. Did you escape all reserved URL characters? "+ - "See: https://github.com/golang-migrate/migrate#database-urls Error: %v", err) - } - - if u.Scheme == "" { - return nil, fmt.Errorf("database driver: invalid URL scheme") + return nil, err } driversMu.RLock() - d, ok := drivers[u.Scheme] + d, ok := drivers[scheme] driversMu.RUnlock() if !ok { - return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", u.Scheme) + return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme) } return d.Open(url) diff --git a/database/driver_test.go b/database/driver_test.go index c0a2930..7880f32 100644 --- a/database/driver_test.go +++ b/database/driver_test.go @@ -1,8 +1,115 @@ package database +import ( + "io" + "testing" +) + func ExampleDriver() { // see database/stub for an example // database/stub/stub.go has the driver implementation // database/stub/stub_test.go runs database/testing/test.go:Test } + +// Using database/stub here is not possible as it +// results in an import cycle. +type mockDriver struct { + url string +} + +func (m *mockDriver) Open(url string) (Driver, error) { + return &mockDriver{ + url: url, + }, nil +} + +func (m *mockDriver) Close() error { + return nil +} + +func (m *mockDriver) Lock() error { + return nil +} + +func (m *mockDriver) Unlock() error { + return nil +} + +func (m *mockDriver) Run(migration io.Reader) error { + return nil +} + +func (m *mockDriver) SetVersion(version int, dirty bool) error { + return nil +} + +func (m *mockDriver) Version() (version int, dirty bool, err error) { + return 0, false, nil +} + +func (m *mockDriver) Drop() error { + return nil +} + +func TestRegisterTwice(t *testing.T) { + Register("mock", &mockDriver{}) + + var err interface{} + func() { + defer func() { + err = recover() + }() + Register("mock", &mockDriver{}) + }() + + if err == nil { + t.Fatal("expected a panic when calling Register twice") + } +} + +func TestOpen(t *testing.T) { + // Make sure the driver is registered. + // But if the previous test already registered it just ignore the panic. + // If we don't do this it will be impossible to run this test standalone. + func() { + defer func() { + _ = recover() + }() + Register("mock", &mockDriver{}) + }() + + cases := []struct { + url string + err bool + }{ + { + "mock://user:pass@tcp(host:1337)/db", + false, + }, + { + "unknown://bla", + true, + }, + } + + for _, c := range cases { + t.Run(c.url, func(t *testing.T) { + d, err := Open(c.url) + + if err == nil { + if c.err { + t.Fatal("expected an error for an unknown driver") + } else { + if md, ok := d.(*mockDriver); !ok { + t.Fatalf("expected *mockDriver got %T", d) + } else if md.url != c.url { + t.Fatalf("expected %q got %q", c.url, md.url) + } + } + } else if !c.err { + t.Fatalf("did not expect %q", err) + } + }) + } +} diff --git a/internal/url/url.go b/internal/url/url.go new file mode 100644 index 0000000..e793fa8 --- /dev/null +++ b/internal/url/url.go @@ -0,0 +1,25 @@ +package url + +import ( + "errors" + "strings" +) + +var errNoScheme = errors.New("no scheme") +var errEmptyURL = errors.New("URL cannot be empty") + +// schemeFromURL returns the scheme from a URL string +func SchemeFromURL(url string) (string, error) { + if url == "" { + return "", errEmptyURL + } + + i := strings.Index(url, ":") + + // No : or : is the first character. + if i < 1 { + return "", errNoScheme + } + + return url[0:i], nil +} diff --git a/internal/url/url_test.go b/internal/url/url_test.go new file mode 100644 index 0000000..de338e7 --- /dev/null +++ b/internal/url/url_test.go @@ -0,0 +1,48 @@ +package url + +import ( + "testing" +) + +func TestSchemeFromUrl(t *testing.T) { + cases := []struct { + name string + urlStr string + expected string + expectErr error + }{ + { + name: "Simple", + urlStr: "protocol://path", + expected: "protocol", + }, + { + // See issue #264 + name: "MySQLWithPort", + urlStr: "mysql://user:pass@tcp(host:1337)/db", + expected: "mysql", + }, + { + name: "Empty", + urlStr: "", + expectErr: errEmptyURL, + }, + { + name: "NoScheme", + urlStr: "hello", + expectErr: errNoScheme, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s, err := SchemeFromURL(tc.urlStr) + if err != tc.expectErr { + t.Fatalf("expected %q, but received %q", tc.expectErr, err) + } + if s != tc.expected { + t.Fatalf("expected %q, but received %q", tc.expected, s) + } + }) + } +} diff --git a/migrate.go b/migrate.go index 3ede504..f692d6f 100644 --- a/migrate.go +++ b/migrate.go @@ -13,6 +13,7 @@ import ( "time" "github.com/golang-migrate/migrate/v4/database" + iurl "github.com/golang-migrate/migrate/v4/internal/url" "github.com/golang-migrate/migrate/v4/source" ) @@ -85,13 +86,13 @@ type Migrate struct { func New(sourceURL, databaseURL string) (*Migrate, error) { m := newCommon() - sourceName, err := sourceSchemeFromURL(sourceURL) + sourceName, err := iurl.SchemeFromURL(sourceURL) if err != nil { return nil, err } m.sourceName = sourceName - databaseName, err := databaseSchemeFromURL(databaseURL) + databaseName, err := iurl.SchemeFromURL(databaseURL) if err != nil { return nil, err } @@ -119,7 +120,7 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { m := newCommon() - sourceName, err := schemeFromURL(sourceURL) + sourceName, err := iurl.SchemeFromURL(sourceURL) if err != nil { return nil, err } @@ -145,7 +146,7 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { m := newCommon() - databaseName, err := schemeFromURL(databaseURL) + databaseName, err := iurl.SchemeFromURL(databaseURL) if err != nil { return nil, err } diff --git a/util.go b/util.go index ecf3773..26131a3 100644 --- a/util.go +++ b/util.go @@ -1,7 +1,6 @@ package migrate import ( - "errors" "fmt" nurl "net/url" "strings" @@ -49,41 +48,6 @@ func suint(n int) uint { return uint(n) } -var errNoScheme = errors.New("no scheme") -var errEmptyURL = errors.New("URL cannot be empty") - -func sourceSchemeFromURL(url string) (string, error) { - u, err := schemeFromURL(url) - if err != nil { - return "", fmt.Errorf("source: %v", err) - } - return u, nil -} - -func databaseSchemeFromURL(url string) (string, error) { - u, err := schemeFromURL(url) - if err != nil { - return "", fmt.Errorf("database: %v", err) - } - return u, nil -} - -// schemeFromURL returns the scheme from a URL string -func schemeFromURL(url string) (string, error) { - if url == "" { - return "", errEmptyURL - } - - i := strings.Index(url, ":") - - // No : or : is the first character. - if i < 1 { - return "", errNoScheme - } - - return url[0:i], nil -} - // FilterCustomQuery filters all query values starting with `x-` func FilterCustomQuery(u *nurl.URL) *nurl.URL { ux := *u diff --git a/util_test.go b/util_test.go index ef395e8..1ad2344 100644 --- a/util_test.go +++ b/util_test.go @@ -1,7 +1,6 @@ package migrate import ( - "errors" nurl "net/url" "testing" ) @@ -31,104 +30,3 @@ func TestFilterCustomQuery(t *testing.T) { t.Fatalf("didn't expect x-custom") } } - -func TestSourceSchemeFromUrlSuccess(t *testing.T) { - urlStr := "protocol://path" - expected := "protocol" - - u, err := sourceSchemeFromURL(urlStr) - if err != nil { - t.Fatalf("expected no error, but received %q", err) - } - if u != expected { - t.Fatalf("expected %q, but received %q", expected, u) - } -} - -func TestSourceSchemeFromUrlFailure(t *testing.T) { - cases := []struct { - name string - urlStr string - expectErr error - }{ - { - name: "Empty", - urlStr: "", - expectErr: errors.New("source: URL cannot be empty"), - }, - { - name: "NoScheme", - urlStr: "hello", - expectErr: errors.New("source: no scheme"), - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - _, err := sourceSchemeFromURL(tc.urlStr) - if err.Error() != tc.expectErr.Error() { - t.Fatalf("expected %q, but received %q", tc.expectErr, err) - } - }) - } -} - -func TestDatabaseSchemeFromUrlSuccess(t *testing.T) { - cases := []struct { - name string - urlStr string - expected string - }{ - { - name: "Simple", - urlStr: "protocol://path", - expected: "protocol", - }, - { - // See issue #264 - name: "MySQLWithPort", - urlStr: "mysql://user:pass@tcp(host:1337)/db", - expected: "mysql", - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - u, err := databaseSchemeFromURL(tc.urlStr) - if err != nil { - t.Fatalf("expected no error, but received %q", err) - } - if u != tc.expected { - t.Fatalf("expected %q, but received %q", tc.expected, u) - } - }) - } -} - -func TestDatabaseSchemeFromUrlFailure(t *testing.T) { - cases := []struct { - name string - urlStr string - expectErr error - }{ - { - name: "Empty", - urlStr: "", - expectErr: errors.New("database: URL cannot be empty"), - }, - { - name: "NoScheme", - urlStr: "hello", - expectErr: errors.New("database: no scheme"), - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - _, err := databaseSchemeFromURL(tc.urlStr) - if err.Error() != tc.expectErr.Error() { - t.Fatalf("expected %q, but received %q", tc.expectErr, err) - } - }) - } -}