mirror of https://github.com/status-im/migrate.git
Merge pull request #50 from buddhamagnet/streamline-driver-code
Streamline driver code
This commit is contained in:
commit
531fdf64e5
|
@ -1 +1,2 @@
|
|||
.DS_Store
|
||||
test.db
|
|
@ -2,7 +2,6 @@
|
|||
package driver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
neturl "net/url" // alias to allow `url string` func signature in New
|
||||
|
||||
|
@ -14,6 +13,14 @@ import (
|
|||
"github.com/mattes/migrate/file"
|
||||
)
|
||||
|
||||
var driverMap = map[string]Driver{
|
||||
"postgres": &postgres.Driver{},
|
||||
"mysql": &mysql.Driver{},
|
||||
"bash": &bash.Driver{},
|
||||
"cassandra": &cassandra.Driver{},
|
||||
"sqlite3": &sqlite3.Driver{},
|
||||
}
|
||||
|
||||
// Driver is the interface type that needs to implemented by all drivers.
|
||||
type Driver interface {
|
||||
|
||||
|
@ -47,48 +54,14 @@ func New(url string) (Driver, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "postgres":
|
||||
d := &postgres.Driver{}
|
||||
verifyFilenameExtension("postgres", d)
|
||||
if d, found := driverMap[u.Scheme]; found {
|
||||
verifyFilenameExtension(u.Scheme, d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
|
||||
case "mysql":
|
||||
d := &mysql.Driver{}
|
||||
verifyFilenameExtension("mysql", d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
|
||||
case "bash":
|
||||
d := &bash.Driver{}
|
||||
verifyFilenameExtension("bash", d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
|
||||
case "cassandra":
|
||||
d := &cassandra.Driver{}
|
||||
verifyFilenameExtension("cassanda", d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
case "sqlite3":
|
||||
d := &sqlite3.Driver{}
|
||||
verifyFilenameExtension("sqlite3", d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
default:
|
||||
return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme))
|
||||
}
|
||||
return nil, fmt.Errorf("Driver '%s' not found.", u.Scheme)
|
||||
}
|
||||
|
||||
// verifyFilenameExtension panics if the drivers filename extension
|
||||
|
|
|
@ -1,11 +1,27 @@
|
|||
package driver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
import "testing"
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
if _, err := New("unknown://url"); err == nil {
|
||||
t.Error("no error although driver unknown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBash(t *testing.T) {
|
||||
driver, err := New("bash://url")
|
||||
if err != nil {
|
||||
t.Error("error although bash driver known")
|
||||
}
|
||||
version, err := driver.Version()
|
||||
if version != 0 {
|
||||
t.Errorf("expected bash driver version to be 0, got %d\n", version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSqlite3(t *testing.T) {
|
||||
_, err := New("sqlite3://test.db")
|
||||
if err != nil {
|
||||
t.Error("error although sqlite3 driver known")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue