diff --git a/agent/connect/ca/provider_consul_config.go b/agent/connect/ca/provider_consul_config.go index 9c94f0e62a..f3af3a4898 100644 --- a/agent/connect/ca/provider_consul_config.go +++ b/agent/connect/ca/provider_consul_config.go @@ -2,7 +2,6 @@ package ca import ( "fmt" - "reflect" "time" "github.com/hashicorp/consul/agent/structs" @@ -15,7 +14,7 @@ func ParseConsulCAConfig(raw map[string]interface{}) (*structs.ConsulCAProviderC } decodeConf := &mapstructure.DecoderConfig{ - DecodeHook: ParseDurationFunc(), + DecodeHook: structs.ParseDurationFunc(), Result: &config, WeaklyTypedInput: true, } @@ -40,48 +39,6 @@ func ParseConsulCAConfig(raw map[string]interface{}) (*structs.ConsulCAProviderC return &config, nil } -// ParseDurationFunc is a mapstructure hook for decoding a string or -// []uint8 into a time.Duration value. -func ParseDurationFunc() mapstructure.DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data interface{}) (interface{}, error) { - var v time.Duration - if t != reflect.TypeOf(v) { - return data, nil - } - - switch { - case f.Kind() == reflect.String: - if dur, err := time.ParseDuration(data.(string)); err != nil { - return nil, err - } else { - v = dur - } - return v, nil - case f == reflect.SliceOf(reflect.TypeOf(uint8(0))): - s := Uint8ToString(data.([]uint8)) - if dur, err := time.ParseDuration(s); err != nil { - return nil, err - } else { - v = dur - } - return v, nil - default: - return data, nil - } - } -} - -func Uint8ToString(bs []uint8) string { - b := make([]byte, len(bs)) - for i, v := range bs { - b[i] = byte(v) - } - return string(b) -} - func defaultCommonConfig() structs.CommonCAProviderConfig { return structs.CommonCAProviderConfig{ LeafCertTTL: 3 * 24 * time.Hour, diff --git a/agent/connect/ca/provider_consul_test.go b/agent/connect/ca/provider_consul_test.go index 2a37f1b94f..2092d934d4 100644 --- a/agent/connect/ca/provider_consul_test.go +++ b/agent/connect/ca/provider_consul_test.go @@ -65,7 +65,10 @@ func testConsulCAConfig() *structs.CAConfiguration { return &structs.CAConfiguration{ ClusterID: "asdf", Provider: "consul", - Config: map[string]interface{}{}, + Config: map[string]interface{}{ + // Tests duration parsing after msgpack type mangling during raft apply. + "LeafCertTTL": []uint8("72h"), + }, } } diff --git a/agent/connect/ca/provider_vault.go b/agent/connect/ca/provider_vault.go index 743ea8957e..eaf3646090 100644 --- a/agent/connect/ca/provider_vault.go +++ b/agent/connect/ca/provider_vault.go @@ -289,7 +289,7 @@ func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderCon } decodeConf := &mapstructure.DecoderConfig{ - DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + DecodeHook: structs.ParseDurationFunc(), Result: &config, WeaklyTypedInput: true, } diff --git a/agent/connect/ca/provider_vault_test.go b/agent/connect/ca/provider_vault_test.go index 5c248e8dc4..05f8c36448 100644 --- a/agent/connect/ca/provider_vault_test.go +++ b/agent/connect/ca/provider_vault_test.go @@ -32,6 +32,8 @@ func testVaultClusterWithConfig(t *testing.T, rawConf map[string]interface{}) (* "Token": token, "RootPKIPath": "pki-root/", "IntermediatePKIPath": "pki-intermediate/", + // Tests duration parsing after msgpack type mangling during raft apply. + "LeafCertTTL": []uint8("72h"), } for k, v := range rawConf { conf[k] = v diff --git a/agent/connect_ca_endpoint.go b/agent/connect_ca_endpoint.go index 82d1233699..402797a8fa 100644 --- a/agent/connect_ca_endpoint.go +++ b/agent/connect_ca_endpoint.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" - "github.com/hashicorp/consul/agent/connect/ca" "github.com/hashicorp/consul/agent/structs" ) @@ -81,7 +80,7 @@ func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *ht func fixupConfig(conf *structs.CAConfiguration) { for k, v := range conf.Config { if raw, ok := v.([]uint8); ok { - strVal := ca.Uint8ToString(raw) + strVal := structs.Uint8ToString(raw) conf.Config[k] = strVal switch conf.Provider { case structs.ConsulCAProvider: diff --git a/agent/consul/leader_test.go b/agent/consul/leader_test.go index fe967c2f5c..ca42a5404a 100644 --- a/agent/consul/leader_test.go +++ b/agent/consul/leader_test.go @@ -1038,10 +1038,10 @@ func TestLeader_CARootPruning(t *testing.T) { newConfig := &structs.CAConfiguration{ Provider: "consul", Config: map[string]interface{}{ - "LeafCertTTL": 500 * time.Millisecond, + "LeafCertTTL": "500ms", "PrivateKey": newKey, "RootCert": "", - "RotationPeriod": 90 * 24 * time.Hour, + "RotationPeriod": "2160h", "SkipValidate": true, }, } diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 43dcd13ff2..bbd3ef60d9 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -100,7 +100,8 @@ func testServerConfig(t *testing.T) (string, *Config) { Config: map[string]interface{}{ "PrivateKey": "", "RootCert": "", - "RotationPeriod": 90 * 24 * time.Hour, + "RotationPeriod": "2160h", + "LeafCertTTL": "72h", }, } diff --git a/agent/structs/connect_ca.go b/agent/structs/connect_ca.go index 1e869fd45d..76162c2f65 100644 --- a/agent/structs/connect_ca.go +++ b/agent/structs/connect_ca.go @@ -2,6 +2,7 @@ package structs import ( "fmt" + "reflect" "time" "github.com/mitchellh/mapstructure" @@ -202,8 +203,9 @@ func (c *CAConfiguration) GetCommonConfig() (*CommonCAProviderConfig, error) { var config CommonCAProviderConfig decodeConf := &mapstructure.DecoderConfig{ - DecodeHook: mapstructure.StringToTimeDurationHookFunc(), - Result: &config, + DecodeHook: ParseDurationFunc(), + Result: &config, + WeaklyTypedInput: true, } decoder, err := mapstructure.NewDecoder(decodeConf) @@ -265,3 +267,45 @@ type VaultCAProviderConfig struct { RootPKIPath string IntermediatePKIPath string } + +// ParseDurationFunc is a mapstructure hook for decoding a string or +// []uint8 into a time.Duration value. +func ParseDurationFunc() mapstructure.DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}) (interface{}, error) { + var v time.Duration + if t != reflect.TypeOf(v) { + return data, nil + } + + switch { + case f.Kind() == reflect.String: + if dur, err := time.ParseDuration(data.(string)); err != nil { + return nil, err + } else { + v = dur + } + return v, nil + case f == reflect.SliceOf(reflect.TypeOf(uint8(0))): + s := Uint8ToString(data.([]uint8)) + if dur, err := time.ParseDuration(s); err != nil { + return nil, err + } else { + v = dur + } + return v, nil + default: + return data, nil + } + } +} + +func Uint8ToString(bs []uint8) string { + b := make([]byte, len(bs)) + for i, v := range bs { + b[i] = byte(v) + } + return string(b) +} diff --git a/agent/structs/connect_ca_test.go b/agent/structs/connect_ca_test.go new file mode 100644 index 0000000000..dd185ebe1f --- /dev/null +++ b/agent/structs/connect_ca_test.go @@ -0,0 +1,58 @@ +package structs + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCAConfiguration_GetCommonConfig(t *testing.T) { + tests := []struct { + name string + cfg *CAConfiguration + want *CommonCAProviderConfig + wantErr bool + }{ + { + name: "basic defaults", + cfg: &CAConfiguration{ + Config: map[string]interface{}{ + "RotationPeriod": "2160h", + "LeafCertTTL": "72h", + }, + }, + want: &CommonCAProviderConfig{ + LeafCertTTL: 72 * time.Hour, + }, + }, + { + // Note that this is currently what is actually stored in MemDB, I think + // due to a trip through msgpack somewhere but I'm not really sure why + // since the defaults are applied on the server and so should probably use + // direct RPC that bypasses encoding? Either way this case is important + // because it reflects the actual data as it's stored in state which is + // what matters in real life. + name: "basic defaults after encoding fun", + cfg: &CAConfiguration{ + Config: map[string]interface{}{ + "RotationPeriod": []uint8("2160h"), + "LeafCertTTL": []uint8("72h"), + }, + }, + want: &CommonCAProviderConfig{ + LeafCertTTL: 72 * time.Hour, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.cfg.GetCommonConfig() + if (err != nil) != tt.wantErr { + t.Errorf("CAConfiguration.GetCommonConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + require.Equal(t, tt.want, got) + }) + } +}