mirror of https://github.com/status-im/migrate.git
Properly filter out custom query params in MySQL DB driver
Addresses: https://github.com/golang-migrate/migrate/issues/272
This commit is contained in:
parent
a354c6d446
commit
0064ee83cf
|
@ -97,6 +97,23 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
|||
return mx, nil
|
||||
}
|
||||
|
||||
// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
|
||||
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
|
||||
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
|
||||
if c == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
customQueryParams := map[string]string{}
|
||||
|
||||
for k, v := range c.Params {
|
||||
if strings.HasPrefix(k, "x-") {
|
||||
customQueryParams[k] = v
|
||||
delete(c.Params, k)
|
||||
}
|
||||
}
|
||||
return customQueryParams, nil
|
||||
}
|
||||
|
||||
func urlToMySQLConfig(url string) (*mysql.Config, error) {
|
||||
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
|
||||
if err != nil {
|
||||
|
@ -174,6 +191,13 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("config: %+v\n", config)
|
||||
|
||||
customParams, err := extractCustomQueryParams(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("config: %+v\n", config)
|
||||
|
||||
db, err := sql.Open("mysql", config.FormatDSN())
|
||||
if err != nil {
|
||||
|
@ -182,7 +206,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
|
|||
|
||||
mx, err := WithInstance(db, &Config{
|
||||
DatabaseName: config.DBName,
|
||||
MigrationsTable: config.Params["x-migrations-table"],
|
||||
MigrationsTable: customParams["x-migrations-table"],
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -6,17 +6,17 @@ import (
|
|||
sqldriver "database/sql/driver"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"testing"
|
||||
)
|
||||
|
||||
import (
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
import (
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
|
@ -175,6 +175,62 @@ func TestLockWorks(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestExtractCustomQueryParams(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
config *mysql.Config
|
||||
expectedParams map[string]string
|
||||
expectedCustomParams map[string]string
|
||||
expectedErr error
|
||||
}{
|
||||
{name: "nil config", expectedErr: ErrNilConfig},
|
||||
{
|
||||
name: "no params",
|
||||
config: mysql.NewConfig(),
|
||||
expectedCustomParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "no custom params",
|
||||
config: &mysql.Config{Params: map[string]string{"hello": "world"}},
|
||||
expectedParams: map[string]string{"hello": "world"},
|
||||
expectedCustomParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "one param, one custom param",
|
||||
config: &mysql.Config{
|
||||
Params: map[string]string{"hello": "world", "x-foo": "bar"},
|
||||
},
|
||||
expectedParams: map[string]string{"hello": "world"},
|
||||
expectedCustomParams: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
{
|
||||
name: "multiple params, multiple custom params",
|
||||
config: &mysql.Config{
|
||||
Params: map[string]string{
|
||||
"hello": "world",
|
||||
"x-foo": "bar",
|
||||
"dead": "beef",
|
||||
"x-cat": "hat",
|
||||
},
|
||||
},
|
||||
expectedParams: map[string]string{"hello": "world", "dead": "beef"},
|
||||
expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
|
||||
},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
customParams, err := extractCustomQueryParams(tc.config)
|
||||
if tc.config != nil {
|
||||
assert.Equal(t, tc.expectedParams, tc.config.Params,
|
||||
"Expected config params have custom params properly removed")
|
||||
}
|
||||
assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
|
||||
assert.Equal(t, tc.expectedCustomParams, customParams,
|
||||
"Expected custom params to be properly extracted")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLToMySQLConfig(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
|
|
Loading…
Reference in New Issue