diff --git a/.changelog/10009.txt b/.changelog/10009.txt new file mode 100644 index 0000000000..44f7174f51 --- /dev/null +++ b/.changelog/10009.txt @@ -0,0 +1,3 @@ +```release-note:bug +dns: fixes a bug with edns truncation where the response could exceed the size limit in some cases. +``` diff --git a/agent/dns.go b/agent/dns.go index 5e5dcbb6a5..69c132cbd6 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -3,6 +3,7 @@ package agent import ( "context" "encoding/hex" + "errors" "fmt" "net" "regexp" @@ -13,6 +14,9 @@ import ( metrics "github.com/armon/go-metrics" radix "github.com/armon/go-radix" "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/hashicorp/go-hclog" + "github.com/miekg/dns" + cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/config" agentdns "github.com/hashicorp/consul/agent/dns" @@ -21,8 +25,6 @@ import ( "github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/logging" - "github.com/hashicorp/go-hclog" - "github.com/miekg/dns" ) const ( @@ -74,7 +76,6 @@ type dnsConfig struct { } type serviceLookup struct { - Network string Datacenter string Service string Tag string @@ -252,34 +253,36 @@ func (d *DNSServer) ReloadConfig(newCfg *config.RuntimeConfig) error { // possibly the ECS headers as well if they were present in the // original request func setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) { - // Enable EDNS if enabled - if edns := request.IsEdns0(); edns != nil { - // cannot just use the SetEdns0 function as we need to embed - // the ECS option as well - ednsResp := new(dns.OPT) - ednsResp.Hdr.Name = "." - ednsResp.Hdr.Rrtype = dns.TypeOPT - ednsResp.SetUDPSize(edns.UDPSize()) - - // Setup the ECS option if present - if subnet := ednsSubnetForRequest(request); subnet != nil { - subOp := new(dns.EDNS0_SUBNET) - subOp.Code = dns.EDNS0SUBNET - subOp.Family = subnet.Family - subOp.Address = subnet.Address - subOp.SourceNetmask = subnet.SourceNetmask - if c := response.Rcode; ecsGlobal || c == dns.RcodeNameError || c == dns.RcodeServerFailure || c == dns.RcodeRefused || c == dns.RcodeNotImplemented { - // reply is globally valid and should be cached accordingly - subOp.SourceScope = 0 - } else { - // reply is only valid for the subnet it was queried with - subOp.SourceScope = subnet.SourceNetmask - } - ednsResp.Option = append(ednsResp.Option, subOp) - } - - response.Extra = append(response.Extra, ednsResp) + edns := request.IsEdns0() + if edns == nil { + return } + + // cannot just use the SetEdns0 function as we need to embed + // the ECS option as well + ednsResp := new(dns.OPT) + ednsResp.Hdr.Name = "." + ednsResp.Hdr.Rrtype = dns.TypeOPT + ednsResp.SetUDPSize(edns.UDPSize()) + + // Setup the ECS option if present + if subnet := ednsSubnetForRequest(request); subnet != nil { + subOp := new(dns.EDNS0_SUBNET) + subOp.Code = dns.EDNS0SUBNET + subOp.Family = subnet.Family + subOp.Address = subnet.Address + subOp.SourceNetmask = subnet.SourceNetmask + if c := response.Rcode; ecsGlobal || c == dns.RcodeNameError || c == dns.RcodeServerFailure || c == dns.RcodeRefused || c == dns.RcodeNotImplemented { + // reply is globally valid and should be cached accordingly + subOp.SourceScope = 0 + } else { + // reply is only valid for the subnet it was queried with + subOp.SourceScope = subnet.SourceNetmask + } + ednsResp.Option = append(ednsResp.Option, subOp) + } + + response.Extra = append(response.Extra, ednsResp) } // recursorAddr is used to add a port to the recursor if omitted. @@ -453,7 +456,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.Authoritative = true m.RecursionAvailable = (len(cfg.Recursors) > 0) - ecsGlobal := true + var err error switch req.Question[0].Qtype { case dns.TypeSOA: @@ -473,12 +476,18 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.SetRcode(req, dns.RcodeNotImplemented) default: - ecsGlobal = d.dispatch(network, resp.RemoteAddr(), req, m) + err = d.dispatch(resp.RemoteAddr(), req, m, maxRecursionLevelDefault) + rCode := rCodeFromError(err) + if rCode == dns.RcodeNameError || errors.Is(err, errNoData) { + d.addSOA(cfg, m) + } + m.SetRcode(req, rCode) } - setEDNS(req, m, ecsGlobal) + setEDNS(req, m, !errors.Is(err, errECSNotGlobal)) + + d.trimDNSResponse(cfg, network, req, m) - // Write out the complete response if err := resp.WriteMsg(m); err != nil { d.logger.Warn("failed to respond", "error", err) } @@ -566,17 +575,6 @@ func (d *DNSServer) nameservers(cfg *dnsConfig, maxRecursionLevel int) (ns []dns return } -// dispatch is used to parse a request and invoke the correct handler -func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) { - return d.doDispatch(network, remoteAddr, req, resp, maxRecursionLevelDefault) -} - -func (d *DNSServer) invalidQuery(req, resp *dns.Msg, cfg *dnsConfig, qName string) { - d.logger.Warn("QName invalid", "qname", qName) - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) -} - func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool { switch len(labels) { case 1: @@ -589,10 +587,39 @@ func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool { } } -// doDispatch is used to parse a request and invoke the correct handler. +var errECSNotGlobal = fmt.Errorf("ECS response is not global") +var errNameNotFound = fmt.Errorf("DNS name not found") + +// errNoData is used to indicate no resource records exist for the specified query type. +// Per the recommendation from Section 2.2 of RFC 2308, the server will return a TYPE 2 +// NODATA response in which the RCODE is set to NOERROR (RcodeSuccess), the Answer +// section is empty, and the Authority section contains the SOA record. +var errNoData = fmt.Errorf("no DNS Answer") + +// ecsNotGlobalError may be used to wrap an error or nil, to indicate that the +// EDNS client subnet source scope is not global. +type ecsNotGlobalError struct { + error +} + +func (e ecsNotGlobalError) Error() string { + if e.error == nil { + return "" + } + return e.error.Error() +} + +func (e ecsNotGlobalError) Is(other error) bool { + return other == errECSNotGlobal +} + +func (e ecsNotGlobalError) Unwrap() error { + return e.error +} + +// dispatch is used to parse a request and invoke the correct handler. // parameter maxRecursionLevel will handle whether recursive call can be performed -func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) (ecsGlobal bool) { - ecsGlobal = true +func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error { // By default the query is in the default datacenter datacenter := d.agent.config.Datacenter @@ -632,23 +659,23 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d } } - if queryKind == "" { - goto INVALID + invalid := func() error { + d.logger.Warn("QName invalid", "qname", qName) + return errNameNotFound } switch queryKind { case "service": n := len(queryParts) if n < 1 { - goto INVALID + return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { - goto INVALID + return invalid() } lookup := serviceLookup{ - Network: network, Datacenter: datacenter, Connect: false, Ingress: false, @@ -669,34 +696,32 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d lookup.Tag = tag lookup.Service = queryParts[0][1:] // _name._tag.service.consul - d.serviceLookup(cfg, lookup, req, resp) - - // Consul 0.3 and prior format for SRV queries - } else { - - // Support "." in the label, re-join all the parts - tag := "" - if n >= 2 { - tag = strings.Join(queryParts[:n-1], ".") - } - - lookup.Tag = tag - lookup.Service = queryParts[n-1] - - // tag[.tag].name.service.consul - d.serviceLookup(cfg, lookup, req, resp) + return d.serviceLookup(cfg, lookup, req, resp) } + + // Consul 0.3 and prior format for SRV queries + // Support "." in the label, re-join all the parts + tag := "" + if n >= 2 { + tag = strings.Join(queryParts[:n-1], ".") + } + + lookup.Tag = tag + lookup.Service = queryParts[n-1] + + // tag[.tag].name.service.consul + return d.serviceLookup(cfg, lookup, req, resp) + case "connect": if len(queryParts) < 1 { - goto INVALID + return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { - goto INVALID + return invalid() } lookup := serviceLookup{ - Network: network, Datacenter: datacenter, Service: queryParts[len(queryParts)-1], Connect: true, @@ -705,18 +730,18 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d EnterpriseMeta: entMeta, } // name.connect.consul - d.serviceLookup(cfg, lookup, req, resp) + return d.serviceLookup(cfg, lookup, req, resp) + case "ingress": if len(queryParts) < 1 { - goto INVALID + return invalid() } if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { - goto INVALID + return invalid() } lookup := serviceLookup{ - Network: network, Datacenter: datacenter, Service: queryParts[len(queryParts)-1], Connect: false, @@ -725,38 +750,40 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d EnterpriseMeta: entMeta, } // name.ingress.consul - d.serviceLookup(cfg, lookup, req, resp) + return d.serviceLookup(cfg, lookup, req, resp) + case "node": if len(queryParts) < 1 { - goto INVALID + return invalid() } if !d.parseDatacenter(querySuffixes, &datacenter) { - goto INVALID + return invalid() } // Allow a "." in the node name, just join all the parts node := strings.Join(queryParts, ".") - d.nodeLookup(cfg, network, datacenter, node, req, resp, maxRecursionLevel) + return d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel) + case "query": // ensure we have a query name if len(queryParts) < 1 { - goto INVALID + return invalid() } if !d.parseDatacenter(querySuffixes, &datacenter) { - goto INVALID + return invalid() } // Allow a "." in the query name, just join all the parts. query := strings.Join(queryParts, ".") - ecsGlobal = false - d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) + err := d.preparedQueryLookup(cfg, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) + return ecsNotGlobalError{error: err} case "addr": //
.addr.. - addr must be the second label, datacenter is optional if len(queryParts) != 1 { - goto INVALID + return invalid() } switch len(queryParts[0]) / 2 { @@ -764,7 +791,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d case 4: ip, err := hex.DecodeString(queryParts[0]) if err != nil { - goto INVALID + return invalid() } resp.Answer = append(resp.Answer, &dns.A{ @@ -780,7 +807,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d case 16: ip, err := hex.DecodeString(queryParts[0]) if err != nil { - goto INVALID + return invalid() } resp.Answer = append(resp.Answer, &dns.AAAA{ @@ -793,15 +820,10 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d AAAA: ip, }) } + return nil + default: + return invalid() } - // early return without error - return - -INVALID: - d.logger.Warn("QName invalid", "qname", qName) - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) - return } func (d *DNSServer) trimDomain(query string) string { @@ -818,23 +840,30 @@ func (d *DNSServer) trimDomain(query string) string { return strings.TrimSuffix(query, shorter) } -// computeRCode Return the DNS Error code from Consul Error -func (d *DNSServer) computeRCode(err error) int { - if err == nil { +// rCodeFromError return the appropriate DNS response code for a given error +func rCodeFromError(err error) int { + switch { + case err == nil: return dns.RcodeSuccess - } - if structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err) { + case errors.Is(err, errNoData): + return dns.RcodeSuccess + case errors.Is(err, errECSNotGlobal): + return rCodeFromError(errors.Unwrap(err)) + case errors.Is(err, errNameNotFound): return dns.RcodeNameError + case structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err): + return dns.RcodeNameError + default: + return dns.RcodeServerFailure } - return dns.RcodeServerFailure } // nodeLookup is used to handle a node query -func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) { +func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) error { // Only handle ANY, A, AAAA, and TXT type requests qType := req.Question[0].Qtype if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT { - return + return nil } // Make an RPC request @@ -848,20 +877,12 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, } out, err := d.lookupNode(cfg, args) if err != nil { - d.logger.Error("rpc error", "error", err) - rCode := d.computeRCode(err) - if rCode == dns.RcodeNameError { - d.addSOA(cfg, resp) - } - resp.SetRcode(req, rCode) - return + return fmt.Errorf("failed rpc request: %w", err) } // If we have no out.NodeServices.Nodeaddress, return not found! if out.NodeServices == nil { - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) - return + return errNameNotFound } // Add the node record @@ -883,6 +904,7 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, metas := d.generateMeta(n.Datacenter, q.Name, n, cfg.NodeTTL) *metaTarget = append(*metaTarget, metas...) } + return nil } func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { @@ -1021,7 +1043,7 @@ func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasE // trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit // of DNS responses -func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { +func trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { hasExtra := len(resp.Extra) > 0 // There is some overhead, 65535 does not work maxSize := 65523 // 64k - 12 bytes DNS raw overhead @@ -1029,8 +1051,6 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { // We avoid some function calls and allocations by only handling the // extra data when necessary. var index map[string]dns.RR - originalSize := resp.Len() - originalNumRecords := len(resp.Answer) // It is not possible to return more than 4k records even with compression // Since we are performing binary search it is not a big deal, but it @@ -1052,6 +1072,10 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { // This enforces the given limit on 64k, the max limit for DNS messages for len(resp.Answer) > 1 && resp.Len() > maxSize { truncated = true + // first try to remove the NS section may be it will truncate enough + if len(resp.Ns) != 0 { + resp.Ns = []dns.RR{} + } // More than 100 bytes, find with a binary search if resp.Len()-maxSize > 100 { bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) @@ -1063,13 +1087,7 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { syncExtra(index, resp) } } - if truncated { - d.logger.Debug("TCP answer to question too large, truncated", - "question", req.Question, - "records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords), - "size", fmt.Sprintf("%d/%d", resp.Len(), originalSize), - ) - } + return truncated } @@ -1118,6 +1136,10 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { // Even when size is too big for one single record, try to send it anyway // (useful for 512 bytes messages) for len(resp.Answer) > 1 && resp.Len() > maxSize-7 { + // first try to remove the NS section may be it will truncate enough + if len(resp.Ns) != 0 { + resp.Ns = []dns.RR{} + } // More than 100 bytes, find with a binary search if resp.Len()-maxSize > 100 { bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) @@ -1136,15 +1158,26 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { } // trimDNSResponse will trim the response for UDP and TCP -func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) (trimmed bool) { +func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) bool { + var trimmed bool + originalSize := resp.Len() + originalNumRecords := len(resp.Answer) if network != "tcp" { trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit) } else { - trimmed = d.trimTCPResponse(req, resp) + trimmed = trimTCPResponse(req, resp) } // Flag that there are more records to return in the UDP response - if trimmed && cfg.EnableTruncate { - resp.Truncated = true + if trimmed { + if cfg.EnableTruncate { + resp.Truncated = true + } + d.logger.Debug("DNS response too large, truncated", + "protocol", network, + "question", req.Question, + "records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords), + "size", fmt.Sprintf("%d/%d", resp.Len(), originalSize), + ) } return trimmed } @@ -1213,23 +1246,15 @@ func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (st } // serviceLookup is used to handle a service query -func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) { +func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) error { out, err := d.lookupServiceNodes(cfg, lookup) if err != nil { - d.logger.Error("rpc error", "error", err) - rCode := d.computeRCode(err) - if rCode == dns.RcodeNameError { - d.addSOA(cfg, resp) - } - resp.SetRcode(req, rCode) - return + return fmt.Errorf("rpc request failed: %w", err) } // If we have no nodes, return not found! if len(out.Nodes) == 0 { - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) - return + return errNameNotFound } // Perform a random shuffle @@ -1246,13 +1271,10 @@ func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, res d.serviceNodeRecords(cfg, lookup.Datacenter, out.Nodes, req, resp, ttl, lookup.MaxRecursionLevel) } - d.trimDNSResponse(cfg, lookup.Network, req, resp) - - // If the answer is empty and the response isn't truncated, return not found - if len(resp.Answer) == 0 && !resp.Truncated { - d.addSOA(cfg, resp) - return + if len(resp.Answer) == 0 { + return errNoData } + return nil } func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { @@ -1273,7 +1295,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { } // preparedQueryLookup is used to handle a prepared query. -func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) { +func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error { // Execute the prepared query. args := structs.PreparedQueryExecuteRequest{ Datacenter: datacenter, @@ -1311,17 +1333,8 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que } out, err := d.lookupPreparedQuery(cfg, args) - - // If they give a bogus query name, treat that as a name error, - // not a full on server error. We have to use a string compare - // here since the RPC layer loses the type information. if err != nil { - rCode := d.computeRCode(err) - if rCode == dns.RcodeNameError { - d.addSOA(cfg, resp) - } - resp.SetRcode(req, rCode) - return + return err } // TODO (slackpad) - What's a safe limit we can set here? It seems like @@ -1352,9 +1365,7 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que // If we have no nodes, return not found! if len(out.Nodes) == 0 { - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) - return + return errNameNotFound } // Add various responses depending on the request. @@ -1365,13 +1376,10 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que d.serviceNodeRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } - d.trimDNSResponse(cfg, network, req, resp) - - // If the answer is empty and the response isn't truncated, return not found - if len(resp.Answer) == 0 && !resp.Truncated { - d.addSOA(cfg, resp) - return + if len(resp.Answer) == 0 { + return errNoData } + return nil } func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { @@ -1907,7 +1915,8 @@ func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel resp := &dns.Msg{} req.SetQuestion(name, dns.TypeANY) - d.doDispatch("udp", nil, req, resp, maxRecursionLevel-1) + // TODO: handle error response + d.dispatch(nil, req, resp, maxRecursionLevel-1) return resp.Answer } diff --git a/agent/dns_test.go b/agent/dns_test.go index d10e2d66a5..1ea42be088 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -1,6 +1,7 @@ package agent import ( + "errors" "fmt" "math/rand" "net" @@ -10,6 +11,11 @@ import ( "testing" "time" + "github.com/hashicorp/serf/coordinate" + "github.com/miekg/dns" + "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" + "github.com/hashicorp/consul/agent/config" agentdns "github.com/hashicorp/consul/agent/dns" "github.com/hashicorp/consul/agent/structs" @@ -17,10 +23,6 @@ import ( "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" - "github.com/hashicorp/serf/coordinate" - "github.com/miekg/dns" - "github.com/pascaldekloe/goe/verify" - "github.com/stretchr/testify/require" ) const ( @@ -508,6 +510,7 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) { m := new(dns.Msg) m.SetQuestion("google.node.consul.", dns.TypeANY) + m.SetEdns0(8192, true) c := new(dns.Client) in, _, err := c.Exchange(m, a.DNSAddr()) @@ -871,7 +874,6 @@ func TestDNS_EDNS0_ECS(t *testing.T) { require.True(t, ok) require.Equal(t, uint16(1), subnet.Family) require.Equal(t, tc.SourceNetmask, subnet.SourceNetmask) - // scope set to 0 for a globally valid reply require.Equal(t, tc.ExpectedScope, subnet.SourceScope) require.Equal(t, net.ParseIP(tc.SubnetAddr), subnet.Address) }) @@ -4180,6 +4182,7 @@ func TestBinarySearch(t *testing.T) { msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV) msg.Answer = msgSrc.Answer msg.Extra = msgSrc.Extra + msg.Ns = msgSrc.Ns index := make(map[string]dns.RR, len(msg.Extra)) indexRRs(msg.Extra, index) blen := dnsBinaryTruncate(msg, maxSize, index, true) @@ -5969,9 +5972,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { - t.Fatalf("Bad: %#v", in) - } + require.Len(t, in.Ns, 1) soaRec, ok := in.Ns[0].(*dns.SOA) if !ok { t.Fatalf("Bad: %#v", in.Ns[0]) @@ -5980,10 +5981,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { t.Fatalf("Bad: %#v", in.Ns[0]) } - if in.Rcode != dns.RcodeSuccess { - t.Fatalf("Bad: %#v", in) - } - + require.Equal(t, dns.RcodeSuccess, in.Rcode) } // Check for ipv4 records on ipv6-only service directly and via the @@ -6303,6 +6301,51 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) { } } +func TestDNS_EDNS_Truncate_AgentSource(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + t.Parallel() + a := NewTestAgent(t, ` + dns_config { + enable_truncate = true + } + `) + defer a.Shutdown() + a.DNSDisableCompression(true) + testrpc.WaitForLeader(t, a.RPC, "dc1") + + m := MockPreparedQuery{ + executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error { + // Check that the agent inserted its self-name and datacenter to + // the RPC request body. + if args.Agent.Datacenter != a.Config.Datacenter || + args.Agent.Node != a.Config.NodeName { + t.Fatalf("bad: %#v", args.Agent) + } + for i := 0; i < 100; i++ { + reply.Nodes = append(reply.Nodes, structs.CheckServiceNode{Node: &structs.Node{Node: "apple", Address: fmt.Sprintf("node.address:%d", i)}, Service: &structs.NodeService{Service: "appleService", Address: fmt.Sprintf("service.address:%d", i)}}) + } + return nil + }, + } + + if err := a.registerEndpoint("PreparedQuery", &m); err != nil { + t.Fatalf("err: %v", err) + } + + req := new(dns.Msg) + req.SetQuestion("foo.query.consul.", dns.TypeSRV) + req.SetEdns0(2048, true) + req.Compress = false + + c := new(dns.Client) + resp, _, err := c.Exchange(req, a.DNSAddr()) + require.NoError(t, err) + require.True(t, resp.Len() < 2048) +} + func TestDNS_trimUDPResponse_NoTrim(t *testing.T) { t.Parallel() req := &dns.Msg{} @@ -6401,6 +6444,111 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) { } } +func TestDNS_trimUDPResponse_TrimLimitWithNS(t *testing.T) { + t.Parallel() + cfg := loadRuntimeConfig(t, `node_name = "test" data_dir = "a" bind_addr = "127.0.0.1" node_name = "dummy"`) + + req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{} + for i := 0; i < cfg.DNSUDPAnswerLimit+1; i++ { + target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i) + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: "redis-cache-redis.service.consul.", + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + }, + Target: target, + } + a := &dns.A{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 185+i)), + } + ns := &dns.SOA{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + }, + Ns: fmt.Sprintf("soa-%d", i), + } + + resp.Answer = append(resp.Answer, srv) + resp.Extra = append(resp.Extra, a) + resp.Ns = append(resp.Ns, ns) + if i < cfg.DNSUDPAnswerLimit { + expected.Answer = append(expected.Answer, srv) + expected.Extra = append(expected.Extra, a) + } + } + + if trimmed := trimUDPResponse(req, resp, cfg.DNSUDPAnswerLimit); !trimmed { + t.Fatalf("Bad %#v", *resp) + } + require.LessOrEqual(t, resp.Len(), defaultMaxUDPSize) + require.Len(t, resp.Ns, 0) +} + +func TestDNS_trimTCPResponse_TrimLimitWithNS(t *testing.T) { + t.Parallel() + cfg := loadRuntimeConfig(t, `node_name = "test" data_dir = "a" bind_addr = "127.0.0.1" node_name = "dummy"`) + + req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{} + for i := 0; i < 5000; i++ { + target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i) + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: "redis-cache-redis.service.consul.", + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + }, + Target: target, + } + a := &dns.A{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 185+i)), + } + ns := &dns.SOA{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + }, + Ns: fmt.Sprintf("soa-%d", i), + } + + resp.Answer = append(resp.Answer, srv) + resp.Extra = append(resp.Extra, a) + resp.Ns = append(resp.Ns, ns) + if i < cfg.DNSUDPAnswerLimit { + expected.Answer = append(expected.Answer, srv) + expected.Extra = append(expected.Extra, a) + } + } + req.Question = append(req.Question, dns.Question{Qtype: dns.TypeSRV}) + + if trimmed := trimTCPResponse(req, resp); !trimmed { + t.Fatalf("Bad %#v", *resp) + } + require.LessOrEqual(t, resp.Len(), 65523) + require.Len(t, resp.Ns, 0) +} + +func loadRuntimeConfig(t *testing.T, hcl string) *config.RuntimeConfig { + t.Helper() + cfg, warns, err := config.Load(config.BuilderOpts{HCL: []string{hcl}}, nil) + require.NoError(t, err) + require.Len(t, warns, 0) + return cfg +} + func TestDNS_trimUDPResponse_TrimSize(t *testing.T) { t.Parallel() cfg := config.DefaultRuntimeConfig(`data_dir = "a" bind_addr = "127.0.0.1"`) @@ -7151,3 +7299,19 @@ func TestDNS_ReloadConfig_DuringQuery(t *testing.T) { } } } + +func TestECSNotGlobalError(t *testing.T) { + t.Run("wrap nil", func(t *testing.T) { + e := ecsNotGlobalError{} + require.True(t, errors.Is(e, errECSNotGlobal)) + require.False(t, errors.Is(e, fmt.Errorf("some other error"))) + require.Equal(t, nil, errors.Unwrap(e)) + }) + + t.Run("wrap some error", func(t *testing.T) { + e := ecsNotGlobalError{error: errNameNotFound} + require.True(t, errors.Is(e, errECSNotGlobal)) + require.False(t, errors.Is(e, fmt.Errorf("some other error"))) + require.Equal(t, errNameNotFound, errors.Unwrap(e)) + }) +}