diff --git a/tlsutil/config.go b/tlsutil/config.go index 037f5978f2..8d66cf9750 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -34,8 +34,8 @@ type DCWrapper func(dc string, conn net.Conn) (net.Conn, error) // a constant value. This is usually done by currying DCWrapper. type Wrapper func(conn net.Conn) (net.Conn, error) -// TLSLookup maps the tls_min_version configuration to the internal value -var TLSLookup = map[string]uint16{ +// tlsLookup maps the tls_min_version configuration to the internal value +var tlsLookup = map[string]uint16{ "": tls.VersionTLS10, // default in golang "tls10": tls.VersionTLS10, "tls11": tls.VersionTLS11, @@ -43,9 +43,6 @@ var TLSLookup = map[string]uint16{ "tls13": tls.VersionTLS13, } -// TLSVersions has all the keys from the map above. -var TLSVersions = strings.Join(tlsVersions(), ", ") - // Config used to create tls.Config type Config struct { // VerifyIncoming is used to verify the authenticity of incoming @@ -133,7 +130,7 @@ type Config struct { func tlsVersions() []string { versions := []string{} - for v := range TLSLookup { + for v := range tlsLookup { if v != "" { versions = append(versions, v) } @@ -142,11 +139,6 @@ func tlsVersions() []string { return versions } -// KeyPair is used to open and parse a certificate and key file -func (c *Config) KeyPair() (*tls.Certificate, error) { - return loadKeyPair(c.CertFile, c.KeyFile) -} - // SpecificDC is used to invoke a static datacenter // and turns a DCWrapper into a Wrapper type. func SpecificDC(dc string, tlsWrap DCWrapper) Wrapper { @@ -158,6 +150,8 @@ func SpecificDC(dc string, tlsWrap DCWrapper) Wrapper { } } +// autoTLS stores configuration that is received from the auto-encrypt or +// auto-config features. type autoTLS struct { manualCAPems []string connectCAPems []string @@ -165,25 +159,31 @@ type autoTLS struct { verifyServerHostname bool } -func (a *autoTLS) caPems() []string { +func (a autoTLS) caPems() []string { return append(a.manualCAPems, a.connectCAPems...) } +// manual stores the TLS CA and cert received from Configurator.Update which +// generally comes from the agent configuration. type manual struct { caPems []string cert *tls.Certificate } -// Configurator holds a Config and is responsible for generating all the -// *tls.Config necessary for Consul. Except the one in the api package. +// Configurator provides tls.Config and net.Dial wrappers to enable TLS for +// clients and servers, for both HTTPS and RPC requests. +// Configurator receives an initial TLS configuration from agent configuration, +// and receives updates from config reloads, auto-encrypt, and auto-config. type Configurator struct { // lock synchronizes access to all fields on this struct except for logger and version. - lock sync.RWMutex - base *Config - autoTLS *autoTLS - manual *manual + lock sync.RWMutex + base *Config + autoTLS autoTLS + manual manual + caPool *x509.CertPool + // peerDatacenterUseTLS is a map of DC name to a bool indicating if the DC + // uses TLS for RPC requests. peerDatacenterUseTLS map[string]bool - caPool *x509.CertPool // logger is not protected by a lock. It must never be changed after // Configurator is created. @@ -204,8 +204,6 @@ func NewConfigurator(config Config, logger hclog.Logger) (*Configurator, error) c := &Configurator{ logger: logger.Named(logging.TLSUtil), - manual: &manual{}, - autoTLS: &autoTLS{}, peerDatacenterUseTLS: map[string]bool{}, } err := c.Update(config) @@ -282,7 +280,7 @@ func (c *Configurator) UpdateAutoTLSCA(connectCAPems []string) error { return nil } -// UpdateAutoTLSCert +// UpdateAutoTLSCert receives the updated Auto-Encrypt certificate. func (c *Configurator) UpdateAutoTLSCert(pub, priv string) error { cert, err := tls.X509KeyPair([]byte(pub), []byte(priv)) if err != nil { @@ -298,8 +296,8 @@ func (c *Configurator) UpdateAutoTLSCert(pub, priv string) error { return nil } -// UpdateAutoTLS sets everything under autoEncrypt. This is being called on the -// client when it received its cert from AutoEncrypt/AutoConfig endpoints. +// UpdateAutoTLS receives updates from Auto-Config, only expected to be called on +// client agents. func (c *Configurator) UpdateAutoTLS(manualCAPems, connectCAPems []string, pub, priv string, verifyServerHostname bool) error { cert, err := tls.X509KeyPair([]byte(pub), []byte(priv)) if err != nil { @@ -364,8 +362,9 @@ func pool(pems []string) (*x509.CertPool, error) { func validateConfig(config Config, pool *x509.CertPool, cert *tls.Certificate) error { // Check if a minimum TLS version was set if config.TLSMinVersion != "" { - if _, ok := TLSLookup[config.TLSMinVersion]; !ok { - return fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [%s]", config.TLSMinVersion, TLSVersions) + if _, ok := tlsLookup[config.TLSMinVersion]; !ok { + versions := strings.Join(tlsVersions(), ", ") + return fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [%s]", config.TLSMinVersion, versions) } } @@ -517,10 +516,10 @@ func (c *Configurator) commonTLSConfig(verifyIncoming bool) *tls.Config { tlsConfig.ClientCAs = c.caPool tlsConfig.RootCAs = c.caPool - // This is possible because TLSLookup also contains "" with golang's + // This is possible because tlsLookup also contains "" with golang's // default (tls10). And because the initial check makes sure the // version correctly matches. - tlsConfig.MinVersion = TLSLookup[c.base.TLSMinVersion] + tlsConfig.MinVersion = tlsLookup[c.base.TLSMinVersion] // Set ClientAuth if necessary if verifyIncoming { @@ -794,9 +793,7 @@ func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper { return nil } - return func(dc, nodeName, alpnProto string, conn net.Conn) (net.Conn, error) { - return c.wrapALPNTLSClient(dc, nodeName, alpnProto, conn) - } + return c.wrapALPNTLSClient } // AutoEncryptCertNotAfter returns NotAfter from the auto_encrypt cert. In case diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index 1282cda958..0811c00ac8 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -405,7 +405,7 @@ func TestConfig_ParseCiphers(t *testing.T) { require.Equal(t, []uint16{}, v) } -func TestConfigurator_loadKeyPair(t *testing.T) { +func TestLoadKeyPair(t *testing.T) { type variant struct { cert, key string shoulderr bool @@ -422,24 +422,20 @@ func TestConfigurator_loadKeyPair(t *testing.T) { false, false}, } for i, v := range variants { - info := fmt.Sprintf("case %d", i) - cert1, err1 := loadKeyPair(v.cert, v.key) - config := &Config{CertFile: v.cert, KeyFile: v.key} - cert2, err2 := config.KeyPair() - if v.shoulderr { - require.Error(t, err1, info) - require.Error(t, err2, info) - } else { - require.NoError(t, err1, info) - require.NoError(t, err2, info) - } - if v.isnil { - require.Nil(t, cert1, info) - require.Nil(t, cert2, info) - } else { - require.NotNil(t, cert1, info) - require.NotNil(t, cert2, info) - } + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + cert, err := loadKeyPair(v.cert, v.key) + if v.shoulderr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + if v.isnil { + require.Nil(t, cert) + } else { + require.NotNil(t, cert) + } + }) } } @@ -510,7 +506,7 @@ func TestConfigurator_ErrorPropagation(t *testing.T) { variants = append(variants, variant{Config{TLSMinVersion: v}, false, false}) } - c := Configurator{autoTLS: &autoTLS{}, manual: &manual{}} + c := Configurator{} for i, v := range variants { info := fmt.Sprintf("case %d, config: %+v", i, v.config) _, err1 := NewConfigurator(v.config, nil) @@ -518,7 +514,7 @@ func TestConfigurator_ErrorPropagation(t *testing.T) { var err3 error if !v.excludeCheck { - cert, err := v.config.KeyPair() + cert, err := loadKeyPair(v.config.CertFile, v.config.KeyFile) require.NoError(t, err, info) pems, err := LoadCAs(v.config.CAFile, v.config.CAPath) require.NoError(t, err, info) @@ -708,19 +704,19 @@ func TestConfigurator_CommonTLSConfigCAs(t *testing.T) { func TestConfigurator_CommonTLSConfigTLSMinVersion(t *testing.T) { c, err := NewConfigurator(Config{TLSMinVersion: ""}, nil) require.NoError(t, err) - require.Equal(t, c.commonTLSConfig(false).MinVersion, TLSLookup["tls10"]) + require.Equal(t, c.commonTLSConfig(false).MinVersion, tlsLookup["tls10"]) for _, version := range tlsVersions() { require.NoError(t, c.Update(Config{TLSMinVersion: version})) require.Equal(t, c.commonTLSConfig(false).MinVersion, - TLSLookup[version]) + tlsLookup[version]) } require.Error(t, c.Update(Config{TLSMinVersion: "tlsBOGUS"})) } func TestConfigurator_CommonTLSConfigVerifyIncoming(t *testing.T) { - c := Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := Configurator{base: &Config{}} type variant struct { verify bool expected tls.ClientAuthType @@ -735,7 +731,7 @@ func TestConfigurator_CommonTLSConfigVerifyIncoming(t *testing.T) { } func TestConfigurator_OutgoingRPCTLSDisabled(t *testing.T) { - c := Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := Configurator{base: &Config{}} type variant struct { verify bool autoEncryptTLS bool @@ -913,7 +909,7 @@ func TestConfigurator_IncomingALPNRPCConfig(t *testing.T) { } func TestConfigurator_IncomingHTTPSConfig(t *testing.T) { - c := Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := Configurator{base: &Config{}} require.Equal(t, []string{"h2", "http/1.1"}, c.IncomingHTTPSConfig().NextProtos) } @@ -921,7 +917,7 @@ func TestConfigurator_OutgoingTLSConfigForChecks(t *testing.T) { c := Configurator{base: &Config{ TLSMinVersion: "tls12", EnableAgentTLSForChecks: false, - }, autoTLS: &autoTLS{}} + }} tlsConf := c.OutgoingTLSConfigForCheck(true, "") require.Equal(t, true, tlsConf.InsecureSkipVerify) require.Equal(t, uint16(0), tlsConf.MinVersion) @@ -930,17 +926,17 @@ func TestConfigurator_OutgoingTLSConfigForChecks(t *testing.T) { c.base.ServerName = "servername" tlsConf = c.OutgoingTLSConfigForCheck(true, "") require.Equal(t, true, tlsConf.InsecureSkipVerify) - require.Equal(t, TLSLookup[c.base.TLSMinVersion], tlsConf.MinVersion) + require.Equal(t, tlsLookup[c.base.TLSMinVersion], tlsConf.MinVersion) require.Equal(t, c.base.ServerName, tlsConf.ServerName) tlsConf = c.OutgoingTLSConfigForCheck(true, "servername2") require.Equal(t, true, tlsConf.InsecureSkipVerify) - require.Equal(t, TLSLookup[c.base.TLSMinVersion], tlsConf.MinVersion) + require.Equal(t, tlsLookup[c.base.TLSMinVersion], tlsConf.MinVersion) require.Equal(t, "servername2", tlsConf.ServerName) } func TestConfigurator_OutgoingRPCConfig(t *testing.T) { - c := &Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := &Configurator{base: &Config{}} require.Nil(t, c.OutgoingRPCConfig()) c, err := NewConfigurator(Config{ @@ -958,7 +954,7 @@ func TestConfigurator_OutgoingRPCConfig(t *testing.T) { } func TestConfigurator_OutgoingALPNRPCConfig(t *testing.T) { - c := &Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := &Configurator{base: &Config{}} require.Nil(t, c.OutgoingALPNRPCConfig()) c, err := NewConfigurator(Config{ @@ -978,7 +974,7 @@ func TestConfigurator_OutgoingALPNRPCConfig(t *testing.T) { } func TestConfigurator_OutgoingRPCWrapper(t *testing.T) { - c := &Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := &Configurator{base: &Config{}} wrapper := c.OutgoingRPCWrapper() require.NotNil(t, wrapper) conn := &net.TCPConn{} @@ -1000,7 +996,7 @@ func TestConfigurator_OutgoingRPCWrapper(t *testing.T) { } func TestConfigurator_OutgoingALPNRPCWrapper(t *testing.T) { - c := &Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := &Configurator{base: &Config{}} wrapper := c.OutgoingRPCWrapper() require.NotNil(t, wrapper) conn := &net.TCPConn{} @@ -1075,7 +1071,7 @@ func TestConfigurator_ServerNameOrNodeName(t *testing.T) { } func TestConfigurator_VerifyOutgoing(t *testing.T) { - c := Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := Configurator{base: &Config{}} type variant struct { verify bool autoEncryptTLS bool @@ -1108,7 +1104,7 @@ func TestConfigurator_Domain(t *testing.T) { } func TestConfigurator_VerifyServerHostname(t *testing.T) { - c := Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := Configurator{base: &Config{}} require.False(t, c.VerifyServerHostname()) c.base.VerifyServerHostname = true @@ -1125,7 +1121,7 @@ func TestConfigurator_VerifyServerHostname(t *testing.T) { } func TestConfigurator_AutoEncrytCertExpired(t *testing.T) { - c := Configurator{base: &Config{}, autoTLS: &autoTLS{}} + c := Configurator{base: &Config{}} require.True(t, c.AutoEncryptCertExpired()) cert, err := loadKeyPair("../test/key/something_expired.cer", "../test/key/something_expired.key") @@ -1141,5 +1137,6 @@ func TestConfigurator_AutoEncrytCertExpired(t *testing.T) { func TestConfig_tlsVersions(t *testing.T) { require.Equal(t, []string{"tls10", "tls11", "tls12", "tls13"}, tlsVersions()) - require.Equal(t, strings.Join(tlsVersions(), ", "), TLSVersions) + expected := "tls10, tls11, tls12, tls13" + require.Equal(t, expected, strings.Join(tlsVersions(), ", ")) }