From 197dc10a7fe42d9913485b1c80b2bd205847b62f Mon Sep 17 00:00:00 2001 From: Kyle Havlovitz Date: Tue, 7 Feb 2017 20:05:16 -0500 Subject: [PATCH] Add utility types to enable checking for unset flags --- command/base/command.go | 28 ++- command/base/config_util.go | 319 ++++++++++++++++++++++++++++++ command/base/config_util_test.go | 128 ++++++++++++ test/command/merge/a.json | 7 + test/command/merge/b.json | 7 + test/command/merge/empty/foo.json | 3 + test/command/merge/nope | 7 + test/command/merge/subdir/c.json | 7 + test/command/merge/zero.json | 0 9 files changed, 489 insertions(+), 17 deletions(-) create mode 100644 command/base/config_util.go create mode 100644 command/base/config_util_test.go create mode 100644 test/command/merge/a.json create mode 100644 test/command/merge/b.json create mode 100644 test/command/merge/empty/foo.json create mode 100644 test/command/merge/nope create mode 100644 test/command/merge/subdir/c.json create mode 100644 test/command/merge/zero.json diff --git a/command/base/command.go b/command/base/command.go index 032b802633..ba159afd19 100644 --- a/command/base/command.go +++ b/command/base/command.go @@ -35,10 +35,10 @@ type Command struct { flagSet *flag.FlagSet // These are the options which correspond to the HTTP API options - httpAddr string - token string - datacenter string - stale bool + httpAddr stringValue + token stringValue + datacenter stringValue + stale boolValue } // HTTPClient returns a client with the parsed flags. It panics if the command @@ -52,15 +52,9 @@ func (c *Command) HTTPClient() (*api.Client, error) { } config := api.DefaultConfig() - if c.datacenter != "" { - config.Datacenter = c.datacenter - } - if c.httpAddr != "" { - config.Address = c.httpAddr - } - if c.token != "" { - config.Token = c.token - } + c.httpAddr.Merge(&config.Address) + c.token.Merge(&config.Token) + c.datacenter.Merge(&config.Datacenter) c.Ui.Info(fmt.Sprintf("client http addr: %s", config.Address)) return api.NewClient(config) } @@ -71,12 +65,12 @@ func (c *Command) httpFlagsClient(f *flag.FlagSet) *flag.FlagSet { f = flag.NewFlagSet("", flag.ContinueOnError) } - f.StringVar(&c.httpAddr, "http-addr", "", + f.Var(&c.httpAddr, "http-addr", "Address and port to the Consul HTTP agent. The value can be an IP "+ "address or DNS address, but it must also include the port. This can "+ "also be specified via the CONSUL_HTTP_ADDR environment variable. The "+ "default value is 127.0.0.1:8500.") - f.StringVar(&c.token, "token", "", + f.Var(&c.token, "token", "ACL token to use in the request. This can also be specified via the "+ "CONSUL_HTTP_TOKEN environment variable. If unspecified, the query will "+ "default to the token of the Consul agent at the HTTP address.") @@ -90,10 +84,10 @@ func (c *Command) httpFlagsServer(f *flag.FlagSet) *flag.FlagSet { f = flag.NewFlagSet("", flag.ContinueOnError) } - f.StringVar(&c.datacenter, "datacenter", "", + f.Var(&c.datacenter, "datacenter", "Name of the datacenter to query. If unspecified, this will default to "+ "the datacenter of the queried agent.") - f.BoolVar(&c.stale, "stale", false, + f.Var(&c.stale, "stale", "Permit any Consul server (non-leader) to respond to this request. This "+ "allows for lower latency and higher throughput, but can result in "+ "stale data. This option has no effect on non-read operations. The "+ diff --git a/command/base/config_util.go b/command/base/config_util.go new file mode 100644 index 0000000000..8496992595 --- /dev/null +++ b/command/base/config_util.go @@ -0,0 +1,319 @@ +package base + +import ( + "fmt" + "reflect" + "strconv" + "time" + + "github.com/mitchellh/mapstructure" + "os" + "path/filepath" + "sort" +) + +// TODO (slackpad) - Trying out a different pattern here for config handling. +// These classes support the flag.Value interface but work in a manner where +// we can tell if they have been set. This lets us work with an all-pointer +// config structure and merge it in a clean-ish way. If this ends up being a +// good pattern we should pull this out into a reusable library. + +// configDecodeHook should be passed to mapstructure in order to decode into +// the *Value objects here. +var configDecodeHook = mapstructure.ComposeDecodeHookFunc( + boolToBoolValueFunc(), + stringToDurationValueFunc(), + stringToStringValueFunc(), + float64ToUintValueFunc(), +) + +// boolValue provides a flag value that's aware if it has been set. +type boolValue struct { + v *bool +} + +// See flag.Value. +func (b *boolValue) IsBoolFlag() bool { + return true +} + +// Merge will overlay this value if it has been set. +func (b *boolValue) Merge(onto *bool) { + if b.v != nil { + *onto = *(b.v) + } +} + +// See flag.Value. +func (b *boolValue) Set(v string) error { + if b.v == nil { + b.v = new(bool) + } + var err error + *(b.v), err = strconv.ParseBool(v) + return err +} + +// See flag.Value. +func (b *boolValue) String() string { + var current bool + if b.v != nil { + current = *(b.v) + } + return fmt.Sprintf("%v", current) +} + +// boolToBoolValueFunc is a mapstructure hook that looks for an incoming bool +// mapped to a boolValue and does the translation. +func boolToBoolValueFunc() mapstructure.DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.Bool { + return data, nil + } + + val := boolValue{} + if t != reflect.TypeOf(val) { + return data, nil + } + + val.v = new(bool) + *(val.v) = data.(bool) + return val, nil + } +} + +// durationValue provides a flag value that's aware if it has been set. +type durationValue struct { + v *time.Duration +} + +// Merge will overlay this value if it has been set. +func (d *durationValue) Merge(onto *time.Duration) { + if d.v != nil { + *onto = *(d.v) + } +} + +// See flag.Value. +func (d *durationValue) Set(v string) error { + if d.v == nil { + d.v = new(time.Duration) + } + var err error + *(d.v), err = time.ParseDuration(v) + return err +} + +// See flag.Value. +func (d *durationValue) String() string { + var current time.Duration + if d.v != nil { + current = *(d.v) + } + return current.String() +} + +// stringToDurationValueFunc is a mapstructure hook that looks for an incoming +// string mapped to a durationValue and does the translation. +func stringToDurationValueFunc() mapstructure.DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + + val := durationValue{} + if t != reflect.TypeOf(val) { + return data, nil + } + if err := val.Set(data.(string)); err != nil { + return nil, err + } + return val, nil + } +} + +// stringValue provides a flag value that's aware if it has been set. +type stringValue struct { + v *string +} + +// Merge will overlay this value if it has been set. +func (s *stringValue) Merge(onto *string) { + if s.v != nil { + *onto = *(s.v) + } +} + +// See flag.Value. +func (s *stringValue) Set(v string) error { + if s.v == nil { + s.v = new(string) + } + *(s.v) = v + return nil +} + +// See flag.Value. +func (s *stringValue) String() string { + var current string + if s.v != nil { + current = *(s.v) + } + return current +} + +// stringToStringValueFunc is a mapstructure hook that looks for an incoming +// string mapped to a stringValue and does the translation. +func stringToStringValueFunc() mapstructure.DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + + val := stringValue{} + if t != reflect.TypeOf(val) { + return data, nil + } + val.v = new(string) + *(val.v) = data.(string) + return val, nil + } +} + +// uintValue provides a flag value that's aware if it has been set. +type uintValue struct { + v *uint +} + +// Merge will overlay this value if it has been set. +func (u *uintValue) Merge(onto *uint) { + if u.v != nil { + *onto = *(u.v) + } +} + +// See flag.Value. +func (u *uintValue) Set(v string) error { + if u.v == nil { + u.v = new(uint) + } + parsed, err := strconv.ParseUint(v, 0, 64) + *(u.v) = (uint)(parsed) + return err +} + +// See flag.Value. +func (u *uintValue) String() string { + var current uint + if u.v != nil { + current = *(u.v) + } + return fmt.Sprintf("%v", current) +} + +// float64ToUintValueFunc is a mapstructure hook that looks for an incoming +// float64 mapped to a uintValue and does the translation. +func float64ToUintValueFunc() mapstructure.DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + if f.Kind() != reflect.Float64 { + return data, nil + } + + val := uintValue{} + if t != reflect.TypeOf(val) { + return data, nil + } + + fv := data.(float64) + if fv < 0 { + return nil, fmt.Errorf("value cannot be negative") + } + + // The standard guarantees at least this, and this is fine for + // values we expect to use in configs vs. being fancy with the + // machine's size for uint. + if fv > (1<<32 - 1) { + return nil, fmt.Errorf("value is too large") + } + + val.v = new(uint) + *(val.v) = (uint)(fv) + return val, nil + } +} + +// visitFn is a callback that gets a chance to visit each file found during a +// traversal with visit(). +type visitFn func(path string) error + +// visit will call the visitor function on the path if it's a file, or for each +// file in the path if it's a directory. Directories will not be recursed into, +// and files in the directory will be visited in alphabetical order. +func visit(path string, visitor visitFn) error { + f, err := os.Open(path) + if err != nil { + return fmt.Errorf("error reading %q: %v", path, err) + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return fmt.Errorf("error checking %q: %v", path, err) + } + + if !fi.IsDir() { + if err := visitor(path); err != nil { + return fmt.Errorf("error in %q: %v", path, err) + } + return nil + } + + contents, err := f.Readdir(-1) + if err != nil { + return fmt.Errorf("error listing %q: %v", path, err) + } + + sort.Sort(dirEnts(contents)) + for _, fi := range contents { + if fi.IsDir() { + continue + } + + fullPath := filepath.Join(path, fi.Name()) + if err := visitor(fullPath); err != nil { + return fmt.Errorf("error in %q: %v", fullPath, err) + } + } + + return nil +} + +// dirEnts applies sort.Interface to directory entries for sorting by name. +type dirEnts []os.FileInfo + +// See sort.Interface. +func (d dirEnts) Len() int { + return len(d) +} + +// See sort.Interface. +func (d dirEnts) Less(i, j int) bool { + return d[i].Name() < d[j].Name() +} + +// See sort.Interface. +func (d dirEnts) Swap(i, j int) { + d[i], d[j] = d[j], d[i] +} diff --git a/command/base/config_util_test.go b/command/base/config_util_test.go new file mode 100644 index 0000000000..b8264751fb --- /dev/null +++ b/command/base/config_util_test.go @@ -0,0 +1,128 @@ +package base + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "path" + "reflect" +) + +func TestConfigUtil_Values(t *testing.T) { + type config struct { + B boolValue `mapstructure:"bool"` + D durationValue `mapstructure:"duration"` + S stringValue `mapstructure:"string"` + U uintValue `mapstructure:"uint"` + } + + cases := []struct { + in string + success string + failure string + }{ + { + `{ }`, + `"false" "0s" "" "0"`, + "", + }, + { + `{ "bool": true, "duration": "2h", "string": "hello", "uint": 23 }`, + `"true" "2h0m0s" "hello" "23"`, + "", + }, + { + `{ "bool": "nope" }`, + "", + "got 'string'", + }, + { + `{ "duration": "nope" }`, + "", + "invalid duration nope", + }, + { + `{ "string": 123 }`, + "", + "got 'float64'", + }, + { + `{ "uint": -1 }`, + "", + "value cannot be negative", + }, + { + `{ "uint": 4294967296 }`, + "", + "value is too large", + }, + } + for i, c := range cases { + var raw interface{} + dec := json.NewDecoder(bytes.NewBufferString(c.in)) + if err := dec.Decode(&raw); err != nil { + t.Fatalf("(case %d) err: %v", i, err) + } + + var r config + msdec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: configDecodeHook, + Result: &r, + ErrorUnused: true, + }) + if err != nil { + t.Fatalf("(case %d) err: %v", i, err) + } + + err = msdec.Decode(raw) + if c.failure != "" { + if err == nil || !strings.Contains(err.Error(), c.failure) { + t.Fatalf("(case %d) err: %v", i, err) + } + continue + } + if err != nil { + t.Fatalf("(case %d) err: %v", i, err) + } + + actual := fmt.Sprintf("%q %q %q %q", + r.B.String(), + r.D.String(), + r.S.String(), + r.U.String()) + if actual != c.success { + t.Fatalf("(case %d) bad: %s", i, actual) + } + } +} + +func TestConfigUtil_Visit(t *testing.T) { + var trail []string + visitor := func(path string) error { + trail = append(trail, path) + return nil + } + + basePath := "../../test/command/merge" + if err := visit(basePath, visitor); err != nil { + t.Fatalf("err: %v", err) + } + if err := visit(path.Join(basePath, "subdir", "c.json"), visitor); err != nil { + t.Fatalf("err: %v", err) + } + + expected := []string{ + path.Join(basePath, "a.json"), + path.Join(basePath, "b.json"), + path.Join(basePath, "nope"), + path.Join(basePath, "zero.json"), + path.Join(basePath, "subdir", "c.json"), + } + if !reflect.DeepEqual(trail, expected) { + t.Fatalf("bad: %#v", trail) + } +} diff --git a/test/command/merge/a.json b/test/command/merge/a.json new file mode 100644 index 0000000000..50d6c60128 --- /dev/null +++ b/test/command/merge/a.json @@ -0,0 +1,7 @@ +{ + "snapshot_agent": { + "snapshot": { + "interval": "1h" + } + } +} diff --git a/test/command/merge/b.json b/test/command/merge/b.json new file mode 100644 index 0000000000..ddf3dff857 --- /dev/null +++ b/test/command/merge/b.json @@ -0,0 +1,7 @@ +{ + "snapshot_agent": { + "snapshot": { + "interval": "2h" + } + } +} diff --git a/test/command/merge/empty/foo.json b/test/command/merge/empty/foo.json new file mode 100644 index 0000000000..756fa36b9b --- /dev/null +++ b/test/command/merge/empty/foo.json @@ -0,0 +1,3 @@ +{ + "not_related": true +} diff --git a/test/command/merge/nope b/test/command/merge/nope new file mode 100644 index 0000000000..0c947af68a --- /dev/null +++ b/test/command/merge/nope @@ -0,0 +1,7 @@ +{ + "snapshot_agent": { + "snapshot": { + "interval": "3h" + } + } +} diff --git a/test/command/merge/subdir/c.json b/test/command/merge/subdir/c.json new file mode 100644 index 0000000000..a72381acb0 --- /dev/null +++ b/test/command/merge/subdir/c.json @@ -0,0 +1,7 @@ +{ + "snapshot_agent": { + "snapshot": { + "interval": "5h" + } + } +} diff --git a/test/command/merge/zero.json b/test/command/merge/zero.json new file mode 100644 index 0000000000..e69de29bb2