From 602e3c4fd50488c9044679c32364550bc7d49b29 Mon Sep 17 00:00:00 2001 From: John Murret Date: Fri, 2 Feb 2024 20:23:52 -0700 Subject: [PATCH] DNS V2 - Revise discovery result to have service and node name and address fields. (#20468) * DNS V2 - Revise discovery result to have service and node name and address fields. * NET-7488 - dns v2 add support for prepared queries in catalog v1 data model (#20470) NET-7488 - dns v2 add support for prepared queries in catalog v1 data model. --- agent/discovery/discovery.go | 22 +- agent/discovery/discovery_test.go | 4 +- agent/discovery/query_fetcher_v1.go | 252 ++++++++++++++++------ agent/discovery/query_fetcher_v1_test.go | 259 +++++------------------ agent/discovery/query_fetcher_v2.go | 10 +- agent/discovery/query_fetcher_v2_test.go | 13 +- agent/dns/router.go | 216 ++++++++++++------- agent/dns/router_ce.go | 21 +- agent/dns/router_ce_test.go | 8 +- agent/dns/router_query.go | 27 ++- agent/dns/router_query_test.go | 2 +- agent/dns/router_test.go | 127 ++++++++--- agent/dns_service_lookup_test.go | 35 +-- 13 files changed, 570 insertions(+), 426 deletions(-) diff --git a/agent/discovery/discovery.go b/agent/discovery/discovery.go index 0439ca20bc..c56a65f7a1 100644 --- a/agent/discovery/discovery.go +++ b/agent/discovery/discovery.go @@ -79,11 +79,11 @@ type QueryTenancy struct { // QueryPayload represents all information needed by the data backend // to decide which records to include. type QueryPayload struct { - Name string - PortName string // v1 - this could optionally be "connect" or "ingress"; v2 - this is the service port name - Tag string // deprecated: use for V1 only - RemoteAddr net.Addr // deprecated: used for prepared queries - Tenancy QueryTenancy // tenancy includes any additional labels specified before the domain + Name string + PortName string // v1 - this could optionally be "connect" or "ingress"; v2 - this is the service port name + Tag string // deprecated: use for V1 only + SourceIP net.IP // deprecated: used for prepared queries + Tenancy QueryTenancy // tenancy includes any additional labels specified before the domain // v2 fields only EnableFailover bool @@ -104,19 +104,23 @@ const ( // It is the responsibility of the DNS encoder to know what to do with // each Result, based on the query type. type Result struct { - Address string // A/AAAA/CNAME records - could be used in the Extra section. CNAME is required to handle hostname addresses in workloads & nodes. + Service *Location // The name and address of the service. + Node *Location // The name and address of the node. Weight uint32 // SRV queries PortName string // Used to generate a fgdn when a specifc port was queried PortNumber uint32 // SRV queries Metadata map[string]string // Used to collect metadata into TXT Records Type ResultType // Used to reconstruct the fqdn name of the resource - // Used in SRV & PTR queries to point at an A/AAAA Record. - Target string - Tenancy ResultTenancy } +// Location is used to represent a service, node, or workload. +type Location struct { + Name string + Address string +} + // ResultTenancy is used to reconstruct the fqdn name of the resource. type ResultTenancy struct { Namespace string diff --git a/agent/discovery/discovery_test.go b/agent/discovery/discovery_test.go index af7fd148b3..a53ec7b866 100644 --- a/agent/discovery/discovery_test.go +++ b/agent/discovery/discovery_test.go @@ -26,9 +26,9 @@ var ( } testResult = &Result{ - Address: "1.2.3.4", + Node: &Location{Address: "1.2.3.4"}, Type: ResultTypeNode, // This isn't correct for some test cases, but we are only asserting the right data fetcher functions are called - Target: "foo", + Service: &Location{Name: "foo"}, } ) diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index 81c73dca66..c3146a48ac 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -33,6 +33,10 @@ type v1DataFetcherDynamicConfig struct { // Default request tenancy datacenter string + segmentName string + nodeName string + nodePartition string + // Catalog configuration allowStale bool maxStale time.Duration @@ -115,17 +119,19 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e } results := make([]*Result, 0, 1) - node := out.NodeServices.Node + n := out.NodeServices.Node results = append(results, &Result{ - Address: node.Address, + Node: &Location{ + Name: n.Node, + Address: n.Address, + }, Type: ResultTypeNode, - Metadata: node.Meta, - Target: node.Node, + Metadata: n.Meta, Tenancy: ResultTenancy{ // Namespace is not required because nodes are not namespaced - Partition: node.GetEnterpriseMeta().PartitionOrDefault(), - Datacenter: node.Datacenter, + Partition: n.GetEnterpriseMeta().PartitionOrDefault(), + Datacenter: n.Datacenter, }, }) @@ -163,8 +169,11 @@ func (f *V1DataFetcher) FetchVirtualIP(ctx Context, req *QueryPayload) (*Result, } result := &Result{ - Address: out, - Type: ResultTypeVirtual, + Service: &Location{ + Name: req.Name, + Address: out, + }, + Type: ResultTypeVirtual, } return result, nil } @@ -196,9 +205,11 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, for _, n := range out.Nodes { if targetIP == n.Address { results = append(results, &Result{ - Address: n.Address, - Type: ResultTypeNode, - Target: n.Node, + Node: &Location{ + Name: n.Node, + Address: n.Address, + }, + Type: ResultTypeNode, Tenancy: ResultTenancy{ Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(), Partition: f.defaultEnterpriseMeta.PartitionOrDefault(), @@ -226,13 +237,19 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, for _, n := range sout.ServiceNodes { if n.ServiceAddress == targetIP { results = append(results, &Result{ - Address: n.ServiceAddress, - Type: ResultTypeService, - Target: n.ServiceName, + Service: &Location{ + Name: n.ServiceName, + Address: n.ServiceAddress, + }, + Type: ResultTypeService, + Node: &Location{ + Name: n.Node, + Address: n.Address, + }, Tenancy: ResultTenancy{ - Namespace: f.defaultEnterpriseMeta.NamespaceOrDefault(), - Partition: f.defaultEnterpriseMeta.PartitionOrDefault(), - Datacenter: configCtx.datacenter, + Namespace: n.NamespaceOrEmpty(), + Partition: n.PartitionOrEmpty(), + Datacenter: n.Datacenter, }, }) return results, nil @@ -256,7 +273,119 @@ func (f *V1DataFetcher) FetchWorkload(ctx Context, req *QueryPayload) (*Result, // FetchPreparedQuery evaluates the results of a prepared query. // deprecated in V2 func (f *V1DataFetcher) FetchPreparedQuery(ctx Context, req *QueryPayload) ([]*Result, error) { - return nil, nil + cfg := f.dynamicConfig.Load().(*v1DataFetcherDynamicConfig) + + // Execute the prepared query. + args := structs.PreparedQueryExecuteRequest{ + Datacenter: req.Tenancy.Datacenter, + QueryIDOrName: req.Name, + QueryOptions: structs.QueryOptions{ + Token: ctx.Token, + AllowStale: cfg.allowStale, + MaxAge: cfg.cacheMaxAge, + }, + + // Always pass the local agent through. In the DNS interface, there + // is no provision for passing additional query parameters, so we + // send the local agent's data through to allow distance sorting + // relative to ourself on the server side. + Agent: structs.QuerySource{ + Datacenter: cfg.datacenter, + Segment: cfg.segmentName, + Node: cfg.nodeName, + NodePartition: cfg.nodePartition, + }, + Source: structs.QuerySource{ + Ip: req.SourceIP.String(), + }, + } + + out, err := f.executePreparedQuery(cfg, args) + if err != nil { + return nil, err + } + + // (v2-dns) TODO: (v2-dns) get TTLS working. They come from the database so not having + // TTL on the discovery result poses challenges. + + /* + // TODO (slackpad) - What's a safe limit we can set here? It seems like + // with dup filtering done at this level we need to get everything to + // match the previous behavior. We can optimize by pushing more filtering + // into the query execution, but for now I think we need to get the full + // response. We could also choose a large arbitrary number that will + // likely work in practice, like 10*maxUDPAnswerLimit which should help + // reduce bandwidth if there are thousands of nodes available. + // Determine the TTL. The parse should never fail since we vet it when + // the query is created, but we check anyway. If the query didn't + // specify a TTL then we will try to use the agent's service-specific + // TTL configs. + var ttl time.Duration + if out.DNS.TTL != "" { + var err error + ttl, err = time.ParseDuration(out.DNS.TTL) + if err != nil { + f.logger.Warn("Failed to parse TTL for prepared query , ignoring", + "ttl", out.DNS.TTL, + "prepared_query", req.Name, + ) + } + } else { + ttl, _ = cfg.GetTTLForService(out.Service) + } + */ + + // If we have no nodes, return not found! + if len(out.Nodes) == 0 { + return nil, ErrNoData + } + + // Perform a random shuffle + out.Nodes.Shuffle() + return f.buildResultsFromServiceNodes(out.Nodes), nil +} + +// executePreparedQuery is used to execute a PreparedQuery against the Consul catalog. +// If the config is set to UseCache, it will use agent cache. +func (f *V1DataFetcher) executePreparedQuery(cfg *v1DataFetcherDynamicConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { + var out structs.PreparedQueryExecuteResponse + +RPC: + if cfg.useCache { + raw, m, err := f.getFromCacheFunc(context.TODO(), cachetype.PreparedQueryName, &args) + if err != nil { + return nil, err + } + reply, ok := raw.(*structs.PreparedQueryExecuteResponse) + if !ok { + // This should never happen, but we want to protect against panics + return nil, err + } + + f.logger.Trace("cache results for prepared query", + "cache_hit", m.Hit, + "prepared_query", args.QueryIDOrName, + ) + + out = *reply + } else { + if err := f.rpcFunc(context.Background(), "PreparedQuery.Execute", &args, &out); err != nil { + return nil, err + } + } + + // Verify that request is not too stale, redo the request. + if args.AllowStale { + if out.LastContact > cfg.maxStale { + args.AllowStale = false + f.logger.Warn("Query results too stale, re-requesting") + goto RPC + } else if out.LastContact > staleCounterThreshold { + metrics.IncrCounter([]string{"dns", "stale_queries"}, 1) + } + } + + return &out, nil } func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { @@ -269,6 +398,34 @@ func (f *V1DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { return validateEnterpriseTenancy(req.Tenancy) } +// buildResultsFromServiceNodes builds a list of results from a list of nodes. +func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServiceNode) []*Result { + results := make([]*Result, 0) + for _, n := range nodes { + + results = append(results, &Result{ + Service: &Location{ + Name: n.Service.Service, + Address: n.Service.Address, + }, + Node: &Location{ + Name: n.Node.Node, + Address: n.Node.Address, + }, + Type: ResultTypeService, + Weight: uint32(findWeight(n)), + PortNumber: uint32(f.translateServicePortFunc(n.Node.Datacenter, n.Service.Port, n.Service.TaggedAddresses)), + Metadata: n.Node.Meta, + Tenancy: ResultTenancy{ + Namespace: n.Service.NamespaceOrEmpty(), + Partition: n.Service.PartitionOrEmpty(), + Datacenter: n.Node.Datacenter, + }, + }) + } + return results +} + // fetchNode is used to look up a node in the Consul catalog within NodeServices. // If the config is set to UseCache, it will get the record from the agent cache. func (f *V1DataFetcher) fetchNode(cfg *v1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { @@ -353,7 +510,12 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa out, _, err := f.rpcFuncForServiceNodes(context.TODO(), args) if err != nil { - return nil, err + return nil, fmt.Errorf("rpc request failed: %w", err) + } + + // If we have no nodes, return not found! + if len(out.Nodes) == 0 { + return nil, ErrNoData } // Filter out any service nodes due to health checks @@ -372,57 +534,7 @@ func (f *V1DataFetcher) fetchServiceBasedOnTenancy(ctx Context, req *QueryPayloa // Perform a random shuffle out.Nodes.Shuffle() - results := make([]*Result, 0, len(out.Nodes)) - for _, node := range out.Nodes { - address, target, resultType := getAddressTargetAndResultType(node) - - results = append(results, &Result{ - Address: address, - Type: resultType, - Target: target, - Weight: uint32(findWeight(node)), - PortNumber: uint32(f.translateServicePortFunc(node.Node.Datacenter, node.Service.Port, node.Service.TaggedAddresses)), - Metadata: node.Node.Meta, - Tenancy: ResultTenancy{ - Namespace: node.Service.NamespaceOrEmpty(), - Partition: node.Service.PartitionOrEmpty(), - Datacenter: node.Node.Datacenter, - }, - }) - } - - return results, nil -} - -// getAddressTargetAndResultType returns the address, target and result type for a check service node. -func getAddressTargetAndResultType(node structs.CheckServiceNode) (string, string, ResultType) { - // Set address and target - // if service address is present, set target and address based on service. - // otherwise get it from the node. - address := node.Service.Address - target := node.Service.Service - resultType := ResultTypeService - - addressIP := net.ParseIP(address) - if addressIP == nil { - resultType = ResultTypeNode - if node.Service.Address != "" { - // cases where service address is foo or foo.node.consul - // For usage in DNS, these discovery results necessitate a CNAME record. - // These cases can be inferred from the discovery result when Type is Node and - // target is not an IP. - target = node.Service.Address - } else { - // cases where service address is empty and the service is bound to - // node with an address. These do not require a CNAME record in. - // For usage in DNS, these discovery results do not require a CNAME record. - // These cases can be inferred from the discovery result when Type is Node and - // target is not an IP. - target = node.Node.Node - } - address = node.Node.Address - } - return address, target, resultType + return f.buildResultsFromServiceNodes(out.Nodes), nil } // findWeight returns the weight of a service node. diff --git a/agent/discovery/query_fetcher_v1_test.go b/agent/discovery/query_fetcher_v1_test.go index 703548f3e5..95f3a9fe50 100644 --- a/agent/discovery/query_fetcher_v1_test.go +++ b/agent/discovery/query_fetcher_v1_test.go @@ -51,8 +51,11 @@ func Test_FetchVirtualIP(t *testing.T) { Token: "test-token", }, expectedResult: &Result{ - Address: "192.168.10.10", - Type: ResultTypeVirtual, + Service: &Location{ + Name: "db", + Address: "192.168.10.10", + }, + Type: ResultTypeVirtual, }, expectedErr: nil, }, @@ -97,7 +100,7 @@ func Test_FetchVirtualIP(t *testing.T) { if tc.expectedErr == nil { // set the out parameter to ensure that it is used to formulate the result.Address reply := args.Get(3).(*string) - *reply = tc.expectedResult.Address + *reply = tc.expectedResult.Service.Address } }) // TODO (v2-dns): mock these properly @@ -131,210 +134,62 @@ func Test_FetchEndpoints(t *testing.T) { DNSUseCache: true, DNSCacheMaxAge: 100, } - tests := []struct { - name string - queryPayload *QueryPayload - context Context - rpcFuncForServiceNodes func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) - expectedResults []*Result - expectedErr error - }{ + ctx := Context{ + Token: "test-token", + } + expectedResults := []*Result{ { - name: "when service address is IPv4, result type is service, address is service address and target is service name", - queryPayload: &QueryPayload{ - Name: "service-name", - Tenancy: QueryTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, + Node: &Location{ + Name: "node-name", + Address: "node-address", }, - rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { - return structs.IndexedCheckServiceNodes{ - Nodes: []structs.CheckServiceNode{ - { - Node: &structs.Node{ - Address: "node-address", - Node: "node-name", - Partition: defaultTestPartition, - }, - Service: &structs.NodeService{ - Address: "127.0.0.1", - Service: "service-name", - EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), - }, - }, - }, - }, cache.ResultMeta{}, nil + Service: &Location{ + Name: "service-name", + Address: "service-address", }, - context: Context{ - Token: "test-token", - }, - expectedResults: []*Result{ - { - Address: "127.0.0.1", - Target: "service-name", - Type: ResultTypeService, - Weight: 1, - Tenancy: ResultTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - }, - expectedErr: nil, - }, - { - name: "when service address is IPv6, result type is service, address is service address and target is service name", - queryPayload: &QueryPayload{ - Name: "service-name", - Tenancy: QueryTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { - return structs.IndexedCheckServiceNodes{ - Nodes: []structs.CheckServiceNode{ - { - Node: &structs.Node{ - Address: "node-address", - Node: "node-name", - Partition: defaultTestPartition, - }, - Service: &structs.NodeService{ - Address: "2001:db8:1:2:cafe::1337", - Service: "service-name", - EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), - }, - }, - }, - }, cache.ResultMeta{}, nil - }, - context: Context{ - Token: "test-token", - }, - expectedResults: []*Result{ - { - Address: "2001:db8:1:2:cafe::1337", - Target: "service-name", - Type: ResultTypeService, - Weight: 1, - Tenancy: ResultTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - }, - expectedErr: nil, - }, - { - name: "when service address is not IP but is not empty, result type is node, address is node address, and target is service address", - queryPayload: &QueryPayload{ - Name: "service-name", - Tenancy: QueryTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { - return structs.IndexedCheckServiceNodes{ - Nodes: []structs.CheckServiceNode{ - { - Node: &structs.Node{ - Address: "node-address", - Node: "node-name", - Partition: defaultTestPartition, - }, - Service: &structs.NodeService{ - Address: "foo", - Service: "service-name", - EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), - }, - }, - }, - }, cache.ResultMeta{}, nil - }, - context: Context{ - Token: "test-token", - }, - expectedResults: []*Result{ - { - Address: "node-address", - Target: "foo", - Type: ResultTypeNode, - Weight: 1, - Tenancy: ResultTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - }, - expectedErr: nil, - }, - { - name: "when service address is empty, result type is node, address is node address, and target is node name", - queryPayload: &QueryPayload{ - Name: "service-name", - Tenancy: QueryTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - rpcFuncForServiceNodes: func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { - return structs.IndexedCheckServiceNodes{ - Nodes: []structs.CheckServiceNode{ - { - Node: &structs.Node{ - Address: "node-address", - Node: "node-name", - Partition: defaultTestPartition, - }, - Service: &structs.NodeService{ - Address: "", - Service: "service-name", - EnterpriseMeta: acl.NewEnterpriseMetaWithPartition(defaultTestPartition, defaultTestNamespace), - }, - }, - }, - }, cache.ResultMeta{}, nil - }, - context: Context{ - Token: "test-token", - }, - expectedResults: []*Result{ - { - Address: "node-address", - Target: "node-name", - Type: ResultTypeNode, - Weight: 1, - Tenancy: ResultTenancy{ - Namespace: defaultTestNamespace, - Partition: defaultTestPartition, - }, - }, - }, - expectedErr: nil, + Type: ResultTypeService, + Weight: 1, }, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - logger := testutil.Logger(t) - mockRPC := cachetype.NewMockRPC(t) - // TODO (v2-dns): mock these properly - translateServicePortFunc := func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int { return 0 } - rpcFuncForSamenessGroup := func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) { - return structs.SamenessGroupConfigEntry{}, cache.ResultMeta{}, nil - } - getFromCacheFunc := func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) { - return nil, cache.ResultMeta{}, nil - } - - df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), getFromCacheFunc, mockRPC.RPC, tc.rpcFuncForServiceNodes, rpcFuncForSamenessGroup, translateServicePortFunc, logger) - - results, err := df.FetchEndpoints(tc.context, tc.queryPayload, LookupTypeService) - require.Equal(t, tc.expectedErr, err) - require.Equal(t, tc.expectedResults, results) - }) + logger := testutil.Logger(t) + mockRPC := cachetype.NewMockRPC(t) + // TODO (v2-dns): mock these properly + translateServicePortFunc := func(dc string, port int, taggedAddresses map[string]structs.ServiceAddress) int { return 0 } + rpcFuncForSamenessGroup := func(ctx context.Context, req *structs.ConfigEntryQuery) (structs.SamenessGroupConfigEntry, cache.ResultMeta, error) { + return structs.SamenessGroupConfigEntry{}, cache.ResultMeta{}, nil } + getFromCacheFunc := func(ctx context.Context, t string, r cache.Request) (interface{}, cache.ResultMeta, error) { + return nil, cache.ResultMeta{}, nil + } + rpcFuncForServiceNodes := func(ctx context.Context, req structs.ServiceSpecificRequest) (structs.IndexedCheckServiceNodes, cache.ResultMeta, error) { + return structs.IndexedCheckServiceNodes{ + Nodes: []structs.CheckServiceNode{ + { + Node: &structs.Node{ + Address: "node-address", + Node: "node-name", + }, + Service: &structs.NodeService{ + Address: "service-address", + Service: "service-name", + }, + }, + }, + }, cache.ResultMeta{}, nil + } + queryPayload := &QueryPayload{ + Name: "service-name", + Tenancy: QueryTenancy{ + Peer: "test-peer", + Namespace: defaultTestNamespace, + Partition: defaultTestPartition, + }, + } + + df := NewV1DataFetcher(rc, acl.DefaultEnterpriseMeta(), getFromCacheFunc, mockRPC.RPC, rpcFuncForServiceNodes, rpcFuncForSamenessGroup, translateServicePortFunc, logger) + + results, err := df.FetchEndpoints(ctx, queryPayload, LookupTypeService) + require.NoError(t, err) + require.Equal(t, expectedResults, results) } diff --git a/agent/discovery/query_fetcher_v2.go b/agent/discovery/query_fetcher_v2.go index 5371b6f4b0..0fd4dd9c74 100644 --- a/agent/discovery/query_fetcher_v2.go +++ b/agent/discovery/query_fetcher_v2.go @@ -124,13 +124,15 @@ func (f *V2DataFetcher) FetchWorkload(reqContext Context, req *QueryPayload) (*R tenancy := response.GetResource().GetId().GetTenancy() result := &Result{ - Address: address, - Type: ResultTypeWorkload, + Node: &Location{ + Address: address, + Name: response.GetResource().GetId().GetName(), + }, + Type: ResultTypeWorkload, Tenancy: ResultTenancy{ Namespace: tenancy.GetNamespace(), Partition: tenancy.GetPartition(), }, - Target: response.GetResource().GetId().GetName(), } if req.PortName == "" { @@ -169,7 +171,7 @@ func (f *V2DataFetcher) ValidateRequest(_ Context, req *QueryPayload) error { if req.Tag != "" { return ErrNotSupported } - if req.RemoteAddr != nil { + if req.SourceIP != nil { return ErrNotSupported } return nil diff --git a/agent/discovery/query_fetcher_v2_test.go b/agent/discovery/query_fetcher_v2_test.go index 86e2af6b63..f93e3d5f48 100644 --- a/agent/discovery/query_fetcher_v2_test.go +++ b/agent/discovery/query_fetcher_v2_test.go @@ -58,13 +58,12 @@ func Test_FetchWorkload(t *testing.T) { }) }, expectedResult: &Result{ - Address: "1.2.3.4", - Type: ResultTypeWorkload, + Node: &Location{Name: "foo-1234", Address: "1.2.3.4"}, + Type: ResultTypeWorkload, Tenancy: ResultTenancy{ Namespace: resource.DefaultNamespaceName, Partition: resource.DefaultPartitionName, }, - Target: "foo-1234", }, expectedErr: nil, }, @@ -130,7 +129,7 @@ func Test_FetchWorkload(t *testing.T) { }) }, expectedResult: &Result{ - Address: "1.2.3.4", + Node: &Location{Name: "foo-1234", Address: "1.2.3.4"}, Type: ResultTypeWorkload, PortName: "api", PortNumber: 5678, @@ -138,7 +137,6 @@ func Test_FetchWorkload(t *testing.T) { Namespace: resource.DefaultNamespaceName, Partition: resource.DefaultPartitionName, }, - Target: "foo-1234", }, expectedErr: nil, }, @@ -189,13 +187,12 @@ func Test_FetchWorkload(t *testing.T) { }) }, expectedResult: &Result{ - Address: "1.2.3.4", - Type: ResultTypeWorkload, + Node: &Location{Name: "foo-1234", Address: "1.2.3.4"}, + Type: ResultTypeWorkload, Tenancy: ResultTenancy{ Namespace: "test-namespace", Partition: "test-partition", }, - Target: "foo-1234", }, expectedErr: nil, }, diff --git a/agent/dns/router.go b/agent/dns/router.go index 94732ee5de..ecf3b9fdd2 100644 --- a/agent/dns/router.go +++ b/agent/dns/router.go @@ -42,6 +42,7 @@ var ( errInvalidQuestion = fmt.Errorf("invalid question") errNameNotFound = fmt.Errorf("name not found") errNotImplemented = fmt.Errorf("not implemented") + errQueryNotFound = fmt.Errorf("query not found") errRecursionFailed = fmt.Errorf("recursion failed") trailingSpacesRE = regexp.MustCompile(" +$") @@ -93,7 +94,7 @@ type DiscoveryQueryProcessor interface { // //go:generate mockery --name dnsRecursor --inpackage type dnsRecursor interface { - handle(req *dns.Msg, cfgCtx *RouterDynamicConfig, remoteAddr net.Addr) (*dns.Msg, error) + handle(req *dns.Msg, cfgCtx *RouterDynamicConfig, remoteAddress net.Addr) (*dns.Msg, error) } // Router replaces miekg/dns.ServeMux with a simpler router that only checks for the 2-3 valid domains @@ -126,12 +127,13 @@ func NewRouter(cfg Config) (*Router, error) { logger := cfg.Logger.Named(logging.DNS) router := &Router{ - processor: cfg.Processor, - recursor: newRecursor(logger), - domain: domain, - altDomain: altDomain, - logger: logger, - tokenFunc: cfg.TokenFunc, + processor: cfg.Processor, + recursor: newRecursor(logger), + domain: domain, + altDomain: altDomain, + datacenter: cfg.AgentConfig.Datacenter, + logger: logger, + tokenFunc: cfg.TokenFunc, } if err := router.ReloadConfig(cfg.AgentConfig); err != nil { @@ -160,7 +162,7 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, return createServerFailureResponse(req, configCtx, false) } - responseDomain, needRecurse := r.parseDomain(req) + responseDomain, needRecurse := r.parseDomain(req.Question[0].Name) if needRecurse && !canRecurse(configCtx) { // This is the same error as an unmatched domain return createRefusedResponse(req) @@ -187,7 +189,7 @@ func (r *Router) handleRequestRecursively(req *dns.Msg, reqCtx Context, } reqType := parseRequestType(req) - results, query, err := r.getQueryResults(req, reqCtx, reqType, qName) + results, query, err := r.getQueryResults(req, reqCtx, reqType, qName, remoteAddress) switch { case errors.Is(err, errNameNotFound): r.logger.Error("name not found", "name", qName) @@ -272,7 +274,8 @@ func getTTLForResult(name string, query *discovery.Query, cfg *RouterDynamicConf } // getQueryResults returns a discovery.Result from a DNS message. -func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestType, qName string) ([]*discovery.Result, *discovery.Query, error) { +func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestType, + qName string, remoteAddress net.Addr) ([]*discovery.Result, *discovery.Query, error) { switch reqType { case requestTypeConsul: // This is a special case of discovery.QueryByName where we know that we need to query the consul service @@ -295,7 +298,7 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token}) return results, query, err case requestTypeName: - query, err := buildQueryFromDNSMessage(req, reqCtx, r.domain, r.altDomain) + query, err := buildQueryFromDNSMessage(req, reqCtx, r.domain, r.altDomain, remoteAddress) if err != nil { r.logger.Error("error building discovery query from DNS request", "error", err) return nil, query, err @@ -303,6 +306,13 @@ func (r *Router) getQueryResults(req *dns.Msg, reqCtx Context, reqType requestTy results, err := r.processor.QueryByName(query, discovery.Context{Token: reqCtx.Token}) if err != nil { r.logger.Error("error processing discovery query", "error", err) + switch err.Error() { + case errNameNotFound.Error(): + return nil, query, errNameNotFound + case errQueryNotFound.Error(): + return nil, query, errQueryNotFound + } + return nil, query, err } return results, query, nil @@ -376,8 +386,8 @@ const ( // it will return true for needRecurse. The logic is based on miekg/dns.ServeDNS matcher. // The implementation assumes that the only valid domains are "consul." and the alternative domain, and // that DS query types are not supported. -func (r *Router) parseDomain(req *dns.Msg) (string, bool) { - target := dns.CanonicalName(req.Question[0].Name) +func (r *Router) parseDomain(questionName string) (string, bool) { + target := dns.CanonicalName(questionName) target, _ = stripSuffix(target) for offset, overflow := 0, false; !overflow; offset, overflow = dns.NextLabel(target, offset) { @@ -786,8 +796,10 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { return []*discovery.Result{ { - Address: ip.String(), - Type: discovery.ResultTypeNode, // We choose node by convention since we do not know the origin of the IP + Node: &discovery.Location{ + Address: ip.String(), + }, + Type: discovery.ResultTypeNode, // We choose node by convention since we do not know the origin of the IP }, }, nil } @@ -796,8 +808,14 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, reqCtx Context, query *discovery.Query, cfg *RouterDynamicConfig, domain string, remoteAddress net.Addr, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR, ns []dns.RR) { - target := newDNSAddress(result.Target) - address := newDNSAddress(result.Address) + serviceAddress := newDNSAddress("") + if result.Service != nil { + serviceAddress = newDNSAddress(result.Service.Address) + } + nodeAddress := newDNSAddress("") + if result.Node != nil { + nodeAddress = newDNSAddress(result.Node.Address) + } qName := req.Question[0].Name ttlLookupName := qName if query != nil { @@ -812,53 +830,60 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, req *dns.Msg, req switch { // PTR requests are first since they are a special case of domain overriding question type case parseRequestType(req) == requestTypeIP: + ptrTarget := "" + if result.Type == discovery.ResultTypeNode { + ptrTarget = result.Node.Name + } else if result.Type == discovery.ResultTypeService { + ptrTarget = result.Service.Name + } + ptr := &dns.PTR{ Hdr: dns.RR_Header{Name: qName, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0}, - Ptr: canonicalNameForResult(result, domain), + Ptr: canonicalNameForResult(result.Type, ptrTarget, domain, result.Tenancy, result.PortName), } answer = append(answer, ptr) case qType == dns.TypeNS: // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result - fqdn := canonicalNameForResult(result, domain) - extraRecord := makeIPBasedRecord(fqdn, address, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported + fqdn := canonicalNameForResult(result.Type, serviceAddress.String(), domain, result.Tenancy, result.PortName) + extraRecord := makeIPBasedRecord(fqdn, nodeAddress, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported answer = append(answer, makeNSRecord(domain, fqdn, ttl)) extra = append(extra, extraRecord) case qType == dns.TypeSOA: // TODO (v2-dns): fqdn in V1 has the datacenter included, this would need to be added to discovery.Result // to be returned in the result. - fqdn := canonicalNameForResult(result, domain) - extraRecord := makeIPBasedRecord(fqdn, address, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported + fqdn := canonicalNameForResult(result.Type, serviceAddress.String(), domain, result.Tenancy, result.PortName) + extraRecord := makeIPBasedRecord(fqdn, nodeAddress, ttl) // TODO (v2-dns): this is not sufficient, because recursion and CNAMES are supported ns = append(ns, makeNSRecord(domain, fqdn, ttl)) extra = append(extra, extraRecord) case qType == dns.TypeSRV: // We put A/AAAA/CNAME records in the additional section for SRV requests - a, e := r.getAnswerExtrasForAddressAndTarget(address, target, req, reqCtx, - result, ttl, remoteAddress, cfg, maxRecursionLevel) + a, e := r.getAnswerExtrasForAddressAndTarget(nodeAddress, serviceAddress, req, reqCtx, + result, ttl, remoteAddress, cfg, domain, maxRecursionLevel) answer = append(answer, a...) extra = append(extra, e...) if cfg.NodeMetaTXT { - name := target.FQDN() - if !target.IsInternalFQDN(r.domain) && !target.IsExternalFQDN(r.domain) { - name = canonicalNameForResult(result, r.domain) + name := serviceAddress.FQDN() + if !serviceAddress.IsInternalFQDN(r.domain) && !serviceAddress.IsExternalFQDN(r.domain) { + name = canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, domain, result.Tenancy, result.PortName) } extra = append(extra, makeTXTRecord(name, result, ttl)...) } default: - a, e := r.getAnswerExtrasForAddressAndTarget(address, target, req, reqCtx, - result, ttl, remoteAddress, cfg, maxRecursionLevel) + a, e := r.getAnswerExtrasForAddressAndTarget(nodeAddress, serviceAddress, req, reqCtx, + result, ttl, remoteAddress, cfg, domain, maxRecursionLevel) answer = append(answer, a...) extra = append(extra, e...) } return } -// getAnswerExtrasForAddressAndTarget creates the dns answer and extra from address and target dnsAddress pairs. -func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target *dnsAddress, req *dns.Msg, +// getAnswerExtrasForAddressAndTarget creates the dns answer and extra from nodeAddress and serviceAddress dnsAddress pairs. +func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, serviceAddress *dnsAddress, req *dns.Msg, reqCtx Context, result *discovery.Result, ttl uint32, remoteAddress net.Addr, - cfg *RouterDynamicConfig, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) { + cfg *RouterDynamicConfig, domain string, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR) { qName := req.Question[0].Name reqType := parseRequestType(req) @@ -866,64 +891,82 @@ func (r *Router) getAnswerExtrasForAddressAndTarget(address *dnsAddress, target // Virtual IPs and Address requests // both return IPs with empty targets case (reqType == requestTypeAddress || result.Type == discovery.ResultTypeVirtual) && - target.IsEmptyString() && address.IsIP(): - a, e := getAnswerExtrasForIP(qName, address, req.Question[0], reqType, - result, ttl) - answer = append(a, answer...) - extra = append(e, extra...) - - // Address is a FQDN and requires a CNAME lookup. - case address.IsFQDN(): - a, e := r.makeRecordFromFQDN(address.FQDN(), result, req, reqCtx, - cfg, ttl, remoteAddress, maxRecursionLevel) - answer = append(a, answer...) - extra = append(e, extra...) - - // Target is FQDN that point to IP - case target.IsFQDN() && address.IsIP(): - var a, e []dns.RR - if result.Type == discovery.ResultTypeNode || result.Type == discovery.ResultTypeWorkload { - // if it is a node record it means the service address pointed to a node - // and the node address was used. So we create an A record for the node address, - // as well as a CNAME for the service to node mapping. - name := target.FQDN() - if !target.IsInternalFQDN(r.domain) && !target.IsExternalFQDN(r.domain) { - name = canonicalNameForResult(result, r.domain) - } else if target.IsInternalFQDN(r.domain) { - answer = append(answer, makeCNAMERecord(qName, canonicalNameForResult(result, r.domain), ttl)) - } - a, e = getAnswerExtrasForIP(name, address, req.Question[0], reqType, - result, ttl) - } else { - // if it is a service record, it means that the service address had the IP directly - // and there was not a need for an intermediate CNAME. - a, e = getAnswerExtrasForIP(qName, address, req.Question[0], reqType, - result, ttl) - } + serviceAddress.IsEmptyString() && nodeAddress.IsIP(): + a, e := getAnswerExtrasForIP(qName, nodeAddress, req.Question[0], reqType, + result, ttl, domain) answer = append(answer, a...) extra = append(extra, e...) - // The target is a FQDN (internal or external service name) - default: - a, e := r.makeRecordFromFQDN(target.FQDN(), result, req, reqCtx, cfg, + case result.Type == discovery.ResultTypeNode: + canonicalNodeName := canonicalNameForResult(result.Type, result.Node.Name, domain, result.Tenancy, result.PortName) + a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType, + result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) + + case serviceAddress.IsEmptyString() && nodeAddress.IsEmptyString(): + return nil, nil + + // There is no service address and the node address is an IP + case serviceAddress.IsEmptyString() && nodeAddress.IsIP(): + canonicalNodeName := canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, domain, result.Tenancy, result.PortName) + a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType, + result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) + + // There is no service address and the node address is a FQDN (external service) + case serviceAddress.IsEmptyString(): + a, e := r.makeRecordFromFQDN(nodeAddress.FQDN(), result, req, reqCtx, cfg, ttl, remoteAddress, maxRecursionLevel) - answer = append(a, answer...) - extra = append(e, extra...) + answer = append(answer, a...) + extra = append(extra, e...) + + // The service address is an IP + case serviceAddress.IsIP(): + canonicalServiceName := canonicalNameForResult(discovery.ResultTypeService, result.Service.Name, domain, result.Tenancy, result.PortName) + a, e := getAnswerExtrasForIP(canonicalServiceName, serviceAddress, req.Question[0], reqType, + result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) + + // If the service address is a CNAME for the service we are looking + // for then use the node address. + case serviceAddress.FQDN() == req.Question[0].Name && nodeAddress.IsIP(): + canonicalNodeName := canonicalNameForResult(discovery.ResultTypeNode, result.Node.Name, domain, result.Tenancy, result.PortName) + a, e := getAnswerExtrasForIP(canonicalNodeName, nodeAddress, req.Question[0], reqType, + result, ttl, domain) + answer = append(answer, a...) + extra = append(extra, e...) + + // The service address is a FQDN (internal or external service name) + default: + a, e := r.makeRecordFromFQDN(serviceAddress.FQDN(), result, req, reqCtx, cfg, + ttl, remoteAddress, maxRecursionLevel) + answer = append(answer, a...) + extra = append(extra, e...) } return } // getAnswerExtrasForIP creates the dns answer and extra from IP dnsAddress pairs. func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, - reqType requestType, result *discovery.Result, ttl uint32) (answer []dns.RR, extra []dns.RR) { + reqType requestType, result *discovery.Result, ttl uint32, _ string) (answer []dns.RR, extra []dns.RR) { qType := question.Qtype // Have to pass original question name here even if the system has recursed // and stripped off the domain suffix. recHdrName := question.Name if qType == dns.TypeSRV { - recHdrName = name + nameSplit := strings.Split(name, ".") + if len(nameSplit) > 1 && nameSplit[1] == addrLabel { + recHdrName = name + } else { + recHdrName = name + } + name = question.Name } + record := makeIPBasedRecord(recHdrName, addr, ttl) isARecordWhenNotExplicitlyQueried := record.Header().Rrtype == dns.TypeA && qType != dns.TypeA && qType != dns.TypeANY @@ -938,12 +981,28 @@ func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, } if reqType != requestTypeAddress && qType == dns.TypeSRV { - srv := makeSRVRecord(name, name, result, ttl) + srv := makeSRVRecord(name, recHdrName, result, ttl) answer = append(answer, srv) } return } +// encodeIPAsFqdn encodes an IP address as a FQDN. +func encodeIPAsFqdn(result *discovery.Result, ip net.IP, responseDomain string) string { + ipv4 := ip.To4() + ipStr := hex.EncodeToString(ip) + if ipv4 != nil { + ipStr = ipStr[len(ipStr)-(net.IPv4len*2):] + } + if result.Tenancy.PeerName != "" { + // Exclude the datacenter from the FQDN on the addr for peers. + // This technically makes no difference, since the addr endpoint ignores the DC + // component of the request, but do it anyway for a less confusing experience. + return fmt.Sprintf("%s.addr.%s", ipStr, responseDomain) + } + return fmt.Sprintf("%s.addr.%s.%s", ipStr, result.Tenancy.Datacenter, responseDomain) +} + func makeSOARecord(domain string, cfg *RouterDynamicConfig) dns.RR { return &dns.SOA{ Hdr: dns.RR_Header{ @@ -1035,8 +1094,15 @@ MORE_REC: return answers, additional } + address := "" + if result.Service != nil && result.Service.Address != "" { + address = result.Service.Address + } else if result.Node != nil { + address = result.Node.Address + } + answers := []dns.RR{ - makeCNAMERecord(q.Name, result.Target, ttl), + makeCNAMERecord(q.Name, address, ttl), } answers = append(answers, additional...) diff --git a/agent/dns/router_ce.go b/agent/dns/router_ce.go index 3a44ca1cdc..67cab00490 100644 --- a/agent/dns/router_ce.go +++ b/agent/dns/router_ce.go @@ -12,26 +12,27 @@ import ( ) // canonicalNameForResult returns the canonical name for a discovery result. -func canonicalNameForResult(result *discovery.Result, domain string) string { - switch result.Type { +func canonicalNameForResult(resultType discovery.ResultType, target, domain string, + tenancy discovery.ResultTenancy, portName string) string { + switch resultType { case discovery.ResultTypeService: - return fmt.Sprintf("%s.%s.%s.%s", result.Target, "service", result.Tenancy.Datacenter, domain) + return fmt.Sprintf("%s.%s.%s.%s", target, "service", tenancy.Datacenter, domain) case discovery.ResultTypeNode: - if result.Tenancy.PeerName != "" { + if tenancy.PeerName != "" { // We must return a more-specific DNS name for peering so // that there is no ambiguity with lookups. return fmt.Sprintf("%s.node.%s.peer.%s", - result.Target, - result.Tenancy.PeerName, + target, + tenancy.PeerName, domain) } // Return a simpler format for non-peering nodes. - return fmt.Sprintf("%s.node.%s.%s", result.Target, result.Tenancy.Datacenter, domain) + return fmt.Sprintf("%s.node.%s.%s", target, tenancy.Datacenter, domain) case discovery.ResultTypeWorkload: - if result.PortName != "" { - return fmt.Sprintf("%s.port.%s.workload.%s", result.PortName, result.Target, domain) + if portName != "" { + return fmt.Sprintf("%s.port.%s.workload.%s", portName, target, domain) } - return fmt.Sprintf("%s.workload.%s", result.Target, domain) + return fmt.Sprintf("%s.workload.%s", target, domain) } return "" } diff --git a/agent/dns/router_ce_test.go b/agent/dns/router_ce_test.go index 69f73e2dbf..72455249fc 100644 --- a/agent/dns/router_ce_test.go +++ b/agent/dns/router_ce_test.go @@ -37,9 +37,9 @@ func getAdditionalTestCases(t *testing.T) []HandleTestCase { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { results := []*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "foonode", Address: "1.2.3.4"}, Type: discovery.ResultTypeNode, - Target: "foo", + Service: &discovery.Location{Name: "foo", Address: "foo"}, Tenancy: discovery.ResultTenancy{ Datacenter: "dc2", PeerName: "peer1", @@ -100,9 +100,9 @@ func getAdditionalTestCases(t *testing.T) []HandleTestCase { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { results := []*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "foonode", Address: "1.2.3.4"}, + Service: &discovery.Location{Name: "foo", Address: "foo"}, Type: discovery.ResultTypeService, - Target: "foo", Tenancy: discovery.ResultTenancy{ Datacenter: "dc2", }, diff --git a/agent/dns/router_query.go b/agent/dns/router_query.go index 13a4935be0..0bc7aae19f 100644 --- a/agent/dns/router_query.go +++ b/agent/dns/router_query.go @@ -4,6 +4,7 @@ package dns import ( + "net" "strings" "github.com/miekg/dns" @@ -12,7 +13,8 @@ import ( ) // buildQueryFromDNSMessage returns a discovery.Query from a DNS message. -func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain string) (*discovery.Query, error) { +func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain string, + remoteAddress net.Addr) (*discovery.Query, error) { queryType, queryParts, querySuffixes := getQueryTypePartsAndSuffixesFromDNSMessage(req, domain, altDomain) queryTenancy, err := getQueryTenancy(reqCtx, queryType, querySuffixes) @@ -36,7 +38,7 @@ func buildQueryFromDNSMessage(req *dns.Msg, reqCtx Context, domain, altDomain st Tenancy: queryTenancy, Tag: tag, PortName: portName, - //RemoteAddr: nil, // TODO (v2-dns): Prepared Queries for V1 Catalog + SourceIP: getSourceIP(req, queryType, remoteAddress), }, }, nil } @@ -177,3 +179,24 @@ func getQueryTypeFromLabels(label string) discovery.QueryType { return discovery.QueryTypeInvalid } } + +// getSourceIP returns the source IP from the dns request. +func getSourceIP(req *dns.Msg, queryType discovery.QueryType, remoteAddr net.Addr) (sourceIP net.IP) { + if queryType == discovery.QueryTypePreparedQuery { + subnet := ednsSubnetForRequest(req) + + if subnet != nil { + sourceIP = subnet.Address + } else { + switch v := remoteAddr.(type) { + case *net.UDPAddr: + sourceIP = v.IP + case *net.TCPAddr: + sourceIP = v.IP + case *net.IPAddr: + sourceIP = v.IP + } + } + } + return sourceIP +} diff --git a/agent/dns/router_query_test.go b/agent/dns/router_query_test.go index dc4ea6592e..94182de9e0 100644 --- a/agent/dns/router_query_test.go +++ b/agent/dns/router_query_test.go @@ -206,7 +206,7 @@ func Test_buildQueryFromDNSMessage(t *testing.T) { if context == nil { context = &Context{} } - query, err := buildQueryFromDNSMessage(tc.request, *context, "consul.", ".") + query, err := buildQueryFromDNSMessage(tc.request, *context, "consul.", ".", nil) require.NoError(t, err) assert.Equal(t, tc.expectedQuery, query) }) diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index aa38d91ef3..36c9eb8331 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -21,8 +21,6 @@ import ( "github.com/hashicorp/consul/agent/structs" ) -// TODO (v2-dns) - // TBD Test Cases // 1. Reload the configuration (e.g. SOA) // 2. Something to check the token makes it through to the data fetcher @@ -717,8 +715,8 @@ func Test_HandleRequest(t *testing.T) { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { fetcher.(*discovery.MockCatalogDataFetcher).On("FetchVirtualIP", mock.Anything, mock.Anything).Return(&discovery.Result{ - Address: "240.0.0.2", - Type: discovery.ResultTypeVirtual, + Node: &discovery.Location{Address: "240.0.0.2"}, + Type: discovery.ResultTypeVirtual, }, nil) }, validateAndNormalizeExpected: true, @@ -767,8 +765,8 @@ func Test_HandleRequest(t *testing.T) { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { fetcher.(*discovery.MockCatalogDataFetcher).On("FetchVirtualIP", mock.Anything, mock.Anything).Return(&discovery.Result{ - Address: "2001:db8:1:2:cafe::1337", - Type: discovery.ResultTypeVirtual, + Node: &discovery.Location{Address: "2001:db8:1:2:cafe::1337"}, + Type: discovery.ResultTypeVirtual, }, nil) }, validateAndNormalizeExpected: true, @@ -819,14 +817,14 @@ func Test_HandleRequest(t *testing.T) { On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). Return([]*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"}, + Service: &discovery.Location{Name: "service-one", Address: "server-one"}, Type: discovery.ResultTypeWorkload, - Target: "server-one", // This would correlate to the workload name }, { - Address: "4.5.6.7", + Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"}, + Service: &discovery.Location{Name: "service-one", Address: "server-two"}, Type: discovery.ResultTypeWorkload, - Target: "server-two", // This would correlate to the workload name }, }, nil). Run(func(args mock.Arguments) { @@ -941,14 +939,14 @@ func Test_HandleRequest(t *testing.T) { On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). Return([]*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "server-one", Address: "1.2.3.4"}, + Service: &discovery.Location{Name: "service-one", Address: "server-one"}, Type: discovery.ResultTypeWorkload, - Target: "server-one", // This would correlate to the workload name }, { - Address: "4.5.6.7", + Node: &discovery.Location{Name: "server-two", Address: "4.5.6.7"}, + Service: &discovery.Location{Name: "service-two", Address: "server-two"}, Type: discovery.ResultTypeWorkload, - Target: "server-two", // This would correlate to the workload name }, }, nil). Run(func(args mock.Arguments) { @@ -1051,9 +1049,9 @@ func Test_HandleRequest(t *testing.T) { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { results := []*discovery.Result{ { - Address: "1.2.3.4", + Node: &discovery.Location{Name: "foo", Address: "1.2.3.4"}, + Service: &discovery.Location{Name: "bar", Address: "foo"}, Type: discovery.ResultTypeNode, - Target: "foo", Tenancy: discovery.ResultTenancy{ Datacenter: "dc2", }, @@ -1113,9 +1111,9 @@ func Test_HandleRequest(t *testing.T) { configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { results := []*discovery.Result{ { - Address: "2001:db8::567:89ab", + Node: &discovery.Location{Name: "foo", Address: "2001:db8::567:89ab"}, + Service: &discovery.Location{Name: "web", Address: "foo"}, Type: discovery.ResultTypeNode, - Target: "foo", Tenancy: discovery.ResultTenancy{ Datacenter: "dc2", }, @@ -1315,6 +1313,85 @@ func Test_HandleRequest(t *testing.T) { }, }, }, + { + // TestDNS_ExternalServiceToConsulCNAMELookup + name: "req type: service / question type: SRV / CNAME required: no", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "alias.service.consul.", + Qtype: dns.TypeSRV, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchEndpoints", mock.Anything, + &discovery.QueryPayload{ + Name: "alias", + Tenancy: discovery.QueryTenancy{}, + }, discovery.LookupTypeService). + Return([]*discovery.Result{ + { + Type: discovery.ResultTypeVirtual, + Service: &discovery.Location{Name: "alias", Address: "web.service.consul"}, + Node: &discovery.Location{Name: "web", Address: "web.service.consul"}, + }, + }, + nil).On("FetchEndpoints", mock.Anything, + &discovery.QueryPayload{ + Name: "web", + Tenancy: discovery.QueryTenancy{}, + }, discovery.LookupTypeService). + Return([]*discovery.Result{ + { + Type: discovery.ResultTypeNode, + Service: &discovery.Location{Name: "web", Address: "webnode"}, + Node: &discovery.Location{Name: "webnode", Address: "127.0.0.2"}, + }, + }, nil).On("ValidateRequest", mock.Anything, + mock.Anything).Return(nil).On("NormalizeRequest", mock.Anything) + }, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "alias.service.consul.", + Qtype: dns.TypeSRV, + }, + }, + Answer: []dns.RR{ + &dns.SRV{ + Hdr: dns.RR_Header{ + Name: "alias.service.consul.", + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 123, + }, + Target: "web.service.consul.", + Priority: 1, + }, + }, + Extra: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "web.service.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("127.0.0.2"), + }, + }, + }, + }, // TODO (v2-dns): add a test to make sure only 3 records are returned // V2 Workload Lookup { @@ -1333,12 +1410,12 @@ func Test_HandleRequest(t *testing.T) { }, configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { result := &discovery.Result{ - Address: "1.2.3.4", + Node: &discovery.Location{Address: "1.2.3.4"}, Type: discovery.ResultTypeWorkload, Tenancy: discovery.ResultTenancy{}, PortName: "api", PortNumber: 5678, - Target: "foo", + Service: &discovery.Location{Name: "foo"}, } fetcher.(*discovery.MockCatalogDataFetcher). @@ -1394,10 +1471,10 @@ func Test_HandleRequest(t *testing.T) { }, configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { result := &discovery.Result{ - Address: "1.2.3.4", + Node: &discovery.Location{Address: "1.2.3.4"}, Type: discovery.ResultTypeWorkload, Tenancy: discovery.ResultTenancy{}, - Target: "foo", + Service: &discovery.Location{Name: "foo"}, } fetcher.(*discovery.MockCatalogDataFetcher). @@ -1453,14 +1530,14 @@ func Test_HandleRequest(t *testing.T) { }, configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { result := &discovery.Result{ - Address: "1.2.3.4", - Type: discovery.ResultTypeWorkload, + Node: &discovery.Location{Address: "1.2.3.4"}, + Type: discovery.ResultTypeWorkload, Tenancy: discovery.ResultTenancy{ Namespace: "bar", Partition: "baz", Datacenter: "dc3", }, - Target: "foo", + Service: &discovery.Location{Name: "foo"}, } fetcher.(*discovery.MockCatalogDataFetcher). diff --git a/agent/dns_service_lookup_test.go b/agent/dns_service_lookup_test.go index 9e021824ae..9a7e71d80f 100644 --- a/agent/dns_service_lookup_test.go +++ b/agent/dns_service_lookup_test.go @@ -6,6 +6,7 @@ package agent import ( "context" "fmt" + "net" "sort" "strings" "testing" @@ -20,7 +21,6 @@ import ( "github.com/hashicorp/consul/testrpc" ) -// TODO (v2-dns): requires PTR implementation func TestDNS_ServiceReverseLookup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -77,7 +77,6 @@ func TestDNS_ServiceReverseLookup(t *testing.T) { } } -// TODO (v2-dns): requires PTR implementation func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -134,7 +133,6 @@ func TestDNS_ServiceReverseLookup_IPV6(t *testing.T) { } } -// TODO (v2-dns): requires PTR implementation func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -193,7 +191,6 @@ func TestDNS_ServiceReverseLookup_CustomDomain(t *testing.T) { } } -// TODO (v2-dns): requires PTR implementation func TestDNS_ServiceReverseLookupNodeAddress(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -379,6 +376,7 @@ func TestDNS_ServiceLookupPreferNoCNAME(t *testing.T) { } } +// TODO (v2-dns): requires additional recursion work func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -453,10 +451,20 @@ func TestDNS_ServiceLookupMultiAddrNoCNAME(t *testing.T) { in, _, err := c.Exchange(m, a.DNSAddr()) require.NoError(t, err) - // expect a CNAME and an A RR + // expect two A RRs require.Len(t, in.Answer, 2) require.IsType(t, &dns.A{}, in.Answer[0]) + require.Equal(t, "db.service.consul.", in.Answer[0].Header().Name) + isOneOfTheseIPs := func(ip net.IP) bool { + if ip.Equal(net.ParseIP("198.18.0.1")) || ip.Equal(net.ParseIP("198.18.0.3")) { + return true + } + return false + } + require.True(t, isOneOfTheseIPs(in.Answer[0].(*dns.A).A)) require.IsType(t, &dns.A{}, in.Answer[1]) + require.Equal(t, "db.service.consul.", in.Answer[1].Header().Name) + require.True(t, isOneOfTheseIPs(in.Answer[1].(*dns.A).A)) }) } } @@ -590,15 +598,13 @@ func TestDNS_ServiceLookup(t *testing.T) { } } -// TODO (v2-dns): this is formulating the correct response -// but failing with an I/O timeout on the dns client Exchange() call func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` node_name = "my.test-node" @@ -999,7 +1005,7 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` node_name = "test-node" @@ -1672,7 +1678,7 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) { } } -// TODO (v2-dns): this requires a prepared query +// TODO (v2-dns): this requires WAN translation work to be implemented func TestDNS_ServiceLookup_WanTranslation(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1889,7 +1895,6 @@ func TestDNS_ServiceLookup_WanTranslation(t *testing.T) { } } -// TODO (v2-dns): this requires a prepared query func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1913,7 +1918,7 @@ func TestDNS_CaseInsensitiveServiceLookup(t *testing.T) { } for _, tst := range tests { t.Run(fmt.Sprintf("A lookup %v", tst.name), func(t *testing.T) { - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, fmt.Sprintf("%s %s", tst.config, experimentsHCL)) defer a.Shutdown() @@ -2161,14 +2166,13 @@ func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) { } } -// TODO (v2-dns): this requires a prepared query func TestDNS_ServiceLookup_Dedup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -2413,6 +2417,7 @@ func TestDNS_ServiceLookup_Dedup_SRV(t *testing.T) { } } +// TODO (v2-dns): this requires implementing health filtering func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -2577,6 +2582,7 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { } } +// TODO (v2-dns): this requires implementing health filtering func TestDNS_ServiceLookup_OnlyFailing(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -2698,6 +2704,7 @@ func TestDNS_ServiceLookup_OnlyFailing(t *testing.T) { } } +// TODO (v2-dns): this requires implementing health filtering func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short")