diff --git a/agent/agent.go b/agent/agent.go index 8dc49e7da8..ab10e21da2 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1106,7 +1106,14 @@ func (a *Agent) listenAndServeV2DNS() error { if a.baseDeps.UseV2Resources() { a.catalogDataFetcher = discovery.NewV2DataFetcher(a.config) } else { - a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config, a.AgentEnterpriseMeta(), a.RPC, a.logger.Named("catalog-data-fetcher")) + a.catalogDataFetcher = discovery.NewV1DataFetcher(a.config, + a.AgentEnterpriseMeta(), + a.cache.Get, + a.RPC, + a.rpcClientHealth.ServiceNodes, + a.rpcClientConfigEntry.GetSamenessGroup, + a.TranslateServicePort, + a.logger.Named("catalog-data-fetcher")) } // Generate a Query Processor with the appropriate data fetcher diff --git a/agent/discovery/discovery.go b/agent/discovery/discovery.go index e04e3c966f..92e6644d2a 100644 --- a/agent/discovery/discovery.go +++ b/agent/discovery/discovery.go @@ -107,11 +107,11 @@ const ( // It is the responsibility of the DNS encoder to know what to do with // each Result, based on the query type. type Result struct { - Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes. - Weight uint32 // SRV queries - Port uint32 // SRV queries - Metadata []string // Used to collect metadata into TXT Records - Type ResultType // Used to reconstruct the fqdn name of the resource + Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes. + Weight uint32 // SRV queries + Port uint32 // SRV queries + Metadata map[string]string // Used to collect metadata into TXT Records + Type ResultType // Used to reconstruct the fqdn name of the resource // Used in SRV & PTR queries to point at an A/AAAA Record. Target string @@ -176,6 +176,7 @@ func NewQueryProcessor(dataFetcher CatalogDataFetcher) *QueryProcessor { } } +// QueryByName is used to look up a service, node, workload, or prepared query. func (p *QueryProcessor) QueryByName(query *Query, ctx Context) ([]*Result, error) { switch query.QueryType { case QueryTypeNode: diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index 4d89ec8645..ff7837cb97 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -5,55 +5,75 @@ package discovery import ( "context" + "errors" "fmt" "net" "sync/atomic" "time" + "github.com/armon/go-metrics" + cachetype "github.com/hashicorp/consul/agent/cache-types" + "github.com/hashicorp/consul/api" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/structs" ) const ( - // TODO (v2-dns): can we move the recursion into the data fetcher? - maxRecursionLevelDefault = 3 // This field comes from the V1 DNS server and affects V1 catalog lookups - maxRecurseRecords = 5 + // Increment a counter when requests staler than this are served + staleCounterThreshold = 5 * time.Second ) // v1DataFetcherDynamicConfig is used to store the dynamic configuration of the V1 data fetcher. type v1DataFetcherDynamicConfig struct { // Default request tenancy - datacenter string + defaultEntMeta acl.EnterpriseMeta + datacenter string // Catalog configuration - allowStale bool - maxStale time.Duration - useCache bool - cacheMaxAge time.Duration - onlyPassing bool + allowStale bool + maxStale time.Duration + useCache bool + cacheMaxAge time.Duration + onlyPassing bool + enterpriseDNSConfig EnterpriseDNSConfig } // V1DataFetcher is used to fetch data from the V1 catalog. type V1DataFetcher struct { + // TODO(v2-dns): store this in the config. defaultEnterpriseMeta acl.EnterpriseMeta dynamicConfig atomic.Value logger hclog.Logger - rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error + getFromCacheFunc func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) + rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error + rpcFuncForServiceNodes func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) + rpcFuncForSamenessGroup func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) + translateServicePortFunc func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int } // NewV1DataFetcher creates a new V1 data fetcher. func NewV1DataFetcher(config *config.RuntimeConfig, entMeta *acl.EnterpriseMeta, + getFromCacheFunc func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error), rpcFunc func(ctx context.Context, method string, args interface{}, reply interface{}) error, + rpcFuncForServiceNodes func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error), + rpcFuncForSamenessGroup func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error), + translateServicePortFunc func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int, logger hclog.Logger) *V1DataFetcher { f := &V1DataFetcher{ - defaultEnterpriseMeta: *entMeta, - rpcFunc: rpcFunc, - logger: logger, + defaultEnterpriseMeta: *entMeta, + getFromCacheFunc: getFromCacheFunc, + rpcFunc: rpcFunc, + rpcFuncForServiceNodes: rpcFuncForServiceNodes, + rpcFuncForSamenessGroup: rpcFuncForSamenessGroup, + translateServicePortFunc: translateServicePortFunc, + logger: logger, } f.LoadConfig(config) return f @@ -62,26 +82,65 @@ func NewV1DataFetcher(config *config.RuntimeConfig, // LoadConfig loads the configuration for the V1 data fetcher. func (f *V1DataFetcher) LoadConfig(config *config.RuntimeConfig) { dynamicConfig := &v1DataFetcherDynamicConfig{ - datacenter: config.Datacenter, - allowStale: config.DNSAllowStale, - maxStale: config.DNSMaxStale, - useCache: config.DNSUseCache, - cacheMaxAge: config.DNSCacheMaxAge, - onlyPassing: config.DNSOnlyPassing, + allowStale: config.DNSAllowStale, + maxStale: config.DNSMaxStale, + useCache: config.DNSUseCache, + cacheMaxAge: config.DNSCacheMaxAge, + onlyPassing: config.DNSOnlyPassing, + enterpriseDNSConfig: GetEnterpriseDNSConfig(config), + datacenter: config.Datacenter, + // TODO (v2-dns): make this work + //defaultEntMeta: config.EnterpriseRuntimeConfig.DefaultEntMeta, } f.dynamicConfig.Store(dynamicConfig) } -// TODO (v2-dns): Implementation of the V1 data fetcher - // FetchNodes fetches A/AAAA/CNAME func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, error) { - return nil, nil + cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) + // Make an RPC request + args := &structs.NodeSpecificRequest{ + Datacenter: req.Tenancy.Datacenter, + PeerName: req.Tenancy.Peer, + Node: req.Name, + QueryOptions: structs.QueryOptions{ + Token: ctx.Token, + AllowStale: cfg.allowStale, + }, + EnterpriseMeta: req.Tenancy.EnterpriseMeta, + } + out, err := f.fetchNode(cfg, args) + if err != nil { + return nil, fmt.Errorf("failed rpc request: %w", err) + } + + // If we have no out.NodeServices.Nodeaddress, return not found! + if out.NodeServices == nil { + return nil, errors.New("no nodes found") + } + + results := make([]*Result, 0, 1) + node := out.NodeServices.Node + + results = append(results, &Result{ + Address: node.Address, + Type: ResultTypeNode, + Metadata: node.Meta, + Target: node.Node, + }) + + return results, nil } // FetchEndpoints fetches records for A/AAAA/CNAME or SRV requests for services func (f *V1DataFetcher) FetchEndpoints(ctx Context, req *QueryPayload, lookupType LookupType) ([]*Result, error) { - return nil, nil + f.logger.Debug(fmt.Sprintf("FetchEndpoints - req: %+v / lookupType: %+v", req, lookupType)) + cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) + if lookupType == LookupTypeService { + return f.fetchService(ctx, req, cfg) + } + + return nil, errors.New(fmt.Sprintf("unsupported lookup type: %s", lookupType)) } // FetchVirtualIP fetches A/AAAA records for virtual IPs @@ -193,3 +252,182 @@ func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { return nil, nil } + +// fetchNode is used to look up a node in the Consul catalog within NodeServices. +// If the config is set to UseCache, it will get the record from the agent cache. +func (f *V1DataFetcher) fetchNode(cfg *v1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { + var out structs.IndexedNodeServices + + useCache := cfg.useCache +RPC: + if useCache { + raw, _, err := f.getFromCacheFunc(context.TODO(), cachetype.NodeServicesName, args) + if err != nil { + return nil, err + } + reply, ok := raw.(*structs.IndexedNodeServices) + if !ok { + // This should never happen, but we want to protect against panics + return nil, fmt.Errorf("internal error: response type not correct") + } + out = *reply + } else { + if err := f.rpcFunc(context.Background(), "Catalog.NodeServices", &args, &out); err != nil { + return nil, err + } + } + + // Verify that request is not too stale, redo the request + if args.AllowStale { + if out.LastContact > cfg.maxStale { + args.AllowStale = false + useCache = false + f.logger.Warn("Query results too stale, re-requesting") + goto RPC + } else if out.LastContact > staleCounterThreshold { + metrics.IncrCounter([]string{"dns", "stale_queries"}, 1) + } + } + + return &out, nil +} + +func (f *V1DataFetcher) fetchService(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig) ([]*Result, error) { + f.logger.Debug("fetchService", "req", req) + if req.Tenancy.SamenessGroup == "" { + return f.fetchServiceBasedOnTenancy(ctx, req, cfg) + } + + return f.fetchServiceFromSamenessGroup(ctx, req, cfg) +} + +// fetchServiceBasedOnTenancy is used to look up a service in the Consul catalog based on its tenancy or default tenancy. +func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig) ([]*Result, error) { + f.logger.Debug(fmt.Sprintf("fetchServiceBasedOnTenancy - req: %+v", req)) + if req.Tenancy.SamenessGroup != "" { + return nil, errors.New("sameness groups are not allowed for service lookups based on tenancy") + } + + datacenter := req.Tenancy.Datacenter + if req.Tenancy.Peer != "" { + datacenter = "" + } + + serviceTags := []string{} + if req.Tag != "" { + serviceTags = []string{req.Tag} + } + args := structs.ServiceSpecificRequest{ + PeerName: req.Tenancy.Peer, + Connect: false, + Ingress: false, + Datacenter: datacenter, + ServiceName: req.Name, + ServiceTags: serviceTags, + TagFilter: req.Tag != "", + QueryOptions: structs.QueryOptions{ + Token: ctx.Token, + AllowStale: cfg.allowStale, + MaxAge: cfg.cacheMaxAge, + UseCache: cfg.useCache, + MaxStaleDuration: cfg.maxStale, + }, + EnterpriseMeta: req.Tenancy.EnterpriseMeta, + } + + out, _, err := f.rpcFuncForServiceNodes(context.TODO(), args) + if err != nil { + return nil, err + } + + // Filter out any service nodes due to health checks + // We copy the slice to avoid modifying the result if it comes from the cache + nodes := make(structs.CheckServiceNodes, len(out.Nodes)) + copy(nodes, out.Nodes) + out.Nodes = nodes.Filter(cfg.onlyPassing) + if err != nil { + return nil, fmt.Errorf("rpc request failed: %w", err) + } + + // If we have no nodes, return not found! + if len(out.Nodes) == 0 { + return nil, ErrNoData + } + + // Perform a random shuffle + out.Nodes.Shuffle() + results := make([]*Result, 0, len(out.Nodes)) + for _, node := range out.Nodes { + target := node.Service.Address + resultType := ResultTypeService + // TODO (v2-dns): IMPORTANT!!!!: this needs to be revisited in how dns v1 utilizes + // the nodeaddress when the service address is an empty string. Need to figure out + // if this can be removed and dns recursion and process can work with only the + // address set to the node.address and the target set to the service.address. + // We may have to look at modifying the discovery result if more metadata is needed to send along. + if target == "" { + target = node.Node.Node + resultType = ResultTypeNode + } + results = append(results, &Result{ + Address: node.Node.Address, + Type: resultType, + Target: target, + Weight: uint32(findWeight(node)), + Port: uint32(f.translateServicePortFunc(node.Node.Datacenter, node.Service.Port, node.Service.TaggedAddresses)), + Metadata: node.Node.Meta, + Tenancy: ResultTenancy{ + EnterpriseMeta: cfg.defaultEntMeta, + Datacenter: cfg.datacenter, + }, + }) + } + + return results, nil +} + +// findWeight returns the weight of a service node. +func findWeight(node structs.CheckServiceNode) int { + // By default, when only_passing is false, warning and passing nodes are returned + // Those values will be used if using a client with support while server has no + // support for weights + weightPassing := 1 + weightWarning := 1 + if node.Service.Weights != nil { + weightPassing = node.Service.Weights.Passing + weightWarning = node.Service.Weights.Warning + } + serviceChecks := make(api.HealthChecks, 0, len(node.Checks)) + for _, c := range node.Checks { + if c.ServiceName == node.Service.Service || c.ServiceName == "" { + healthCheck := &api.HealthCheck{ + Node: c.Node, + CheckID: string(c.CheckID), + Name: c.Name, + Status: c.Status, + Notes: c.Notes, + Output: c.Output, + ServiceID: c.ServiceID, + ServiceName: c.ServiceName, + ServiceTags: c.ServiceTags, + } + serviceChecks = append(serviceChecks, healthCheck) + } + } + status := serviceChecks.AggregatedStatus() + switch status { + case api.HealthWarning: + return weightWarning + case api.HealthPassing: + return weightPassing + case api.HealthMaint: + // Not used in theory + return 0 + case api.HealthCritical: + // Should not happen since already filtered + return 0 + default: + // When non-standard status, return 1 + return 1 + } +} diff --git a/agent/discovery/query_fetcher_v1_ce.go b/agent/discovery/query_fetcher_v1_ce.go new file mode 100644 index 0000000000..6540dea7fe --- /dev/null +++ b/agent/discovery/query_fetcher_v1_ce.go @@ -0,0 +1,20 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !consulent + +package discovery + +import ( + "errors" + "fmt" +) + +// fetchServiceFromSamenessGroup fetches a service from a sameness group. +func (f *V1DataFetcher) fetchServiceFromSamenessGroup(ctx Context, req *QueryPayload, cfg *v1DataFetcherDynamicConfig) ([]*Result, error) { + f.logger.Debug(fmt.Sprintf("fetchServiceFromSamenessGroup - req: %+v", req)) + if req.Tenancy.SamenessGroup == "" { + return nil, errors.New("sameness groups must be provided for service lookups") + } + return f.fetchServiceBasedOnTenancy(ctx, req, cfg) +} diff --git a/agent/discovery/query_fetcher_v1_test.go b/agent/discovery/query_fetcher_v1_test.go index 25371f1138..add1b7bda7 100644 --- a/agent/discovery/query_fetcher_v1_test.go +++ b/agent/discovery/query_fetcher_v1_test.go @@ -4,10 +4,13 @@ package discovery import ( + "context" "errors" "testing" "time" + "github.com/hashicorp/consul/agent/cache" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -96,7 +99,19 @@ func Test_FetchVirtualIP(t *testing.T) { *reply = tc.expectedResult.Address } }) - df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), mockRPC.RPC, logger) + // TODO (v2-dns): mock these properly + translateServicePortFunc := func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int { return 0 } + rpcFuncForServiceNodes := func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { + return structs.IndexedCheckServiceNodes{}, cache.ResultMeta{}, nil + } + rpcFuncForSamenessGroup := func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) { + return structs.SamenessGroupConfigEntry{}, cache.ResultMeta{}, nil + } + getFromCacheFunc := func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) { + return nil, cache.ResultMeta{}, nil + } + + df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), getFromCacheFunc, mockRPC.RPC, rpcFuncForServiceNodes, rpcFuncForSamenessGroup, translateServicePortFunc, logger) result, err := df.FetchVirtualIP(tc.context, tc.queryPayload) require.Equal(t, tc.expectedErr, err) diff --git a/agent/discovery/query_locality.go b/agent/discovery/query_locality.go new file mode 100644 index 0000000000..55b77352e9 --- /dev/null +++ b/agent/discovery/query_locality.go @@ -0,0 +1,61 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package discovery + +import "github.com/hashicorp/consul/acl" + +// QueryLocality is the locality parsed from a DNS query. +type QueryLocality struct { + // Datacenter is the datacenter parsed from a label that has an explicit datacenter part. + // Example query: .virtual..ns..ap..dc.consul + Datacenter string + + // Peer is the peer name parsed from a label that has explicit parts. + // Example query: .virtual..ns..peer..ap.consul + Peer string + + // PeerOrDatacenter is parsed from DNS queries where the datacenter and peer name are + // specified in the same query part. + // Example query: .virtual..consul + // + // Note that this field should only be a "peer" for virtual queries, since virtual IPs should + // not be shared between datacenters. In all other cases, it should be considered a DC. + PeerOrDatacenter string + + acl.EnterpriseMeta +} + +// EffectiveDatacenter returns the datacenter parsed from a query, or a default +// value if none is specified. +func (l QueryLocality) EffectiveDatacenter(defaultDC string) string { + // Prefer the value parsed from a query with explicit parts: .ns..ap..dc + if l.Datacenter != "" { + return l.Datacenter + } + // Fall back to the ambiguously parsed DC or Peer. + if l.PeerOrDatacenter != "" { + return l.PeerOrDatacenter + } + // If all are empty, use a default value. + return defaultDC +} + +// GetQueryTenancyBasedOnLocality returns a discovery.QueryTenancy from a DNS message. +func GetQueryTenancyBasedOnLocality(locality QueryLocality, defaultDatacenter string) (QueryTenancy, error) { + datacenter := locality.EffectiveDatacenter(defaultDatacenter) + // Only one of dc or peer can be used. + if locality.Peer != "" { + datacenter = "" + } + + return QueryTenancy{ + EnterpriseMeta: locality.EnterpriseMeta, + // The datacenter of the request is not specified because cross-datacenter virtual IP + // queries are not supported. This guard rail is in place because virtual IPs are allocated + // within a DC, therefore their uniqueness is not guaranteed globally. + Peer: locality.Peer, + Datacenter: datacenter, + SamenessGroup: "", // this should be nil since the single locality was directly used to configure tenancy. + }, nil +} diff --git a/agent/discovery/query_locality_ce.go b/agent/discovery/query_locality_ce.go new file mode 100644 index 0000000000..4cc4f312d4 --- /dev/null +++ b/agent/discovery/query_locality_ce.go @@ -0,0 +1,57 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !consulent + +package discovery + +import ( + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/config" +) + +// ParseLocality can parse peer name or datacenter from a DNS query's labels. +// Peer name is parsed from the same query part that datacenter is, so given this ambiguity +// we parse a "peerOrDatacenter". The caller or RPC handler are responsible for disambiguating. +func ParseLocality(labels []string, defaultEnterpriseMeta acl.EnterpriseMeta, _ EnterpriseDNSConfig) (QueryLocality, bool) { + locality := QueryLocality{ + EnterpriseMeta: defaultEnterpriseMeta, + } + + switch len(labels) { + case 2, 4: + // Support the following formats: + // - [..dc] + // - [..peer] + for i := 0; i < len(labels); i += 2 { + switch labels[i+1] { + case "dc": + locality.Datacenter = labels[i] + case "peer": + locality.Peer = labels[i] + default: + return QueryLocality{}, false + } + } + // Return error when both datacenter and peer are specified. + if locality.Datacenter != "" && locality.Peer != "" { + return QueryLocality{}, false + } + return locality, true + case 1: + return QueryLocality{PeerOrDatacenter: labels[0]}, true + + case 0: + return QueryLocality{}, true + } + + return QueryLocality{}, false +} + +// EnterpriseDNSConfig is the configuration for enterprise DNS. +type EnterpriseDNSConfig struct{} + +// GetEnterpriseDNSConfig returns the enterprise DNS configuration. +func GetEnterpriseDNSConfig(conf *config.RuntimeConfig) EnterpriseDNSConfig { + return EnterpriseDNSConfig{} +} diff --git a/agent/discovery/query_locality_ce_test.go b/agent/discovery/query_locality_ce_test.go new file mode 100644 index 0000000000..5f720c2121 --- /dev/null +++ b/agent/discovery/query_locality_ce_test.go @@ -0,0 +1,60 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !consulent + +package discovery + +import ( + "github.com/hashicorp/consul/acl" +) + +func getTestCases() []testCaseParseLocality { + testCases := []testCaseParseLocality{ + { + name: "test [..dc]", + labels: []string{"test-dc", "dc"}, + enterpriseDNSConfig: EnterpriseDNSConfig{}, + expectedResult: QueryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + Datacenter: "test-dc", + }, + expectedOK: true, + }, + { + name: "test [..peer]", + labels: []string{"test-peer", "peer"}, + enterpriseDNSConfig: EnterpriseDNSConfig{}, + expectedResult: QueryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + Peer: "test-peer", + }, + expectedOK: true, + }, + { + name: "test 1 label", + labels: []string{"test-peer"}, + enterpriseDNSConfig: EnterpriseDNSConfig{}, + expectedResult: QueryLocality{ + EnterpriseMeta: acl.EnterpriseMeta{}, + PeerOrDatacenter: "test-peer", + }, + expectedOK: true, + }, + { + name: "test 0 labels", + labels: []string{}, + enterpriseDNSConfig: EnterpriseDNSConfig{}, + expectedResult: QueryLocality{}, + expectedOK: true, + }, + { + name: "test 3 labels returns not found", + labels: []string{"test-dc", "dc", "test-blah"}, + enterpriseDNSConfig: EnterpriseDNSConfig{}, + expectedResult: QueryLocality{}, + expectedOK: false, + }, + } + return testCases +} diff --git a/agent/discovery/query_locality_test.go b/agent/discovery/query_locality_test.go new file mode 100644 index 0000000000..2c1ce28c9d --- /dev/null +++ b/agent/discovery/query_locality_test.go @@ -0,0 +1,73 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 +package discovery + +import ( + "testing" + + "github.com/hashicorp/consul/acl" + "github.com/stretchr/testify/require" +) + +type testCaseParseLocality struct { + name string + labels []string + defaultMeta acl.EnterpriseMeta + enterpriseDNSConfig EnterpriseDNSConfig + expectedResult QueryLocality + expectedOK bool +} + +func Test_parseLocality(t *testing.T) { + testCases := getTestCases() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualResult, actualOK := ParseLocality(tc.labels, tc.defaultMeta, tc.enterpriseDNSConfig) + require.Equal(t, tc.expectedOK, actualOK) + require.Equal(t, tc.expectedResult, actualResult) + + }) + } + +} + +func Test_effectiveDatacenter(t *testing.T) { + type testCase struct { + name string + QueryLocality QueryLocality + defaultDC string + expected string + } + testCases := []testCase{ + { + name: "return Datacenter first", + QueryLocality: QueryLocality{ + Datacenter: "test-dc", + PeerOrDatacenter: "test-peer", + }, + defaultDC: "default-dc", + expected: "test-dc", + }, + { + name: "return PeerOrDatacenter second", + QueryLocality: QueryLocality{ + PeerOrDatacenter: "test-peer", + }, + defaultDC: "default-dc", + expected: "test-peer", + }, + { + name: "return defaultDC as fallback", + QueryLocality: QueryLocality{}, + defaultDC: "default-dc", + expected: "default-dc", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := tc.QueryLocality.EffectiveDatacenter(tc.defaultDC) + require.Equal(t, tc.expected, got) + }) + } +} diff --git a/agent/dns/dns_address.go b/agent/dns/dns_address.go new file mode 100644 index 0000000000..caadf68422 --- /dev/null +++ b/agent/dns/dns_address.go @@ -0,0 +1,87 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 +package dns + +import ( + "github.com/miekg/dns" + "net" + "strings" +) + +func newDNSAddress(addr string) *dnsAddress { + a := &dnsAddress{} + a.SetAddress(addr) + return a +} + +// dnsAddress is a wrapper around a string that represents a DNS address and +// provides helper methods for determining whether it is an IP or FQDN and +// whether it is internal or external to the domain. +type dnsAddress struct { + addr string + + // store an IP so helpers don't have to parse it multiple times + ip net.IP +} + +// SetAddress sets the address field and the ip field if the string is an IP. +func (a *dnsAddress) SetAddress(addr string) { + a.addr = addr + a.ip = net.ParseIP(addr) +} + +// IP returns the IP address if the address is an IP. +func (a *dnsAddress) IP() net.IP { + return a.ip +} + +// IsIP returns true if the address is an IP. +func (a *dnsAddress) IsIP() bool { + return a.IP() != nil +} + +// IsIPV4 returns true if the address is an IPv4 address. +func (a *dnsAddress) IsIPV4() bool { + if a.IP() == nil { + return false + } + return a.IP().To4() != nil +} + +// FQDN returns the FQDN if the address is not an IP. +func (a *dnsAddress) FQDN() string { + if !a.IsEmptyString() && !a.IsIP() { + return dns.Fqdn(a.addr) + } + return "" +} + +// IsFQDN returns true if the address is a FQDN and not an IP. +func (a *dnsAddress) IsFQDN() bool { + return !a.IsEmptyString() && !a.IsIP() && dns.IsFqdn(a.FQDN()) +} + +// String returns the address as a string. +func (a *dnsAddress) String() string { + return a.addr +} + +// IsEmptyString returns true if the address is an empty string. +func (a *dnsAddress) IsEmptyString() bool { + return a.addr == "" +} + +// IsInternalFQDN returns true if the address is a FQDN and is internal to the domain. +func (a *dnsAddress) IsInternalFQDN(domain string) bool { + return !a.IsIP() && a.IsFQDN() && strings.HasSuffix(a.FQDN(), domain) +} + +// IsInternalFQDNOrIP returns true if the address is an IP or a FQDN and is internal to the domain. +func (a *dnsAddress) IsInternalFQDNOrIP(domain string) bool { + return a.IsIP() || a.IsInternalFQDN(domain) +} + +// IsExternalFQDN returns true if the address is a FQDN and is external to the domain. +func (a *dnsAddress) IsExternalFQDN(domain string) bool { + return !a.IsIP() && a.IsFQDN() && !strings.HasSuffix(a.FQDN(), domain) +} diff --git a/agent/dns/dns_address_test.go b/agent/dns/dns_address_test.go new file mode 100644 index 0000000000..bf55c295ba --- /dev/null +++ b/agent/dns/dns_address_test.go @@ -0,0 +1,154 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 +package dns + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_dnsAddress(t *testing.T) { + const domain = "consul." + type expectedResults struct { + isIp bool + stringResult string + fqdn string + isFQDN bool + isEmptyString bool + isExternalFQDN bool + isInternalFQDN bool + isInternalFQDNOrIP bool + } + type testCase struct { + name string + input string + expectedResults expectedResults + } + testCases := []testCase{ + { + name: "empty string", + input: "", + expectedResults: expectedResults{ + isIp: false, + stringResult: "", + fqdn: "", + isFQDN: false, + isEmptyString: true, + isExternalFQDN: false, + isInternalFQDN: false, + isInternalFQDNOrIP: false, + }, + }, + { + name: "ipv4 address", + input: "127.0.0.1", + expectedResults: expectedResults{ + isIp: true, + stringResult: "127.0.0.1", + fqdn: "", + isFQDN: false, + isEmptyString: false, + isExternalFQDN: false, + isInternalFQDN: false, + isInternalFQDNOrIP: true, + }, + }, + { + name: "ipv6 address", + input: "2001:db8:1:2:cafe::1337", + expectedResults: expectedResults{ + isIp: true, + stringResult: "2001:db8:1:2:cafe::1337", + fqdn: "", + isFQDN: false, + isEmptyString: false, + isExternalFQDN: false, + isInternalFQDN: false, + isInternalFQDNOrIP: true, + }, + }, + { + name: "internal FQDN without trailing period", + input: "web.service.consul", + expectedResults: expectedResults{ + isIp: false, + stringResult: "web.service.consul", + fqdn: "web.service.consul.", + isFQDN: true, + isEmptyString: false, + isExternalFQDN: false, + isInternalFQDN: true, + isInternalFQDNOrIP: true, + }, + }, + { + name: "internal FQDN with period", + input: "web.service.consul.", + expectedResults: expectedResults{ + isIp: false, + stringResult: "web.service.consul.", + fqdn: "web.service.consul.", + isFQDN: true, + isEmptyString: false, + isExternalFQDN: false, + isInternalFQDN: true, + isInternalFQDNOrIP: true, + }, + }, + { + name: "external FQDN without trailing period", + input: "web.service.vault", + expectedResults: expectedResults{ + isIp: false, + stringResult: "web.service.vault", + fqdn: "web.service.vault.", + isFQDN: true, + isEmptyString: false, + isExternalFQDN: true, + isInternalFQDN: false, + isInternalFQDNOrIP: false, + }, + }, + { + name: "external FQDN with trailing period", + input: "web.service.vault.", + expectedResults: expectedResults{ + isIp: false, + stringResult: "web.service.vault.", + fqdn: "web.service.vault.", + isFQDN: true, + isEmptyString: false, + isExternalFQDN: true, + isInternalFQDN: false, + isInternalFQDNOrIP: false, + }, + }, + { + name: "another external FQDN", + input: "www.google.com", + expectedResults: expectedResults{ + isIp: false, + stringResult: "www.google.com", + fqdn: "www.google.com.", + isFQDN: true, + isEmptyString: false, + isExternalFQDN: true, + isInternalFQDN: false, + isInternalFQDNOrIP: false, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dnsAddress := newDNSAddress(tc.input) + assert.Equal(t, tc.expectedResults.isIp, dnsAddress.IsIP()) + assert.Equal(t, tc.expectedResults.stringResult, dnsAddress.String()) + assert.Equal(t, tc.expectedResults.isFQDN, dnsAddress.IsFQDN()) + assert.Equal(t, tc.expectedResults.isEmptyString, dnsAddress.IsEmptyString()) + assert.Equal(t, tc.expectedResults.isExternalFQDN, dnsAddress.IsExternalFQDN(domain)) + assert.Equal(t, tc.expectedResults.isInternalFQDN, dnsAddress.IsInternalFQDN(domain)) + assert.Equal(t, tc.expectedResults.isInternalFQDNOrIP, dnsAddress.IsInternalFQDNOrIP(domain)) + }) + } +} diff --git a/agent/dns/mock_dnsRecursor.go b/agent/dns/mock_dnsRecursor.go index 83f41a30ed..b590661da1 100644 --- a/agent/dns/mock_dnsRecursor.go +++ b/agent/dns/mock_dnsRecursor.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. +// Code generated by mockery v2.32.4. DO NOT EDIT. package dns @@ -40,13 +40,12 @@ func (_m *mockDnsRecursor) handle(req *miekgdns.Msg, cfgCtx *RouterDynamicConfig return r0, r1 } -type mockConstructorTestingTnewMockDnsRecursor interface { +// newMockDnsRecursor creates a new instance of mockDnsRecursor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockDnsRecursor(t interface { mock.TestingT Cleanup(func()) -} - -// newMockDnsRecursor creates a new instance of mockDnsRecursor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func newMockDnsRecursor(t mockConstructorTestingTnewMockDnsRecursor) *mockDnsRecursor { +}) *mockDnsRecursor { mock := &mockDnsRecursor{} mock.Mock.Test(t) diff --git a/agent/dns/router.go b/agent/dns/router.go index 358aa0ce7a..39562c90d3 100644 --- a/agent/dns/router.go +++ b/agent/dns/router.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "net" + "regexp" + "strings" "sync/atomic" "time" @@ -29,14 +31,18 @@ const ( arpaDomain = "arpa." arpaLabel = "arpa" - suffixFailover = "failover." - suffixNoFailover = "no-failover." + suffixFailover = "failover." + suffixNoFailover = "no-failover." + maxRecursionLevelDefault = 3 // This field comes from the V1 DNS server and affects V1 catalog lookups + maxRecurseRecords = 5 ) var ( errInvalidQuestion = fmt.Errorf("invalid question") errNameNotFound = fmt.Errorf("name not found") errRecursionFailed = fmt.Errorf("recursion failed") + + trailingSpacesRE = regexp.MustCompile(" +$") ) // TODO (v2-dns): metrics @@ -59,7 +65,25 @@ type RouterDynamicConfig struct { TTLStrict map[string]time.Duration UDPAnswerLimit int - enterpriseDNSConfig + discovery.EnterpriseDNSConfig +} + +// GetTTLForService Find the TTL for a given service. +// return ttl, true if found, 0, false otherwise +func (cfg *RouterDynamicConfig) GetTTLForService(service string) (time.Duration, bool) { + if cfg.TTLStrict != nil { + ttl, ok := cfg.TTLStrict[service] + if ok { + return ttl, true + } + } + if cfg.TTLRadix != nil { + _, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service) + if ok { + return ttlRaw.(time.Duration), true + } + } + return 0, false } type SOAConfig struct { @@ -135,6 +159,13 @@ func NewRouter(cfg Config) (*Router, error) { // HandleRequest is used to process an individual DNS request. It returns a message in success or fail cases. func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAddress net.Addr) *dns.Msg { + return r.handleRequestRecursively(req, reqCtx, remoteAddress, maxRecursionLevelDefault) +} + +// handleRequestRecursively is used to process an individual DNS request. It will recurse as needed +// a maximum number of times and returns a message in success or fail cases. +func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx discovery.Context, + remoteAddress net.Addr, maxRecursionLevel int) *dns.Msg { configCtx := r.dynamicConfig.Load().(*RouterDynamicConfig) err := validateAndNormalizeRequest(req) @@ -165,7 +196,7 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd } reqType := parseRequestType(req) - results, err := r.getQueryResults(req, reqCtx, reqType, configCtx) + results, query, err := r.getQueryResults(req, reqCtx, reqType, configCtx) switch { case errors.Is(err, errNameNotFound): r.logger.Error("name not found", "name", req.Question[0].Name) @@ -185,7 +216,7 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd // This needs the question information because it affects the serialization format. // e.g., the Consul service has the same "results" for both NS and A/AAAA queries, but the serialization differs. - resp, err := r.serializeQueryResults(req, results, configCtx, responseDomain) + resp, err := r.serializeQueryResults(req, reqCtx, query, results, configCtx, responseDomain, remoteAddress, maxRecursionLevel) if err != nil { r.logger.Error("error serializing DNS results", "error", err) return createServerFailureResponse(req, configCtx, false) @@ -193,8 +224,27 @@ func (r *Router) HandleRequest(req *dns.Msg, reqCtx discovery.Context, remoteAdd return resp } +// getTTLForResult returns the TTL for a given result. +func getTTLForResult(name string, query *discovery.Query, cfg *RouterDynamicConfig) uint32 { + switch { + // TODO (v2-dns): currently have to do this related to the results type being changed to node whe + // the v1 data fetcher encounters a blank service address and uses the node address instead. + // we will revisiting this when look at modifying the discovery result struct to + // possibly include additional metadata like the node address. + case query != nil && query.QueryType == discovery.QueryTypeService: + ttl, ok := cfg.GetTTLForService(name) + if ok { + return uint32(ttl / time.Second) + } + fallthrough + default: + return uint32(cfg.NodeTTL / time.Second) + } +} + // getQueryResults returns a discovery.Result from a DNS message. -func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType requestType, cfgCtx *RouterDynamicConfig) ([]*discovery.Result, error) { +func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType requestType, cfg *RouterDynamicConfig) ([]*discovery.Result, *discovery.Query, error) { + var query *discovery.Query switch reqType { case requestTypeConsul: // This is a special case of discovery.QueryByName where we know that we need to query the consul service @@ -206,25 +256,38 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx discovery.Context, reqType }, Limit: 3, // TODO (v2-dns): need to thread this through to the backend and make sure we shuffle the results } - return r.processor.QueryByName(query, reqCtx) + + results, err := r.processor.QueryByName(query, reqCtx) + return results, query, err case requestTypeName: - query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfgCtx, r.defaultEntMeta) + query, err := buildQueryFromDNSMessage(req, r.domain, r.altDomain, cfg, r.defaultEntMeta, r.datacenter) if err != nil { r.logger.Error("error building discovery query from DNS request", "error", err) - return nil, err + return nil, query, err } - return r.processor.QueryByName(query, reqCtx) + results, err := r.processor.QueryByName(query, reqCtx) + if err != nil { + r.logger.Error("error processing discovery query", "error", err) + return nil, query, err + } + return results, query, nil case requestTypeIP: ip := dnsutil.IPFromARPA(req.Question[0].Name) if ip == nil { r.logger.Error("error building IP from DNS request", "name", req.Question[0].Name) - return nil, errNameNotFound + return nil, nil, errNameNotFound } - return r.processor.QueryByIP(ip, reqCtx) + results, err := r.processor.QueryByIP(ip, reqCtx) + return results, query, err case requestTypeAddress: - return buildAddressResults(req) + results, err := buildAddressResults(req) + if err != nil { + r.logger.Error("error processing discovery query", "error", err) + return nil, query, err + } + return results, query, nil } - return nil, errors.New("invalid request type") + return nil, query, errors.New("invalid request type") } // ServeDNS implements the miekg/dns.Handler interface. @@ -304,23 +367,99 @@ func parseRequestType(req *dns.Msg) requestType { } // serializeQueryResults converts a discovery.Result into a DNS message. -func (r *Router) serializeQueryResults(req *dns.Msg, results []*discovery.Result, cfg *RouterDynamicConfig, responseDomain string) (*dns.Msg, error) { +func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx discovery.Context, + query *discovery.Query, results []*discovery.Result, cfg *RouterDynamicConfig, + responseDomain string, remoteAddress net.Addr, maxRecursionLevel int) (*dns.Msg, error) { resp := new(dns.Msg) resp.SetReply(req) resp.Compress = !cfg.DisableCompression resp.Authoritative = true resp.RecursionAvailable = canRecurse(cfg) + qType := req.Question[0].Qtype + reqType := parseRequestType(req) + + // Always add the SOA record if requested. + switch { + case qType == dns.TypeSOA: + resp.Answer = append(resp.Answer, makeSOARecord(responseDomain, cfg)) + for _, result := range results { + ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel) + resp.Answer = append(resp.Answer, ans...) + resp.Extra = append(resp.Extra, ex...) + resp.Ns = append(resp.Ns, ns...) + } + case qType == dns.TypeSRV, reqType == requestTypeAddress: + for _, result := range results { + ans, ex, ns := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel) + resp.Answer = append(resp.Answer, ans...) + resp.Extra = append(resp.Extra, ex...) + resp.Ns = append(resp.Ns, ns...) + } + default: + // default will send it to where it does some de-duping while it calls getAnswerExtraAndNs and recurses. + r.appendResultsToDNSResponse(req, reqCtx, query, resp, results, cfg, responseDomain, remoteAddress, maxRecursionLevel) + } + + return resp, nil +} + +// appendResultsToDNSResponse builds dns message from the discovery results and +// appends them to the dns response. +func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx discovery.Context, + query *discovery.Query, resp *dns.Msg, results []*discovery.Result, cfg *RouterDynamicConfig, + responseDomain string, remoteAddress net.Addr, maxRecursionLevel int) { + // Always add the SOA record if requested. if req.Question[0].Qtype == dns.TypeSOA { resp.Answer = append(resp.Answer, makeSOARecord(responseDomain, cfg)) } + handled := make(map[string]struct{}) + var answerCNAME []dns.RR = nil + + count := 0 for _, result := range results { - appendResultToDNSResponse(result, req, resp, responseDomain, cfg) + // Add the node record + had_answer := false + ans, extra, _ := r.getAnswerExtraAndNs(result, req, reqCtx, query, cfg, responseDomain, remoteAddress, maxRecursionLevel) + resp.Extra = append(resp.Extra, extra...) + + if len(ans) == 0 { + continue + } + + // Avoid duplicate entries, possible if a node has + // the same service on multiple ports, etc. + if _, ok := handled[ans[0].String()]; ok { + continue + } + handled[ans[0].String()] = struct{}{} + + switch ans[0].(type) { + case *dns.CNAME: + // keep track of the first CNAME + associated RRs but don't add to the resp.Answer yet + // this will only be added if no non-CNAME RRs are found + if len(answerCNAME) == 0 { + answerCNAME = ans + } + default: + resp.Answer = append(resp.Answer, ans...) + had_answer = true + } + + if had_answer { + count++ + if count == cfg.ARecordLimit { + // We stop only if greater than 0 or we reached the limit + return + } + } } - return resp, nil + if len(resp.Answer) == 0 && len(answerCNAME) > 0 { + resp.Answer = answerCNAME + } } // defaultAgentDNSRequestContext returns a default request context based on the agent's config. @@ -332,6 +471,46 @@ func (r *Router) defaultAgentDNSRequestContext() discovery.Context { } } +// resolveCNAME is used to recursively resolve CNAME records +func (r *Router) resolveCNAME(cfg *RouterDynamicConfig, name string, reqCtx discovery.Context, + remoteAddress net.Addr, maxRecursionLevel int) []dns.RR { + // If the CNAME record points to a Consul address, resolve it internally + // Convert query to lowercase because DNS is case insensitive; d.domain and + // d.altDomain are already converted + + if ln := strings.ToLower(name); strings.HasSuffix(ln, "."+r.domain) || strings.HasSuffix(ln, "."+r.altDomain) { + if maxRecursionLevel < 1 { + //d.logger.Error("Infinite recursion detected for name, won't perform any CNAME resolution.", "name", name) + return nil + } + req := &dns.Msg{} + + req.SetQuestion(name, dns.TypeANY) + // TODO: handle error response + resp := r.handleRequestRecursively(req, reqCtx, nil, maxRecursionLevel-1) + + return resp.Answer + } + + // Do nothing if we don't have a recursor + if !canRecurse(cfg) { + return nil + } + + // Ask for any A records + m := new(dns.Msg) + m.SetQuestion(name, dns.TypeA) + + // Make a DNS lookup request + recursorResponse, err := r.recursor.handle(m, cfg, remoteAddress) + if err == nil { + return recursorResponse.Answer + } + + r.logger.Error("all resolvers failed for name", "name", name) + return nil +} + // validateAndNormalizeRequest validates the DNS request and normalizes the request name. func validateAndNormalizeRequest(req *dns.Msg) error { // like upstream miekg/dns, we require at least one question, @@ -406,10 +585,26 @@ func getDynamicRouterConfig(conf *config.RuntimeConfig) (*RouterDynamicConfig, e Refresh: conf.DNSSOA.Refresh, Retry: conf.DNSSOA.Retry, }, - enterpriseDNSConfig: getEnterpriseDNSConfig(conf), + EnterpriseDNSConfig: discovery.GetEnterpriseDNSConfig(conf), } - // TODO (v2-dns): add service TTL recalculation + if conf.DNSServiceTTL != nil { + cfg.TTLRadix = radix.New() + cfg.TTLStrict = make(map[string]time.Duration) + + for key, ttl := range conf.DNSServiceTTL { + // All suffix with '*' are put in radix + // This include '*' that will match anything + if strings.HasSuffix(key, "*") { + cfg.TTLRadix.Insert(key[:len(key)-1], ttl) + } else { + cfg.TTLStrict[key] = ttl + } + } + } else { + cfg.TTLRadix = nil + cfg.TTLStrict = nil + } for _, r := range conf.DNSRecursors { ra, err := formatRecursorAddress(r) @@ -545,30 +740,18 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { }, nil } -// buildQueryFromDNSMessage appends the discovery result to the dns message. -func appendResultToDNSResponse(result *discovery.Result, req *dns.Msg, resp *dns.Msg, domain string, cfg *RouterDynamicConfig) { - ip, ok := convertToIp(result) - - // if the result is not an IP, we can try to recurse on the hostname. - // TODO (v2-dns): hostnames are valid for workloads in V2, do we just want to return the CNAME? - if !ok { - // TODO (v2-dns): recurse on HandleRequest() - panic("not implemented") +// getAnswerAndExtra creates the dns answer and extra from discovery results. +func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx discovery.Context, + query *discovery.Query, cfg *RouterDynamicConfig, domain string, remoteAddress net.Addr, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR, ns []dns.RR) { + address, target := getAddressAndTargetFromDiscoveryResult(result, r.domain) + qName := req.Question[0].Name + ttlLookupName := qName + if query != nil { + ttlLookupName = query.QueryPayload.Name } - - var ttl uint32 - switch result.Type { - case discovery.ResultTypeNode, discovery.ResultTypeVirtual, discovery.ResultTypeWorkload: - ttl = uint32(cfg.NodeTTL / time.Second) - case discovery.ResultTypeService: - // TODO (v2-dns): implement service TTL using the radix tree - } - - qName := dns.CanonicalName(req.Question[0].Name) + ttl := getTTLForResult(ttlLookupName, query, cfg) qType := req.Question[0].Qtype - record, isIPV4 := makeRecord(qName, ip, ttl) - // TODO (v2-dns): skip records that refer to a workload/node that don't have a valid DNS name. // Special case responses @@ -579,54 +762,120 @@ func appendResultToDNSResponse(result *discovery.Result, req *dns.Msg, resp *dns Hdr: dns.RR_Header{Name: qName, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0}, Ptr: canonicalNameForResult(result, domain), } - resp.Answer = append(resp.Answer, ptr) - return + answer = append(answer, ptr) case qType == dns.TypeNS: // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result fqdn := canonicalNameForResult(result, domain) - extraRecord, _ := makeRecord(fqdn, ip, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported - - resp.Answer = append(resp.Ns, makeNSRecord(domain, fqdn, ttl)) - resp.Extra = append(resp.Extra, extraRecord) - return + extraRecord := makeIPBasedRecord(fqdn, address, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported + answer = append(answer, makeNSRecord(domain, fqdn, ttl)) + extra = append(extra, extraRecord) case qType == dns.TypeSOA: // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result // to be returned in the result. fqdn := canonicalNameForResult(result, domain) - extraRecord, _ := makeRecord(fqdn, ip, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported + extraRecord := makeIPBasedRecord(fqdn, address, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported - resp.Ns = append(resp.Ns, makeNSRecord(domain, fqdn, ttl)) - resp.Extra = append(resp.Extra, extraRecord) - return + ns = append(ns, makeNSRecord(domain, fqdn, ttl)) + extra = append(extra, extraRecord) case qType == dns.TypeSRV: // We put A/AAAA/CNAME records in the additional section for SRV requests - resp.Extra = append(resp.Extra, record) + a, e := r.getAnswerExtrasForAddressAndTarget(address, target, req, reqCtx, + result, ttl, remoteAddress, maxRecursionLevel) + answer = append(answer, a...) + extra = append(extra, e...) - // TODO (v2-dns): implement SRV records for the answer section - return + cfg := r.dynamicConfig.Load().(*RouterDynamicConfig) + if cfg.NodeMetaTXT { + extra = append(extra, makeTXTRecord(target.FQDN(), result, ttl)...) + } + default: + a, e := r.getAnswerExtrasForAddressAndTarget(address, target, req, reqCtx, + result, ttl, remoteAddress, maxRecursionLevel) + answer = append(answer, a...) + extra = append(extra, e...) } - - // For explicit A/AAAA queries, we must only return those records in the answer section. - if isIPV4 && qType != dns.TypeA && qType != dns.TypeANY { - resp.Extra = append(resp.Extra, record) - return - } - if !isIPV4 && qType != dns.TypeAAAA && qType != dns.TypeANY { - resp.Extra = append(resp.Extra, record) - return - } - - resp.Answer = append(resp.Answer, record) + return } -// convertToIp converts a discovery.Result to a net.IP. -func convertToIp(result *discovery.Result) (net.IP, bool) { - ip := net.ParseIP(result.Address) - if ip == nil { - return nil, false +// getAnswerExtrasForAddressAndTarget creates the dns answer and extra from address and target dnsAddress pairs. +func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target *dnsAddress, req *dns.Msg, + reqCtx discovery.Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr, + maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) { + qName := req.Question[0].Name + reqType := parseRequestType(req) + + cfg := r.dynamicConfig.Load().(*RouterDynamicConfig) + switch { + + // There is no target and the address is a FQDN (external service) + case address.IsFQDN(): + a, e := r.makeRecordFromFQDN(address.FQDN(), result, req, reqCtx, + cfg, ttl, remoteAddress, maxRecursionLevel) + answer = append(a, answer...) + extra = append(e, extra...) + + // The target is a FQDN (internal or external service name) + case result.Type != discovery.ResultTypeNode && target.IsFQDN(): + a, e := r.makeRecordFromFQDN(target.FQDN(), result, req, reqCtx, + cfg, ttl, remoteAddress, maxRecursionLevel) + answer = append(answer, a...) + extra = append(extra, e...) + + // There is no target and the address is an IP + case address.IsIP(): + // TODO (v2-dns): Do not CNAME node address in case of WAN address. + ipRecordName := target.FQDN() + if maxRecursionLevel < maxRecursionLevelDefault || ipRecordName == "" { + ipRecordName = qName + } + a, e := getAnswerExtrasForIP(ipRecordName, address, req.Question[0], reqType, result, ttl) + answer = append(answer, a...) + extra = append(extra, e...) + + // The target is an IP + case target.IsIP(): + a, e := getAnswerExtrasForIP(qName, target, req.Question[0], reqType, result, ttl) + answer = append(answer, a...) + extra = append(extra, e...) + + // The target is a CNAME for the service we are looking + // for. So we use the address. + case target.FQDN() == req.Question[0].Name && address.IsIP(): + a, e := getAnswerExtrasForIP(qName, address, req.Question[0], reqType, result, ttl) + answer = append(answer, a...) + extra = append(extra, e...) + + // The target is a FQDN (internal or external service name) + default: + a, e := r.makeRecordFromFQDN(target.FQDN(), result, req, reqCtx, cfg, ttl, remoteAddress, maxRecursionLevel) + answer = append(a, answer...) + extra = append(e, extra...) } - return ip, true + return +} + +// getAddressAndTargetFromDiscoveryResult returns the address and target from a discovery result. +func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, reqType requestType, result *discovery.Result, ttl uint32) (answer []dns.RR, extra []dns.RR) { + record := makeIPBasedRecord(name, addr, ttl) + qType := question.Qtype + + isARecordWhenNotExplicitlyQueried := record.Header().Rrtype == dns.TypeA && qType != dns.TypeA && qType != dns.TypeANY + isAAAARecordWhenNotExplicitlyQueried := record.Header().Rrtype == dns.TypeAAAA && qType != dns.TypeAAAA && qType != dns.TypeANY + + // For explicit A/AAAA queries, we must only return those records in the answer section. + if isARecordWhenNotExplicitlyQueried || + isAAAARecordWhenNotExplicitlyQueried { + extra = append(extra, record) + } else { + answer = append(answer, record) + } + + if reqType != requestTypeAddress && qType == dns.TypeSRV { + srv := makeSRVRecord(name, name, result, ttl) + answer = append(answer, srv) + } + return } func makeSOARecord(domain string, cfg *RouterDynamicConfig) dns.RR { @@ -660,13 +909,12 @@ func makeNSRecord(domain, fqdn string, ttl uint32) dns.RR { } } -// makeRecord an A or AAAA record for the given name and IP. +// makeIPBasedRecord an A or AAAA record for the given name and IP. // Note: we might want to pass in the Query Name here, which is used in addr. and virtual. queries // since there is only ever one result. Right now choosing to leave it off for simplification. -func makeRecord(name string, ip net.IP, ttl uint32) (dns.RR, bool) { - isIPV4 := ip.To4() != nil +func makeIPBasedRecord(name string, addr *dnsAddress, ttl uint32) dns.RR { - if isIPV4 { + if addr.IsIPV4() { // check if the query type is A for IPv4 or ANY return &dns.A{ Hdr: dns.RR_Header{ @@ -675,8 +923,8 @@ func makeRecord(name string, ip net.IP, ttl uint32) (dns.RR, bool) { Class: dns.ClassINET, Ttl: ttl, }, - A: ip, - }, true + A: addr.IP(), + } } return &dns.AAAA{ @@ -686,6 +934,126 @@ func makeRecord(name string, ip net.IP, ttl uint32) (dns.RR, bool) { Class: dns.ClassINET, Ttl: ttl, }, - AAAA: ip, - }, false + AAAA: addr.IP(), + } +} + +func (r *Router) makeRecordFromFQDN(fqdn string, result *discovery.Result, + req *dns.Msg, reqCtx discovery.Context, cfg *RouterDynamicConfig, ttl uint32, + remoteAddress net.Addr, maxRecursionLevel int) ([]dns.RR, []dns.RR) { + edns := req.IsEdns0() != nil + q := req.Question[0] + + more := r.resolveCNAME(cfg, dns.Fqdn(fqdn), reqCtx, remoteAddress, maxRecursionLevel) + var additional []dns.RR + extra := 0 +MORE_REC: + for _, rr := range more { + switch rr.Header().Rrtype { + case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA: + // set the TTL manually + rr.Header().Ttl = ttl + additional = append(additional, rr) + + extra++ + if extra == maxRecurseRecords && !edns { + break MORE_REC + } + } + } + + if q.Qtype == dns.TypeSRV { + answers := []dns.RR{ + makeSRVRecord(q.Name, fqdn, result, ttl), + } + return answers, additional + } + + answers := []dns.RR{ + makeCNAMERecord(result, q.Name, ttl), + } + answers = append(answers, additional...) + + return answers, nil +} + +// makeCNAMERecord returns a CNAME record for the given name and target. +func makeCNAMERecord(result *discovery.Result, qName string, ttl uint32) *dns.CNAME { + return &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: qName, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: ttl, + }, + Target: dns.Fqdn(result.Target), + } +} + +// func makeSRVRecord returns an SRV record for the given name and target. +func makeSRVRecord(name, target string, result *discovery.Result, ttl uint32) *dns.SRV { + return &dns.SRV{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: ttl, + }, + Priority: 1, + Weight: uint16(result.Weight), + Port: uint16(result.Port), + Target: target, + } +} + +// encodeKVasRFC1464 encodes a key-value pair according to RFC1464 +func encodeKVasRFC1464(key, value string) (txt string) { + // For details on these replacements c.f. https://www.ietf.org/rfc/rfc1464.txt + key = strings.Replace(key, "`", "``", -1) + key = strings.Replace(key, "=", "`=", -1) + + // Backquote the leading spaces + leadingSpacesRE := regexp.MustCompile("^ +") + numLeadingSpaces := len(leadingSpacesRE.FindString(key)) + key = leadingSpacesRE.ReplaceAllString(key, strings.Repeat("` ", numLeadingSpaces)) + + // Backquote the trailing spaces + numTrailingSpaces := len(trailingSpacesRE.FindString(key)) + key = trailingSpacesRE.ReplaceAllString(key, strings.Repeat("` ", numTrailingSpaces)) + + value = strings.Replace(value, "`", "``", -1) + + return key + "=" + value +} + +// makeTXTRecord returns a TXT record for the given name and result metadata. +func makeTXTRecord(name string, result *discovery.Result, ttl uint32) []dns.RR { + extra := make([]dns.RR, 0, len(result.Metadata)) + for key, value := range result.Metadata { + txt := value + if !strings.HasPrefix(strings.ToLower(key), "rfc1035-") { + txt = encodeKVasRFC1464(key, value) + } + + extra = append(extra, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: ttl, + }, + Txt: []string{txt}, + }) + } + return extra +} + +// getAddressAndTargetFromCheckServiceNode returns the address and target for a given discovery.Result +func getAddressAndTargetFromDiscoveryResult(result *discovery.Result, domain string) (*dnsAddress, *dnsAddress) { + target := newDNSAddress(result.Target) + if !target.IsEmptyString() && !target.IsInternalFQDNOrIP(domain) { + target.SetAddress(canonicalNameForResult(result, domain)) + } + address := newDNSAddress(result.Address) + return address, target } diff --git a/agent/dns/router_ce.go b/agent/dns/router_ce.go index d5fff53235..5ffc9f51ce 100644 --- a/agent/dns/router_ce.go +++ b/agent/dns/router_ce.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/consul/agent/discovery" ) +// canonicalNameForResult returns the canonical name for a discovery result. func canonicalNameForResult(result *discovery.Result, domain string) string { switch result.Type { case discovery.ResultTypeService: diff --git a/agent/dns/router_query.go b/agent/dns/router_query.go index 5f46413681..847dc45c5e 100644 --- a/agent/dns/router_query.go +++ b/agent/dns/router_query.go @@ -14,44 +14,77 @@ import ( ) // buildQueryFromDNSMessage returns a discovery.Query from a DNS message. -func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string, cfgCtx *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta) (*discovery.Query, error) { +func buildQueryFromDNSMessage(req *dns.Msg, domain, altDomain string, + cfg *RouterDynamicConfig, defaultEntMeta acl.EnterpriseMeta, defaultDatacenter string) (*discovery.Query, error) { queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain) - locality, ok := ParseLocality(querySuffixes, defaultEntMeta, cfgCtx.enterpriseDNSConfig) - if !ok { - return nil, errors.New("invalid locality") + queryTenancy, err := getQueryTenancy(queryType, querySuffixes, defaultEntMeta, cfg, defaultDatacenter) + if err != nil { + return nil, err } - // TODO(v2-dns): This needs to be deprecated. - peerName := locality.peer - if peerName == "" { - // If the peer name was not explicitly defined, fall back to the ambiguously-parsed version. - peerName = locality.peerOrDatacenter - } + name, tag := getQueryNameAndTagFromParts(queryType, queryParts) return &discovery.Query{ QueryType: queryType, QueryPayload: discovery.QueryPayload{ - Name: queryParts[len(queryParts)-1], - Tenancy: discovery.QueryTenancy{ - EnterpriseMeta: locality.EnterpriseMeta, - // v2-dns: revisit if we need this after the rest of this works. - // SamenessGroup: "", - // The datacenter of the request is not specified because cross-datacenter virtual IP - // queries are not supported. This guard rail is in place because virtual IPs are allocated - // within a DC, therefore their uniqueness is not guaranteed globally. - Peer: peerName, - Datacenter: locality.datacenter, - }, - // TODO(v2-dns): what should these be? + Name: name, + Tenancy: queryTenancy, + Tag: tag, + // TODO (v2-dns): what should these be? //PortName: "", - //Tag: "", //RemoteAddr: nil, //DisableFailover: false, }, }, nil } +// getQueryNameAndTagFromParts returns the query name and tag from the query parts that are taken from the original dns question. +func getQueryNameAndTagFromParts(queryType discovery.QueryType, queryParts []string) (string, string) { + switch queryType { + case discovery.QueryTypeService: + n := len(queryParts) + // Support RFC 2782 style syntax + if n == 2 && strings.HasPrefix(queryParts[1], "_") && strings.HasPrefix(queryParts[0], "_") { + // Grab the tag since we make nuke it if it's tcp + tag := queryParts[1][1:] + + // Treat _name._tcp.service.consul as a default, no need to filter on that tag + if tag == "tcp" { + tag = "" + } + + name := queryParts[0][1:] + // _name._tag.service.consul + return name, tag + } + return queryParts[len(queryParts)-1], "" + } + return queryParts[len(queryParts)-1], "" +} + +// getQueryTenancy returns a discovery.QueryTenancy from a DNS message. +func getQueryTenancy(queryType discovery.QueryType, querySuffixes []string, + defaultEntMeta acl.EnterpriseMeta, cfg *RouterDynamicConfig, defaultDatacenter string) (discovery.QueryTenancy, error) { + if queryType == discovery.QueryTypeService { + return getQueryTenancyForService(querySuffixes, defaultEntMeta, cfg, defaultDatacenter) + } + + locality, ok := discovery.ParseLocality(querySuffixes, defaultEntMeta, cfg.EnterpriseDNSConfig) + if !ok { + return discovery.QueryTenancy{}, errors.New("invalid locality") + } + + if queryType == discovery.QueryTypeVirtual { + if locality.Peer == "" { + // If the peer name was not explicitly defined, fall back to the ambiguously-parsed version. + locality.Peer = locality.PeerOrDatacenter + } + } + + return discovery.GetQueryTenancyBasedOnLocality(locality, defaultDatacenter) +} + // getQueryTypePartsAndSuffixesFromDNSMessage returns the query type, the parts, and suffixes of the query name. func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain string) (queryType discovery.QueryType, parts []string, suffixes []string) { // Get the QName without the domain suffix @@ -64,18 +97,19 @@ func getQueryTypePartsAndSuffixesFromDNSMessage(req *dns.Msg, domain, altDomain for i := len(labels) - 1; i >= 0 && !done; i-- { queryType = getQueryTypeFromLabels(labels[i]) switch queryType { - case discovery.QueryTypeInvalid: - // If we don't recognize the query type, we keep going until we find one we do. case discovery.QueryTypeService, discovery.QueryTypeConnect, discovery.QueryTypeVirtual, discovery.QueryTypeIngress, discovery.QueryTypeNode, discovery.QueryTypePreparedQuery: parts = labels[:i] suffixes = labels[i+1:] done = true + case discovery.QueryTypeInvalid: + fallthrough default: // If this is a SRV query the "service" label is optional, we add it back to use the // existing code-path. if req.Question[0].Qtype == dns.TypeSRV && strings.HasPrefix(labels[i], "_") { + queryType = discovery.QueryTypeService parts = labels[:i+1] suffixes = labels[i+1:] done = true diff --git a/agent/dns/router_query_ce.go b/agent/dns/router_query_ce.go new file mode 100644 index 0000000000..bbe868a2c8 --- /dev/null +++ b/agent/dns/router_query_ce.go @@ -0,0 +1,24 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !consulent + +package dns + +import ( + "errors" + + "github.com/hashicorp/consul/acl" + "github.com/hashicorp/consul/agent/discovery" +) + +// getQueryTenancy returns a discovery.QueryTenancy from a DNS message. +func getQueryTenancyForService(querySuffixes []string, + defaultEntMeta acl.EnterpriseMeta, cfg *RouterDynamicConfig, defaultDatacenter string) (discovery.QueryTenancy, error) { + locality, ok := discovery.ParseLocality(querySuffixes, defaultEntMeta, cfg.EnterpriseDNSConfig) + if !ok { + return discovery.QueryTenancy{}, errors.New("invalid locality") + } + + return discovery.GetQueryTenancyBasedOnLocality(locality, defaultDatacenter) +} diff --git a/agent/dns/router_query_test.go b/agent/dns/router_query_test.go index 14348bfb96..726ef32ba8 100644 --- a/agent/dns/router_query_test.go +++ b/agent/dns/router_query_test.go @@ -29,7 +29,7 @@ func Test_buildQueryFromDNSMessage(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - query, err := buildQueryFromDNSMessage(tc.request, "domain", "altDomain", &RouterDynamicConfig{}, acl.EnterpriseMeta{}) + query, err := buildQueryFromDNSMessage(tc.request, "domain", "altDomain", &RouterDynamicConfig{}, acl.EnterpriseMeta{}, "defaultDatacenter") require.NoError(t, err) assert.Equal(t, tc.expectedQuery, query) }) diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index 3064204b24..33fcbdb72e 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -706,7 +706,7 @@ func Test_HandleRequest(t *testing.T) { }, Question: []dns.Question{ { - Name: "c000020a.virtual.consul", // "intentionally missing the trailing dot" + Name: "c000020a.virtual.dc1.consul", // "intentionally missing the trailing dot" Qtype: dns.TypeA, Qclass: dns.ClassINET, }, @@ -728,7 +728,7 @@ func Test_HandleRequest(t *testing.T) { Compress: true, Question: []dns.Question{ { - Name: "c000020a.virtual.consul.", + Name: "c000020a.virtual.dc1.consul.", Qtype: dns.TypeA, Qclass: dns.ClassINET, }, @@ -736,7 +736,7 @@ func Test_HandleRequest(t *testing.T) { Answer: []dns.RR{ &dns.A{ Hdr: dns.RR_Header{ - Name: "c000020a.virtual.consul.", + Name: "c000020a.virtual.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 123, @@ -1345,6 +1345,58 @@ func Test_HandleRequest(t *testing.T) { } +func TestRouterDynamicConfig_GetTTLForService(t *testing.T) { + type testCase struct { + name string + inputKey string + shouldMatch bool + expectedDuration time.Duration + } + + testCases := []testCase{ + { + name: "strict match", + inputKey: "foo", + shouldMatch: true, + expectedDuration: 1 * time.Second, + }, + { + name: "wildcard match", + inputKey: "bar", + shouldMatch: true, + expectedDuration: 2 * time.Second, + }, + { + name: "wildcard match 2", + inputKey: "bart", + shouldMatch: true, + expectedDuration: 2 * time.Second, + }, + { + name: "no match", + inputKey: "homer", + shouldMatch: false, + expectedDuration: 0 * time.Second, + }, + } + + rtCfg := &config.RuntimeConfig{ + DNSServiceTTL: map[string]time.Duration{ + "foo": 1 * time.Second, + "bar*": 2 * time.Second, + }, + } + cfg, err := getDynamicRouterConfig(rtCfg) + require.NoError(t, err) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual, ok := cfg.GetTTLForService(tc.inputKey) + require.Equal(t, tc.shouldMatch, ok) + require.Equal(t, tc.expectedDuration, actual) + }) + } +} func buildDNSConfig(agentConfig *config.RuntimeConfig, cdf discovery.CatalogDataFetcher, _ error) Config { cfg := Config{ AgentConfig: &config.RuntimeConfig{ diff --git a/agent/dns_service_lookup_test.go b/agent/dns_service_lookup_test.go index 3361477076..9e021824ae 100644 --- a/agent/dns_service_lookup_test.go +++ b/agent/dns_service_lookup_test.go @@ -20,6 +20,7 @@ import ( "github.com/hashicorp/consul/testrpc" ) +// TODO (v2-dns): requires PTR implementation func TestDNS_ServiceReverseLookup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -76,6 +77,7 @@ func TestDNS_ServiceReverseLookup(t *testing.T) { } } +// TODO (v2-dns): requires PTR implementation func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -132,6 +134,7 @@ func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) { } } +// TODO (v2-dns): requires PTR implementation func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -190,6 +193,7 @@ func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) { } } +// TODO (v2-dns): requires PTR implementation func TestDNS_ServiceReverseLookupNodeAddress(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -252,7 +256,7 @@ func TestDNS_ServiceLookupNoMultiCNAME(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -315,7 +319,7 @@ func TestDNS_ServiceLookupPreferNoCNAME(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -364,7 +368,7 @@ func TestDNS_ServiceLookupPreferNoCNAME(t *testing.T) { in, _, err := c.Exchange(m, a.DNSAddr()) require.NoError(t, err) - // expect a CNAME and an A RR + // expect an A RR require.Len(t, in.Answer, 1) aRec, ok := in.Answer[0].(*dns.A) require.Truef(t, ok, "Not an A RR") @@ -381,7 +385,7 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -457,6 +461,7 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -585,6 +590,8 @@ func TestDNS_ServiceLookup(t *testing.T) { } } +// TODO (v2-dns): this is formulating the correct response +// but failing with an I/O timeout on the dns client Exchange() call func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -820,7 +827,7 @@ func TestDNS_ExternalServiceLookup(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -858,7 +865,7 @@ func TestDNS_ExternalServiceLookup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Answer) != 1 { + if len(in.Answer) != 1 || len(in.Extra) > 0 { t.Fatalf("Bad: %#v", in) } @@ -886,7 +893,7 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` domain = "CONSUL." @@ -1121,6 +1128,7 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_ServiceAddress_A(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1222,6 +1230,7 @@ func TestDNS_ServiceLookup_ServiceAddress_A(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1330,6 +1339,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1451,6 +1461,7 @@ func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_ServiceAddressIPV6(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1552,6 +1563,7 @@ func TestDNS_ServiceLookup_ServiceAddressIPV6(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1660,6 +1672,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_WanTranslation(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1876,6 +1889,7 @@ func TestDNS_ServiceLookup_WanTranslation(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1978,6 +1992,7 @@ func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { } } +// TODO (v2-dns): this returns a response where the answer is an SOA record func TestDNS_ServiceLookup_TagPeriod(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -2058,6 +2073,7 @@ func TestDNS_ServiceLookup_TagPeriod(t *testing.T) { } } +// TODO (v2-dns): this returns a response where the answer is an SOA record func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -2145,6 +2161,7 @@ func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_Dedup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -2256,6 +2273,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -2831,6 +2849,7 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_Randomize(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -2930,6 +2949,7 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_Truncate(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -3007,6 +3027,7 @@ func TestDNS_ServiceLookup_Truncate(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_LargeResponses(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -3386,7 +3407,6 @@ func TestDNS_ServiceLookup_ARecordLimits(t *testing.T) { } } -// TODO(jmurret): func TestDNS_ServiceLookup_AnswerLimits(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -3462,6 +3482,7 @@ func TestDNS_ServiceLookup_AnswerLimits(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_CNAME(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -3567,6 +3588,7 @@ func TestDNS_ServiceLookup_CNAME(t *testing.T) { } } +// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_ServiceAddress_CNAME(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -3679,7 +3701,7 @@ func TestDNS_ServiceLookup_TTL(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` dns_config { @@ -3765,7 +3787,7 @@ func TestDNS_ServiceLookup_SRV_RFC(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -3847,75 +3869,81 @@ func TestDNS_ServiceLookup_SRV_RFC_TCP_Default(t *testing.T) { } t.Parallel() - a := NewTestAgent(t, "") - defer a.Shutdown() - testrpc.WaitForLeader(t, a.RPC, "dc1") + for name, experimentsHCL := range getVersionHCL(true) { + t.Run(name, func(t *testing.T) { + a := NewTestAgent(t, experimentsHCL) + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") - // Register node - args := &structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "127.0.0.1", - Service: &structs.NodeService{ - Service: "db", - Tags: []string{"primary"}, - Port: 12345, - }, - } + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Service: "db", + Tags: []string{"primary"}, + Port: 12345, + }, + } - var out struct{} - if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { - t.Fatalf("err: %v", err) - } + var out struct{} + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } - questions := []string{ - "_db._tcp.service.dc1.consul.", - "_db._tcp.service.consul.", - "_db._tcp.dc1.consul.", - "_db._tcp.consul.", - } + questions := []string{ + "_db._tcp.service.dc1.consul.", + "_db._tcp.service.consul.", + "_db._tcp.dc1.consul.", + "_db._tcp.consul.", + } - for _, question := range questions { - m := new(dns.Msg) - m.SetQuestion(question, dns.TypeSRV) + for _, question := range questions { + t.Run(question, func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion(question, dns.TypeSRV) - c := new(dns.Client) - in, _, err := c.Exchange(m, a.DNSAddr()) - if err != nil { - t.Fatalf("err: %v", err) - } + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } - if len(in.Answer) != 1 { - t.Fatalf("Bad: %#v", in) - } + if len(in.Answer) != 1 { + t.Fatalf("Bad: %#v", in) + } - srvRec, ok := in.Answer[0].(*dns.SRV) - if !ok { - t.Fatalf("Bad: %#v", in.Answer[0]) - } - if srvRec.Port != 12345 { - t.Fatalf("Bad: %#v", srvRec) - } - if srvRec.Target != "foo.node.dc1.consul." { - t.Fatalf("Bad: %#v", srvRec) - } - if srvRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Answer[0]) - } + srvRec, ok := in.Answer[0].(*dns.SRV) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if srvRec.Port != 12345 { + t.Fatalf("Bad: %#v", srvRec) + } + if srvRec.Target != "foo.node.dc1.consul." { + t.Fatalf("Bad: %#v", srvRec) + } + if srvRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Answer[0]) + } - aRec, ok := in.Extra[0].(*dns.A) - if !ok { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if aRec.Hdr.Name != "foo.node.dc1.consul." { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if aRec.A.String() != "127.0.0.1" { - t.Fatalf("Bad: %#v", in.Extra[0]) - } - if aRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Extra[0]) - } + aRec, ok := in.Extra[0].(*dns.A) + if !ok { + t.Fatalf("Bad: %#v", in.Extra[0]) + } + if aRec.Hdr.Name != "foo.node.dc1.consul." { + t.Fatalf("Bad: %#v", in.Extra[0]) + } + if aRec.A.String() != "127.0.0.1" { + t.Fatalf("Bad: %#v", in.Extra[0]) + } + if aRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Extra[0]) + } + }) + } + }) } } @@ -3982,41 +4010,45 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) { } ` - a := NewTestAgent(t, hcl) - defer a.Shutdown() - testrpc.WaitForLeader(t, a.RPC, "dc1") + for name, experimentsHCL := range getVersionHCL(false) { + t.Run(name, func(t *testing.T) { + a := NewTestAgent(t, hcl+experimentsHCL) + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") - if tt.token == "dns" { - initDNSToken(t, a) - } + if tt.token == "dns" { + initDNSToken(t, a) + } - // Register a service - args := &structs.RegisterRequest{ - Datacenter: "dc1", - Node: "foo", - Address: "127.0.0.1", - Service: &structs.NodeService{ - Service: "foo", - Port: 12345, - }, - WriteRequest: structs.WriteRequest{Token: "root"}, - } - var out struct{} - if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { - t.Fatalf("err: %v", err) - } + // Register a service + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + Service: &structs.NodeService{ + Service: "foo", + Port: 12345, + }, + WriteRequest: structs.WriteRequest{Token: "root"}, + } + var out struct{} + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } - // Set up the DNS query - c := new(dns.Client) - m := new(dns.Msg) - m.SetQuestion("foo.service.consul.", dns.TypeA) + // Set up the DNS query + c := new(dns.Client) + m := new(dns.Msg) + m.SetQuestion("foo.service.consul.", dns.TypeA) - in, _, err := c.Exchange(m, a.DNSAddr()) - if err != nil { - t.Fatalf("err: %v", err) - } - if len(in.Answer) != tt.results { - t.Fatalf("Bad: %#v", in) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(in.Answer) != tt.results { + t.Fatalf("Bad: %#v", in) + } + }) } }) } @@ -4027,49 +4059,53 @@ func TestDNS_ServiceLookup_MetaTXT(t *testing.T) { t.Skip("too slow for testing.Short") } - a := NewTestAgent(t, `dns_config = { enable_additional_node_meta_txt = true }`) - defer a.Shutdown() - testrpc.WaitForLeader(t, a.RPC, "dc1") + for name, experimentsHCL := range getVersionHCL(true) { + t.Run(name, func(t *testing.T) { + a := NewTestAgent(t, `dns_config = { enable_additional_node_meta_txt = true } `+experimentsHCL) + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") - args := &structs.RegisterRequest{ - Datacenter: "dc1", - Node: "bar", - Address: "127.0.0.1", - NodeMeta: map[string]string{ - "key": "value", - }, - Service: &structs.NodeService{ - Service: "db", - Tags: []string{"primary"}, - Port: 12345, - }, + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "bar", + Address: "127.0.0.1", + NodeMeta: map[string]string{ + "key": "value", + }, + Service: &structs.NodeService{ + Service: "db", + Tags: []string{"primary"}, + Port: 12345, + }, + } + + var out struct{} + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("db.service.consul.", dns.TypeSRV) + + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } + + wantAdditional := []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4}, + A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 + }, + &dns.TXT{ + Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Rdlength: 0xa}, + Txt: []string{"key=value"}, + }, + } + require.Equal(t, wantAdditional, in.Extra) + }) } - - var out struct{} - if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { - t.Fatalf("err: %v", err) - } - - m := new(dns.Msg) - m.SetQuestion("db.service.consul.", dns.TypeSRV) - - c := new(dns.Client) - in, _, err := c.Exchange(m, a.DNSAddr()) - if err != nil { - t.Fatalf("err: %v", err) - } - - wantAdditional := []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4}, - A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 - }, - &dns.TXT{ - Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Rdlength: 0xa}, - Txt: []string{"key=value"}, - }, - } - require.Equal(t, wantAdditional, in.Extra) } func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) { @@ -4077,44 +4113,48 @@ func TestDNS_ServiceLookup_SuppressTXT(t *testing.T) { t.Skip("too slow for testing.Short") } - a := NewTestAgent(t, `dns_config = { enable_additional_node_meta_txt = false }`) - defer a.Shutdown() - testrpc.WaitForLeader(t, a.RPC, "dc1") + for name, experimentsHCL := range getVersionHCL(true) { + t.Run(name, func(t *testing.T) { + a := NewTestAgent(t, `dns_config = { enable_additional_node_meta_txt = false } `+experimentsHCL) + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") - // Register a node with a service. - args := &structs.RegisterRequest{ - Datacenter: "dc1", - Node: "bar", - Address: "127.0.0.1", - NodeMeta: map[string]string{ - "key": "value", - }, - Service: &structs.NodeService{ - Service: "db", - Tags: []string{"primary"}, - Port: 12345, - }, + // Register a node with a service. + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "bar", + Address: "127.0.0.1", + NodeMeta: map[string]string{ + "key": "value", + }, + Service: &structs.NodeService{ + Service: "db", + Tags: []string{"primary"}, + Port: 12345, + }, + } + + var out struct{} + if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("db.service.consul.", dns.TypeSRV) + + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } + + wantAdditional := []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4}, + A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 + }, + } + require.Equal(t, wantAdditional, in.Extra) + }) } - - var out struct{} - if err := a.RPC(context.Background(), "Catalog.Register", args, &out); err != nil { - t.Fatalf("err: %v", err) - } - - m := new(dns.Msg) - m.SetQuestion("db.service.consul.", dns.TypeSRV) - - c := new(dns.Client) - in, _, err := c.Exchange(m, a.DNSAddr()) - if err != nil { - t.Fatalf("err: %v", err) - } - - wantAdditional := []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{Name: "bar.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4}, - A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 - }, - } - require.Equal(t, wantAdditional, in.Extra) }