tlsutil: unexport two types

These types are only used internally and should not be exported.

Also remove some unnecessary function wrapping.
This commit is contained in:
Daniel Nephin 2021-06-21 11:42:32 -04:00
parent 62340a56b9
commit 8d9d6c6a09
2 changed files with 15 additions and 18 deletions

View File

@ -34,8 +34,8 @@ type DCWrapper func(dc string, conn net.Conn) (net.Conn, error)
// a constant value. This is usually done by currying DCWrapper. // a constant value. This is usually done by currying DCWrapper.
type Wrapper func(conn net.Conn) (net.Conn, error) type Wrapper func(conn net.Conn) (net.Conn, error)
// TLSLookup maps the tls_min_version configuration to the internal value // tlsLookup maps the tls_min_version configuration to the internal value
var TLSLookup = map[string]uint16{ var tlsLookup = map[string]uint16{
"": tls.VersionTLS10, // default in golang "": tls.VersionTLS10, // default in golang
"tls10": tls.VersionTLS10, "tls10": tls.VersionTLS10,
"tls11": tls.VersionTLS11, "tls11": tls.VersionTLS11,
@ -43,9 +43,6 @@ var TLSLookup = map[string]uint16{
"tls13": tls.VersionTLS13, "tls13": tls.VersionTLS13,
} }
// TLSVersions has all the keys from the map above.
var TLSVersions = strings.Join(tlsVersions(), ", ")
// Config used to create tls.Config // Config used to create tls.Config
type Config struct { type Config struct {
// VerifyIncoming is used to verify the authenticity of incoming // VerifyIncoming is used to verify the authenticity of incoming
@ -133,7 +130,7 @@ type Config struct {
func tlsVersions() []string { func tlsVersions() []string {
versions := []string{} versions := []string{}
for v := range TLSLookup { for v := range tlsLookup {
if v != "" { if v != "" {
versions = append(versions, v) versions = append(versions, v)
} }
@ -364,8 +361,9 @@ func pool(pems []string) (*x509.CertPool, error) {
func validateConfig(config Config, pool *x509.CertPool, cert *tls.Certificate) error { func validateConfig(config Config, pool *x509.CertPool, cert *tls.Certificate) error {
// Check if a minimum TLS version was set // Check if a minimum TLS version was set
if config.TLSMinVersion != "" { if config.TLSMinVersion != "" {
if _, ok := TLSLookup[config.TLSMinVersion]; !ok { if _, ok := tlsLookup[config.TLSMinVersion]; !ok {
return fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [%s]", config.TLSMinVersion, TLSVersions) versions := strings.Join(tlsVersions(), ", ")
return fmt.Errorf("TLSMinVersion: value %s not supported, please specify one of [%s]", config.TLSMinVersion, versions)
} }
} }
@ -517,10 +515,10 @@ func (c *Configurator) commonTLSConfig(verifyIncoming bool) *tls.Config {
tlsConfig.ClientCAs = c.caPool tlsConfig.ClientCAs = c.caPool
tlsConfig.RootCAs = c.caPool tlsConfig.RootCAs = c.caPool
// This is possible because TLSLookup also contains "" with golang's // This is possible because tlsLookup also contains "" with golang's
// default (tls10). And because the initial check makes sure the // default (tls10). And because the initial check makes sure the
// version correctly matches. // version correctly matches.
tlsConfig.MinVersion = TLSLookup[c.base.TLSMinVersion] tlsConfig.MinVersion = tlsLookup[c.base.TLSMinVersion]
// Set ClientAuth if necessary // Set ClientAuth if necessary
if verifyIncoming { if verifyIncoming {
@ -794,9 +792,7 @@ func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper {
return nil return nil
} }
return func(dc, nodeName, alpnProto string, conn net.Conn) (net.Conn, error) { return c.wrapALPNTLSClient
return c.wrapALPNTLSClient(dc, nodeName, alpnProto, conn)
}
} }
// AutoEncryptCertNotAfter returns NotAfter from the auto_encrypt cert. In case // AutoEncryptCertNotAfter returns NotAfter from the auto_encrypt cert. In case

View File

@ -708,12 +708,12 @@ func TestConfigurator_CommonTLSConfigCAs(t *testing.T) {
func TestConfigurator_CommonTLSConfigTLSMinVersion(t *testing.T) { func TestConfigurator_CommonTLSConfigTLSMinVersion(t *testing.T) {
c, err := NewConfigurator(Config{TLSMinVersion: ""}, nil) c, err := NewConfigurator(Config{TLSMinVersion: ""}, nil)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.commonTLSConfig(false).MinVersion, TLSLookup["tls10"]) require.Equal(t, c.commonTLSConfig(false).MinVersion, tlsLookup["tls10"])
for _, version := range tlsVersions() { for _, version := range tlsVersions() {
require.NoError(t, c.Update(Config{TLSMinVersion: version})) require.NoError(t, c.Update(Config{TLSMinVersion: version}))
require.Equal(t, c.commonTLSConfig(false).MinVersion, require.Equal(t, c.commonTLSConfig(false).MinVersion,
TLSLookup[version]) tlsLookup[version])
} }
require.Error(t, c.Update(Config{TLSMinVersion: "tlsBOGUS"})) require.Error(t, c.Update(Config{TLSMinVersion: "tlsBOGUS"}))
@ -930,12 +930,12 @@ func TestConfigurator_OutgoingTLSConfigForChecks(t *testing.T) {
c.base.ServerName = "servername" c.base.ServerName = "servername"
tlsConf = c.OutgoingTLSConfigForCheck(true, "") tlsConf = c.OutgoingTLSConfigForCheck(true, "")
require.Equal(t, true, tlsConf.InsecureSkipVerify) require.Equal(t, true, tlsConf.InsecureSkipVerify)
require.Equal(t, TLSLookup[c.base.TLSMinVersion], tlsConf.MinVersion) require.Equal(t, tlsLookup[c.base.TLSMinVersion], tlsConf.MinVersion)
require.Equal(t, c.base.ServerName, tlsConf.ServerName) require.Equal(t, c.base.ServerName, tlsConf.ServerName)
tlsConf = c.OutgoingTLSConfigForCheck(true, "servername2") tlsConf = c.OutgoingTLSConfigForCheck(true, "servername2")
require.Equal(t, true, tlsConf.InsecureSkipVerify) require.Equal(t, true, tlsConf.InsecureSkipVerify)
require.Equal(t, TLSLookup[c.base.TLSMinVersion], tlsConf.MinVersion) require.Equal(t, tlsLookup[c.base.TLSMinVersion], tlsConf.MinVersion)
require.Equal(t, "servername2", tlsConf.ServerName) require.Equal(t, "servername2", tlsConf.ServerName)
} }
@ -1141,5 +1141,6 @@ func TestConfigurator_AutoEncrytCertExpired(t *testing.T) {
func TestConfig_tlsVersions(t *testing.T) { func TestConfig_tlsVersions(t *testing.T) {
require.Equal(t, []string{"tls10", "tls11", "tls12", "tls13"}, tlsVersions()) require.Equal(t, []string{"tls10", "tls11", "tls12", "tls13"}, tlsVersions())
require.Equal(t, strings.Join(tlsVersions(), ", "), TLSVersions) expected := "tls10, tls11, tls12, tls13"
require.Equal(t, expected, strings.Join(tlsVersions(), ", "))
} }