// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package flags import ( "fmt" "os" "path/filepath" "reflect" "sort" "strconv" "time" "github.com/mitchellh/mapstructure" ) // TODO (slackpad) - Trying out a different pattern here for config handling. // These classes support the flag.Value interface but work in a manner where // we can tell if they have been set. This lets us work with an all-pointer // config structure and merge it in a clean-ish way. If this ends up being a // good pattern we should pull this out into a reusable library. // ConfigDecodeHook should be passed to mapstructure in order to decode into // the *Value objects here. var ConfigDecodeHook = mapstructure.ComposeDecodeHookFunc( BoolToBoolValueFunc(), StringToDurationValueFunc(), StringToStringValueFunc(), Float64ToUintValueFunc(), ) // BoolValue provides a flag value that's aware if it has been set. type BoolValue struct { v *bool } // IsBoolFlag is an optional method of the flag.Value // interface which marks this value as boolean when // the return value is true. See flag.Value for details. func (b *BoolValue) IsBoolFlag() bool { return true } // Merge will overlay this value if it has been set. func (b *BoolValue) Merge(onto *bool) { if b.v != nil { *onto = *(b.v) } } // Set implements the flag.Value interface. func (b *BoolValue) Set(v string) error { if b.v == nil { b.v = new(bool) } var err error *(b.v), err = strconv.ParseBool(v) return err } // String implements the flag.Value interface. func (b *BoolValue) String() string { var current bool if b.v != nil { current = *(b.v) } return fmt.Sprintf("%v", current) } // BoolToBoolValueFunc is a mapstructure hook that looks for an incoming bool // mapped to a BoolValue and does the translation. func BoolToBoolValueFunc() mapstructure.DecodeHookFunc { return func( f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { if f.Kind() != reflect.Bool { return data, nil } val := BoolValue{} if t != reflect.TypeOf(val) { return data, nil } val.v = new(bool) *(val.v) = data.(bool) return val, nil } } // DurationValue provides a flag value that's aware if it has been set. type DurationValue struct { v *time.Duration } // Merge will overlay this value if it has been set. func (d *DurationValue) Merge(onto *time.Duration) { if d.v != nil { *onto = *(d.v) } } // Set implements the flag.Value interface. func (d *DurationValue) Set(v string) error { if d.v == nil { d.v = new(time.Duration) } var err error *(d.v), err = time.ParseDuration(v) return err } // String implements the flag.Value interface. func (d *DurationValue) String() string { var current time.Duration if d.v != nil { current = *(d.v) } return current.String() } // StringToDurationValueFunc is a mapstructure hook that looks for an incoming // string mapped to a DurationValue and does the translation. func StringToDurationValueFunc() mapstructure.DecodeHookFunc { return func( f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { if f.Kind() != reflect.String { return data, nil } val := DurationValue{} if t != reflect.TypeOf(val) { return data, nil } if err := val.Set(data.(string)); err != nil { return nil, err } return val, nil } } // StringValue provides a flag value that's aware if it has been set. type StringValue struct { v *string } // Merge will overlay this value if it has been set. func (s *StringValue) Merge(onto *string) { if s.v != nil { *onto = *(s.v) } } // Set implements the flag.Value interface. func (s *StringValue) Set(v string) error { if s.v == nil { s.v = new(string) } *(s.v) = v return nil } // String implements the flag.Value interface. func (s *StringValue) String() string { var current string if s.v != nil { current = *(s.v) } return current } // StringToStringValueFunc is a mapstructure hook that looks for an incoming // string mapped to a StringValue and does the translation. func StringToStringValueFunc() mapstructure.DecodeHookFunc { return func( f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { if f.Kind() != reflect.String { return data, nil } val := StringValue{} if t != reflect.TypeOf(val) { return data, nil } val.v = new(string) *(val.v) = data.(string) return val, nil } } // UintValue provides a flag value that's aware if it has been set. type UintValue struct { v *uint } // Merge will overlay this value if it has been set. func (u *UintValue) Merge(onto *uint) { if u.v != nil { *onto = *(u.v) } } // Set implements the flag.Value interface. func (u *UintValue) Set(v string) error { if u.v == nil { u.v = new(uint) } parsed, err := strconv.ParseUint(v, 0, 64) *(u.v) = (uint)(parsed) return err } // String implements the flag.Value interface. func (u *UintValue) String() string { var current uint if u.v != nil { current = *(u.v) } return fmt.Sprintf("%v", current) } // Float64ToUintValueFunc is a mapstructure hook that looks for an incoming // float64 mapped to a UintValue and does the translation. func Float64ToUintValueFunc() mapstructure.DecodeHookFunc { return func( f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { if f.Kind() != reflect.Float64 { return data, nil } val := UintValue{} if t != reflect.TypeOf(val) { return data, nil } fv := data.(float64) if fv < 0 { return nil, fmt.Errorf("value cannot be negative") } // The standard guarantees at least this, and this is fine for // values we expect to use in configs vs. being fancy with the // machine's size for uint. if fv > (1<<32 - 1) { return nil, fmt.Errorf("value is too large") } val.v = new(uint) *(val.v) = (uint)(fv) return val, nil } } // VisitFn is a callback that gets a chance to visit each file found during a // traversal with visit(). type VisitFn func(path string) error // Visit will call the visitor function on the path if it's a file, or for each // file in the path if it's a directory. Directories will not be recursed into, // and files in the directory will be visited in alphabetical order. func Visit(path string, visitor VisitFn) error { f, err := os.Open(path) if err != nil { return fmt.Errorf("error reading %q: %v", path, err) } defer f.Close() fi, err := f.Stat() if err != nil { return fmt.Errorf("error checking %q: %v", path, err) } if !fi.IsDir() { if err := visitor(path); err != nil { return fmt.Errorf("error in %q: %v", path, err) } return nil } contents, err := f.Readdir(-1) if err != nil { return fmt.Errorf("error listing %q: %v", path, err) } sort.Sort(dirEnts(contents)) for _, fi := range contents { if fi.IsDir() { continue } fullPath := filepath.Join(path, fi.Name()) if err := visitor(fullPath); err != nil { return fmt.Errorf("error in %q: %v", fullPath, err) } } return nil } // dirEnts applies sort.Interface to directory entries for sorting by name. type dirEnts []os.FileInfo func (d dirEnts) Len() int { return len(d) } func (d dirEnts) Less(i, j int) bool { return d[i].Name() < d[j].Name() } func (d dirEnts) Swap(i, j int) { d[i], d[j] = d[j], d[i] }