drop custom query params

This commit is contained in:
Matthias Kadenbach 2017-02-17 16:59:47 -08:00
parent f45821581c
commit d574676702
No known key found for this signature in database
GPG Key ID: DC1F4DC6D31A7031
4 changed files with 48 additions and 10 deletions

View File

@ -9,6 +9,7 @@ import (
nurl "net/url"
"github.com/lib/pq"
"github.com/mattes/migrate"
"github.com/mattes/migrate/database"
)
@ -30,6 +31,14 @@ type Config struct {
DatabaseName string
}
type Postgres struct {
db *sql.DB
isLocked bool
// Open and WithInstance need to garantuee that config is never nil
config *Config
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
@ -63,21 +72,13 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return px, nil
}
type Postgres struct {
db *sql.DB
isLocked bool
// Open and WithInstance need to garantuee that config is never nil
config *Config
}
func (p *Postgres) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
db, err := sql.Open("postgres", url)
db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
@ -155,7 +156,7 @@ func (p *Postgres) Run(version int, migration io.Reader) error {
if dirty, err := p.isDirty(); err != nil {
return err
} else if dirty {
return ErrDatabaseDirty
return ErrDatabaseDirty // TODO: add more verbose error
}
if migration == nil {

View File

@ -78,6 +78,18 @@ func TestMultiStatement(t *testing.T) {
})
}
func TestFilterCustomQuery(t *testing.T) {
mt.ParallelTest(t, versions, isReady,
func(t *testing.T, i mt.Instance) {
p := &Postgres{}
addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&x-custom=foobar", i.Host(), i.Port())
_, err := p.Open(addr)
if err != nil {
t.Fatalf("%v", err)
}
})
}
func TestWithSchema(t *testing.T) {
mt.ParallelTest(t, versions, isReady,
func(t *testing.T, i mt.Instance) {

13
util.go
View File

@ -90,3 +90,16 @@ func schemeFromUrl(url string) (string, error) {
return u.Scheme, nil
}
// FilterCustomQuery filters all query values starting with `x-`
func FilterCustomQuery(u *nurl.URL) *nurl.URL {
ux := *u
vx := make(nurl.Values)
for k, v := range ux.Query() {
if len(k) <= 1 || (len(k) > 1 && k[0:2] != "x-") {
vx[k] = v
}
}
ux.RawQuery = vx.Encode()
return &ux
}

View File

@ -1,6 +1,7 @@
package migrate
import (
nurl "net/url"
"testing"
)
@ -18,3 +19,14 @@ func TestSuint(t *testing.T) {
t.Fatalf("expected 0, got %v", u)
}
}
func TestFilterCustomQuery(t *testing.T) {
n, err := nurl.Parse("foo://host?a=b&x-custom=foo&c=d")
if err != nil {
t.Fatal(err)
}
nx := FilterCustomQuery(n).Query()
if nx.Get("x-custom") != "" {
t.Fatalf("didn't expect x-custom")
}
}