diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 0cd737a..5f791af 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -182,3 +182,28 @@ func versionCmd(m *migrate.Migrate) { log.Println(v) } } + +// numDownMigrationsFromArgs returns an int for number of migrations to apply +// and a bool indicating if we need a confirm before applying +func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) { + if applyAll { + if len(args) > 0 { + return 0, false, errors.New("-all cannot be used with other arguments") + } + return -1, false, nil + } + + switch len(args) { + case 0: + return -1, true, nil + case 1: + downValue := args[0] + n, err := strconv.ParseUint(downValue, 10, 64) + if err != nil { + return 0, false, errors.New("can't read limit argument N") + } + return int(n), false, nil + default: + return 0, false, errors.New("too many arguments") + } +} diff --git a/internal/cli/commands_test.go b/internal/cli/commands_test.go index 040815b..a0b0856 100644 --- a/internal/cli/commands_test.go +++ b/internal/cli/commands_test.go @@ -78,3 +78,41 @@ func TestNextSeq(t *testing.T) { }) } } + +func TestNumDownFromArgs(t *testing.T) { + cases := []struct { + name string + args []string + applyAll bool + expectedNeedConfirm bool + expectedNum int + expectedErrStr string + }{ + {"no args", []string{}, false, true, -1, ""}, + {"down all", []string{}, true, false, -1, ""}, + {"down 5", []string{"5"}, false, false, 5, ""}, + {"down N", []string{"N"}, false, false, 0, "can't read limit argument N"}, + {"extra arg after -all", []string{"5"}, true, false, 0, "-all cannot be used with other arguments"}, + {"extra arg before -all", []string{"5", "-all"}, false, false, 0, "too many arguments"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + num, needsConfirm, err := numDownMigrationsFromArgs(c.applyAll, c.args) + if needsConfirm != c.expectedNeedConfirm { + t.Errorf("Incorrect needsConfirm was: %v wanted %v", needsConfirm, c.expectedNeedConfirm) + } + + if num != c.expectedNum { + t.Errorf("Incorrect num was: %v wanted %v", num, c.expectedNum) + } + + if err != nil { + if err.Error() != c.expectedErrStr { + t.Error("Incorrect error: " + err.Error() + " != " + c.expectedErrStr) + } + } else if c.expectedErrStr != "" { + t.Error("Expected error: " + c.expectedErrStr + " but got nil instead") + } + }) + } +} diff --git a/internal/cli/main.go b/internal/cli/main.go index 9992bf2..7860639 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -37,7 +37,7 @@ func Main(version string) { Options: -source Location of the migrations (driver://url) - -path Shorthand for -source=file://path + -path Shorthand for -source=file://path -database Run migrations against this database (driver://url) -prefetch N Number of migrations to load in advance before executing (default 10) -lock-timeout N Allow N seconds to acquire database lock (default 15) @@ -186,16 +186,33 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n") log.fatalErr(migraterErr) } - limit := -1 - if flag.Arg(1) != "" { - n, err := strconv.ParseUint(flag.Arg(1), 10, 64) - if err != nil { - log.fatal("error: can't read limit argument N") - } - limit = int(n) + downFlagSet := flag.NewFlagSet("down", flag.ExitOnError) + applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") + + args := flag.Args()[1:] + if err := downFlagSet.Parse(args); err != nil { + log.fatalErr(err) } - downCmd(migrater, limit) + downArgs := downFlagSet.Args() + num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) + if err != nil { + log.fatalErr(err) + } + if needsConfirm { + log.Println("Are you sure you want to apply all down migrations? [y/N]") + var response string + fmt.Scanln(&response) + response = strings.ToLower(strings.TrimSpace(response)) + + if response == "y" { + log.Println("Applying all down migrations") + } else { + log.fatal("Not applying all down migrations") + } + } + + downCmd(migrater, num) if log.verbose { log.Println("Finished after", time.Since(startTime))