diff --git a/agent/agent.go b/agent/agent.go index 8dfd6450ca..7b483e9b57 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -17,6 +17,7 @@ import ( "sync" "time" + "github.com/hashicorp/consul/agent/dns" "github.com/hashicorp/go-connlimit" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" @@ -642,21 +643,6 @@ func (a *Agent) Start(ctx context.Context) error { return err } - // Warn if the node name is incompatible with DNS - if InvalidDnsRe.MatchString(a.config.NodeName) { - a.logger.Warn("Node name will not be discoverable "+ - "via DNS due to invalid characters. Valid characters include "+ - "all alpha-numerics and dashes.", - "node_name", a.config.NodeName, - ) - } else if len(a.config.NodeName) > MaxDNSLabelLength { - a.logger.Warn("Node name will not be discoverable "+ - "via DNS due to it being too long. Valid lengths are between "+ - "1 and 63 bytes.", - "node_name", a.config.NodeName, - ) - } - // load the tokens - this requires the logger to be setup // which is why we can't do this in New a.loadTokens(a.config) @@ -2484,13 +2470,13 @@ func (a *Agent) validateService(service *structs.NodeService, chkTypes []*struct } // Warn if the service name is incompatible with DNS - if InvalidDnsRe.MatchString(service.Service) { + if dns.InvalidNameRe.MatchString(service.Service) { a.logger.Warn("Service name will not be discoverable "+ "via DNS due to invalid characters. Valid characters include "+ "all alpha-numerics and dashes.", "service", service.Service, ) - } else if len(service.Service) > MaxDNSLabelLength { + } else if len(service.Service) > dns.MaxLabelLength { a.logger.Warn("Service name will not be discoverable "+ "via DNS due to it being too long. Valid lengths are between "+ "1 and 63 bytes.", @@ -2500,13 +2486,13 @@ func (a *Agent) validateService(service *structs.NodeService, chkTypes []*struct // Warn if any tags are incompatible with DNS for _, tag := range service.Tags { - if InvalidDnsRe.MatchString(tag) { + if dns.InvalidNameRe.MatchString(tag) { a.logger.Debug("Service tag will not be discoverable "+ "via DNS due to invalid characters. Valid characters include "+ "all alpha-numerics and dashes.", "tag", tag, ) - } else if len(tag) > MaxDNSLabelLength { + } else if len(tag) > dns.MaxLabelLength { a.logger.Debug("Service tag will not be discoverable "+ "via DNS due to it being too long. Valid lengths are between "+ "1 and 63 bytes.", diff --git a/agent/config/builder.go b/agent/config/builder.go index 4b11a2e337..93e94b4ea3 100644 --- a/agent/config/builder.go +++ b/agent/config/builder.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/consul/agent/connect/ca" "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth" + "github.com/hashicorp/consul/agent/dns" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/lib" @@ -1117,9 +1118,20 @@ func (b *Builder) Validate(rt RuntimeConfig) error { return fmt.Errorf("data_dir %q is not a directory", rt.DataDir) } } - if rt.NodeName == "" { + + switch { + case rt.NodeName == "": return fmt.Errorf("node_name cannot be empty") + case dns.InvalidNameRe.MatchString(rt.NodeName): + b.warn("Node name %q will not be discoverable "+ + "via DNS due to invalid characters. Valid characters include "+ + "all alpha-numerics and dashes.", rt.NodeName) + case len(rt.NodeName) > dns.MaxLabelLength: + b.warn("Node name %q will not be discoverable "+ + "via DNS due to it being too long. Valid lengths are between "+ + "1 and 63 bytes.", rt.NodeName) } + if ipaddr.IsAny(rt.AdvertiseAddrLAN.IP) { return fmt.Errorf("Advertise address cannot be 0.0.0.0, :: or [::]") } diff --git a/agent/config/builder_test.go b/agent/config/builder_test.go index 01e58b8385..9be12a4d34 100644 --- a/agent/config/builder_test.go +++ b/agent/config/builder_test.go @@ -3,8 +3,10 @@ package config import ( "fmt" "io/ioutil" + "net" "os" "path/filepath" + "strings" "testing" "time" @@ -121,3 +123,62 @@ func setupConfigFiles(t *testing.T) []string { subpath, } } + +func TestBuilder_BuildAndValidate_NodeName(t *testing.T) { + type testCase struct { + name string + nodeName string + expectedWarn string + } + + fn := func(t *testing.T, tc testCase) { + b, err := NewBuilder(BuilderOpts{ + Config: Config{ + NodeName: pString(tc.nodeName), + DataDir: pString("dir"), + }, + }) + patchBuilderShims(b) + require.NoError(t, err) + _, err = b.BuildAndValidate() + require.NoError(t, err) + require.Len(t, b.Warnings, 1) + require.Contains(t, b.Warnings[0], tc.expectedWarn) + } + + var testCases = []testCase{ + { + name: "invalid character - unicode", + nodeName: "🐼", + expectedWarn: `Node name "🐼" will not be discoverable via DNS due to invalid characters`, + }, + { + name: "invalid character - slash", + nodeName: "thing/other/ok", + expectedWarn: `Node name "thing/other/ok" will not be discoverable via DNS due to invalid characters`, + }, + { + name: "too long", + nodeName: strings.Repeat("a", 66), + expectedWarn: "due to it being too long.", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fn(t, tc) + }) + } +} + +func patchBuilderShims(b *Builder) { + b.hostname = func() (string, error) { + return "thehostname", nil + } + b.getPrivateIPv4 = func() ([]*net.IPAddr, error) { + return []*net.IPAddr{ipAddr("10.0.0.1")}, nil + } + b.getPublicIPv6 = func() ([]*net.IPAddr, error) { + return []*net.IPAddr{ipAddr("dead:beef::1")}, nil + } +} diff --git a/agent/config/runtime_test.go b/agent/config/runtime_test.go index 1a7edb29ab..819dce8979 100644 --- a/agent/config/runtime_test.go +++ b/agent/config/runtime_test.go @@ -4280,27 +4280,16 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) { t.Fatal("NewBuilder", err) } - // mock the hostname function unless a mock is provided - b.hostname = tt.hostname - if b.hostname == nil { - b.hostname = func() (string, error) { return "nodex", nil } + patchBuilderShims(b) + if tt.hostname != nil { + b.hostname = tt.hostname } - - // mock the ip address detection - privatev4 := tt.privatev4 - if privatev4 == nil { - privatev4 = func() ([]*net.IPAddr, error) { - return []*net.IPAddr{ipAddr("10.0.0.1")}, nil - } + if tt.privatev4 != nil { + b.getPrivateIPv4 = tt.privatev4 } - publicv6 := tt.publicv6 - if publicv6 == nil { - publicv6 = func() ([]*net.IPAddr, error) { - return []*net.IPAddr{ipAddr("dead:beef::1")}, nil - } + if tt.publicv6 != nil { + b.getPublicIPv6 = tt.publicv6 } - b.getPrivateIPv4 = privatev4 - b.getPublicIPv6 = publicv6 // read the source fragements for i, data := range srcs { @@ -4332,12 +4321,10 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) { if err != nil && tt.err != "" && !strings.Contains(err.Error(), tt.err) { t.Fatalf("error %q does not contain %q", err.Error(), tt.err) } - require.Equal(t, tt.warns, b.Warnings, "warnings") - - // stop if we expected an error if tt.err != "" { return } + require.Equal(t, tt.warns, b.Warnings, "warnings") // build a default configuration, then patch the fields we expect to change // and compare it with the generated configuration. Since the expected diff --git a/agent/dns.go b/agent/dns.go index 187b0463c5..686ce8b20b 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -5,17 +5,17 @@ import ( "encoding/hex" "fmt" "net" + "regexp" "strings" "sync/atomic" "time" - "regexp" - metrics "github.com/armon/go-metrics" radix "github.com/armon/go-radix" "github.com/coredns/coredns/plugin/pkg/dnsutil" cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/config" + agentdns "github.com/hashicorp/consul/agent/dns" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/ipaddr" @@ -38,12 +38,8 @@ const ( staleCounterThreshold = 5 * time.Second defaultMaxUDPSize = 512 - - MaxDNSLabelLength = 63 ) -var InvalidDnsRe = regexp.MustCompile(`[^A-Za-z0-9\\-]+`) - type dnsSOAConfig struct { Refresh uint32 // 3600 by default Retry uint32 // 600 @@ -539,7 +535,7 @@ func (d *DNSServer) nameservers(cfg *dnsConfig, maxRecursionLevel int) (ns []dns for _, o := range out.Nodes { name, dc := o.Node.Node, o.Node.Datacenter - if InvalidDnsRe.MatchString(name) { + if agentdns.InvalidNameRe.MatchString(name) { d.logger.Warn("Skipping invalid node for NS records", "node", name) continue } diff --git a/agent/dns/dns.go b/agent/dns/dns.go new file mode 100644 index 0000000000..8744eb351f --- /dev/null +++ b/agent/dns/dns.go @@ -0,0 +1,10 @@ +package dns + +import "regexp" + +// MaxLabelLength is the maximum length for a name that can be used in DNS. +const MaxLabelLength = 63 + +// InvalidNameRe is a regex that matches characters which can not be included in +// a DNS name. +var InvalidNameRe = regexp.MustCompile(`[^A-Za-z0-9\\-]+`) diff --git a/agent/dns_test.go b/agent/dns_test.go index 07c4211871..a82cc9c7a3 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/hashicorp/consul/agent/config" + agentdns "github.com/hashicorp/consul/agent/dns" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/lib" @@ -6976,7 +6977,7 @@ func TestDNSInvalidRegex(t *testing.T) { } for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - if got, want := InvalidDnsRe.MatchString(test.in), test.invalid; got != want { + if got, want := agentdns.InvalidNameRe.MatchString(test.in), test.invalid; got != want { t.Fatalf("Expected %v to return %v", test.in, want) } })