Merge branch 'master' into fix-mongodb-dep

This commit is contained in:
zikaeroh 2019-08-17 08:13:16 -07:00
commit d80e0e2f7f
5 changed files with 70 additions and 66 deletions

View File

@ -21,7 +21,6 @@ import (
)
import (
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
)
@ -98,43 +97,35 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return mx, nil
}
// urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config.
// Manually sets username and password to avoid net/url from url-encoding the reserved URL characters
func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
origUserInfo := u.User
u.User = nil
c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://"))
if err != nil {
return nil, err
}
if origUserInfo != nil {
c.User = origUserInfo.Username()
if p, ok := origUserInfo.Password(); ok {
c.Passwd = p
}
}
return c, nil
}
func (m *Mysql) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
func urlToMySQLConfig(url string) (*mysql.Config, error) {
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
if err != nil {
return nil, err
}
q := purl.Query()
q.Set("multiStatements", "true")
purl.RawQuery = q.Encode()
config.MultiStatements = true
migrationsTable := purl.Query().Get("x-migrations-table")
// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
// net/url.Parse() would automatically unescape it for us.
// See: https://play.golang.org/p/q9j1io-YICQ
user, err := nurl.QueryUnescape(config.User)
if err != nil {
return nil, err
}
config.User = user
password, err := nurl.QueryUnescape(config.Passwd)
if err != nil {
return nil, err
}
config.Passwd = password
// use custom TLS?
ctls := purl.Query().Get("tls")
ctls := config.TLSConfig
if len(ctls) > 0 {
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
pem, err := ioutil.ReadFile(config.Params["x-tls-ca"])
if err != nil {
return nil, err
}
@ -144,7 +135,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}
clientCert := make([]tls.Certificate, 0, 1)
if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" {
if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" {
if ccert == "" || ckey == "" {
return nil, ErrTLSCertKeyConfig
}
@ -156,8 +147,8 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}
insecureSkipVerify := false
if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
if len(config.Params["x-tls-insecure-skip-verify"]) > 0 {
x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"])
if err != nil {
return nil, err
}
@ -175,18 +166,23 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
}
}
c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl))
return config, nil
}
func (m *Mysql) Open(url string) (database.Driver, error) {
config, err := urlToMySQLConfig(url)
if err != nil {
return nil, err
}
db, err := sql.Open("mysql", c.FormatDSN())
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return nil, err
}
mx, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
DatabaseName: config.DBName,
MigrationsTable: config.Params["x-migrations-table"],
})
if err != nil {
return nil, err

View File

@ -8,7 +8,6 @@ import (
"log"
"github.com/golang-migrate/migrate/v4"
"net/url"
"testing"
)
@ -210,19 +209,13 @@ func TestURLToMySQLConfig(t *testing.T) {
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
u, err := url.Parse(tc.urlStr)
config, err := urlToMySQLConfig(tc.urlStr)
if err != nil {
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
}
if config, err := urlToMySQLConfig(*u); err == nil {
dsn := config.FormatDSN()
if dsn != tc.expectedDSN {
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
}
} else {
if tc.expectedDSN != "" {
t.Error("Got unexpected error:", err, "urlStr:", tc.urlStr)
}
dsn := config.FormatDSN()
if dsn != tc.expectedDSN {
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
}
})
}

View File

@ -139,8 +139,7 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) {
}{
{char: "!", parses: true, expectedUsername: username, expectedPassword: basePassword + "!",
encodedURL: schemeAndUsernameAndSep + basePassword + "%21" + urlSuffixAndSep},
{char: "#", parses: true, expectedUsername: "", expectedPassword: "",
encodedURL: schemeAndUsernameAndSep + basePassword + "#" + urlSuffixAndSep},
{char: "#", parses: false},
{char: "$", parses: true, expectedUsername: username, expectedPassword: basePassword + "$",
encodedURL: schemeAndUsernameAndSep + basePassword + "$" + urlSuffixAndSep},
{char: "%", parses: false},
@ -158,16 +157,14 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) {
encodedURL: schemeAndUsernameAndSep + basePassword + "+" + urlSuffixAndSep},
{char: ",", parses: true, expectedUsername: username, expectedPassword: "password,",
encodedURL: schemeAndUsernameAndSep + basePassword + "," + urlSuffixAndSep},
{char: "/", parses: true, expectedUsername: "", expectedPassword: "",
encodedURL: schemeAndUsernameAndSep + basePassword + "/" + urlSuffixAndSep},
{char: "/", parses: false},
{char: ":", parses: true, expectedUsername: username, expectedPassword: "password:",
encodedURL: schemeAndUsernameAndSep + basePassword + "%3A" + urlSuffixAndSep},
{char: ";", parses: true, expectedUsername: username, expectedPassword: "password;",
encodedURL: schemeAndUsernameAndSep + basePassword + ";" + urlSuffixAndSep},
{char: "=", parses: true, expectedUsername: username, expectedPassword: "password=",
encodedURL: schemeAndUsernameAndSep + basePassword + "=" + urlSuffixAndSep},
{char: "?", parses: true, expectedUsername: "", expectedPassword: "",
encodedURL: schemeAndUsernameAndSep + basePassword + "?" + urlSuffixAndSep},
{char: "?", parses: false},
{char: "@", parses: true, expectedUsername: username, expectedPassword: "password@",
encodedURL: schemeAndUsernameAndSep + basePassword + "%40" + urlSuffixAndSep},
{char: "[", parses: false},

11
util.go
View File

@ -74,15 +74,14 @@ func schemeFromURL(url string) (string, error) {
return "", errEmptyURL
}
u, err := nurl.Parse(url)
if err != nil {
return "", err
}
if len(u.Scheme) == 0 {
i := strings.Index(url, ":")
// No : or : is the first character.
if i < 1 {
return "", errNoScheme
}
return u.Scheme, nil
return url[0:i], nil
}
// FilterCustomQuery filters all query values starting with `x-`

View File

@ -74,15 +74,34 @@ func TestSourceSchemeFromUrlFailure(t *testing.T) {
}
func TestDatabaseSchemeFromUrlSuccess(t *testing.T) {
urlStr := "protocol://path"
expected := "protocol"
u, err := databaseSchemeFromURL(urlStr)
if err != nil {
t.Fatalf("expected no error, but received %q", err)
cases := []struct {
name string
urlStr string
expected string
}{
{
name: "Simple",
urlStr: "protocol://path",
expected: "protocol",
},
{
// See issue #264
name: "MySQLWithPort",
urlStr: "mysql://user:pass@tcp(host:1337)/db",
expected: "mysql",
},
}
if u != expected {
t.Fatalf("expected %q, but received %q", expected, u)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
u, err := databaseSchemeFromURL(tc.urlStr)
if err != nil {
t.Fatalf("expected no error, but received %q", err)
}
if u != tc.expected {
t.Fatalf("expected %q, but received %q", tc.expected, u)
}
})
}
}