mirror of https://github.com/status-im/migrate.git
Merge branch 'master' into fix-mongodb-dep
This commit is contained in:
commit
d80e0e2f7f
|
@ -21,7 +21,6 @@ import (
|
|||
)
|
||||
|
||||
import (
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
)
|
||||
|
||||
|
@ -98,43 +97,35 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
|||
return mx, nil
|
||||
}
|
||||
|
||||
// urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config.
|
||||
// Manually sets username and password to avoid net/url from url-encoding the reserved URL characters
|
||||
func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
|
||||
origUserInfo := u.User
|
||||
u.User = nil
|
||||
|
||||
c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if origUserInfo != nil {
|
||||
c.User = origUserInfo.Username()
|
||||
if p, ok := origUserInfo.Password(); ok {
|
||||
c.Passwd = p
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Open(url string) (database.Driver, error) {
|
||||
purl, err := nurl.Parse(url)
|
||||
func urlToMySQLConfig(url string) (*mysql.Config, error) {
|
||||
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := purl.Query()
|
||||
q.Set("multiStatements", "true")
|
||||
purl.RawQuery = q.Encode()
|
||||
config.MultiStatements = true
|
||||
|
||||
migrationsTable := purl.Query().Get("x-migrations-table")
|
||||
// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
|
||||
// net/url.Parse() would automatically unescape it for us.
|
||||
// See: https://play.golang.org/p/q9j1io-YICQ
|
||||
user, err := nurl.QueryUnescape(config.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.User = user
|
||||
|
||||
password, err := nurl.QueryUnescape(config.Passwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Passwd = password
|
||||
|
||||
// use custom TLS?
|
||||
ctls := purl.Query().Get("tls")
|
||||
ctls := config.TLSConfig
|
||||
if len(ctls) > 0 {
|
||||
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
|
||||
rootCertPool := x509.NewCertPool()
|
||||
pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
|
||||
pem, err := ioutil.ReadFile(config.Params["x-tls-ca"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -144,7 +135,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
|
|||
}
|
||||
|
||||
clientCert := make([]tls.Certificate, 0, 1)
|
||||
if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" {
|
||||
if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" {
|
||||
if ccert == "" || ckey == "" {
|
||||
return nil, ErrTLSCertKeyConfig
|
||||
}
|
||||
|
@ -156,8 +147,8 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
|
|||
}
|
||||
|
||||
insecureSkipVerify := false
|
||||
if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
|
||||
x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
|
||||
if len(config.Params["x-tls-insecure-skip-verify"]) > 0 {
|
||||
x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -175,18 +166,23 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
|
|||
}
|
||||
}
|
||||
|
||||
c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl))
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Open(url string) (database.Driver, error) {
|
||||
config, err := urlToMySQLConfig(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db, err := sql.Open("mysql", c.FormatDSN())
|
||||
|
||||
db, err := sql.Open("mysql", config.FormatDSN())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx, err := WithInstance(db, &Config{
|
||||
DatabaseName: purl.Path,
|
||||
MigrationsTable: migrationsTable,
|
||||
DatabaseName: config.DBName,
|
||||
MigrationsTable: config.Params["x-migrations-table"],
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"log"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -210,19 +209,13 @@ func TestURLToMySQLConfig(t *testing.T) {
|
|||
}
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
u, err := url.Parse(tc.urlStr)
|
||||
config, err := urlToMySQLConfig(tc.urlStr)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
|
||||
}
|
||||
if config, err := urlToMySQLConfig(*u); err == nil {
|
||||
dsn := config.FormatDSN()
|
||||
if dsn != tc.expectedDSN {
|
||||
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
|
||||
}
|
||||
} else {
|
||||
if tc.expectedDSN != "" {
|
||||
t.Error("Got unexpected error:", err, "urlStr:", tc.urlStr)
|
||||
}
|
||||
dsn := config.FormatDSN()
|
||||
if dsn != tc.expectedDSN {
|
||||
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -139,8 +139,7 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) {
|
|||
}{
|
||||
{char: "!", parses: true, expectedUsername: username, expectedPassword: basePassword + "!",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%21" + urlSuffixAndSep},
|
||||
{char: "#", parses: true, expectedUsername: "", expectedPassword: "",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "#" + urlSuffixAndSep},
|
||||
{char: "#", parses: false},
|
||||
{char: "$", parses: true, expectedUsername: username, expectedPassword: basePassword + "$",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "$" + urlSuffixAndSep},
|
||||
{char: "%", parses: false},
|
||||
|
@ -158,16 +157,14 @@ func TestPasswordUnencodedReservedURLChars(t *testing.T) {
|
|||
encodedURL: schemeAndUsernameAndSep + basePassword + "+" + urlSuffixAndSep},
|
||||
{char: ",", parses: true, expectedUsername: username, expectedPassword: "password,",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "," + urlSuffixAndSep},
|
||||
{char: "/", parses: true, expectedUsername: "", expectedPassword: "",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "/" + urlSuffixAndSep},
|
||||
{char: "/", parses: false},
|
||||
{char: ":", parses: true, expectedUsername: username, expectedPassword: "password:",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%3A" + urlSuffixAndSep},
|
||||
{char: ";", parses: true, expectedUsername: username, expectedPassword: "password;",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + ";" + urlSuffixAndSep},
|
||||
{char: "=", parses: true, expectedUsername: username, expectedPassword: "password=",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "=" + urlSuffixAndSep},
|
||||
{char: "?", parses: true, expectedUsername: "", expectedPassword: "",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "?" + urlSuffixAndSep},
|
||||
{char: "?", parses: false},
|
||||
{char: "@", parses: true, expectedUsername: username, expectedPassword: "password@",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%40" + urlSuffixAndSep},
|
||||
{char: "[", parses: false},
|
||||
|
|
11
util.go
11
util.go
|
@ -74,15 +74,14 @@ func schemeFromURL(url string) (string, error) {
|
|||
return "", errEmptyURL
|
||||
}
|
||||
|
||||
u, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(u.Scheme) == 0 {
|
||||
i := strings.Index(url, ":")
|
||||
|
||||
// No : or : is the first character.
|
||||
if i < 1 {
|
||||
return "", errNoScheme
|
||||
}
|
||||
|
||||
return u.Scheme, nil
|
||||
return url[0:i], nil
|
||||
}
|
||||
|
||||
// FilterCustomQuery filters all query values starting with `x-`
|
||||
|
|
35
util_test.go
35
util_test.go
|
@ -74,15 +74,34 @@ func TestSourceSchemeFromUrlFailure(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDatabaseSchemeFromUrlSuccess(t *testing.T) {
|
||||
urlStr := "protocol://path"
|
||||
expected := "protocol"
|
||||
|
||||
u, err := databaseSchemeFromURL(urlStr)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, but received %q", err)
|
||||
cases := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple",
|
||||
urlStr: "protocol://path",
|
||||
expected: "protocol",
|
||||
},
|
||||
{
|
||||
// See issue #264
|
||||
name: "MySQLWithPort",
|
||||
urlStr: "mysql://user:pass@tcp(host:1337)/db",
|
||||
expected: "mysql",
|
||||
},
|
||||
}
|
||||
if u != expected {
|
||||
t.Fatalf("expected %q, but received %q", expected, u)
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
u, err := databaseSchemeFromURL(tc.urlStr)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, but received %q", err)
|
||||
}
|
||||
if u != tc.expected {
|
||||
t.Fatalf("expected %q, but received %q", tc.expected, u)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue