diff --git a/agent/agent.go b/agent/agent.go index 781c36f4e0..98e1252cef 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1140,11 +1140,13 @@ func (a *Agent) listenAndServeV2DNS() error { // create server cfg := dns.Config{ - AgentConfig: a.config, - EntMeta: *a.AgentEnterpriseMeta(), - Logger: a.logger, - Processor: processor, - TokenFunc: a.getTokenFunc(), + AgentConfig: a.config, + EntMeta: *a.AgentEnterpriseMeta(), + Logger: a.logger, + Processor: processor, + TokenFunc: a.getTokenFunc(), + TranslateAddressFunc: a.TranslateAddress, + TranslateServiceAddressFunc: a.TranslateServiceAddress, } for _, addr := range a.config.DNSAddrs { diff --git a/agent/catalog_endpoint.go b/agent/catalog_endpoint.go index 1dac61befa..8af4654b90 100644 --- a/agent/catalog_endpoint.go +++ b/agent/catalog_endpoint.go @@ -13,6 +13,7 @@ import ( cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/internal/dnsutil" ) var CatalogCounters = []prometheus.CounterDefinition{ @@ -257,7 +258,7 @@ RETRY_ONCE: } out.ConsistencyLevel = args.QueryOptions.ConsistencyLevel() - s.agent.TranslateAddresses(args.Datacenter, out.Nodes, TranslateAddressAcceptAny) + s.agent.TranslateAddresses(args.Datacenter, out.Nodes, dnsutil.TranslateAddressAcceptAny) // Use empty list instead of nil if out.Nodes == nil { @@ -403,7 +404,7 @@ func (s *HTTPHandlers) catalogServiceNodes(resp http.ResponseWriter, req *http.R } out.ConsistencyLevel = args.QueryOptions.ConsistencyLevel() - s.agent.TranslateAddresses(args.Datacenter, out.ServiceNodes, TranslateAddressAcceptAny) + s.agent.TranslateAddresses(args.Datacenter, out.ServiceNodes, dnsutil.TranslateAddressAcceptAny) // Use empty list instead of nil if out.ServiceNodes == nil { @@ -457,7 +458,7 @@ RETRY_ONCE: } out.ConsistencyLevel = args.QueryOptions.ConsistencyLevel() if out.NodeServices != nil { - s.agent.TranslateAddresses(args.Datacenter, out.NodeServices, TranslateAddressAcceptAny) + s.agent.TranslateAddresses(args.Datacenter, out.NodeServices, dnsutil.TranslateAddressAcceptAny) } // TODO: The NodeServices object in IndexedNodeServices is a pointer to @@ -521,7 +522,7 @@ RETRY_ONCE: goto RETRY_ONCE } out.ConsistencyLevel = args.QueryOptions.ConsistencyLevel() - s.agent.TranslateAddresses(args.Datacenter, &out.NodeServices, TranslateAddressAcceptAny) + s.agent.TranslateAddresses(args.Datacenter, &out.NodeServices, dnsutil.TranslateAddressAcceptAny) // Use empty list instead of nil for _, s := range out.NodeServices.Services { diff --git a/agent/discovery/discovery.go b/agent/discovery/discovery.go index 8dd5f62828..51557f9ec8 100644 --- a/agent/discovery/discovery.go +++ b/agent/discovery/discovery.go @@ -119,10 +119,18 @@ type Result struct { Tenancy ResultTenancy } -// Location is used to represent a service, node, or workload. -type Location struct { +// TaggedAddress is used to represent a tagged address. +type TaggedAddress struct { Name string Address string + Port Port +} + +// Location is used to represent a service, node, or workload. +type Location struct { + Name string + Address string + TaggedAddresses map[string]*TaggedAddress // Used to collect tagged addresses into A/AAAA Records } type DNSConfig struct { diff --git a/agent/discovery/query_fetcher_v1.go b/agent/discovery/query_fetcher_v1.go index 56ed3ce27d..3cf1654dec 100644 --- a/agent/discovery/query_fetcher_v1.go +++ b/agent/discovery/query_fetcher_v1.go @@ -131,11 +131,13 @@ func (f *V1DataFetcher) FetchNodes(ctx Context, req *QueryPayload) ([]*Result, e results = append(results, &Result{ Node: &Location{ - Name: n.Node, - Address: n.Address, + Name: n.Node, + Address: n.Address, + TaggedAddresses: makeTaggedAddressesFromStrings(n.TaggedAddresses), }, Type: ResultTypeNode, Metadata: n.Meta, + Tenancy: ResultTenancy{ // Namespace is not required because nodes are not namespaced Partition: n.GetEnterpriseMeta().PartitionOrDefault(), @@ -210,8 +212,9 @@ func (f *V1DataFetcher) FetchRecordsByIp(reqCtx Context, ip net.IP) ([]*Result, if targetIP == n.Address { results = append(results, &Result{ Node: &Location{ - Name: n.Node, - Address: n.Address, + Name: n.Node, + Address: n.Address, + TaggedAddresses: makeTaggedAddressesFromStrings(n.TaggedAddresses), }, Type: ResultTypeNode, Tenancy: ResultTenancy{ @@ -415,12 +418,14 @@ func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServic n := nodes[idx] results = append(results, &Result{ Service: &Location{ - Name: n.Service.Service, - Address: n.Service.Address, + Name: n.Service.Service, + Address: n.Service.Address, + TaggedAddresses: makeTaggedAddressesFromServiceAddresses(n.Service.TaggedAddresses), }, Node: &Location{ - Name: n.Node.Node, - Address: n.Node.Address, + Name: n.Node.Node, + Address: n.Node.Address, + TaggedAddresses: makeTaggedAddressesFromStrings(n.Node.TaggedAddresses), }, Type: ResultTypeService, DNS: DNSConfig{ @@ -442,6 +447,33 @@ func (f *V1DataFetcher) buildResultsFromServiceNodes(nodes []structs.CheckServic return results } +// makeTaggedAddressesFromServiceAddresses is used to convert a map of service addresses to a map of Locations. +func makeTaggedAddressesFromServiceAddresses(tagged map[string]structs.ServiceAddress) map[string]*TaggedAddress { + taggedAddresses := make(map[string]*TaggedAddress) + for k, v := range tagged { + taggedAddresses[k] = &TaggedAddress{ + Name: k, + Address: v.Address, + Port: Port{ + Number: uint32(v.Port), + }, + } + } + return taggedAddresses +} + +// makeTaggedAddressesFromStrings is used to convert a map of strings to a map of Locations. +func makeTaggedAddressesFromStrings(tagged map[string]string) map[string]*TaggedAddress { + taggedAddresses := make(map[string]*TaggedAddress) + for k, v := range tagged { + taggedAddresses[k] = &TaggedAddress{ + Name: k, + Address: v, + } + } + return taggedAddresses +} + // fetchNode is used to look up a node in the Consul catalog within NodeServices. // If the config is set to UseCache, it will get the record from the agent cache. func (f *V1DataFetcher) fetchNode(cfg *v1DataFetcherDynamicConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { diff --git a/agent/discovery/query_fetcher_v1_test.go b/agent/discovery/query_fetcher_v1_test.go index 61e0618936..de9e6f22b0 100644 --- a/agent/discovery/query_fetcher_v1_test.go +++ b/agent/discovery/query_fetcher_v1_test.go @@ -140,12 +140,14 @@ func Test_FetchEndpoints(t *testing.T) { expectedResults := []*Result{ { Node: &Location{ - Name: "node-name", - Address: "node-address", + Name: "node-name", + Address: "node-address", + TaggedAddresses: map[string]*TaggedAddress{}, }, Service: &Location{ - Name: "service-name", - Address: "service-address", + Name: "service-name", + Address: "service-address", + TaggedAddresses: map[string]*TaggedAddress{}, }, Type: ResultTypeService, DNS: DNSConfig{ diff --git a/agent/dns.go b/agent/dns.go index 1bbf485fcb..c07fd12175 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -26,6 +26,7 @@ import ( "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" + dnsutil "github.com/hashicorp/consul/internal/dnsutil" libdns "github.com/hashicorp/consul/internal/dnsutil" "github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/lib" @@ -1801,13 +1802,13 @@ func makeARecord(qType uint16, ip net.IP, ttl time.Duration) dns.RR { // In case of an SRV query the answer will be a IN SRV and additional data will store an IN A to the node IP // Otherwise it will return a IN A record func (d *DNSServer) makeRecordFromNode(node *structs.Node, qType uint16, qName string, ttl time.Duration, maxRecursionLevel int) []dns.RR { - addrTranslate := TranslateAddressAcceptDomain + addrTranslate := dnsutil.TranslateAddressAcceptDomain if qType == dns.TypeA { - addrTranslate |= TranslateAddressAcceptIPv4 + addrTranslate |= dnsutil.TranslateAddressAcceptIPv4 } else if qType == dns.TypeAAAA { - addrTranslate |= TranslateAddressAcceptIPv6 + addrTranslate |= dnsutil.TranslateAddressAcceptIPv6 } else { - addrTranslate |= TranslateAddressAcceptAny + addrTranslate |= dnsutil.TranslateAddressAcceptAny } addr := d.agent.TranslateAddress(node.Datacenter, node.Address, node.TaggedAddresses, addrTranslate) @@ -1973,13 +1974,13 @@ MORE_REC: // Craft dns records from a CheckServiceNode struct func (d *DNSServer) makeNodeServiceRecords(lookup serviceLookup, node structs.CheckServiceNode, req *dns.Msg, ttl time.Duration, cfg *dnsConfig, maxRecursionLevel int) ([]dns.RR, []dns.RR) { - addrTranslate := TranslateAddressAcceptDomain + addrTranslate := dnsutil.TranslateAddressAcceptDomain if req.Question[0].Qtype == dns.TypeA { - addrTranslate |= TranslateAddressAcceptIPv4 + addrTranslate |= dnsutil.TranslateAddressAcceptIPv4 } else if req.Question[0].Qtype == dns.TypeAAAA { - addrTranslate |= TranslateAddressAcceptIPv6 + addrTranslate |= dnsutil.TranslateAddressAcceptIPv6 } else { - addrTranslate |= TranslateAddressAcceptAny + addrTranslate |= dnsutil.TranslateAddressAcceptAny } // The datacenter should be empty during translation if it is a peering lookup. @@ -2055,7 +2056,7 @@ func (d *DNSServer) addServiceSRVRecordsToMessage(cfg *dnsConfig, lookup service // The datacenter should be empty during translation if it is a peering lookup. // This should be fine because we should always prefer the WAN address. - serviceAddress := d.agent.TranslateServiceAddress(lookup.Datacenter, node.Service.Address, node.Service.TaggedAddresses, TranslateAddressAcceptAny) + serviceAddress := d.agent.TranslateServiceAddress(lookup.Datacenter, node.Service.Address, node.Service.TaggedAddresses, dnsutil.TranslateAddressAcceptAny) servicePort := d.agent.TranslateServicePort(lookup.Datacenter, node.Service.Port, node.Service.TaggedAddresses) tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, serviceAddress, servicePort) if _, ok := handled[tuple]; ok { diff --git a/agent/dns/router.go b/agent/dns/router.go index c70648d3ee..538e7c9111 100644 --- a/agent/dns/router.go +++ b/agent/dns/router.go @@ -107,7 +107,9 @@ type Router struct { datacenter string logger hclog.Logger - tokenFunc func() string + tokenFunc func() string + translateAddressFunc func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string + translateServiceAddressFunc func(dc string, address string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string // dynamicConfig stores the config as an atomic value (for hot-reloading). // It is always of type *RouterDynamicConfig @@ -127,13 +129,15 @@ func NewRouter(cfg Config) (*Router, error) { logger := cfg.Logger.Named(logging.DNS) router := &Router{ - processor: cfg.Processor, - recursor: newRecursor(logger), - domain: domain, - altDomain: altDomain, - datacenter: cfg.AgentConfig.Datacenter, - logger: logger, - tokenFunc: cfg.TokenFunc, + processor: cfg.Processor, + recursor: newRecursor(logger), + domain: domain, + altDomain: altDomain, + datacenter: cfg.AgentConfig.Datacenter, + logger: logger, + tokenFunc: cfg.TokenFunc, + translateAddressFunc: cfg.TranslateAddressFunc, + translateServiceAddressFunc: cfg.TranslateServiceAddressFunc, } if err := router.ReloadConfig(cfg.AgentConfig); err != nil { @@ -526,9 +530,6 @@ func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx Context, // The datacenter should be empty during translation if it is a peering lookup. // This should be fine because we should always prefer the WAN address. - //serviceAddress := d.agent.TranslateServiceAddress(lookup.Datacenter, node.Service.Address, node.Service.TaggedAddresses, TranslateAddressAcceptAny) - //servicePort := d.agent.TranslateServicePort(lookup.Datacenter, node.Service.Port, node.Service.TaggedAddresses) - //tuple := fmt.Sprintf("%s:%s:%d", node.Node.Node, serviceAddress, servicePort) // TODO (v2-dns): this needs a clean up so we're not assuming this everywhere. address := "" @@ -554,13 +555,35 @@ func (r *Router) serializeQueryResults(req *dns.Msg, reqCtx Context, r.appendResultsToDNSResponse(req, reqCtx, query, resp, results, cfg, responseDomain, remoteAddress, maxRecursionLevel) } - if len(resp.Answer) == 0 && len(resp.Extra) == 0 { + if query != nil && query.QueryType != discovery.QueryTypeVirtual && + len(resp.Answer) == 0 && len(resp.Extra) == 0 { return nil, discovery.ErrNoData } return resp, nil } +// getServiceAddressMapFromLocationMap converts a map of Location to a map of ServiceAddress. +func getServiceAddressMapFromLocationMap(taggedAddresses map[string]*discovery.TaggedAddress) map[string]structs.ServiceAddress { + taggedServiceAddresses := make(map[string]structs.ServiceAddress, len(taggedAddresses)) + for k, v := range taggedAddresses { + taggedServiceAddresses[k] = structs.ServiceAddress{ + Address: v.Address, + Port: int(v.Port.Number), + } + } + return taggedServiceAddresses +} + +// getStringAddressMapFromTaggedAddressMap converts a map of Location to a map of string. +func getStringAddressMapFromTaggedAddressMap(taggedAddresses map[string]*discovery.TaggedAddress) map[string]string { + taggedServiceAddresses := make(map[string]string, len(taggedAddresses)) + for k, v := range taggedAddresses { + taggedServiceAddresses[k] = v.Address + } + return taggedServiceAddresses +} + // appendResultsToDNSResponse builds dns message from the discovery results and // appends them to the dns response. func (r *Router) appendResultsToDNSResponse(req *dns.Msg, reqCtx Context, @@ -906,15 +929,7 @@ func buildAddressResults(req *dns.Msg) ([]*discovery.Result, error) { func (r *Router) getAnswerExtraAndNs(result *discovery.Result, port discovery.Port, req *dns.Msg, reqCtx Context, query *discovery.Query, cfg *RouterDynamicConfig, domain string, remoteAddress net.Addr, maxRecursionLevel int) (answer []dns.RR, extra []dns.RR, ns []dns.RR) { - - serviceAddress := newDNSAddress("") - if result.Service != nil { - serviceAddress = newDNSAddress(result.Service.Address) - } - nodeAddress := newDNSAddress("") - if result.Node != nil { - nodeAddress = newDNSAddress(result.Node.Address) - } + serviceAddress, nodeAddress := r.getServiceAndNodeAddresses(result, req) qName := req.Question[0].Name ttlLookupName := qName if query != nil { @@ -983,6 +998,35 @@ func (r *Router) getAnswerExtraAndNs(result *discovery.Result, port discovery.Po return } +// getServiceAndNodeAddresses returns the service and node addresses from a discovery result. +func (r *Router) getServiceAndNodeAddresses(result *discovery.Result, req *dns.Msg) (*dnsAddress, *dnsAddress) { + addrTranslate := dnsutil.TranslateAddressAcceptDomain + if req.Question[0].Qtype == dns.TypeA { + addrTranslate |= dnsutil.TranslateAddressAcceptIPv4 + } else if req.Question[0].Qtype == dns.TypeAAAA { + addrTranslate |= dnsutil.TranslateAddressAcceptIPv6 + } else { + addrTranslate |= dnsutil.TranslateAddressAcceptAny + } + + // The datacenter should be empty during translation if it is a peering lookup. + // This should be fine because we should always prefer the WAN address. + serviceAddress := newDNSAddress("") + if result.Service != nil { + sa := r.translateServiceAddressFunc(result.Tenancy.Datacenter, + result.Service.Address, getServiceAddressMapFromLocationMap(result.Service.TaggedAddresses), + addrTranslate) + serviceAddress = newDNSAddress(sa) + } + nodeAddress := newDNSAddress("") + if result.Node != nil { + na := r.translateAddressFunc(result.Tenancy.Datacenter, result.Node.Address, + getStringAddressMapFromTaggedAddressMap(result.Node.TaggedAddresses), addrTranslate) + nodeAddress = newDNSAddress(na) + } + return serviceAddress, nodeAddress +} + // getAnswerExtrasForAddressAndTarget creates the dns answer and extra from nodeAddress and serviceAddress dnsAddress pairs. func (r *Router) getAnswerExtrasForAddressAndTarget(nodeAddress *dnsAddress, serviceAddress *dnsAddress, req *dns.Msg, reqCtx Context, result *discovery.Result, port discovery.Port, ttl uint32, remoteAddress net.Addr, @@ -1119,7 +1163,7 @@ func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, qType := question.Qtype canReturnARecord := qType == dns.TypeSRV || qType == dns.TypeA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT canReturnAAAARecord := qType == dns.TypeSRV || qType == dns.TypeAAAA || qType == dns.TypeANY || qType == dns.TypeNS || qType == dns.TypeTXT - if reqType != requestTypeAddress { + if reqType != requestTypeAddress && result.Type != discovery.ResultTypeVirtual { switch { // check IPV4 case addr.IsIP() && addr.IsIPV4() && !canReturnARecord, @@ -1143,8 +1187,7 @@ func getAnswerExtrasForIP(name string, addr *dnsAddress, question dns.Question, } if reqType != requestTypeAddress && qType == dns.TypeSRV { - if result.Type == discovery.ResultTypeService && addr.IsIP() && result.Service. - Address == addr.String() { + if result.Type == discovery.ResultTypeService && addr.IsIP() && result.Node.Address != addr.String() { // encode the ip to be used in the header of the A/AAAA record // as well as the target of the SRV record. recHdrName = encodeIPAsFqdn(result, addr.IP(), domain) diff --git a/agent/dns/router_query.go b/agent/dns/router_query.go index ca1056ef65..bbcbca6698 100644 --- a/agent/dns/router_query.go +++ b/agent/dns/router_query.go @@ -70,6 +70,24 @@ func getQueryNameAndTagFromParts(queryType discovery.QueryType, queryParts []str return name, tag } return queryParts[n-1], "" + case discovery.QueryTypePreparedQuery: + name := "" + + // If the first and last DNS query parts begin with _, this is an RFC 2782 style SRV lookup. + // This allows for prepared query names to include "." (for backwards compatibility). + // Otherwise, this is a standard prepared query lookup. + if n >= 2 && strings.HasPrefix(queryParts[0], "_") && strings.HasPrefix(queryParts[n-1], "_") { + // The last DNS query part is the protocol field (ignored). + // All prior parts are the prepared query name or ID. + name = strings.Join(queryParts[:n-1], ".") + + // Strip leading underscore + name = name[1:] + } else { + // Allow a "." in the query name, just join all the parts. + name = strings.Join(queryParts, ".") + } + return name, "" } return queryParts[n-1], "" } diff --git a/agent/dns/router_service_question_test.go b/agent/dns/router_service_question_test.go new file mode 100644 index 0000000000..76fce89107 --- /dev/null +++ b/agent/dns/router_service_question_test.go @@ -0,0 +1,169 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package dns + +import ( + "net" + "testing" + "time" + + "github.com/hashicorp/consul/agent/discovery" + "github.com/miekg/dns" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_HandleRequest_ServiceQuestions(t *testing.T) { + testCases := []HandleTestCase{ + // Service Lookup + { + name: "When no data is return from a query, send SOA", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "foo.service.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). + Return(nil, discovery.ErrNoData). + Run(func(args mock.Arguments) { + req := args.Get(1).(*discovery.QueryPayload) + reqType := args.Get(2).(discovery.LookupType) + + require.Equal(t, discovery.LookupTypeService, reqType) + require.Equal(t, "foo", req.Name) + }) + }, + validateAndNormalizeExpected: true, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + Response: true, + Authoritative: true, + Rcode: dns.RcodeSuccess, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "foo.service.consul.", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, + }, + Ns: []dns.RR{ + &dns.SOA{ + Hdr: dns.RR_Header{ + Name: "consul.", + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + Ttl: 4, + }, + Ns: "ns.consul.", + Serial: uint32(time.Now().Unix()), + Mbox: "hostmaster.consul.", + Refresh: 1, + Expire: 3, + Retry: 2, + Minttl: 4, + }, + }, + }, + }, + { + // TestDNS_ExternalServiceToConsulCNAMELookup + name: "req type: service / question type: SRV / CNAME required: no", + request: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{ + { + Name: "alias.service.consul.", + Qtype: dns.TypeSRV, + }, + }, + }, + configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { + fetcher.(*discovery.MockCatalogDataFetcher). + On("FetchEndpoints", mock.Anything, + &discovery.QueryPayload{ + Name: "alias", + Tenancy: discovery.QueryTenancy{}, + }, discovery.LookupTypeService). + Return([]*discovery.Result{ + { + Type: discovery.ResultTypeVirtual, + Service: &discovery.Location{Name: "alias", Address: "web.service.consul"}, + Node: &discovery.Location{Name: "web", Address: "web.service.consul"}, + }, + }, + nil).On("FetchEndpoints", mock.Anything, + &discovery.QueryPayload{ + Name: "web", + Tenancy: discovery.QueryTenancy{}, + }, discovery.LookupTypeService). + Return([]*discovery.Result{ + { + Type: discovery.ResultTypeNode, + Service: &discovery.Location{Name: "web", Address: "webnode"}, + Node: &discovery.Location{Name: "webnode", Address: "127.0.0.2"}, + }, + }, nil).On("ValidateRequest", mock.Anything, + mock.Anything).Return(nil).On("NormalizeRequest", mock.Anything) + }, + response: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Authoritative: true, + }, + Compress: true, + Question: []dns.Question{ + { + Name: "alias.service.consul.", + Qtype: dns.TypeSRV, + }, + }, + Answer: []dns.RR{ + &dns.SRV{ + Hdr: dns.RR_Header{ + Name: "alias.service.consul.", + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 123, + }, + Target: "web.service.consul.", + Priority: 1, + }, + }, + Extra: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "web.service.consul.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 123, + }, + A: net.ParseIP("127.0.0.2"), + }, + }, + }, + }, + } + + testCases = append(testCases, getAdditionalTestCases(t)...) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + runHandleTestCases(t, tc) + }) + } +} diff --git a/agent/dns/router_test.go b/agent/dns/router_test.go index 93fd789bc0..f5014d6c88 100644 --- a/agent/dns/router_test.go +++ b/agent/dns/router_test.go @@ -6,6 +6,7 @@ package dns import ( "errors" "fmt" + "github.com/hashicorp/consul/internal/dnsutil" "net" "reflect" "testing" @@ -37,23 +38,23 @@ type HandleTestCase struct { response *dns.Msg } -func Test_HandleRequest(t *testing.T) { - soa := &dns.SOA{ - Hdr: dns.RR_Header{ - Name: "consul.", - Rrtype: dns.TypeSOA, - Class: dns.ClassINET, - Ttl: 4, - }, - Ns: "ns.consul.", - Mbox: "hostmaster.consul.", - Serial: uint32(time.Now().Unix()), - Refresh: 1, - Retry: 2, - Expire: 3, - Minttl: 4, - } +var testSOA = &dns.SOA{ + Hdr: dns.RR_Header{ + Name: "consul.", + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + Ttl: 4, + }, + Ns: "ns.consul.", + Mbox: "hostmaster.consul.", + Serial: uint32(time.Now().Unix()), + Refresh: 1, + Retry: 2, + Expire: 3, + Minttl: 4, +} +func Test_HandleRequest(t *testing.T) { testCases := []HandleTestCase{ // recursor queries { @@ -800,7 +801,17 @@ func Test_HandleRequest(t *testing.T) { Qclass: dns.ClassINET, }, }, - Ns: []dns.RR{soa}, + Extra: []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: "20010db800010002cafe000000001337.virtual.dc1.consul.", + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 123, + }, + AAAA: net.ParseIP("2001:db8:1:2:cafe::1337"), + }, + }, }, }, // SOA Queries @@ -1456,158 +1467,7 @@ func Test_HandleRequest(t *testing.T) { }, }, }, - // Service Lookup - { - name: "When no data is return from a query, send SOA", - request: &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Opcode: dns.OpcodeQuery, - }, - Question: []dns.Question{ - { - Name: "foo.service.consul.", - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }, - }, - }, - configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { - fetcher.(*discovery.MockCatalogDataFetcher). - On("FetchEndpoints", mock.Anything, mock.Anything, mock.Anything). - Return(nil, discovery.ErrNoData). - Run(func(args mock.Arguments) { - req := args.Get(1).(*discovery.QueryPayload) - reqType := args.Get(2).(discovery.LookupType) - - require.Equal(t, discovery.LookupTypeService, reqType) - require.Equal(t, "foo", req.Name) - }) - }, - validateAndNormalizeExpected: true, - response: &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Opcode: dns.OpcodeQuery, - Response: true, - Authoritative: true, - Rcode: dns.RcodeSuccess, - }, - Compress: true, - Question: []dns.Question{ - { - Name: "foo.service.consul.", - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }, - }, - Ns: []dns.RR{ - &dns.SOA{ - Hdr: dns.RR_Header{ - Name: "consul.", - Rrtype: dns.TypeSOA, - Class: dns.ClassINET, - Ttl: 4, - }, - Ns: "ns.consul.", - Serial: uint32(time.Now().Unix()), - Mbox: "hostmaster.consul.", - Refresh: 1, - Expire: 3, - Retry: 2, - Minttl: 4, - }, - }, - }, - }, - { - // TestDNS_ExternalServiceToConsulCNAMELookup - name: "req type: service / question type: SRV / CNAME required: no", - request: &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Opcode: dns.OpcodeQuery, - }, - Question: []dns.Question{ - { - Name: "alias.service.consul.", - Qtype: dns.TypeSRV, - }, - }, - }, - configureDataFetcher: func(fetcher discovery.CatalogDataFetcher) { - fetcher.(*discovery.MockCatalogDataFetcher). - On("FetchEndpoints", mock.Anything, - &discovery.QueryPayload{ - Name: "alias", - Tenancy: discovery.QueryTenancy{}, - }, discovery.LookupTypeService). - Return([]*discovery.Result{ - { - Type: discovery.ResultTypeVirtual, - Service: &discovery.Location{Name: "alias", Address: "web.service.consul"}, - Node: &discovery.Location{Name: "web", Address: "web.service.consul"}, - Ports: []discovery.Port{ - { - Number: 1234, - }, - }, - }, - }, - nil).On("FetchEndpoints", mock.Anything, - &discovery.QueryPayload{ - Name: "web", - Tenancy: discovery.QueryTenancy{}, - }, discovery.LookupTypeService). - Return([]*discovery.Result{ - { - Type: discovery.ResultTypeNode, - Service: &discovery.Location{Name: "web", Address: "webnode"}, - Node: &discovery.Location{Name: "webnode", Address: "127.0.0.2"}, - Ports: []discovery.Port{ - { - Number: 1234, - }, - }, - }, - }, nil).On("ValidateRequest", mock.Anything, - mock.Anything).Return(nil).On("NormalizeRequest", mock.Anything) - }, - response: &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Response: true, - Authoritative: true, - }, - Compress: true, - Question: []dns.Question{ - { - Name: "alias.service.consul.", - Qtype: dns.TypeSRV, - }, - }, - Answer: []dns.RR{ - &dns.SRV{ - Hdr: dns.RR_Header{ - Name: "alias.service.consul.", - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - Ttl: 123, - }, - Target: "web.service.consul.", - Priority: 1, - Port: 1234, - }, - }, - Extra: []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{ - Name: "web.service.consul.", - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 123, - }, - A: net.ParseIP("127.0.0.2"), - }, - }, - }, - }, + // TODO (v2-dns): add a test to make sure only 3 records are returned // V2 Workload Lookup { name: "workload A query w/ port, returns A record", @@ -2851,41 +2711,40 @@ func Test_HandleRequest(t *testing.T) { testCases = append(testCases, getAdditionalTestCases(t)...) - run := func(t *testing.T, tc HandleTestCase) { - cdf := discovery.NewMockCatalogDataFetcher(t) - if tc.validateAndNormalizeExpected { - cdf.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) - cdf.On("NormalizeRequest", mock.Anything).Return() - } - - if tc.configureDataFetcher != nil { - tc.configureDataFetcher(cdf) - } - cfg := buildDNSConfig(tc.agentConfig, cdf, tc.mockProcessorError) - - router, err := NewRouter(cfg) - require.NoError(t, err) - - // Replace the recursor with a mock and configure - router.recursor = newMockDnsRecursor(t) - if tc.configureRecursor != nil { - tc.configureRecursor(router.recursor) - } - - ctx := tc.requestContext - if ctx == nil { - ctx = &Context{} - } - actual := router.HandleRequest(tc.request, *ctx, tc.remoteAddress) - require.Equal(t, tc.response, actual) - } - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - run(t, tc) + runHandleTestCases(t, tc) }) } +} +func runHandleTestCases(t *testing.T, tc HandleTestCase) { + cdf := discovery.NewMockCatalogDataFetcher(t) + if tc.validateAndNormalizeExpected { + cdf.On("ValidateRequest", mock.Anything, mock.Anything).Return(nil) + cdf.On("NormalizeRequest", mock.Anything).Return() + } + + if tc.configureDataFetcher != nil { + tc.configureDataFetcher(cdf) + } + cfg := buildDNSConfig(tc.agentConfig, cdf, tc.mockProcessorError) + + router, err := NewRouter(cfg) + require.NoError(t, err) + + // Replace the recursor with a mock and configure + router.recursor = newMockDnsRecursor(t) + if tc.configureRecursor != nil { + tc.configureRecursor(router.recursor) + } + + ctx := tc.requestContext + if ctx == nil { + ctx = &Context{} + } + actual := router.HandleRequest(tc.request, *ctx, tc.remoteAddress) + require.Equal(t, tc.response, actual) } func TestRouterDynamicConfig_GetTTLForService(t *testing.T) { @@ -2957,6 +2816,12 @@ func buildDNSConfig(agentConfig *config.RuntimeConfig, cdf discovery.CatalogData Logger: hclog.NewNullLogger(), Processor: discovery.NewQueryProcessor(cdf), TokenFunc: func() string { return "" }, + TranslateServiceAddressFunc: func(dc string, address string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string { + return address + }, + TranslateAddressFunc: func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string { + return addr + }, } if agentConfig != nil { diff --git a/agent/dns/server.go b/agent/dns/server.go index 9508e34159..74da3fa663 100644 --- a/agent/dns/server.go +++ b/agent/dns/server.go @@ -5,6 +5,8 @@ package dns import ( "fmt" + "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/internal/dnsutil" "net" "github.com/miekg/dns" @@ -36,11 +38,13 @@ type Server struct { // Config represent all the DNS configuration required to construct a DNS server. type Config struct { - AgentConfig *config.RuntimeConfig - EntMeta acl.EnterpriseMeta - Logger hclog.Logger - Processor DiscoveryQueryProcessor - TokenFunc func() string + AgentConfig *config.RuntimeConfig + EntMeta acl.EnterpriseMeta + Logger hclog.Logger + Processor DiscoveryQueryProcessor + TokenFunc func() string + TranslateAddressFunc func(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string + TranslateServiceAddressFunc func(dc string, address string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string } // NewServer creates a new DNS server. diff --git a/agent/dns_node_lookup_test.go b/agent/dns_node_lookup_test.go index 1e5b35ebb5..3e8f6b9c81 100644 --- a/agent/dns_node_lookup_test.go +++ b/agent/dns_node_lookup_test.go @@ -14,14 +14,12 @@ import ( "github.com/hashicorp/consul/testrpc" ) -// TODO (v2-dns): Failing on "lookup a non-existing node, we should receive a SOA" -// it is coming back empty. func TestDNS_NodeLookup(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() diff --git a/agent/dns_service_lookup_test.go b/agent/dns_service_lookup_test.go index fbca1bb1fb..4243a69014 100644 --- a/agent/dns_service_lookup_test.go +++ b/agent/dns_service_lookup_test.go @@ -1338,7 +1338,6 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddress_A(t *testing.T) { } } -// TODO (v2-dns): NET-7632 - Fix node and prepared query lookups when question name has a period in it func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -1352,7 +1351,7 @@ func TestDNS_ServiceLookup_ServiceAddress_SRV(t *testing.T) { }) defer recursor.Shutdown() - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, ` recursors = ["`+recursor.Addr+`"] @@ -1666,13 +1665,12 @@ func TestDNS_AltDomain_ServiceLookup_ServiceAddressIPV6(t *testing.T) { } } -// TODO (v2-dns): NET-7634 - Implement WAN translation func TestDNS_ServiceLookup_WanTranslation(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a1 := NewTestAgent(t, ` datacenter = "dc1" @@ -2063,13 +2061,12 @@ func TestDNS_ServiceLookup_TagPeriod(t *testing.T) { } } -// TODO (v2-dns): NET-7632 - Fix node and prepared query lookups when question name has a period in it. func TestDNS_ServiceLookup_PreparedQueryNamePeriod(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -3297,7 +3294,6 @@ func checkDNSService( } } -// TODO (v2-dns): NET-7633 - implement answer limits. func TestDNS_ServiceLookup_ARecordLimits(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") diff --git a/agent/dns_test.go b/agent/dns_test.go index 17cd4c9f72..22385df421 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -1008,13 +1008,12 @@ func TestDNS_AltDomain_NSRecords_IPV6(t *testing.T) { } } -// TODO NET-7644 - Implement service and prepared query lookup for tagged addresses func TestDNS_Lookup_TaggedIPAddresses(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -1222,7 +1221,7 @@ func TestDNS_PreparedQueryNearIPEDNS(t *testing.T) { {"foo3", "198.18.0.3", lib.GenerateCoordinate(30 * time.Millisecond)}, } - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -1356,7 +1355,7 @@ func TestDNS_PreparedQueryNearIP(t *testing.T) { {"foo3", "198.18.0.3", lib.GenerateCoordinate(30 * time.Millisecond)}, } - for name, experimentsHCL := range getVersionHCL(false) { + for name, experimentsHCL := range getVersionHCL(true) { t.Run(name, func(t *testing.T) { a := NewTestAgent(t, experimentsHCL) defer a.Shutdown() @@ -2149,25 +2148,27 @@ func TestDNS_NonExistentLookupEmptyAorAAAA(t *testing.T) { "webv4.query.consul.", } for _, question := range questions { - m := new(dns.Msg) - m.SetQuestion(question, dns.TypeAAAA) + t.Run(question, func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion(question, dns.TypeAAAA) - c := new(dns.Client) - in, _, err := c.Exchange(m, a.DNSAddr()) - if err != nil { - t.Fatalf("err: %v", err) - } + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } - require.Len(t, in.Ns, 1) - soaRec, ok := in.Ns[0].(*dns.SOA) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[0]) - } - if soaRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[0]) - } + require.Len(t, in.Ns, 1) + soaRec, ok := in.Ns[0].(*dns.SOA) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[0]) + } + if soaRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[0]) + } - require.Equal(t, dns.RcodeSuccess, in.Rcode) + require.Equal(t, dns.RcodeSuccess, in.Rcode) + }) } // Check for ipv4 records on ipv6-only service directly and via the @@ -2177,30 +2178,32 @@ func TestDNS_NonExistentLookupEmptyAorAAAA(t *testing.T) { "webv6.query.consul.", } for _, question := range questions { - m := new(dns.Msg) - m.SetQuestion(question, dns.TypeA) + t.Run(question, func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion(question, dns.TypeA) - c := new(dns.Client) - in, _, err := c.Exchange(m, a.DNSAddr()) - if err != nil { - t.Fatalf("err: %v", err) - } + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } - if len(in.Ns) != 1 { - t.Fatalf("Bad: %#v", in) - } + if len(in.Ns) != 1 { + t.Fatalf("Bad: %#v", in) + } - soaRec, ok := in.Ns[0].(*dns.SOA) - if !ok { - t.Fatalf("Bad: %#v", in.Ns[0]) - } - if soaRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[0]) - } + soaRec, ok := in.Ns[0].(*dns.SOA) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[0]) + } + if soaRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[0]) + } - if in.Rcode != dns.RcodeSuccess { - t.Fatalf("Bad: %#v", in) - } + if in.Rcode != dns.RcodeSuccess { + t.Fatalf("Bad: %#v", in) + } + }) } }) } diff --git a/agent/health_endpoint.go b/agent/health_endpoint.go index 3b888988d2..1ce464d91c 100644 --- a/agent/health_endpoint.go +++ b/agent/health_endpoint.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" + "github.com/hashicorp/consul/internal/dnsutil" ) const ( @@ -243,7 +244,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re } // Translate addresses after filtering so we don't waste effort. - s.agent.TranslateAddresses(args.Datacenter, out.Nodes, TranslateAddressAcceptAny) + s.agent.TranslateAddresses(args.Datacenter, out.Nodes, dnsutil.TranslateAddressAcceptAny) // Use empty list instead of nil if out.Nodes == nil { diff --git a/agent/prepared_query_endpoint.go b/agent/prepared_query_endpoint.go index 15ab1005e4..8a3f1f038e 100644 --- a/agent/prepared_query_endpoint.go +++ b/agent/prepared_query_endpoint.go @@ -11,6 +11,7 @@ import ( cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/internal/dnsutil" ) // preparedQueryCreateResponse is used to wrap the query ID. @@ -162,7 +163,7 @@ func (s *HTTPHandlers) preparedQueryExecute(id string, resp http.ResponseWriter, // a query can fail over to a different DC than where the execute request // was sent to. That's why we use the reply's DC and not the one from // the args. - s.agent.TranslateAddresses(reply.Datacenter, reply.Nodes, TranslateAddressAcceptAny) + s.agent.TranslateAddresses(reply.Datacenter, reply.Nodes, dnsutil.TranslateAddressAcceptAny) // Use empty list instead of nil. if reply.Nodes == nil { diff --git a/agent/translate_addr.go b/agent/translate_addr.go index 1c0f8a4003..326117cc98 100644 --- a/agent/translate_addr.go +++ b/agent/translate_addr.go @@ -5,21 +5,12 @@ package agent import ( "fmt" + "github.com/hashicorp/consul/internal/dnsutil" "net" "github.com/hashicorp/consul/agent/structs" ) -type TranslateAddressAccept int - -const ( - TranslateAddressAcceptDomain TranslateAddressAccept = 1 << iota - TranslateAddressAcceptIPv4 - TranslateAddressAcceptIPv6 - - TranslateAddressAcceptAny TranslateAddressAccept = ^0 -) - // TranslateServicePort is used to provide the final, translated port for a service, // depending on how the agent and the other node are configured. The dc // parameter is the dc the datacenter this node is from. @@ -35,7 +26,7 @@ func (a *Agent) TranslateServicePort(dc string, port int, taggedAddresses map[st // TranslateServiceAddress is used to provide the final, translated address for a node, // depending on how the agent and the other node are configured. The dc // parameter is the dc the datacenter this node is from. -func (a *Agent) TranslateServiceAddress(dc string, addr string, taggedAddresses map[string]structs.ServiceAddress, accept TranslateAddressAccept) string { +func (a *Agent) TranslateServiceAddress(dc string, addr string, taggedAddresses map[string]structs.ServiceAddress, accept dnsutil.TranslateAddressAccept) string { def := addr v4 := taggedAddresses[structs.TaggedAddressLANIPv4].Address v6 := taggedAddresses[structs.TaggedAddressLANIPv6].Address @@ -59,7 +50,7 @@ func (a *Agent) TranslateServiceAddress(dc string, addr string, taggedAddresses // TranslateAddress is used to provide the final, translated address for a node, // depending on how the agent and the other node are configured. The dc // parameter is the dc the datacenter this node is from. -func (a *Agent) TranslateAddress(dc string, addr string, taggedAddresses map[string]string, accept TranslateAddressAccept) string { +func (a *Agent) TranslateAddress(dc string, addr string, taggedAddresses map[string]string, accept dnsutil.TranslateAddressAccept) string { def := addr v4 := taggedAddresses[structs.TaggedAddressLANIPv4] v6 := taggedAddresses[structs.TaggedAddressLANIPv6] @@ -80,22 +71,22 @@ func (a *Agent) TranslateAddress(dc string, addr string, taggedAddresses map[str return translateAddressAccept(accept, def, v4, v6) } -func translateAddressAccept(accept TranslateAddressAccept, def, v4, v6 string) string { +func translateAddressAccept(accept dnsutil.TranslateAddressAccept, def, v4, v6 string) string { switch { - case accept&TranslateAddressAcceptIPv6 > 0 && v6 != "": + case accept&dnsutil.TranslateAddressAcceptIPv6 > 0 && v6 != "": return v6 - case accept&TranslateAddressAcceptIPv4 > 0 && v4 != "": + case accept&dnsutil.TranslateAddressAcceptIPv4 > 0 && v4 != "": return v4 - case accept&TranslateAddressAcceptAny > 0 && def != "": + case accept&dnsutil.TranslateAddressAcceptAny > 0 && def != "": return def default: defIP := net.ParseIP(def) switch { - case defIP != nil && defIP.To4() != nil && accept&TranslateAddressAcceptIPv4 > 0: + case defIP != nil && defIP.To4() != nil && accept&dnsutil.TranslateAddressAcceptIPv4 > 0: return def - case defIP != nil && defIP.To4() == nil && accept&TranslateAddressAcceptIPv6 > 0: + case defIP != nil && defIP.To4() == nil && accept&dnsutil.TranslateAddressAcceptIPv6 > 0: return def - case defIP == nil && accept&TranslateAddressAcceptDomain > 0: + case defIP == nil && accept&dnsutil.TranslateAddressAcceptDomain > 0: return def } } @@ -106,7 +97,7 @@ func translateAddressAccept(accept TranslateAddressAccept, def, v4, v6 string) s // TranslateAddresses translates addresses in the given structure into the // final, translated address, depending on how the agent and the other node are // configured. The dc parameter is the datacenter this structure is from. -func (a *Agent) TranslateAddresses(dc string, subj interface{}, accept TranslateAddressAccept) { +func (a *Agent) TranslateAddresses(dc string, subj interface{}, accept dnsutil.TranslateAddressAccept) { // CAUTION - SUBTLE! An agent running on a server can, in some cases, // return pointers directly into the immutable state store for // performance (it's via the in-memory RPC mechanism). It's never safe diff --git a/internal/dnsutil/dns.go b/internal/dnsutil/dns.go index 7d6877091c..07c7306091 100644 --- a/internal/dnsutil/dns.go +++ b/internal/dnsutil/dns.go @@ -13,6 +13,8 @@ import ( "github.com/miekg/dns" ) +type TranslateAddressAccept int + // MaxLabelLength is the maximum length for a name that can be used in DNS. const ( MaxLabelLength = 63 @@ -20,6 +22,12 @@ const ( arpaLabel = "arpa" arpaIPV4Label = "in-addr" arpaIPV6Label = "ip6" + + TranslateAddressAcceptDomain TranslateAddressAccept = 1 << iota + TranslateAddressAcceptIPv4 + TranslateAddressAcceptIPv6 + + TranslateAddressAcceptAny TranslateAddressAccept = ^0 ) // InvalidNameRe is a regex that matches characters which can not be included in