diff --git a/migrate.go b/migrate.go index 1a3ba6a..690a671 100644 --- a/migrate.go +++ b/migrate.go @@ -82,13 +82,13 @@ type Migrate struct { func New(sourceUrl, databaseUrl string) (*Migrate, error) { m := newCommon() - sourceName, err := schemeFromUrl(sourceUrl) + sourceName, err := sourceSchemeFromUrl(sourceUrl) if err != nil { return nil, err } m.sourceName = sourceName - databaseName, err := schemeFromUrl(databaseUrl) + databaseName, err := databaseSchemeFromUrl(databaseUrl) if err != nil { return nil, err } diff --git a/util.go b/util.go index 00efa23..96b6746 100644 --- a/util.go +++ b/util.go @@ -1,6 +1,7 @@ package migrate import ( + "errors" "fmt" nurl "net/url" "strings" @@ -43,15 +44,35 @@ func suint(n int) uint { return uint(n) } -var errNoScheme = fmt.Errorf("no scheme") +var errNoScheme = errors.New("no scheme") +var errEmptyURL = errors.New("URL cannot be empty") + +func sourceSchemeFromUrl(url string) (string, error) { + u, err := schemeFromUrl(url) + if err != nil { + return "", fmt.Errorf("source: %v", err) + } + return u, nil +} + +func databaseSchemeFromUrl(url string) (string, error) { + u, err := schemeFromUrl(url) + if err != nil { + return "", fmt.Errorf("database: %v", err) + } + return u, nil +} // schemeFromUrl returns the scheme from a URL string func schemeFromUrl(url string) (string, error) { + if url == "" { + return "", errEmptyURL + } + u, err := nurl.Parse(url) if err != nil { return "", err } - if len(u.Scheme) == 0 { return "", errNoScheme } diff --git a/util_test.go b/util_test.go index 1ad2344..b484341 100644 --- a/util_test.go +++ b/util_test.go @@ -1,6 +1,7 @@ package migrate import ( + "errors" nurl "net/url" "testing" ) @@ -30,3 +31,85 @@ func TestFilterCustomQuery(t *testing.T) { t.Fatalf("didn't expect x-custom") } } + +func TestSourceSchemeFromUrlSuccess(t *testing.T) { + urlStr := "protocol://path" + expected := "protocol" + + u, err := sourceSchemeFromUrl(urlStr) + if err != nil { + t.Fatalf("expected no error, but received %q", err) + } + if u != expected { + t.Fatalf("expected %q, but received %q", expected, u) + } +} + +func TestSourceSchemeFromUrlFailure(t *testing.T) { + cases := []struct { + name string + urlStr string + expectErr error + }{ + { + name: "Empty", + urlStr: "", + expectErr: errors.New("source: URL cannot be empty"), + }, + { + name: "NoScheme", + urlStr: "hello", + expectErr: errors.New("source: no scheme"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := sourceSchemeFromUrl(tc.urlStr) + if err.Error() != tc.expectErr.Error() { + t.Fatalf("expected %q, but received %q", tc.expectErr, err) + } + }) + } +} + +func TestDatabaseSchemeFromUrlSuccess(t *testing.T) { + urlStr := "protocol://path" + expected := "protocol" + + u, err := databaseSchemeFromUrl(urlStr) + if err != nil { + t.Fatalf("expected no error, but received %q", err) + } + if u != expected { + t.Fatalf("expected %q, but received %q", expected, u) + } +} + +func TestDatabaseSchemeFromUrlFailure(t *testing.T) { + cases := []struct { + name string + urlStr string + expectErr error + }{ + { + name: "Empty", + urlStr: "", + expectErr: errors.New("database: URL cannot be empty"), + }, + { + name: "NoScheme", + urlStr: "hello", + expectErr: errors.New("database: no scheme"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := databaseSchemeFromUrl(tc.urlStr) + if err.Error() != tc.expectErr.Error() { + t.Fatalf("expected %q, but received %q", tc.expectErr, err) + } + }) + } +}