Merge branch 'master' into fix-mongodb-dep

This commit is contained in:
zikaeroh 2019-08-17 08:13:16 -07:00
commit d80e0e2f7f
5 changed files with 70 additions and 66 deletions

View File

@ -21,7 +21,6 @@ import (
) )
import ( import (
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database" "github.com/golang-migrate/migrate/v4/database"
) )
@ -98,43 +97,35 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return mx, nil return mx, nil
} }
// urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config. func urlToMySQLConfig(url string) (*mysql.Config, error) {
// Manually sets username and password to avoid net/url from url-encoding the reserved URL characters config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
origUserInfo := u.User
u.User = nil
c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://"))
if err != nil {
return nil, err
}
if origUserInfo != nil {
c.User = origUserInfo.Username()
if p, ok := origUserInfo.Password(); ok {
c.Passwd = p
}
}
return c, nil
}
func (m *Mysql) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
q := purl.Query() config.MultiStatements = true
q.Set("multiStatements", "true")
purl.RawQuery = q.Encode()
migrationsTable := purl.Query().Get("x-migrations-table") // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
// net/url.Parse() would automatically unescape it for us.
// See: https://play.golang.org/p/q9j1io-YICQ
user, err := nurl.QueryUnescape(config.User)
if err != nil {
return nil, err
}
config.User = user
password, err := nurl.QueryUnescape(config.Passwd)
if err != nil {
return nil, err
}
config.Passwd = password
// use custom TLS? // use custom TLS?
ctls := purl.Query().Get("tls") ctls := config.TLSConfig
if len(ctls) > 0 { if len(ctls) > 0 {
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
rootCertPool := x509.NewCertPool() rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) pem, err := ioutil.ReadFile(config.Params["x-tls-ca"])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -144,7 +135,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
} }
clientCert := make([]tls.Certificate, 0, 1) clientCert := make([]tls.Certificate, 0, 1)
if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" { if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" {
if ccert == "" || ckey == "" { if ccert == "" || ckey == "" {
return nil, ErrTLSCertKeyConfig return nil, ErrTLSCertKeyConfig
} }
@ -156,8 +147,8 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
} }
insecureSkipVerify := false insecureSkipVerify := false
if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { if len(config.Params["x-tls-insecure-skip-verify"]) > 0 {
x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -175,18 +166,23 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
} }
} }
c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl)) return config, nil
}
func (m *Mysql) Open(url string) (database.Driver, error) {
config, err := urlToMySQLConfig(url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db, err := sql.Open("mysql", c.FormatDSN())
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil { if err != nil {
return nil, err return nil, err
} }
mx, err := WithInstance(db, &Config{ mx, err := WithInstance(db, &Config{
DatabaseName: purl.Path, DatabaseName: config.DBName,
MigrationsTable: migrationsTable, MigrationsTable: config.Params["x-migrations-table"],
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -8,7 +8,6 @@ import (
"log" "log"
"github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4"
"net/url"
"testing" "testing"
) )
@ -210,19 +209,13 @@ func TestURLToMySQLConfig(t *testing.T) {
} }
for _, tc := range testcases { for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
u, err := url.Parse(tc.urlStr) config, err := urlToMySQLConfig(tc.urlStr)
if err != nil { if err != nil {
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err) t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
} }
if config, err := urlToMySQLConfig(*u); err == nil { dsn := config.FormatDSN()
dsn := config.FormatDSN() if dsn != tc.expectedDSN {
if dsn != tc.expectedDSN { t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
}
} else {
if tc.expectedDSN != "" {
t.Error("Got unexpected error:", err, "urlStr:", tc.urlStr)
}
} }
}) })
} }

View File

@ -139,8 +139,7 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) {
}{ }{
{char: "!", parses: true, expectedUsername: username, expectedPassword: basePassword + "!", {char: "!", parses: true, expectedUsername: username, expectedPassword: basePassword + "!",
encodedURL: schemeAndUsernameAndSep + basePassword + "%21" + urlSuffixAndSep}, encodedURL: schemeAndUsernameAndSep + basePassword + "%21" + urlSuffixAndSep},
{char: "#", parses: true, expectedUsername: "", expectedPassword: "", {char: "#", parses: false},
encodedURL: schemeAndUsernameAndSep + basePassword + "#" + urlSuffixAndSep},
{char: "$", parses: true, expectedUsername: username, expectedPassword: basePassword + "$", {char: "$", parses: true, expectedUsername: username, expectedPassword: basePassword + "$",
encodedURL: schemeAndUsernameAndSep + basePassword + "$" + urlSuffixAndSep}, encodedURL: schemeAndUsernameAndSep + basePassword + "$" + urlSuffixAndSep},
{char: "%", parses: false}, {char: "%", parses: false},
@ -158,16 +157,14 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) {
encodedURL: schemeAndUsernameAndSep + basePassword + "+" + urlSuffixAndSep}, encodedURL: schemeAndUsernameAndSep + basePassword + "+" + urlSuffixAndSep},
{char: ",", parses: true, expectedUsername: username, expectedPassword: "password,", {char: ",", parses: true, expectedUsername: username, expectedPassword: "password,",
encodedURL: schemeAndUsernameAndSep + basePassword + "," + urlSuffixAndSep}, encodedURL: schemeAndUsernameAndSep + basePassword + "," + urlSuffixAndSep},
{char: "/", parses: true, expectedUsername: "", expectedPassword: "", {char: "/", parses: false},
encodedURL: schemeAndUsernameAndSep + basePassword + "/" + urlSuffixAndSep},
{char: ":", parses: true, expectedUsername: username, expectedPassword: "password:", {char: ":", parses: true, expectedUsername: username, expectedPassword: "password:",
encodedURL: schemeAndUsernameAndSep + basePassword + "%3A" + urlSuffixAndSep}, encodedURL: schemeAndUsernameAndSep + basePassword + "%3A" + urlSuffixAndSep},
{char: ";", parses: true, expectedUsername: username, expectedPassword: "password;", {char: ";", parses: true, expectedUsername: username, expectedPassword: "password;",
encodedURL: schemeAndUsernameAndSep + basePassword + ";" + urlSuffixAndSep}, encodedURL: schemeAndUsernameAndSep + basePassword + ";" + urlSuffixAndSep},
{char: "=", parses: true, expectedUsername: username, expectedPassword: "password=", {char: "=", parses: true, expectedUsername: username, expectedPassword: "password=",
encodedURL: schemeAndUsernameAndSep + basePassword + "=" + urlSuffixAndSep}, encodedURL: schemeAndUsernameAndSep + basePassword + "=" + urlSuffixAndSep},
{char: "?", parses: true, expectedUsername: "", expectedPassword: "", {char: "?", parses: false},
encodedURL: schemeAndUsernameAndSep + basePassword + "?" + urlSuffixAndSep},
{char: "@", parses: true, expectedUsername: username, expectedPassword: "password@", {char: "@", parses: true, expectedUsername: username, expectedPassword: "password@",
encodedURL: schemeAndUsernameAndSep + basePassword + "%40" + urlSuffixAndSep}, encodedURL: schemeAndUsernameAndSep + basePassword + "%40" + urlSuffixAndSep},
{char: "[", parses: false}, {char: "[", parses: false},

11
util.go
View File

@ -74,15 +74,14 @@ func schemeFromURL(url string) (string, error) {
return "", errEmptyURL return "", errEmptyURL
} }
u, err := nurl.Parse(url) i := strings.Index(url, ":")
if err != nil {
return "", err // No : or : is the first character.
} if i < 1 {
if len(u.Scheme) == 0 {
return "", errNoScheme return "", errNoScheme
} }
return u.Scheme, nil return url[0:i], nil
} }
// FilterCustomQuery filters all query values starting with `x-` // FilterCustomQuery filters all query values starting with `x-`

View File

@ -74,15 +74,34 @@ func TestSourceSchemeFromUrlFailure(t *testing.T) {
} }
func TestDatabaseSchemeFromUrlSuccess(t *testing.T) { func TestDatabaseSchemeFromUrlSuccess(t *testing.T) {
urlStr := "protocol://path" cases := []struct {
expected := "protocol" name string
urlStr string
u, err := databaseSchemeFromURL(urlStr) expected string
if err != nil { }{
t.Fatalf("expected no error, but received %q", err) {
name: "Simple",
urlStr: "protocol://path",
expected: "protocol",
},
{
// See issue #264
name: "MySQLWithPort",
urlStr: "mysql://user:pass@tcp(host:1337)/db",
expected: "mysql",
},
} }
if u != expected {
t.Fatalf("expected %q, but received %q", expected, u) for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
u, err := databaseSchemeFromURL(tc.urlStr)
if err != nil {
t.Fatalf("expected no error, but received %q", err)
}
if u != tc.expected {
t.Fatalf("expected %q, but received %q", tc.expected, u)
}
})
} }
} }