mirror of https://github.com/status-im/migrate.git
streamline driver code
This commit is contained in:
parent
e857fcc785
commit
02c5fc24d3
|
@ -2,7 +2,6 @@
|
|||
package driver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
neturl "net/url" // alias to allow `url string` func signature in New
|
||||
|
||||
|
@ -14,6 +13,14 @@ import (
|
|||
"github.com/mattes/migrate/file"
|
||||
)
|
||||
|
||||
var driverMap = map[string]Driver{
|
||||
"postgres": &postgres.Driver{},
|
||||
"mysql": &mysql.Driver{},
|
||||
"bash": &bash.Driver{},
|
||||
"cassandra": &cassandra.Driver{},
|
||||
"sqlite3": &sqlite3.Driver{},
|
||||
}
|
||||
|
||||
// Driver is the interface type that needs to implemented by all drivers.
|
||||
type Driver interface {
|
||||
|
||||
|
@ -47,48 +54,14 @@ func New(url string) (Driver, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "postgres":
|
||||
d := &postgres.Driver{}
|
||||
verifyFilenameExtension("postgres", d)
|
||||
if d, found := driverMap[u.Scheme]; found {
|
||||
verifyFilenameExtension(u.Scheme, d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
|
||||
case "mysql":
|
||||
d := &mysql.Driver{}
|
||||
verifyFilenameExtension("mysql", d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
|
||||
case "bash":
|
||||
d := &bash.Driver{}
|
||||
verifyFilenameExtension("bash", d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
|
||||
case "cassandra":
|
||||
d := &cassandra.Driver{}
|
||||
verifyFilenameExtension("cassanda", d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
case "sqlite3":
|
||||
d := &sqlite3.Driver{}
|
||||
verifyFilenameExtension("sqlite3", d)
|
||||
if err := d.Initialize(url); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
default:
|
||||
return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme))
|
||||
}
|
||||
return nil, fmt.Errorf("Driver '%s' not found.", u.Scheme)
|
||||
}
|
||||
|
||||
// verifyFilenameExtension panics if the drivers filename extension
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package driver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
import "testing"
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
if _, err := New("unknown://url"); err == nil {
|
||||
|
|
Loading…
Reference in New Issue