streamline driver code

This commit is contained in:
buddhamagnet 2015-09-15 22:56:29 +01:00
parent e857fcc785
commit 02c5fc24d3
2 changed files with 12 additions and 41 deletions

View File

@ -2,7 +2,6 @@
package driver package driver
import ( import (
"errors"
"fmt" "fmt"
neturl "net/url" // alias to allow `url string` func signature in New neturl "net/url" // alias to allow `url string` func signature in New
@ -14,6 +13,14 @@ import (
"github.com/mattes/migrate/file" "github.com/mattes/migrate/file"
) )
var driverMap = map[string]Driver{
"postgres": &postgres.Driver{},
"mysql": &mysql.Driver{},
"bash": &bash.Driver{},
"cassandra": &cassandra.Driver{},
"sqlite3": &sqlite3.Driver{},
}
// Driver is the interface type that needs to implemented by all drivers. // Driver is the interface type that needs to implemented by all drivers.
type Driver interface { type Driver interface {
@ -47,48 +54,14 @@ func New(url string) (Driver, error) {
return nil, err return nil, err
} }
switch u.Scheme { if d, found := driverMap[u.Scheme]; found {
case "postgres": verifyFilenameExtension(u.Scheme, d)
d := &postgres.Driver{}
verifyFilenameExtension("postgres", d)
if err := d.Initialize(url); err != nil { if err := d.Initialize(url); err != nil {
return nil, err return nil, err
} }
return d, nil return d, nil
case "mysql":
d := &mysql.Driver{}
verifyFilenameExtension("mysql", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "bash":
d := &bash.Driver{}
verifyFilenameExtension("bash", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "cassandra":
d := &cassandra.Driver{}
verifyFilenameExtension("cassanda", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "sqlite3":
d := &sqlite3.Driver{}
verifyFilenameExtension("sqlite3", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
default:
return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme))
} }
return nil, fmt.Errorf("Driver '%s' not found.", u.Scheme)
} }
// verifyFilenameExtension panics if the drivers filename extension // verifyFilenameExtension panics if the drivers filename extension

View File

@ -1,8 +1,6 @@
package driver package driver
import ( import "testing"
"testing"
)
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
if _, err := New("unknown://url"); err == nil { if _, err := New("unknown://url"); err == nil {