Merge pull request #50 from buddhamagnet/streamline-driver-code

Streamline driver code
This commit is contained in:
Matthias Kadenbach 2015-09-24 10:33:18 -07:00
commit 531fdf64e5
3 changed files with 32 additions and 42 deletions

3
.gitignore vendored
View File

@ -1 +1,2 @@
.DS_Store
.DS_Store
test.db

View File

@ -2,7 +2,6 @@
package driver
import (
"errors"
"fmt"
neturl "net/url" // alias to allow `url string` func signature in New
@ -14,6 +13,14 @@ import (
"github.com/mattes/migrate/file"
)
var driverMap = map[string]Driver{
"postgres": &postgres.Driver{},
"mysql": &mysql.Driver{},
"bash": &bash.Driver{},
"cassandra": &cassandra.Driver{},
"sqlite3": &sqlite3.Driver{},
}
// Driver is the interface type that needs to implemented by all drivers.
type Driver interface {
@ -47,48 +54,14 @@ func New(url string) (Driver, error) {
return nil, err
}
switch u.Scheme {
case "postgres":
d := &postgres.Driver{}
verifyFilenameExtension("postgres", d)
if d, found := driverMap[u.Scheme]; found {
verifyFilenameExtension(u.Scheme, d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "mysql":
d := &mysql.Driver{}
verifyFilenameExtension("mysql", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "bash":
d := &bash.Driver{}
verifyFilenameExtension("bash", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "cassandra":
d := &cassandra.Driver{}
verifyFilenameExtension("cassanda", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "sqlite3":
d := &sqlite3.Driver{}
verifyFilenameExtension("sqlite3", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
default:
return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme))
}
return nil, fmt.Errorf("Driver '%s' not found.", u.Scheme)
}
// verifyFilenameExtension panics if the drivers filename extension

View File

@ -1,11 +1,27 @@
package driver
import (
"testing"
)
import "testing"
func TestNew(t *testing.T) {
if _, err := New("unknown://url"); err == nil {
t.Error("no error although driver unknown")
}
}
func TestNewBash(t *testing.T) {
driver, err := New("bash://url")
if err != nil {
t.Error("error although bash driver known")
}
version, err := driver.Version()
if version != 0 {
t.Errorf("expected bash driver version to be 0, got %d\n", version)
}
}
func TestNewSqlite3(t *testing.T) {
_, err := New("sqlite3://test.db")
if err != nil {
t.Error("error although sqlite3 driver known")
}
}