add mysql driver, add ENV to docker containers

This commit is contained in:
Matthias Kadenbach 2017-02-28 15:10:56 -08:00
parent 760bc3eb2d
commit be1ba9204a
No known key found for this signature in database
GPG Key ID: DC1F4DC6D31A7031
11 changed files with 423 additions and 53 deletions

View File

@ -1,5 +1,5 @@
SOURCE ?= file go-bindata github
DATABASE ?= postgres
DATABASE ?= postgres mysql
VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-)
TEST_FLAGS ?=
REPO_OWNER ?= $(shell cd .. && basename "$$(pwd)")

7
cli/build_mysql.go Normal file
View File

@ -0,0 +1,7 @@
// +build mysql
package main
import (
_ "github.com/mattes/migrate/database/mysql"
)

272
database/mysql/mysql.go Normal file
View File

@ -0,0 +1,272 @@
package mysql
import (
"database/sql"
"fmt"
"io"
"io/ioutil"
nurl "net/url"
"strings"
"github.com/go-sql-driver/mysql"
"github.com/mattes/migrate"
"github.com/mattes/migrate/database"
)
func init() {
database.Register("mysql", &Mysql{})
}
var DefaultMigrationsTable = "schema_migrations"
var (
ErrDatabaseDirty = fmt.Errorf("database is dirty")
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
)
type Config struct {
MigrationsTable string
DatabaseName string
}
type Mysql struct {
db *sql.DB
isLocked bool
config *Config
}
// instance must have `multiStatements` set to true
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
query := `SELECT DATABASE()`
var databaseName sql.NullString
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName.String) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName.String
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
mx := &Mysql{
db: instance,
config: config,
}
if err := mx.ensureVersionTable(); err != nil {
return nil, err
}
return mx, nil
}
func (m *Mysql) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
purl.Query().Set("multiStatements", "true")
db, err := sql.Open("mysql", strings.Replace(
migrate.FilterCustomQuery(purl).String(), "mysql://", "", 1))
if err != nil {
return nil, err
}
migrationsTable := purl.Query().Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}
mx, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
})
if err != nil {
return nil, err
}
return mx, nil
}
func (m *Mysql) Close() error {
return m.db.Close()
}
func (m *Mysql) Lock() error {
if m.isLocked {
return database.ErrLocked
}
aid, err := database.GenerateAdvisoryLockId(m.config.DatabaseName)
if err != nil {
return err
}
query := "SELECT GET_LOCK(?, 1)"
var success bool
if err := m.db.QueryRow(query, aid).Scan(&success); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
if success {
m.isLocked = true
return nil
}
return database.ErrLocked
}
func (m *Mysql) Unlock() error {
if !m.isLocked {
return nil
}
aid, err := database.GenerateAdvisoryLockId(m.config.DatabaseName)
if err != nil {
return err
}
query := `SELECT RELEASE_LOCK(?)`
if _, err := m.db.Exec(query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
m.isLocked = false
return nil
}
func (m *Mysql) Run(migration io.Reader) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
}
query := string(migr[:])
if _, err := m.db.Exec(query); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
}
func (m *Mysql) SetVersion(version int, dirty bool) error {
tx, err := m.db.Begin()
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := "TRUNCATE `" + m.config.MigrationsTable + "`"
if _, err := m.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if version >= 0 {
query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
if _, err := m.db.Exec(query, version, dirty); err != nil {
tx.Rollback()
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (m *Mysql) Version() (version int, dirty bool, err error) {
query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
err = m.db.QueryRow(query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*mysql.MySQLError); ok {
if e.Number == 0 {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (m *Mysql) Drop() error {
// select all tables
query := `SHOW TABLES LIKE '%'`
tables, err := m.db.Query(query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer tables.Close()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = "DROP TABLE IF EXISTS `" + t + "` CASCADE"
if _, err := m.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := m.ensureVersionTable(); err != nil {
return err
}
}
return nil
}
func (m *Mysql) ensureVersionTable() error {
// check if migration table exists
var count int
query := `SHOW TABLES WHERE ?`
if err := m.db.QueryRow(query, m.config.MigrationsTable).Scan(&count); err != nil {
if err != sql.ErrNoRows {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if count == 1 {
return nil
}
// if not, create the empty migration table
query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
if _, err := m.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}

View File

@ -0,0 +1,51 @@
package mysql
import (
"database/sql"
sqldriver "database/sql/driver"
"fmt"
// "io/ioutil"
// "log"
"testing"
// "github.com/go-sql-driver/mysql"
dt "github.com/mattes/migrate/database/testing"
mt "github.com/mattes/migrate/testing"
)
var versions = []mt.Version{
{"mysql:8", []string{"MYSQL_ROOT_PASSWORD=root", "MYSQL_DATABASE=public"}},
{"mysql:5.7", []string{"MYSQL_ROOT_PASSWORD=root", "MYSQL_DATABASE=public"}},
{"mysql:5.6", []string{"MYSQL_ROOT_PASSWORD=root", "MYSQL_DATABASE=public"}},
{"mysql:5.5", []string{"MYSQL_ROOT_PASSWORD=root", "MYSQL_DATABASE=public"}},
}
func isReady(i mt.Instance) bool {
db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", i.Host(), i.Port()))
if err != nil {
return false
}
defer db.Close()
err = db.Ping()
if err == sqldriver.ErrBadConn {
return false
}
return true
}
func Test(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
mt.ParallelTest(t, versions, isReady,
func(t *testing.T, i mt.Instance) {
p := &Mysql{}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", i.Host(), i.Port())
d, err := p.Open(addr)
if err != nil {
t.Fatalf("%v", err)
}
dt.Test(t, d, []byte("SELECT 1"))
})
}

View File

@ -3,7 +3,6 @@ package postgres
import (
"database/sql"
"fmt"
"hash/crc32"
"io"
"io/ioutil"
nurl "net/url"
@ -44,6 +43,10 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
@ -96,10 +99,6 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, err
}
return px, nil
}
@ -113,7 +112,7 @@ func (p *Postgres) Lock() error {
return database.ErrLocked
}
aid, err := p.generateAdvisoryLockId()
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
if err != nil {
return err
}
@ -139,7 +138,7 @@ func (p *Postgres) Unlock() error {
return nil
}
aid, err := p.generateAdvisoryLockId()
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
if err != nil {
return err
}
@ -270,12 +269,3 @@ func (p *Postgres) ensureVersionTable() error {
}
return nil
}
const advisoryLockIdSalt uint = 1486364155
// inspired by rails migrations, see https://goo.gl/8o9bCT
func (p *Postgres) generateAdvisoryLockId() (string, error) {
sum := crc32.ChecksumIEEE([]byte(p.config.DatabaseName))
sum = sum * uint32(advisoryLockIdSalt)
return fmt.Sprintf("%v", sum), nil
}

View File

@ -14,12 +14,12 @@ import (
mt "github.com/mattes/migrate/testing"
)
var versions = []string{
"postgres:9.6",
"postgres:9.5",
"postgres:9.4",
"postgres:9.3",
"postgres:9.2",
var versions = []mt.Version{
{Image: "postgres:9.6"},
{Image: "postgres:9.5"},
{Image: "postgres:9.4"},
{Image: "postgres:9.3"},
{Image: "postgres:9.2"},
}
func isReady(i mt.Instance) bool {
@ -148,16 +148,3 @@ func TestWithSchema(t *testing.T) {
func TestWithInstance(t *testing.T) {
}
func TestGenerateAdvisoryLockId(t *testing.T) {
p := &Postgres{}
p.config = &Config{DatabaseName: "database_name"}
id, err := p.generateAdvisoryLockId()
if err != nil {
t.Errorf("expected err to be nil, got %v", err)
}
if len(id) == 0 {
t.Errorf("expected generated id not to be empty")
}
t.Logf("generated id: %v", id)
}

15
database/util.go Normal file
View File

@ -0,0 +1,15 @@
package database
import (
"fmt"
"hash/crc32"
)
const advisoryLockIdSalt uint = 1486364155
// inspired by rails migrations, see https://goo.gl/8o9bCT
func GenerateAdvisoryLockId(databaseName string) (string, error) {
sum := crc32.ChecksumIEEE([]byte(databaseName))
sum = sum * uint32(advisoryLockIdSalt)
return fmt.Sprintf("%v", sum), nil
}

12
database/util_test.go Normal file
View File

@ -0,0 +1,12 @@
package database
func TestGenerateAdvisoryLockId(t *testing.T) {
id, err := p.generateAdvisoryLockId("database_name")
if err != nil {
t.Errorf("expected err to be nil, got %v", err)
}
if len(id) == 0 {
t.Errorf("expected generated id not to be empty")
}
t.Logf("generated id: %v", id)
}

View File

@ -6,6 +6,7 @@ import (
"context" // TODO: is issue with go < 1.7?
"encoding/json"
"fmt"
"io"
"math/rand"
"strconv"
"strings"
@ -18,7 +19,7 @@ import (
dockerclient "github.com/docker/docker/client"
)
func NewDockerContainer(t testing.TB, image string) (*DockerContainer, error) {
func NewDockerContainer(t testing.TB, image string, env []string) (*DockerContainer, error) {
c, err := dockerclient.NewEnvClient()
if err != nil {
return nil, err
@ -28,6 +29,7 @@ func NewDockerContainer(t testing.TB, image string) (*DockerContainer, error) {
t: t,
client: c,
ImageName: image,
ENV: env,
}
if err := contr.PullImage(); err != nil {
@ -46,6 +48,7 @@ type DockerContainer struct {
t testing.TB
client *dockerclient.Client
ImageName string
ENV []string
ContainerId string
ContainerName string
ContainerJSON dockertypes.ContainerJSON
@ -83,9 +86,9 @@ func (d *DockerContainer) Start() error {
&dockercontainer.Config{
Image: d.ImageName,
Labels: map[string]string{"migrate_test": "true"},
Env: d.ENV,
},
&dockercontainer.HostConfig{
AutoRemove: true,
PublishAllPorts: true,
},
&dockernetwork.NetworkingConfig{},
@ -146,6 +149,17 @@ func (d *DockerContainer) Inspect() error {
return nil
}
func (d *DockerContainer) Logs() (io.ReadCloser, error) {
if len(d.ContainerId) == 0 {
return nil, fmt.Errorf("missing containerId")
}
return d.client.ContainerLogs(context.Background(), d.ContainerId, dockertypes.ContainerLogsOptions{
ShowStdout: true,
ShowStderr: true,
})
}
func (d *DockerContainer) firstPortMapping() (containerPort uint, hostIP string, hostPort uint, err error) {
if !d.containerInspected {
if err := d.Inspect(); err != nil {

View File

@ -1,6 +1,7 @@
package testing
import (
"io/ioutil"
"os"
"strconv"
"testing"
@ -11,22 +12,33 @@ type IsReadyFunc func(Instance) bool
type TestFunc func(*testing.T, Instance)
func ParallelTest(t *testing.T, versions []string, readyFn IsReadyFunc, testFn TestFunc) {
type Version struct {
Image string
ENV []string
}
func ParallelTest(t *testing.T, versions []Version, readyFn IsReadyFunc, testFn TestFunc) {
delay, err := strconv.Atoi(os.Getenv("MIGRATE_TEST_CONTAINER_BOOT_DELAY"))
if err != nil {
delay = 0
}
for i, version := range versions {
version := version // capture range variable, see https://goo.gl/60w3p2
// Only test against first found version in short mode
// Only test against one version in short mode
// TODO: order is random, maybe always pick first version instead?
if i > 0 && testing.Short() {
t.Logf("Skipping %v in short mode", version)
} else {
t.Run(version, func(t *testing.T) {
t.Run(version.Image, func(t *testing.T) {
t.Parallel()
// creata new container
container, err := NewDockerContainer(t, version)
// create new container
container, err := NewDockerContainer(t, version.Image, version.ENV)
if err != nil {
t.Fatal(err)
t.Fatalf("%v\n%s", err, containerLogs(t, container))
}
// make sure to remove container once done
@ -34,7 +46,7 @@ func ParallelTest(t *testing.T, versions []string, readyFn IsReadyFunc, testFn T
// wait until database is ready
tick := time.Tick(1000 * time.Millisecond)
timeout := time.After(30 * time.Second)
timeout := time.After(time.Duration(delay+60) * time.Second)
outer:
for {
select {
@ -44,16 +56,11 @@ func ParallelTest(t *testing.T, versions []string, readyFn IsReadyFunc, testFn T
}
case <-timeout:
t.Fatalf("Docker: Container not ready, timeout for %v.", version)
t.Fatalf("Docker: Container not ready, timeout for %v.\n%s", version, containerLogs(t, container))
}
}
delay, err := strconv.Atoi(os.Getenv("MIGRATE_TEST_CONTAINER_BOOT_DELAY"))
if err == nil {
time.Sleep(time.Duration(int64(delay)) * time.Second)
} else {
time.Sleep(2 * time.Second)
}
time.Sleep(time.Duration(int64(delay)) * time.Second)
// we can now run the tests
testFn(t, container)
@ -62,6 +69,21 @@ func ParallelTest(t *testing.T, versions []string, readyFn IsReadyFunc, testFn T
}
}
func containerLogs(t *testing.T, c *DockerContainer) []byte {
r, err := c.Logs()
if err != nil {
t.Error("%v", err)
return nil
}
defer r.Close()
b, err := ioutil.ReadAll(r)
if err != nil {
t.Error("%v", err)
return nil
}
return b
}
type Instance interface {
Host() string
Port() uint

View File

@ -12,7 +12,7 @@ func ExampleParallelTest(t *testing.T) {
}
// t is *testing.T coming from parent Test(t *testing.T)
ParallelTest(t, []string{"docker_image:9.6"}, isReady,
ParallelTest(t, []Version{{Image: "docker_image:9.6"}}, isReady,
func(t *testing.T, i Instance) {
// Run your test/s ...
t.Fatal("...")