diff --git a/driver/driver.go b/driver/driver.go index 9c00074..7192f7f 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -2,7 +2,6 @@ package driver import ( - "errors" "fmt" neturl "net/url" // alias to allow `url string` func signature in New @@ -14,6 +13,14 @@ import ( "github.com/mattes/migrate/file" ) +var driverMap = map[string]Driver{ + "postgres": &postgres.Driver{}, + "mysql": &mysql.Driver{}, + "bash": &bash.Driver{}, + "cassandra": &cassandra.Driver{}, + "sqlite3": &sqlite3.Driver{}, +} + // Driver is the interface type that needs to implemented by all drivers. type Driver interface { @@ -47,48 +54,14 @@ func New(url string) (Driver, error) { return nil, err } - switch u.Scheme { - case "postgres": - d := &postgres.Driver{} - verifyFilenameExtension("postgres", d) + if d, found := driverMap[u.Scheme]; found { + verifyFilenameExtension(u.Scheme, d) if err := d.Initialize(url); err != nil { return nil, err } return d, nil - - case "mysql": - d := &mysql.Driver{} - verifyFilenameExtension("mysql", d) - if err := d.Initialize(url); err != nil { - return nil, err - } - return d, nil - - case "bash": - d := &bash.Driver{} - verifyFilenameExtension("bash", d) - if err := d.Initialize(url); err != nil { - return nil, err - } - return d, nil - - case "cassandra": - d := &cassandra.Driver{} - verifyFilenameExtension("cassanda", d) - if err := d.Initialize(url); err != nil { - return nil, err - } - return d, nil - case "sqlite3": - d := &sqlite3.Driver{} - verifyFilenameExtension("sqlite3", d) - if err := d.Initialize(url); err != nil { - return nil, err - } - return d, nil - default: - return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme)) } + return nil, fmt.Errorf("Driver '%s' not found.", u.Scheme) } // verifyFilenameExtension panics if the drivers filename extension diff --git a/driver/driver_test.go b/driver/driver_test.go index d0b11d6..b9c191b 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -1,8 +1,6 @@ package driver -import ( - "testing" -) +import "testing" func TestNew(t *testing.T) { if _, err := New("unknown://url"); err == nil {