Don't load in all drivers by default #40

Requires activating drivers with a _ style import, e.g.
import "_ github.com/mattes/migrate/driver/postgres"
This commit is contained in:
Dave Jeffrey 2015-06-11 11:11:28 +01:00
parent 9bb037339f
commit 0741616d2e
9 changed files with 63 additions and 41 deletions

View File

@ -78,6 +78,9 @@ See GoDoc here: http://godoc.org/github.com/mattes/migrate/migrate
```go ```go
import "github.com/mattes/migrate/migrate" import "github.com/mattes/migrate/migrate"
// Import any required drivers so that they are registered and available
import _ "github.com/mattes/migrate/drivers/mysql"
// use synchronous versions of migration functions ... // use synchronous versions of migration functions ...
allErrors, ok := migrate.UpSync("driver://url", "./path") allErrors, ok := migrate.UpSync("driver://url", "./path")
if !ok { if !ok {

View File

@ -2,6 +2,7 @@
package bash package bash
import ( import (
"github.com/mattes/migrate/driver/registry"
"github.com/mattes/migrate/file" "github.com/mattes/migrate/file"
_ "github.com/mattes/migrate/migrate/direction" _ "github.com/mattes/migrate/migrate/direction"
) )
@ -30,3 +31,7 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
func (driver *Driver) Version() (uint64, error) { func (driver *Driver) Version() (uint64, error) {
return uint64(0), nil return uint64(0), nil
} }
func init() {
registry.RegisterDriver("bash", Driver{})
}

View File

@ -4,6 +4,7 @@ package cassandra
import ( import (
"fmt" "fmt"
"github.com/gocql/gocql" "github.com/gocql/gocql"
"github.com/mattes/migrate/driver/registry"
"github.com/mattes/migrate/file" "github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction" "github.com/mattes/migrate/migrate/direction"
"net/url" "net/url"
@ -153,3 +154,7 @@ func (driver *Driver) Version() (uint64, error) {
err := driver.session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version) err := driver.session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version)
return uint64(version) - 1, err return uint64(version) - 1, err
} }
func init() {
registry.RegisterDriver("cassandra", Driver{})
}

View File

@ -5,12 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
neturl "net/url" // alias to allow `url string` func signature in New neturl "net/url" // alias to allow `url string` func signature in New
"reflect"
"github.com/mattes/migrate/driver/bash" "github.com/mattes/migrate/driver/registry"
"github.com/mattes/migrate/driver/cassandra"
"github.com/mattes/migrate/driver/mysql"
"github.com/mattes/migrate/driver/postgres"
"github.com/mattes/migrate/driver/sqlite3"
"github.com/mattes/migrate/file" "github.com/mattes/migrate/file"
) )
@ -47,51 +44,26 @@ func New(url string) (Driver, error) {
return nil, err return nil, err
} }
switch u.Scheme { driver := registry.GetDriver(u.Scheme)
case "postgres": if driver != nil {
d := &postgres.Driver{} blankDriver := reflect.New(reflect.TypeOf(driver)).Interface()
verifyFilenameExtension("postgres", d) d, ok := blankDriver.(Driver)
if !ok {
err := errors.New(fmt.Sprintf("Driver '%s' does not implement the Driver interface"))
return nil, err
}
verifyFilenameExtension(u.Scheme, d)
if err := d.Initialize(url); err != nil { if err := d.Initialize(url); err != nil {
return nil, err return nil, err
} }
return d, nil
case "mysql":
d := &mysql.Driver{}
verifyFilenameExtension("mysql", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil return d, nil
} else {
case "bash":
d := &bash.Driver{}
verifyFilenameExtension("bash", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "cassandra":
d := &cassandra.Driver{}
verifyFilenameExtension("cassanda", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "sqlite3":
d := &sqlite3.Driver{}
verifyFilenameExtension("sqlite3", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
default:
return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme)) return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme))
} }
} }
// verifyFilenameExtension panics if the drivers filename extension // verifyFilenameExtension panics if the driver's filename extension
// is not correct or empty. // is not correct or empty.
func verifyFilenameExtension(driverName string, d Driver) { func verifyFilenameExtension(driverName string, d Driver) {
f := d.FilenameExtension() f := d.FilenameExtension()

View File

@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/mattes/migrate/driver/registry"
"github.com/mattes/migrate/file" "github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction" "github.com/mattes/migrate/migrate/direction"
"regexp" "regexp"
@ -177,3 +178,7 @@ func (driver *Driver) Version() (uint64, error) {
return version, nil return version, nil
} }
} }
func init() {
registry.RegisterDriver("mysql", Driver{})
}

View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/mattes/migrate/driver/registry"
"github.com/mattes/migrate/file" "github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction" "github.com/mattes/migrate/migrate/direction"
"strconv" "strconv"
@ -119,3 +120,7 @@ func (driver *Driver) Version() (uint64, error) {
return version, nil return version, nil
} }
} }
func init() {
registry.RegisterDriver("postgres", Driver{})
}

View File

@ -0,0 +1,20 @@
// Package registry maintains a map of imported and available drivers
package registry
var driverRegistry map[string]interface{}
// Registers a driver so it can be created from its name. Drivers should
// call this from an init() function so that they registers themselvse on
// import
func RegisterDriver(name string, driver interface{}) {
driverRegistry[name] = driver
}
// Retrieves a registered driver by name
func GetDriver(name string) interface{} {
return driverRegistry[name]
}
func init() {
driverRegistry = make(map[string]interface{})
}

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"github.com/mattes/migrate/driver/registry"
"github.com/mattes/migrate/file" "github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction" "github.com/mattes/migrate/migrate/direction"
"github.com/mattn/go-sqlite3" "github.com/mattn/go-sqlite3"
@ -123,3 +124,7 @@ func (driver *Driver) Version() (uint64, error) {
return version, nil return version, nil
} }
} }
func init() {
registry.RegisterDriver("sqlite3", Driver{})
}

View File

@ -3,6 +3,8 @@ package migrate
import ( import (
"io/ioutil" "io/ioutil"
"testing" "testing"
// Ensure imports for each driver we wish to test
_ "github.com/mattes/migrate/driver/postgres"
) )
// Add Driver URLs here to test basic Up, Down, .. functions. // Add Driver URLs here to test basic Up, Down, .. functions.