diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 428fcb8..85afbfa 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -13,8 +13,13 @@ import ( nurl "net/url" "strconv" "strings" +) +import ( "github.com/go-sql-driver/mysql" +) + +import ( "github.com/golang-migrate/migrate" "github.com/golang-migrate/migrate/database" ) @@ -89,6 +94,25 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return mx, nil } +// urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config. +// Manually sets username and password to avoid net/url from url-encoding the reserved URL characters +func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) { + origUserInfo := u.User + u.User = nil + + c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://")) + if err != nil { + return nil, err + } + if origUserInfo != nil { + c.User = origUserInfo.Username() + if p, ok := origUserInfo.Password(); ok { + c.Passwd = p + } + } + return c, nil +} + func (m *Mysql) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { @@ -99,8 +123,11 @@ func (m *Mysql) Open(url string) (database.Driver, error) { q.Set("multiStatements", "true") purl.RawQuery = q.Encode() - db, err := sql.Open("mysql", strings.Replace( - migrate.FilterCustomQuery(purl).String(), "mysql://", "", 1)) + c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl)) + if err != nil { + return nil, err + } + db, err := sql.Open("mysql", c.FormatDSN()) if err != nil { return nil, err } diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index ae2c956..56925fe 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -4,11 +4,15 @@ import ( "database/sql" sqldriver "database/sql/driver" "fmt" - // "io/ioutil" - // "log" + "net/url" "testing" +) +import ( "github.com/go-sql-driver/mysql" +) + +import ( dt "github.com/golang-migrate/migrate/database/testing" mt "github.com/golang-migrate/migrate/testing" ) @@ -97,3 +101,55 @@ func TestLockWorks(t *testing.T) { } }) } + +func TestURLToMySQLConfig(t *testing.T) { + testcases := []struct { + name string + urlStr string + expectedDSN string // empty string signifies that an error is expected + }{ + {name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true", + expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + {name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true", + expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + {name: "only user - with encoded :", + urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", + expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + {name: "only user - with encoded @", + urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", + expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + {name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", + expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + // Not supported yet: https://github.com/go-sql-driver/mysql/issues/591 + // {name: "user/password - user with encoded :", + // urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", + // expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + {name: "user/password - user with encoded @", + urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", + expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + {name: "user/password - password with encoded :", + urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", + expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + {name: "user/password - password with encoded @", + urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", + expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + u, err := url.Parse(tc.urlStr) + if err != nil { + t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err) + } + if config, err := urlToMySQLConfig(*u); err == nil { + dsn := config.FormatDSN() + if dsn != tc.expectedDSN { + t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN) + } + } else { + if tc.expectedDSN != "" { + t.Error("Got unexpected error:", err, "urlStr:", tc.urlStr) + } + } + }) + } +}