mirror of https://github.com/status-im/migrate.git
Let database.Open() use schemeFromURL as well (#271)
* Let database.Open() use schemeFromURL as well Otherwise it will fail on MySQL DSNs. Moved schemeFromURL into the database package. Also removed databaseSchemeFromURL and sourceSchemeFromURL as they were just calling schemeFromURL. Fixes https://github.com/golang-migrate/migrate/pull/265#issuecomment-522301237 * Moved url functions into internal/url Also merged the test cases. * Add some database tests to improve coverage * Fix suggestions
This commit is contained in:
parent
d5960ade4a
commit
e5b4be7771
|
@ -7,8 +7,9 @@ package database
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"sync"
|
||||
|
||||
iurl "github.com/golang-migrate/migrate/v4/internal/url"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -81,21 +82,16 @@ type Driver interface {
|
|||
|
||||
// Open returns a new driver instance.
|
||||
func Open(url string) (Driver, error) {
|
||||
u, err := nurl.Parse(url)
|
||||
scheme, err := iurl.SchemeFromURL(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse URL. Did you escape all reserved URL characters? "+
|
||||
"See: https://github.com/golang-migrate/migrate#database-urls Error: %v", err)
|
||||
}
|
||||
|
||||
if u.Scheme == "" {
|
||||
return nil, fmt.Errorf("database driver: invalid URL scheme")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
driversMu.RLock()
|
||||
d, ok := drivers[u.Scheme]
|
||||
d, ok := drivers[scheme]
|
||||
driversMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", u.Scheme)
|
||||
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme)
|
||||
}
|
||||
|
||||
return d.Open(url)
|
||||
|
|
|
@ -1,8 +1,115 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ExampleDriver() {
|
||||
// see database/stub for an example
|
||||
|
||||
// database/stub/stub.go has the driver implementation
|
||||
// database/stub/stub_test.go runs database/testing/test.go:Test
|
||||
}
|
||||
|
||||
// Using database/stub here is not possible as it
|
||||
// results in an import cycle.
|
||||
type mockDriver struct {
|
||||
url string
|
||||
}
|
||||
|
||||
func (m *mockDriver) Open(url string) (Driver, error) {
|
||||
return &mockDriver{
|
||||
url: url,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Lock() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Unlock() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Run(migration io.Reader) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) SetVersion(version int, dirty bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Version() (version int, dirty bool, err error) {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Drop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRegisterTwice(t *testing.T) {
|
||||
Register("mock", &mockDriver{})
|
||||
|
||||
var err interface{}
|
||||
func() {
|
||||
defer func() {
|
||||
err = recover()
|
||||
}()
|
||||
Register("mock", &mockDriver{})
|
||||
}()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected a panic when calling Register twice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
// Make sure the driver is registered.
|
||||
// But if the previous test already registered it just ignore the panic.
|
||||
// If we don't do this it will be impossible to run this test standalone.
|
||||
func() {
|
||||
defer func() {
|
||||
_ = recover()
|
||||
}()
|
||||
Register("mock", &mockDriver{})
|
||||
}()
|
||||
|
||||
cases := []struct {
|
||||
url string
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
"mock://user:pass@tcp(host:1337)/db",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"unknown://bla",
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.url, func(t *testing.T) {
|
||||
d, err := Open(c.url)
|
||||
|
||||
if err == nil {
|
||||
if c.err {
|
||||
t.Fatal("expected an error for an unknown driver")
|
||||
} else {
|
||||
if md, ok := d.(*mockDriver); !ok {
|
||||
t.Fatalf("expected *mockDriver got %T", d)
|
||||
} else if md.url != c.url {
|
||||
t.Fatalf("expected %q got %q", c.url, md.url)
|
||||
}
|
||||
}
|
||||
} else if !c.err {
|
||||
t.Fatalf("did not expect %q", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
package url
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var errNoScheme = errors.New("no scheme")
|
||||
var errEmptyURL = errors.New("URL cannot be empty")
|
||||
|
||||
// schemeFromURL returns the scheme from a URL string
|
||||
func SchemeFromURL(url string) (string, error) {
|
||||
if url == "" {
|
||||
return "", errEmptyURL
|
||||
}
|
||||
|
||||
i := strings.Index(url, ":")
|
||||
|
||||
// No : or : is the first character.
|
||||
if i < 1 {
|
||||
return "", errNoScheme
|
||||
}
|
||||
|
||||
return url[0:i], nil
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package url
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSchemeFromUrl(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expected string
|
||||
expectErr error
|
||||
}{
|
||||
{
|
||||
name: "Simple",
|
||||
urlStr: "protocol://path",
|
||||
expected: "protocol",
|
||||
},
|
||||
{
|
||||
// See issue #264
|
||||
name: "MySQLWithPort",
|
||||
urlStr: "mysql://user:pass@tcp(host:1337)/db",
|
||||
expected: "mysql",
|
||||
},
|
||||
{
|
||||
name: "Empty",
|
||||
urlStr: "",
|
||||
expectErr: errEmptyURL,
|
||||
},
|
||||
{
|
||||
name: "NoScheme",
|
||||
urlStr: "hello",
|
||||
expectErr: errNoScheme,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s, err := SchemeFromURL(tc.urlStr)
|
||||
if err != tc.expectErr {
|
||||
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
|
||||
}
|
||||
if s != tc.expected {
|
||||
t.Fatalf("expected %q, but received %q", tc.expected, s)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -13,6 +13,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
iurl "github.com/golang-migrate/migrate/v4/internal/url"
|
||||
"github.com/golang-migrate/migrate/v4/source"
|
||||
)
|
||||
|
||||
|
@ -85,13 +86,13 @@ type Migrate struct {
|
|||
func New(sourceURL, databaseURL string) (*Migrate, error) {
|
||||
m := newCommon()
|
||||
|
||||
sourceName, err := sourceSchemeFromURL(sourceURL)
|
||||
sourceName, err := iurl.SchemeFromURL(sourceURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.sourceName = sourceName
|
||||
|
||||
databaseName, err := databaseSchemeFromURL(databaseURL)
|
||||
databaseName, err := iurl.SchemeFromURL(databaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -119,7 +120,7 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
|
|||
func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
|
||||
m := newCommon()
|
||||
|
||||
sourceName, err := schemeFromURL(sourceURL)
|
||||
sourceName, err := iurl.SchemeFromURL(sourceURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -145,7 +146,7 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst
|
|||
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
|
||||
m := newCommon()
|
||||
|
||||
databaseName, err := schemeFromURL(databaseURL)
|
||||
databaseName, err := iurl.SchemeFromURL(databaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
36
util.go
36
util.go
|
@ -1,7 +1,6 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
nurl "net/url"
|
||||
"strings"
|
||||
|
@ -49,41 +48,6 @@ func suint(n int) uint {
|
|||
return uint(n)
|
||||
}
|
||||
|
||||
var errNoScheme = errors.New("no scheme")
|
||||
var errEmptyURL = errors.New("URL cannot be empty")
|
||||
|
||||
func sourceSchemeFromURL(url string) (string, error) {
|
||||
u, err := schemeFromURL(url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("source: %v", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func databaseSchemeFromURL(url string) (string, error) {
|
||||
u, err := schemeFromURL(url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("database: %v", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// schemeFromURL returns the scheme from a URL string
|
||||
func schemeFromURL(url string) (string, error) {
|
||||
if url == "" {
|
||||
return "", errEmptyURL
|
||||
}
|
||||
|
||||
i := strings.Index(url, ":")
|
||||
|
||||
// No : or : is the first character.
|
||||
if i < 1 {
|
||||
return "", errNoScheme
|
||||
}
|
||||
|
||||
return url[0:i], nil
|
||||
}
|
||||
|
||||
// FilterCustomQuery filters all query values starting with `x-`
|
||||
func FilterCustomQuery(u *nurl.URL) *nurl.URL {
|
||||
ux := *u
|
||||
|
|
102
util_test.go
102
util_test.go
|
@ -1,7 +1,6 @@
|
|||
package migrate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
nurl "net/url"
|
||||
"testing"
|
||||
)
|
||||
|
@ -31,104 +30,3 @@ func TestFilterCustomQuery(t *testing.T) {
|
|||
t.Fatalf("didn't expect x-custom")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSourceSchemeFromUrlSuccess(t *testing.T) {
|
||||
urlStr := "protocol://path"
|
||||
expected := "protocol"
|
||||
|
||||
u, err := sourceSchemeFromURL(urlStr)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, but received %q", err)
|
||||
}
|
||||
if u != expected {
|
||||
t.Fatalf("expected %q, but received %q", expected, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSourceSchemeFromUrlFailure(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expectErr error
|
||||
}{
|
||||
{
|
||||
name: "Empty",
|
||||
urlStr: "",
|
||||
expectErr: errors.New("source: URL cannot be empty"),
|
||||
},
|
||||
{
|
||||
name: "NoScheme",
|
||||
urlStr: "hello",
|
||||
expectErr: errors.New("source: no scheme"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := sourceSchemeFromURL(tc.urlStr)
|
||||
if err.Error() != tc.expectErr.Error() {
|
||||
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseSchemeFromUrlSuccess(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple",
|
||||
urlStr: "protocol://path",
|
||||
expected: "protocol",
|
||||
},
|
||||
{
|
||||
// See issue #264
|
||||
name: "MySQLWithPort",
|
||||
urlStr: "mysql://user:pass@tcp(host:1337)/db",
|
||||
expected: "mysql",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
u, err := databaseSchemeFromURL(tc.urlStr)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, but received %q", err)
|
||||
}
|
||||
if u != tc.expected {
|
||||
t.Fatalf("expected %q, but received %q", tc.expected, u)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseSchemeFromUrlFailure(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expectErr error
|
||||
}{
|
||||
{
|
||||
name: "Empty",
|
||||
urlStr: "",
|
||||
expectErr: errors.New("database: URL cannot be empty"),
|
||||
},
|
||||
{
|
||||
name: "NoScheme",
|
||||
urlStr: "hello",
|
||||
expectErr: errors.New("database: no scheme"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := databaseSchemeFromURL(tc.urlStr)
|
||||
if err.Error() != tc.expectErr.Error() {
|
||||
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue