Properly filter out custom query params in MySQL DB driver

Addresses: https://github.com/golang-migrate/migrate/issues/272
This commit is contained in:
Dale Hui 2019-08-22 00:00:32 -07:00
parent a354c6d446
commit 0064ee83cf
2 changed files with 83 additions and 3 deletions

View File

@ -97,6 +97,23 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return mx, nil
}
// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
if c == nil {
return nil, ErrNilConfig
}
customQueryParams := map[string]string{}
for k, v := range c.Params {
if strings.HasPrefix(k, "x-") {
customQueryParams[k] = v
delete(c.Params, k)
}
}
return customQueryParams, nil
}
func urlToMySQLConfig(url string) (*mysql.Config, error) {
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
if err != nil {
@ -174,6 +191,13 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
if err != nil {
return nil, err
}
fmt.Printf("config: %+v\n", config)
customParams, err := extractCustomQueryParams(config)
if err != nil {
return nil, err
}
fmt.Printf("config: %+v\n", config)
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
@ -182,7 +206,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
mx, err := WithInstance(db, &Config{
DatabaseName: config.DBName,
MigrationsTable: config.Params["x-migrations-table"],
MigrationsTable: customParams["x-migrations-table"],
})
if err != nil {
return nil, err

View File

@ -6,17 +6,17 @@ import (
sqldriver "database/sql/driver"
"fmt"
"log"
"github.com/golang-migrate/migrate/v4"
"testing"
)
import (
"github.com/dhui/dktest"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
)
import (
"github.com/golang-migrate/migrate/v4"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
@ -175,6 +175,62 @@ func TestLockWorks(t *testing.T) {
})
}
func TestExtractCustomQueryParams(t *testing.T) {
testcases := []struct {
name string
config *mysql.Config
expectedParams map[string]string
expectedCustomParams map[string]string
expectedErr error
}{
{name: "nil config", expectedErr: ErrNilConfig},
{
name: "no params",
config: mysql.NewConfig(),
expectedCustomParams: map[string]string{},
},
{
name: "no custom params",
config: &mysql.Config{Params: map[string]string{"hello": "world"}},
expectedParams: map[string]string{"hello": "world"},
expectedCustomParams: map[string]string{},
},
{
name: "one param, one custom param",
config: &mysql.Config{
Params: map[string]string{"hello": "world", "x-foo": "bar"},
},
expectedParams: map[string]string{"hello": "world"},
expectedCustomParams: map[string]string{"x-foo": "bar"},
},
{
name: "multiple params, multiple custom params",
config: &mysql.Config{
Params: map[string]string{
"hello": "world",
"x-foo": "bar",
"dead": "beef",
"x-cat": "hat",
},
},
expectedParams: map[string]string{"hello": "world", "dead": "beef"},
expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
customParams, err := extractCustomQueryParams(tc.config)
if tc.config != nil {
assert.Equal(t, tc.expectedParams, tc.config.Params,
"Expected config params have custom params properly removed")
}
assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
assert.Equal(t, tc.expectedCustomParams, customParams,
"Expected custom params to be properly extracted")
})
}
}
func TestURLToMySQLConfig(t *testing.T) {
testcases := []struct {
name string