diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 96cb600..5cc892a 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -21,7 +21,6 @@ import ( ) import ( - "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" ) @@ -98,43 +97,35 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return mx, nil } -// urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config. -// Manually sets username and password to avoid net/url from url-encoding the reserved URL characters -func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) { - origUserInfo := u.User - u.User = nil - - c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://")) - if err != nil { - return nil, err - } - if origUserInfo != nil { - c.User = origUserInfo.Username() - if p, ok := origUserInfo.Password(); ok { - c.Passwd = p - } - } - return c, nil -} - -func (m *Mysql) Open(url string) (database.Driver, error) { - purl, err := nurl.Parse(url) +func urlToMySQLConfig(url string) (*mysql.Config, error) { + config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) if err != nil { return nil, err } - q := purl.Query() - q.Set("multiStatements", "true") - purl.RawQuery = q.Encode() + config.MultiStatements = true - migrationsTable := purl.Query().Get("x-migrations-table") + // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN. + // net/url.Parse() would automatically unescape it for us. + // See: https://play.golang.org/p/q9j1io-YICQ + user, err := nurl.QueryUnescape(config.User) + if err != nil { + return nil, err + } + config.User = user + + password, err := nurl.QueryUnescape(config.Passwd) + if err != nil { + return nil, err + } + config.Passwd = password // use custom TLS? - ctls := purl.Query().Get("tls") + ctls := config.TLSConfig if len(ctls) > 0 { if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { rootCertPool := x509.NewCertPool() - pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) + pem, err := ioutil.ReadFile(config.Params["x-tls-ca"]) if err != nil { return nil, err } @@ -144,7 +135,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) { } clientCert := make([]tls.Certificate, 0, 1) - if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" { + if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" { if ccert == "" || ckey == "" { return nil, ErrTLSCertKeyConfig } @@ -156,8 +147,8 @@ func (m *Mysql) Open(url string) (database.Driver, error) { } insecureSkipVerify := false - if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { - x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) + if len(config.Params["x-tls-insecure-skip-verify"]) > 0 { + x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"]) if err != nil { return nil, err } @@ -175,18 +166,23 @@ func (m *Mysql) Open(url string) (database.Driver, error) { } } - c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl)) + return config, nil +} + +func (m *Mysql) Open(url string) (database.Driver, error) { + config, err := urlToMySQLConfig(url) if err != nil { return nil, err } - db, err := sql.Open("mysql", c.FormatDSN()) + + db, err := sql.Open("mysql", config.FormatDSN()) if err != nil { return nil, err } mx, err := WithInstance(db, &Config{ - DatabaseName: purl.Path, - MigrationsTable: migrationsTable, + DatabaseName: config.DBName, + MigrationsTable: config.Params["x-migrations-table"], }) if err != nil { return nil, err diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index e5f73ff..5d6e82e 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -8,7 +8,6 @@ import ( "log" "github.com/golang-migrate/migrate/v4" - "net/url" "testing" ) @@ -210,19 +209,13 @@ func TestURLToMySQLConfig(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - u, err := url.Parse(tc.urlStr) + config, err := urlToMySQLConfig(tc.urlStr) if err != nil { t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err) } - if config, err := urlToMySQLConfig(*u); err == nil { - dsn := config.FormatDSN() - if dsn != tc.expectedDSN { - t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN) - } - } else { - if tc.expectedDSN != "" { - t.Error("Got unexpected error:", err, "urlStr:", tc.urlStr) - } + dsn := config.FormatDSN() + if dsn != tc.expectedDSN { + t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN) } }) } diff --git a/database/parse_test.go b/database/parse_test.go index 6558e25..3709a67 100644 --- a/database/parse_test.go +++ b/database/parse_test.go @@ -139,8 +139,7 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) { }{ {char: "!", parses: true, expectedUsername: username, expectedPassword: basePassword + "!", encodedURL: schemeAndUsernameAndSep + basePassword + "%21" + urlSuffixAndSep}, - {char: "#", parses: true, expectedUsername: "", expectedPassword: "", - encodedURL: schemeAndUsernameAndSep + basePassword + "#" + urlSuffixAndSep}, + {char: "#", parses: false}, {char: "$", parses: true, expectedUsername: username, expectedPassword: basePassword + "$", encodedURL: schemeAndUsernameAndSep + basePassword + "$" + urlSuffixAndSep}, {char: "%", parses: false}, @@ -158,16 +157,14 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) { encodedURL: schemeAndUsernameAndSep + basePassword + "+" + urlSuffixAndSep}, {char: ",", parses: true, expectedUsername: username, expectedPassword: "password,", encodedURL: schemeAndUsernameAndSep + basePassword + "," + urlSuffixAndSep}, - {char: "/", parses: true, expectedUsername: "", expectedPassword: "", - encodedURL: schemeAndUsernameAndSep + basePassword + "/" + urlSuffixAndSep}, + {char: "/", parses: false}, {char: ":", parses: true, expectedUsername: username, expectedPassword: "password:", encodedURL: schemeAndUsernameAndSep + basePassword + "%3A" + urlSuffixAndSep}, {char: ";", parses: true, expectedUsername: username, expectedPassword: "password;", encodedURL: schemeAndUsernameAndSep + basePassword + ";" + urlSuffixAndSep}, {char: "=", parses: true, expectedUsername: username, expectedPassword: "password=", encodedURL: schemeAndUsernameAndSep + basePassword + "=" + urlSuffixAndSep}, - {char: "?", parses: true, expectedUsername: "", expectedPassword: "", - encodedURL: schemeAndUsernameAndSep + basePassword + "?" + urlSuffixAndSep}, + {char: "?", parses: false}, {char: "@", parses: true, expectedUsername: username, expectedPassword: "password@", encodedURL: schemeAndUsernameAndSep + basePassword + "%40" + urlSuffixAndSep}, {char: "[", parses: false}, diff --git a/util.go b/util.go index 1cef03c..ecf3773 100644 --- a/util.go +++ b/util.go @@ -74,15 +74,14 @@ func schemeFromURL(url string) (string, error) { return "", errEmptyURL } - u, err := nurl.Parse(url) - if err != nil { - return "", err - } - if len(u.Scheme) == 0 { + i := strings.Index(url, ":") + + // No : or : is the first character. + if i < 1 { return "", errNoScheme } - return u.Scheme, nil + return url[0:i], nil } // FilterCustomQuery filters all query values starting with `x-` diff --git a/util_test.go b/util_test.go index 6543b28..ef395e8 100644 --- a/util_test.go +++ b/util_test.go @@ -74,15 +74,34 @@ func TestSourceSchemeFromUrlFailure(t *testing.T) { } func TestDatabaseSchemeFromUrlSuccess(t *testing.T) { - urlStr := "protocol://path" - expected := "protocol" - - u, err := databaseSchemeFromURL(urlStr) - if err != nil { - t.Fatalf("expected no error, but received %q", err) + cases := []struct { + name string + urlStr string + expected string + }{ + { + name: "Simple", + urlStr: "protocol://path", + expected: "protocol", + }, + { + // See issue #264 + name: "MySQLWithPort", + urlStr: "mysql://user:pass@tcp(host:1337)/db", + expected: "mysql", + }, } - if u != expected { - t.Fatalf("expected %q, but received %q", expected, u) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + u, err := databaseSchemeFromURL(tc.urlStr) + if err != nil { + t.Fatalf("expected no error, but received %q", err) + } + if u != tc.expected { + t.Fatalf("expected %q, but received %q", tc.expected, u) + } + }) } }