This commit is contained in:
mattes 2014-08-11 03:42:57 +02:00
commit a45e244a71
19 changed files with 1264 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
.DS_Store

12
.travis.yml Normal file
View File

@ -0,0 +1,12 @@
language: go
go:
- 1.2
- 1.3
- tip
addons:
postgresql: "9.3"
before_script:
- psql -c 'create database migratetest;' -U postgres

90
README.md Normal file
View File

@ -0,0 +1,90 @@
# migrate
[![Build Status](https://travis-ci.org/mattes/migrate.svg?branch=master)](https://travis-ci.org/mattes/migrate)
migrate can be used as CLI or can be imported into your existing Go code.
## Available Drivers
* [Postgres](https://github.com/mattes/migrate/tree/master/driver/postgres)
* Bash (planned)
Need another driver? Just implement the [Driver interface](http://godoc.org/github.com/mattes/migrate/driver#Driver) and open a PR.
## Usage from Terminal
```bash
# install
go get github.com/mattes/migrate
# create new migration
migrate -url="postgres://user@host:port/database" create
# apply all *up* migrations
migrate -url="postgres://user@host:port/database" up
# apply all *down* migrations
migrate -url="postgres://user@host:port/database" down
# roll back the most recently applied migration, then run it again.
migrate -url="postgres://user@host:port/database" redo
# down and up again
migrate -url="postgres://user@host:port/database" reset
# show current migration version
migrate -url="postgres://user@host:port/database" version
# apply the next n migrations
migrate -url="postgres://user@host:port/database" migrate +1
migrate -url="postgres://user@host:port/database" migrate +2
migrate -url="postgres://user@host:port/database" migrate +n
# apply the *down* migration of the current version
# and the previous n-1 migrations
migrate -url="postgres://user@host:port/database" migrate -1
migrate -url="postgres://user@host:port/database" migrate -2
migrate -url="postgres://user@host:port/database" migrate -n
```
``migrate`` looks for migration files in the following directories:
```
./db/migrations
./migrations
./db
```
You can explicitly set the search path with ``-path``.
## Usage from within Go
See http://godoc.org/github.com/mattes/migrate/migrate
```golang
import "github.com/mattes/migrate/migrate"
// optionally set search path
// migrate.SetSearchPath("./location1", "./location2")
migrate.Up("postgres://user@host:port/database")
// ...
// ...
```
## Migrations format
```
./db/migrations/001_initial.up.sql
./db/migrations/001_initial.down.sql
```
Why two files? This way you could do sth like ``psql -f ./db/migrations/001_initial.up.sql``.
## Credits
* https://bitbucket.org/liamstask/goose

7
driver/bash/README.md Normal file
View File

@ -0,0 +1,7 @@
# Bash Driver
```
-url="bash://"
```
* Runs all SQL commands in transcations

26
driver/bash/bash.go Normal file
View File

@ -0,0 +1,26 @@
package bash
import (
"github.com/mattes/migrate/file"
_ "github.com/mattes/migrate/migrate/direction"
)
type Driver struct {
}
func (driver *Driver) Initialize(url string) error {
return nil
}
func (driver *Driver) FilenameExtension() string {
return "sh"
}
func (driver *Driver) Migrate(files file.Files) error {
return nil
}
func (driver *Driver) Version() (uint64, error) {
return uint64(0), nil
}

9
driver/bash/bash_test.go Normal file
View File

@ -0,0 +1,9 @@
package bash
import (
"testing"
)
func TestFoobar(t *testing.T) {
}

54
driver/driver.go Normal file
View File

@ -0,0 +1,54 @@
package driver
import (
"errors"
"fmt"
"github.com/mattes/migrate/driver/bash"
"github.com/mattes/migrate/driver/postgres"
"github.com/mattes/migrate/file"
neturl "net/url" // alias to allow `url string` func signature in New
)
type Driver interface {
Initialize(url string) error
FilenameExtension() string
Migrate(files file.Files) error
Version() (uint64, error)
}
// InitDriver returns Driver and initializes it
func New(url string) (Driver, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
switch u.Scheme {
case "postgres":
d := &postgres.Driver{}
verifyFilenameExtension("postgres", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
case "bash":
d := &bash.Driver{}
verifyFilenameExtension("bash", d)
if err := d.Initialize(url); err != nil {
return nil, err
}
return d, nil
default:
return nil, errors.New(fmt.Sprintf("Driver '%s' not found.", u.Scheme))
}
}
func verifyFilenameExtension(driverName string, d Driver) {
f := d.FilenameExtension()
if f == "" {
panic(fmt.Sprintf("%s.FilenameExtension() returns empty string.", driverName))
}
if f[0:1] == "." {
panic(fmt.Sprintf("%s.FilenameExtension() returned string must not start with a dot.", driverName))
}
}

11
driver/driver_test.go Normal file
View File

@ -0,0 +1,11 @@
package driver
import (
"testing"
)
func TestNew(t *testing.T) {
if _, err := New("unknown://host/database"); err == nil {
t.Error("no error although driver unknown")
}
}

10
driver/postgres/README.md Normal file
View File

@ -0,0 +1,10 @@
# Postgres Driver
```
-url="postgres://user@host:port/database"
# thinking about adding some custom flag:
-url="postgres://user@host:port/database?schema=name"
```
* Runs all SQL commands in transcations

View File

@ -0,0 +1,96 @@
package postgres
import (
"database/sql"
_ "github.com/lib/pq"
"github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction"
)
type Driver struct {
db *sql.DB
}
const tableName = "schema_migrations"
func (driver *Driver) Initialize(url string) error {
db, err := sql.Open("postgres", url)
if err != nil {
return err
}
if err := db.Ping(); err != nil {
return err
}
driver.db = db
if err := driver.ensureVersionTableExists(); err != nil {
return err
}
return nil
}
func (driver *Driver) ensureVersionTableExists() error {
if _, err := driver.db.Exec(`CREATE TABLE IF NOT EXISTS ` + tableName + ` (
version int not null primary key
);`); err != nil {
return err
}
return nil
}
func (driver *Driver) FilenameExtension() string {
return "sql"
}
func (driver *Driver) Migrate(files file.Files) error {
for _, f := range files {
tx, err := driver.db.Begin()
if err != nil {
return err
}
if f.Direction == direction.Up {
if _, err := tx.Exec(`INSERT INTO `+tableName+` (version) VALUES ($1)`, f.Version); err != nil {
if err := tx.Rollback(); err != nil {
// haha, what now?
}
return err
}
} else if f.Direction == direction.Down {
if _, err := tx.Exec(`DELETE FROM `+tableName+` WHERE version=$1`, f.Version); err != nil {
if err := tx.Rollback(); err != nil {
// haha, what now?
}
return err
}
}
f.Read()
if _, err := tx.Exec(string(f.Content)); err != nil {
if err := tx.Rollback(); err != nil {
// haha, what now?
}
return err
}
if err := tx.Commit(); err != nil {
return err
}
}
return nil
}
func (driver *Driver) Version() (uint64, error) {
var version uint64
err := driver.db.QueryRow(`SELECT version FROM ` + tableName + ` ORDER BY version DESC`).Scan(&version)
switch {
case err == sql.ErrNoRows:
return 0, nil
case err != nil:
return 0, err
default:
return version, nil
}
}

View File

@ -0,0 +1,96 @@
package postgres
import (
"database/sql"
_ "github.com/lib/pq"
"github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction"
"testing"
)
func TestMigrate(t *testing.T) {
connection, err := sql.Open("postgres", "postgres://localhost/migratetest?sslmode=disable")
if err != nil {
t.Fatal(err)
}
if _, err := connection.Exec(`DROP TABLE IF EXISTS hello; DROP TABLE IF EXISTS yolo; DROP TABLE IF EXISTS ` + tableName + `;`); err != nil {
t.Fatal(err)
}
d := &Driver{}
if err := d.Initialize("postgres://localhost/migratetest?sslmode=disable"); err != nil {
t.Fatal(err)
}
version, err := d.Version()
if err != nil {
t.Fatal(err)
}
if version != 0 {
t.Fatal("wrong version", version)
}
files := make(file.MigrationFiles, 0)
files = append(files, file.MigrationFile{
Version: 1,
UpFile: &file.File{
Path: "/tmp",
FileName: "001_initial.up.sql",
Version: 1,
Name: "initial",
Direction: direction.Up,
Content: []byte(`
CREATE TABLE hello (
id serial not null primary key,
message varchar(255) not null default ''
);
CREATE TABLE yolo (
id serial not null primary key,
foobar varchar(255) not null default ''
);
`),
},
DownFile: &file.File{
Path: "/tmp",
FileName: "001_initial.down.sql",
Version: 1,
Name: "initial",
Direction: direction.Down,
Content: []byte(`
DROP TABLE IF EXISTS hello;
DROP TABLE IF EXISTS yolo;
`),
},
})
applyFiles, _ := files.ToLastFrom(0)
if err := d.Migrate(applyFiles); err != nil {
t.Fatal(err)
}
version, _ = d.Version()
if version != 1 {
t.Fatalf("wrong version %v expected 1", version)
}
if _, err := connection.Exec(`INSERT INTO hello (message) VALUES ($1)`, "whats up"); err != nil {
t.Fatal("Migrations failed")
}
applyFiles2, _ := files.ToFirstFrom(1)
if err := d.Migrate(applyFiles2); err != nil {
t.Fatal(err)
}
version, _ = d.Version()
if version != 0 {
t.Fatalf("wrong version %v expected 0", version)
}
if _, err := connection.Exec(`INSERT INTO hello (message) VALUES ($1)`, "whats up"); err == nil {
t.Fatal("Migrations failed")
}
}

252
file/file.go Normal file
View File

@ -0,0 +1,252 @@
package file
import (
"errors"
"fmt"
"github.com/mattes/migrate/migrate/direction"
"io/ioutil"
"path"
"regexp"
"sort"
"strconv"
)
var filenameRegex = "^([0-9]+)_(.*)\\.(up|down)\\.%s$"
func FilenameRegex(filenameExtension string) *regexp.Regexp {
return regexp.MustCompile(fmt.Sprintf(filenameRegex, filenameExtension))
}
type File struct {
Path string
FileName string
Version uint64
Name string
Content []byte
Direction direction.Direction
}
type Files []File
type MigrationFile struct {
Version uint64
UpFile *File
DownFile *File
}
type MigrationFiles []MigrationFile
// Read reads the file contents
func (f *File) Read() error {
content, err := ioutil.ReadFile(path.Join(f.Path, f.FileName))
if err != nil {
return err
}
f.Content = content
return nil
}
// ToFirstFrom fetches all (down) migration files including the migration file
// of the current version to the very first migration file.
func (mf *MigrationFiles) ToFirstFrom(version uint64) (Files, error) {
sort.Sort(sort.Reverse(mf))
files := make(Files, 0)
for _, migrationFile := range *mf {
if migrationFile.Version <= version && migrationFile.DownFile != nil {
files = append(files, *migrationFile.DownFile)
}
}
return files, nil
}
// ToLastFrom fetches all (up) migration files to the most recent migration file.
// The migration file of the current version is not included.
func (mf *MigrationFiles) ToLastFrom(version uint64) (Files, error) {
sort.Sort(mf)
files := make(Files, 0)
for _, migrationFile := range *mf {
if migrationFile.Version > version && migrationFile.UpFile != nil {
files = append(files, *migrationFile.UpFile)
}
}
return files, nil
}
// From travels relatively through migration files.
// +1 will fetch the next up migration file
// +2 will fetch the next two up migration files
// -1 will fetch the the current down migration file
// -2 will fetch the current down and the next down migration file
func (mf *MigrationFiles) From(version uint64, relativeN int) (Files, error) {
var d direction.Direction
if relativeN > 0 {
d = direction.Up
} else if relativeN < 0 {
d = direction.Down
} else { // relativeN == 0
return nil, nil
}
if d == direction.Down {
sort.Sort(sort.Reverse(mf))
} else {
sort.Sort(mf)
}
files := make(Files, 0)
counter := relativeN
if relativeN < 0 {
counter = relativeN * -1
}
for _, migrationFile := range *mf {
if counter > 0 {
if d == direction.Up && migrationFile.Version > version && migrationFile.UpFile != nil {
files = append(files, *migrationFile.UpFile)
counter -= 1
} else if d == direction.Down && migrationFile.Version <= version && migrationFile.DownFile != nil {
files = append(files, *migrationFile.DownFile)
counter -= 1
}
} else {
break
}
}
return files, nil
}
// readMigrationFiles reads all migration files from a given path
func ReadMigrationFiles(path string, filenameRegex *regexp.Regexp) (files MigrationFiles, err error) {
// find all migration files in path
ioFiles, err := ioutil.ReadDir(path)
if err != nil {
return nil, err
}
type tmpFile struct {
version uint64
name string
filename string
d direction.Direction
}
tmpFiles := make([]*tmpFile, 0)
for _, file := range ioFiles {
version, name, d, err := parseFilenameSchema(file.Name(), filenameRegex)
if err == nil {
tmpFiles = append(tmpFiles, &tmpFile{version, name, file.Name(), d})
}
}
// put tmpFiles into MigrationFile struct
parsedVersions := make(map[uint64]bool)
newFiles := make(MigrationFiles, 0)
for _, file := range tmpFiles {
if _, ok := parsedVersions[file.version]; !ok {
migrationFile := MigrationFile{
Version: file.version,
}
var lookFordirection direction.Direction
switch file.d {
case direction.Up:
migrationFile.UpFile = &File{
Path: path,
FileName: file.filename,
Version: file.version,
Name: file.name,
Content: nil,
Direction: direction.Up,
}
lookFordirection = direction.Down
case direction.Down:
migrationFile.DownFile = &File{
Path: path,
FileName: file.filename,
Version: file.version,
Name: file.name,
Content: nil,
Direction: direction.Down,
}
lookFordirection = direction.Up
default:
return nil, errors.New("Unsupported direction.Direction Type")
}
for _, file2 := range tmpFiles {
if file2.version == file.version && file2.d == lookFordirection {
switch lookFordirection {
case direction.Up:
migrationFile.UpFile = &File{
Path: path,
FileName: file2.filename,
Version: file.version,
Name: file2.name,
Content: nil,
Direction: direction.Up,
}
case direction.Down:
migrationFile.DownFile = &File{
Path: path,
FileName: file2.filename,
Version: file.version,
Name: file2.name,
Content: nil,
Direction: direction.Down,
}
}
break
}
}
newFiles = append(newFiles, migrationFile)
parsedVersions[file.version] = true
}
}
sort.Sort(newFiles)
return newFiles, nil
}
// parseFilenameSchema parses the filename and returns
// version, name, d (up|down)
// the schema looks like 000_name.(up|down).extension
func parseFilenameSchema(filename string, filenameRegex *regexp.Regexp) (version uint64, name string, d direction.Direction, err error) {
matches := filenameRegex.FindStringSubmatch(filename)
if len(matches) != 4 {
return 0, "", 0, errors.New("Unable to parse filename schema")
}
version, err = strconv.ParseUint(matches[1], 10, 0)
if err != nil {
return 0, "", 0, errors.New(fmt.Sprintf("Unable to parse version '%v' in filename schema", matches[0]))
}
if matches[3] == "up" {
d = direction.Up
} else if matches[3] == "down" {
d = direction.Down
} else {
return 0, "", 0, errors.New(fmt.Sprintf("Unable to parse up|down '%v' in filename schema", matches[3]))
}
return version, matches[2], d, nil
}
// implement sort interface ...
// Len is the number of elements in the collection.
func (f MigrationFiles) Len() int {
return len(f)
}
// Less reports whether the element with
// index i should sort before the element with index j.
func (f MigrationFiles) Less(i, j int) bool {
return f[i].Version < f[j].Version
}
// Swap swaps the elements with indexes i and j.
func (f MigrationFiles) Swap(i, j int) {
f[i], f[j] = f[j], f[i]
}

211
file/file_test.go Normal file
View File

@ -0,0 +1,211 @@
package file
import (
"github.com/mattes/migrate/migrate/direction"
"io/ioutil"
"os"
"path"
"testing"
)
func TestParseFilenameSchema(t *testing.T) {
var tests = []struct {
filename string
filenameExtension string
expectVersion uint64
expectName string
expectDirection direction.Direction
expectErr bool
}{
{"001_test_file.up.sql", "sql", 1, "test_file", direction.Up, false},
{"001_test_file.down.sql", "sql", 1, "test_file", direction.Down, false},
{"10034_test_file.down.sql", "sql", 10034, "test_file", direction.Down, false},
{"-1_test_file.down.sql", "sql", 0, "", direction.Up, true},
{"test_file.down.sql", "sql", 0, "", direction.Up, true},
{"100_test_file.down", "sql", 0, "", direction.Up, true},
{"100_test_file.sql", "sql", 0, "", direction.Up, true},
{"100_test_file", "sql", 0, "", direction.Up, true},
{"test_file", "sql", 0, "", direction.Up, true},
{"100", "sql", 0, "", direction.Up, true},
{".sql", "sql", 0, "", direction.Up, true},
{"up.sql", "sql", 0, "", direction.Up, true},
{"down.sql", "sql", 0, "", direction.Up, true},
}
for _, test := range tests {
version, name, migrate, err := parseFilenameSchema(test.filename, FilenameRegex(test.filenameExtension))
if test.expectErr && err == nil {
t.Fatal("Expected error, but got none.", test)
}
if !test.expectErr && err != nil {
t.Fatal("Did not expect error, but got one:", err, test)
}
if err == nil {
if version != test.expectVersion {
t.Error("Wrong version number", test)
}
if name != test.expectName {
t.Error("wrong name", test)
}
if migrate != test.expectDirection {
t.Error("wrong migrate", test)
}
}
}
}
func TestFiles(t *testing.T) {
tmpdir, err := ioutil.TempDir("/tmp", "TestLookForMigrationFilesInSearchPath")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
if err := ioutil.WriteFile(path.Join(tmpdir, "nonsense.txt"), nil, 0755); err != nil {
t.Fatal("Unable to write files in tmpdir", err)
}
ioutil.WriteFile(path.Join(tmpdir, "002_migrationfile.up.sql"), nil, 0755)
ioutil.WriteFile(path.Join(tmpdir, "002_migrationfile.down.sql"), nil, 0755)
ioutil.WriteFile(path.Join(tmpdir, "001_migrationfile.up.sql"), nil, 0755)
ioutil.WriteFile(path.Join(tmpdir, "001_migrationfile.down.sql"), nil, 0755)
ioutil.WriteFile(path.Join(tmpdir, "101_create_table.up.sql"), nil, 0755)
ioutil.WriteFile(path.Join(tmpdir, "101_drop_tables.down.sql"), nil, 0755)
ioutil.WriteFile(path.Join(tmpdir, "301_migrationfile.up.sql"), nil, 0755)
ioutil.WriteFile(path.Join(tmpdir, "401_migrationfile.down.sql"), []byte("test"), 0755)
files, err := ReadMigrationFiles(tmpdir, FilenameRegex("sql"))
if err != nil {
t.Fatal(err)
}
if len(files) == 0 {
t.Fatal("No files returned.")
}
if len(files) != 5 {
t.Fatal("Wrong number of files returned.")
}
// test sort order
if files[0].Version != 1 || files[1].Version != 2 || files[2].Version != 101 || files[3].Version != 301 || files[4].Version != 401 {
t.Error("Sort order is incorrect")
t.Error(files)
}
// test UpFile and DownFile
if files[0].UpFile == nil {
t.Fatalf("Missing up file for version %v", files[0].Version)
}
if files[0].DownFile == nil {
t.Fatalf("Missing down file for version %v", files[0].Version)
}
if files[1].UpFile == nil {
t.Fatalf("Missing up file for version %v", files[1].Version)
}
if files[1].DownFile == nil {
t.Fatalf("Missing down file for version %v", files[1].Version)
}
if files[2].UpFile == nil {
t.Fatalf("Missing up file for version %v", files[2].Version)
}
if files[2].DownFile == nil {
t.Fatalf("Missing down file for version %v", files[2].Version)
}
if files[3].UpFile == nil {
t.Fatalf("Missing up file for version %v", files[3].Version)
}
if files[3].DownFile != nil {
t.Fatalf("There should not be a down file for version %v", files[3].Version)
}
if files[4].UpFile != nil {
t.Fatalf("There should not be a up file for version %v", files[4].Version)
}
if files[4].DownFile == nil {
t.Fatalf("Missing down file for version %v", files[4].Version)
}
// test read
if err := files[4].DownFile.Read(); err != nil {
t.Error("Unable to read file", err)
}
if files[4].DownFile.Content == nil {
t.Fatal("Read content is nil")
}
if string(files[4].DownFile.Content) != "test" {
t.Fatal("Read content is wrong")
}
// test names
if files[0].UpFile.Name != "migrationfile" {
t.Error("file name is not correct", files[0].UpFile.Name)
}
if files[0].UpFile.FileName != "001_migrationfile.up.sql" {
t.Error("file name is not correct", files[0].UpFile.FileName)
}
// test file.From()
// there should be the following versions:
// 1(up&down), 2(up&down), 101(up&down), 301(up), 401(down)
var tests = []struct {
from uint64
relative int
expectRange []uint64
}{
{0, 2, []uint64{1, 2}},
{1, 4, []uint64{2, 101, 301}},
{1, 0, nil},
{0, 1, []uint64{1}},
{0, 0, nil},
{101, -2, []uint64{101, 2}},
{401, -1, []uint64{401}},
}
for _, test := range tests {
rangeFiles, err := files.From(test.from, test.relative)
if err != nil {
t.Error("Unable to fetch range:", err)
}
if len(rangeFiles) != len(test.expectRange) {
t.Fatalf("file.From(): expected %v files, got %v. For test %v.", len(test.expectRange), len(rangeFiles), test.expectRange)
}
for i, version := range test.expectRange {
if rangeFiles[i].Version != version {
t.Fatal("file.From(): returned files dont match expectations", test.expectRange)
}
}
}
// test ToFirstFrom
tffFiles, err := files.ToFirstFrom(401)
if err != nil {
t.Fatal(err)
}
if len(tffFiles) != 4 {
t.Fatalf("Wrong number of files returned by ToFirstFrom(), expected %v, got %v.", 5, len(tffFiles))
}
if tffFiles[0].Direction != direction.Down {
t.Error("ToFirstFrom() did not return DownFiles")
}
// test ToLastFrom
tofFiles, err := files.ToLastFrom(0)
if err != nil {
t.Fatal(err)
}
if len(tofFiles) != 4 {
t.Fatalf("Wrong number of files returned by ToLastFrom(), expected %v, got %v.", 5, len(tofFiles))
}
if tofFiles[0].Direction != direction.Up {
t.Error("ToFirstFrom() did not return UpFiles")
}
}

43
main.go Normal file
View File

@ -0,0 +1,43 @@
package main
import (
"flag"
"fmt"
"github.com/mattes/migrate/migrate"
"os"
)
var db = flag.String("db", "schema://url", "Driver connection URL")
var path = flag.String("path", "./db/migrations:./migrations:./db", "Migrations search path")
var help = flag.Bool("help", false, "Show help")
func main() {
flag.Parse()
if *help {
usage()
os.Exit(0)
}
command := flag.Arg(0)
switch command {
case "create":
if *path != "" {
migrate.SetSearchPath(*path)
}
files, err := migrate.Create(*db, "blablabla")
if err != nil {
fmt.Println(err)
os.Exit(1)
}
fmt.Println(files)
}
// fmt.Println(*db)
}
func usage() {
fmt.Fprint(os.Stderr, "Usage of migrate:\n")
flag.PrintDefaults()
}

View File

@ -0,0 +1,8 @@
package direction
type Direction int
const (
Up Direction = +1
Down = -1
)

219
migrate/migrate.go Normal file
View File

@ -0,0 +1,219 @@
package migrate
import (
"errors"
"fmt"
"github.com/mattes/migrate/driver"
"github.com/mattes/migrate/file"
"github.com/mattes/migrate/migrate/direction"
"github.com/mattes/migrate/searchpath"
"io/ioutil"
"path"
"strconv"
"strings"
)
func init() {
SetSearchPath("./db/migrations", "./migrations", "./db")
}
// Convenience func for searchpath.SetSearchPath(), so users
// don't have to import searchpath
func SetSearchPath(paths ...string) {
searchpath.SetSearchPath(paths...)
}
func common(db string) (driver.Driver, *file.MigrationFiles, uint64, error) {
d, err := driver.New(db)
if err != nil {
return nil, nil, 0, err
}
p, err := searchpath.FindPath(file.FilenameRegex(d.FilenameExtension()))
if err != nil {
return nil, nil, 0, err
}
files, err := file.ReadMigrationFiles(p, file.FilenameRegex(d.FilenameExtension()))
if err != nil {
return nil, nil, 0, err
}
version, err := d.Version()
if err != nil {
return nil, nil, 0, err
}
return d, &files, version, nil
}
func Up(db string) error {
d, files, version, err := common(db)
if err != nil {
return err
}
applyMigrationFiles, err := files.ToLastFrom(version)
if err != nil {
return err
}
if len(applyMigrationFiles) > 0 {
return d.Migrate(applyMigrationFiles)
}
return errors.New("No migrations to apply.")
}
func Down(db string) error {
d, files, version, err := common(db)
if err != nil {
return err
}
applyMigrationFiles, err := files.ToFirstFrom(version)
if err != nil {
return err
}
if len(applyMigrationFiles) > 0 {
return d.Migrate(applyMigrationFiles)
}
return errors.New("No migrations to apply.")
}
func Redo(db string) error {
d, files, version, err := common(db)
if err != nil {
return err
}
applyMigrationFilesDown, err := files.From(version, -1)
if err != nil {
return err
}
if len(applyMigrationFilesDown) > 0 {
if err := d.Migrate(applyMigrationFilesDown); err != nil {
return err
}
}
applyMigrationFilesUp, err := files.From(version, +1)
if err != nil {
return err
}
if len(applyMigrationFilesUp) > 0 {
return d.Migrate(applyMigrationFilesUp)
}
return errors.New("No migrations to apply.")
}
func Reset(db string) error {
d, files, version, err := common(db)
if err != nil {
return err
}
applyMigrationFilesDown, err := files.ToFirstFrom(version)
if err != nil {
return err
}
if len(applyMigrationFilesDown) > 0 {
if err := d.Migrate(applyMigrationFilesDown); err != nil {
return err
}
}
applyMigrationFilesUp, err := files.ToLastFrom(0)
if err != nil {
return err
}
if len(applyMigrationFilesUp) > 0 {
return d.Migrate(applyMigrationFilesUp)
}
return errors.New("No migrations to apply.")
}
func Migrate(db string, relativeN int) error {
d, files, version, err := common(db)
if err != nil {
return err
}
applyMigrationFiles, err := files.From(version, relativeN)
if err != nil {
return err
}
if len(applyMigrationFiles) > 0 {
if relativeN > 0 {
return d.Migrate(applyMigrationFiles)
} else if relativeN < 0 {
return d.Migrate(applyMigrationFiles)
} else {
return errors.New("No migrations to apply.")
}
}
return errors.New("No migrations to apply.")
}
func Version(db string) (version uint64, err error) {
d, err := driver.New(db)
if err != nil {
return 0, err
}
return d.Version()
}
func Create(db, name string) (*file.MigrationFile, error) {
d, err := driver.New(db)
if err != nil {
return nil, err
}
p, _ := searchpath.FindPath(file.FilenameRegex(d.FilenameExtension()))
if p == "" {
paths := searchpath.GetSearchPath()
if len(paths) > 0 {
p = paths[0]
} else {
return nil, errors.New("Please specify at least one search path.")
}
}
files, err := file.ReadMigrationFiles(p, file.FilenameRegex(d.FilenameExtension()))
if err != nil {
return nil, err
}
version := uint64(0)
if len(files) > 0 {
lastFile := files[len(files)-1]
version = lastFile.Version
}
version += 1
versionStr := strconv.FormatUint(version, 10)
length := 4
if len(versionStr)%length != 0 {
versionStr = strings.Repeat("0", length-len(versionStr)%length) + versionStr
}
filenamef := "%s_%s.%s.%s"
name = strings.Replace(name, " ", "_", -1)
mfile := &file.MigrationFile{
Version: version,
UpFile: &file.File{
Path: p,
FileName: fmt.Sprintf(filenamef, versionStr, name, "up", d.FilenameExtension()),
Name: name,
Content: []byte(""),
Direction: direction.Up,
},
DownFile: &file.File{
Path: p,
FileName: fmt.Sprintf(filenamef, versionStr, name, "down", d.FilenameExtension()),
Name: name,
Content: []byte(""),
Direction: direction.Down,
},
}
if err := ioutil.WriteFile(path.Join(mfile.UpFile.Path, mfile.UpFile.FileName), mfile.UpFile.Content, 0644); err != nil {
return nil, err
}
if err := ioutil.WriteFile(path.Join(mfile.DownFile.Path, mfile.DownFile.FileName), mfile.DownFile.Content, 0644); err != nil {
return nil, err
}
return mfile, nil
}

45
migrate/migrate_test.go Normal file
View File

@ -0,0 +1,45 @@
package migrate
import (
"github.com/mattes/migrate/searchpath"
"io/ioutil"
"testing"
)
func TestCreate(t *testing.T) {
tmpdir, err := ioutil.TempDir("/tmp", "migrate-postgres-test")
if err != nil {
t.Fatal(err)
}
searchpath.SetSearchPath(tmpdir)
if _, err := Create("postgres://localhost/migratetest?sslmode=disable", "test_migration"); err != nil {
t.Fatal(err)
}
if _, err := Create("postgres://localhost/migratetest?sslmode=disable", "another migration"); err != nil {
t.Fatal(err)
}
files, err := ioutil.ReadDir(tmpdir)
if err != nil {
t.Fatal(err)
}
if len(files) != 4 {
t.Fatal("Expected 2 new files, got", len(files))
}
expectFiles := []string{
"0001_test_migration.up.sql", "0001_test_migration.down.sql",
"0002_another_migration.up.sql", "0002_another_migration.down.sql",
}
foundCounter := 0
for _, expectFile := range expectFiles {
for _, file := range files {
if expectFile == file.Name() {
foundCounter += 1
break
}
}
}
if foundCounter != len(expectFiles) {
t.Error("not all expected files have been found")
}
}

46
searchpath/searchpath.go Normal file
View File

@ -0,0 +1,46 @@
package searchpath
import (
"errors"
"io/ioutil"
"regexp"
)
var searchpath []string
func SetSearchPath(paths ...string) {
searchpath = paths
}
func AppendSearchPath(path string) {
searchpath = append(searchpath, path)
}
func PrependSearchPath(path string) {
searchpath = append((searchpath)[:0], append([]string{path}, (searchpath)[0:]...)...)
}
func GetSearchPath() []string {
return searchpath
}
// FindPath scans files in the search paths and
// returns the path where the regex matches at least twice
func FindPath(regex *regexp.Regexp) (path string, err error) {
count := 0
for _, path := range searchpath {
// TODO refactor ioutil.ReadDir to read only first files per dir
files, err := ioutil.ReadDir(path)
if err == nil {
for _, file := range files {
if regex.MatchString(file.Name()) {
count += 1
}
if count >= 2 {
return path, nil
}
}
}
}
return "", errors.New("no path found")
}

View File

@ -0,0 +1,28 @@
package searchpath
import (
"testing"
)
func TestSetSearchPath(t *testing.T) {
SetSearchPath("a")
if len(searchpath) != 1 || searchpath[0] != "a" {
t.Error("SetSearchPath failed")
}
}
func TestAppendSearchPath(t *testing.T) {
SetSearchPath("a")
AppendSearchPath("b")
if len(searchpath) != 2 || searchpath[0] != "a" || searchpath[1] != "b" {
t.Error("AppendSearchPath failed")
}
}
func TestPrependSearchPath(t *testing.T) {
SetSearchPath("a")
PrependSearchPath("b")
if len(searchpath) != 2 || searchpath[0] != "b" || searchpath[1] != "a" {
t.Error("PrependSearchPath failed")
}
}