diff --git a/file/file.go b/file/file.go index 57addb3..54992a4 100644 --- a/file/file.go +++ b/file/file.go @@ -163,9 +163,18 @@ func ReadMigrationFiles(path string, filenameRegex *regexp.Regexp) (files Migrat d direction.Direction } tmpFiles := make([]*tmpFile, 0) + tmpFileMap := map[uint64]map[direction.Direction]tmpFile{} for _, file := range ioFiles { version, name, d, err := parseFilenameSchema(file.Name(), filenameRegex) if err == nil { + if _, ok := tmpFileMap[version]; !ok { + tmpFileMap[version] = map[direction.Direction]tmpFile{} + } + if existing, ok := tmpFileMap[version][d]; !ok { + tmpFileMap[version][d] = tmpFile{version: version, name: name, filename: file.Name(), d: d} + } else { + return nil, fmt.Errorf("duplicate migration file version %d : %q and %q", version, existing.filename, file.Name()) + } tmpFiles = append(tmpFiles, &tmpFile{version, name, file.Name(), d}) } } diff --git a/file/file_test.go b/file/file_test.go index 482c1a9..b9ddeab 100644 --- a/file/file_test.go +++ b/file/file_test.go @@ -209,3 +209,45 @@ func TestFiles(t *testing.T) { } } + +func TestDuplicateFiles(t *testing.T) { + dups := []string{ + "001_migration.up.sql", + "001_duplicate.up.sql", + } + + root, cleanFn, err := makeFiles("TestDuplicateFiles", dups...) + defer cleanFn() + + if err != nil { + t.Fatal(err) + } + + _, err = ReadMigrationFiles(root, FilenameRegex("sql")) + if err == nil { + t.Fatal("Expected duplicate migration file error") + } +} + +// makeFiles takes an identifier, and a list of file names and uses them to create a temporary +// directory populated with files named with the names passed in. makeFiles returns the root +// directory name, and a func suitable for a defer cleanup to remove the temporary files after +// the calling function exits. +func makeFiles(testname string, names ...string) (root string, cleanup func(), err error) { + cleanup = func() {} + root, err = ioutil.TempDir("/tmp", testname) + if err != nil { + return + } + cleanup = func() { os.RemoveAll(root) } + if err = ioutil.WriteFile(path.Join(root, "nonsense.txt"), nil, 0755); err != nil { + return + } + + for _, name := range names { + if err = ioutil.WriteFile(path.Join(root, name), nil, 0755); err != nil { + return + } + } + return +}