diff --git a/.changelog/10009.txt b/.changelog/10009.txt
new file mode 100644
index 0000000000..44f7174f51
--- /dev/null
+++ b/.changelog/10009.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+dns: fixes a bug with edns truncation where the response could exceed the size limit in some cases.
+```
diff --git a/agent/dns.go b/agent/dns.go
index 5e5dcbb6a5..69c132cbd6 100644
--- a/agent/dns.go
+++ b/agent/dns.go
@@ -3,6 +3,7 @@ package agent
import (
"context"
"encoding/hex"
+ "errors"
"fmt"
"net"
"regexp"
@@ -13,6 +14,9 @@ import (
metrics "github.com/armon/go-metrics"
radix "github.com/armon/go-radix"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
+ "github.com/hashicorp/go-hclog"
+ "github.com/miekg/dns"
+
cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/config"
agentdns "github.com/hashicorp/consul/agent/dns"
@@ -21,8 +25,6 @@ import (
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/logging"
- "github.com/hashicorp/go-hclog"
- "github.com/miekg/dns"
)
const (
@@ -74,7 +76,6 @@ type dnsConfig struct {
}
type serviceLookup struct {
- Network string
Datacenter string
Service string
Tag string
@@ -252,34 +253,36 @@ func (d *DNSServer) ReloadConfig(newCfg *config.RuntimeConfig) error {
// possibly the ECS headers as well if they were present in the
// original request
func setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) {
- // Enable EDNS if enabled
- if edns := request.IsEdns0(); edns != nil {
- // cannot just use the SetEdns0 function as we need to embed
- // the ECS option as well
- ednsResp := new(dns.OPT)
- ednsResp.Hdr.Name = "."
- ednsResp.Hdr.Rrtype = dns.TypeOPT
- ednsResp.SetUDPSize(edns.UDPSize())
-
- // Setup the ECS option if present
- if subnet := ednsSubnetForRequest(request); subnet != nil {
- subOp := new(dns.EDNS0_SUBNET)
- subOp.Code = dns.EDNS0SUBNET
- subOp.Family = subnet.Family
- subOp.Address = subnet.Address
- subOp.SourceNetmask = subnet.SourceNetmask
- if c := response.Rcode; ecsGlobal || c == dns.RcodeNameError || c == dns.RcodeServerFailure || c == dns.RcodeRefused || c == dns.RcodeNotImplemented {
- // reply is globally valid and should be cached accordingly
- subOp.SourceScope = 0
- } else {
- // reply is only valid for the subnet it was queried with
- subOp.SourceScope = subnet.SourceNetmask
- }
- ednsResp.Option = append(ednsResp.Option, subOp)
- }
-
- response.Extra = append(response.Extra, ednsResp)
+ edns := request.IsEdns0()
+ if edns == nil {
+ return
}
+
+ // cannot just use the SetEdns0 function as we need to embed
+ // the ECS option as well
+ ednsResp := new(dns.OPT)
+ ednsResp.Hdr.Name = "."
+ ednsResp.Hdr.Rrtype = dns.TypeOPT
+ ednsResp.SetUDPSize(edns.UDPSize())
+
+ // Setup the ECS option if present
+ if subnet := ednsSubnetForRequest(request); subnet != nil {
+ subOp := new(dns.EDNS0_SUBNET)
+ subOp.Code = dns.EDNS0SUBNET
+ subOp.Family = subnet.Family
+ subOp.Address = subnet.Address
+ subOp.SourceNetmask = subnet.SourceNetmask
+ if c := response.Rcode; ecsGlobal || c == dns.RcodeNameError || c == dns.RcodeServerFailure || c == dns.RcodeRefused || c == dns.RcodeNotImplemented {
+ // reply is globally valid and should be cached accordingly
+ subOp.SourceScope = 0
+ } else {
+ // reply is only valid for the subnet it was queried with
+ subOp.SourceScope = subnet.SourceNetmask
+ }
+ ednsResp.Option = append(ednsResp.Option, subOp)
+ }
+
+ response.Extra = append(response.Extra, ednsResp)
}
// recursorAddr is used to add a port to the recursor if omitted.
@@ -453,7 +456,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.Authoritative = true
m.RecursionAvailable = (len(cfg.Recursors) > 0)
- ecsGlobal := true
+ var err error
switch req.Question[0].Qtype {
case dns.TypeSOA:
@@ -473,12 +476,18 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.SetRcode(req, dns.RcodeNotImplemented)
default:
- ecsGlobal = d.dispatch(network, resp.RemoteAddr(), req, m)
+ err = d.dispatch(resp.RemoteAddr(), req, m, maxRecursionLevelDefault)
+ rCode := rCodeFromError(err)
+ if rCode == dns.RcodeNameError || errors.Is(err, errNoData) {
+ d.addSOA(cfg, m)
+ }
+ m.SetRcode(req, rCode)
}
- setEDNS(req, m, ecsGlobal)
+ setEDNS(req, m, !errors.Is(err, errECSNotGlobal))
+
+ d.trimDNSResponse(cfg, network, req, m)
- // Write out the complete response
if err := resp.WriteMsg(m); err != nil {
d.logger.Warn("failed to respond", "error", err)
}
@@ -566,17 +575,6 @@ func (d *DNSServer) nameservers(cfg *dnsConfig, maxRecursionLevel int) (ns []dns
return
}
-// dispatch is used to parse a request and invoke the correct handler
-func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) {
- return d.doDispatch(network, remoteAddr, req, resp, maxRecursionLevelDefault)
-}
-
-func (d *DNSServer) invalidQuery(req, resp *dns.Msg, cfg *dnsConfig, qName string) {
- d.logger.Warn("QName invalid", "qname", qName)
- d.addSOA(cfg, resp)
- resp.SetRcode(req, dns.RcodeNameError)
-}
-
func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool {
switch len(labels) {
case 1:
@@ -589,10 +587,39 @@ func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool {
}
}
-// doDispatch is used to parse a request and invoke the correct handler.
+var errECSNotGlobal = fmt.Errorf("ECS response is not global")
+var errNameNotFound = fmt.Errorf("DNS name not found")
+
+// errNoData is used to indicate no resource records exist for the specified query type.
+// Per the recommendation from Section 2.2 of RFC 2308, the server will return a TYPE 2
+// NODATA response in which the RCODE is set to NOERROR (RcodeSuccess), the Answer
+// section is empty, and the Authority section contains the SOA record.
+var errNoData = fmt.Errorf("no DNS Answer")
+
+// ecsNotGlobalError may be used to wrap an error or nil, to indicate that the
+// EDNS client subnet source scope is not global.
+type ecsNotGlobalError struct {
+ error
+}
+
+func (e ecsNotGlobalError) Error() string {
+ if e.error == nil {
+ return ""
+ }
+ return e.error.Error()
+}
+
+func (e ecsNotGlobalError) Is(other error) bool {
+ return other == errECSNotGlobal
+}
+
+func (e ecsNotGlobalError) Unwrap() error {
+ return e.error
+}
+
+// dispatch is used to parse a request and invoke the correct handler.
// parameter maxRecursionLevel will handle whether recursive call can be performed
-func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) (ecsGlobal bool) {
- ecsGlobal = true
+func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
// By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter
@@ -632,23 +659,23 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
}
}
- if queryKind == "" {
- goto INVALID
+ invalid := func() error {
+ d.logger.Warn("QName invalid", "qname", qName)
+ return errNameNotFound
}
switch queryKind {
case "service":
n := len(queryParts)
if n < 1 {
- goto INVALID
+ return invalid()
}
if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) {
- goto INVALID
+ return invalid()
}
lookup := serviceLookup{
- Network: network,
Datacenter: datacenter,
Connect: false,
Ingress: false,
@@ -669,34 +696,32 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
lookup.Tag = tag
lookup.Service = queryParts[0][1:]
// _name._tag.service.consul
- d.serviceLookup(cfg, lookup, req, resp)
-
- // Consul 0.3 and prior format for SRV queries
- } else {
-
- // Support "." in the label, re-join all the parts
- tag := ""
- if n >= 2 {
- tag = strings.Join(queryParts[:n-1], ".")
- }
-
- lookup.Tag = tag
- lookup.Service = queryParts[n-1]
-
- // tag[.tag].name.service.consul
- d.serviceLookup(cfg, lookup, req, resp)
+ return d.serviceLookup(cfg, lookup, req, resp)
}
+
+ // Consul 0.3 and prior format for SRV queries
+ // Support "." in the label, re-join all the parts
+ tag := ""
+ if n >= 2 {
+ tag = strings.Join(queryParts[:n-1], ".")
+ }
+
+ lookup.Tag = tag
+ lookup.Service = queryParts[n-1]
+
+ // tag[.tag].name.service.consul
+ return d.serviceLookup(cfg, lookup, req, resp)
+
case "connect":
if len(queryParts) < 1 {
- goto INVALID
+ return invalid()
}
if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) {
- goto INVALID
+ return invalid()
}
lookup := serviceLookup{
- Network: network,
Datacenter: datacenter,
Service: queryParts[len(queryParts)-1],
Connect: true,
@@ -705,18 +730,18 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
EnterpriseMeta: entMeta,
}
// name.connect.consul
- d.serviceLookup(cfg, lookup, req, resp)
+ return d.serviceLookup(cfg, lookup, req, resp)
+
case "ingress":
if len(queryParts) < 1 {
- goto INVALID
+ return invalid()
}
if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) {
- goto INVALID
+ return invalid()
}
lookup := serviceLookup{
- Network: network,
Datacenter: datacenter,
Service: queryParts[len(queryParts)-1],
Connect: false,
@@ -725,38 +750,40 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
EnterpriseMeta: entMeta,
}
// name.ingress.consul
- d.serviceLookup(cfg, lookup, req, resp)
+ return d.serviceLookup(cfg, lookup, req, resp)
+
case "node":
if len(queryParts) < 1 {
- goto INVALID
+ return invalid()
}
if !d.parseDatacenter(querySuffixes, &datacenter) {
- goto INVALID
+ return invalid()
}
// Allow a "." in the node name, just join all the parts
node := strings.Join(queryParts, ".")
- d.nodeLookup(cfg, network, datacenter, node, req, resp, maxRecursionLevel)
+ return d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel)
+
case "query":
// ensure we have a query name
if len(queryParts) < 1 {
- goto INVALID
+ return invalid()
}
if !d.parseDatacenter(querySuffixes, &datacenter) {
- goto INVALID
+ return invalid()
}
// Allow a "." in the query name, just join all the parts.
query := strings.Join(queryParts, ".")
- ecsGlobal = false
- d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
+ err := d.preparedQueryLookup(cfg, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
+ return ecsNotGlobalError{error: err}
case "addr":
//
.addr.. - addr must be the second label, datacenter is optional
if len(queryParts) != 1 {
- goto INVALID
+ return invalid()
}
switch len(queryParts[0]) / 2 {
@@ -764,7 +791,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
case 4:
ip, err := hex.DecodeString(queryParts[0])
if err != nil {
- goto INVALID
+ return invalid()
}
resp.Answer = append(resp.Answer, &dns.A{
@@ -780,7 +807,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
case 16:
ip, err := hex.DecodeString(queryParts[0])
if err != nil {
- goto INVALID
+ return invalid()
}
resp.Answer = append(resp.Answer, &dns.AAAA{
@@ -793,15 +820,10 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
AAAA: ip,
})
}
+ return nil
+ default:
+ return invalid()
}
- // early return without error
- return
-
-INVALID:
- d.logger.Warn("QName invalid", "qname", qName)
- d.addSOA(cfg, resp)
- resp.SetRcode(req, dns.RcodeNameError)
- return
}
func (d *DNSServer) trimDomain(query string) string {
@@ -818,23 +840,30 @@ func (d *DNSServer) trimDomain(query string) string {
return strings.TrimSuffix(query, shorter)
}
-// computeRCode Return the DNS Error code from Consul Error
-func (d *DNSServer) computeRCode(err error) int {
- if err == nil {
+// rCodeFromError return the appropriate DNS response code for a given error
+func rCodeFromError(err error) int {
+ switch {
+ case err == nil:
return dns.RcodeSuccess
- }
- if structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err) {
+ case errors.Is(err, errNoData):
+ return dns.RcodeSuccess
+ case errors.Is(err, errECSNotGlobal):
+ return rCodeFromError(errors.Unwrap(err))
+ case errors.Is(err, errNameNotFound):
return dns.RcodeNameError
+ case structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err):
+ return dns.RcodeNameError
+ default:
+ return dns.RcodeServerFailure
}
- return dns.RcodeServerFailure
}
// nodeLookup is used to handle a node query
-func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) {
+func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) error {
// Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
- return
+ return nil
}
// Make an RPC request
@@ -848,20 +877,12 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string,
}
out, err := d.lookupNode(cfg, args)
if err != nil {
- d.logger.Error("rpc error", "error", err)
- rCode := d.computeRCode(err)
- if rCode == dns.RcodeNameError {
- d.addSOA(cfg, resp)
- }
- resp.SetRcode(req, rCode)
- return
+ return fmt.Errorf("failed rpc request: %w", err)
}
// If we have no out.NodeServices.Nodeaddress, return not found!
if out.NodeServices == nil {
- d.addSOA(cfg, resp)
- resp.SetRcode(req, dns.RcodeNameError)
- return
+ return errNameNotFound
}
// Add the node record
@@ -883,6 +904,7 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string,
metas := d.generateMeta(n.Datacenter, q.Name, n, cfg.NodeTTL)
*metaTarget = append(*metaTarget, metas...)
}
+ return nil
}
func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
@@ -1021,7 +1043,7 @@ func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasE
// trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit
// of DNS responses
-func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
+func trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
hasExtra := len(resp.Extra) > 0
// There is some overhead, 65535 does not work
maxSize := 65523 // 64k - 12 bytes DNS raw overhead
@@ -1029,8 +1051,6 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
// We avoid some function calls and allocations by only handling the
// extra data when necessary.
var index map[string]dns.RR
- originalSize := resp.Len()
- originalNumRecords := len(resp.Answer)
// It is not possible to return more than 4k records even with compression
// Since we are performing binary search it is not a big deal, but it
@@ -1052,6 +1072,10 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
// This enforces the given limit on 64k, the max limit for DNS messages
for len(resp.Answer) > 1 && resp.Len() > maxSize {
truncated = true
+ // first try to remove the NS section may be it will truncate enough
+ if len(resp.Ns) != 0 {
+ resp.Ns = []dns.RR{}
+ }
// More than 100 bytes, find with a binary search
if resp.Len()-maxSize > 100 {
bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra)
@@ -1063,13 +1087,7 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
syncExtra(index, resp)
}
}
- if truncated {
- d.logger.Debug("TCP answer to question too large, truncated",
- "question", req.Question,
- "records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords),
- "size", fmt.Sprintf("%d/%d", resp.Len(), originalSize),
- )
- }
+
return truncated
}
@@ -1118,6 +1136,10 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
// Even when size is too big for one single record, try to send it anyway
// (useful for 512 bytes messages)
for len(resp.Answer) > 1 && resp.Len() > maxSize-7 {
+ // first try to remove the NS section may be it will truncate enough
+ if len(resp.Ns) != 0 {
+ resp.Ns = []dns.RR{}
+ }
// More than 100 bytes, find with a binary search
if resp.Len()-maxSize > 100 {
bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra)
@@ -1136,15 +1158,26 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
}
// trimDNSResponse will trim the response for UDP and TCP
-func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) (trimmed bool) {
+func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) bool {
+ var trimmed bool
+ originalSize := resp.Len()
+ originalNumRecords := len(resp.Answer)
if network != "tcp" {
trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit)
} else {
- trimmed = d.trimTCPResponse(req, resp)
+ trimmed = trimTCPResponse(req, resp)
}
// Flag that there are more records to return in the UDP response
- if trimmed && cfg.EnableTruncate {
- resp.Truncated = true
+ if trimmed {
+ if cfg.EnableTruncate {
+ resp.Truncated = true
+ }
+ d.logger.Debug("DNS response too large, truncated",
+ "protocol", network,
+ "question", req.Question,
+ "records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords),
+ "size", fmt.Sprintf("%d/%d", resp.Len(), originalSize),
+ )
}
return trimmed
}
@@ -1213,23 +1246,15 @@ func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (st
}
// serviceLookup is used to handle a service query
-func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) {
+func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) error {
out, err := d.lookupServiceNodes(cfg, lookup)
if err != nil {
- d.logger.Error("rpc error", "error", err)
- rCode := d.computeRCode(err)
- if rCode == dns.RcodeNameError {
- d.addSOA(cfg, resp)
- }
- resp.SetRcode(req, rCode)
- return
+ return fmt.Errorf("rpc request failed: %w", err)
}
// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
- d.addSOA(cfg, resp)
- resp.SetRcode(req, dns.RcodeNameError)
- return
+ return errNameNotFound
}
// Perform a random shuffle
@@ -1246,13 +1271,10 @@ func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, res
d.serviceNodeRecords(cfg, lookup.Datacenter, out.Nodes, req, resp, ttl, lookup.MaxRecursionLevel)
}
- d.trimDNSResponse(cfg, lookup.Network, req, resp)
-
- // If the answer is empty and the response isn't truncated, return not found
- if len(resp.Answer) == 0 && !resp.Truncated {
- d.addSOA(cfg, resp)
- return
+ if len(resp.Answer) == 0 {
+ return errNoData
}
+ return nil
}
func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
@@ -1273,7 +1295,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
}
// preparedQueryLookup is used to handle a prepared query.
-func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) {
+func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
@@ -1311,17 +1333,8 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
}
out, err := d.lookupPreparedQuery(cfg, args)
-
- // If they give a bogus query name, treat that as a name error,
- // not a full on server error. We have to use a string compare
- // here since the RPC layer loses the type information.
if err != nil {
- rCode := d.computeRCode(err)
- if rCode == dns.RcodeNameError {
- d.addSOA(cfg, resp)
- }
- resp.SetRcode(req, rCode)
- return
+ return err
}
// TODO (slackpad) - What's a safe limit we can set here? It seems like
@@ -1352,9 +1365,7 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
- d.addSOA(cfg, resp)
- resp.SetRcode(req, dns.RcodeNameError)
- return
+ return errNameNotFound
}
// Add various responses depending on the request.
@@ -1365,13 +1376,10 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
d.serviceNodeRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}
- d.trimDNSResponse(cfg, network, req, resp)
-
- // If the answer is empty and the response isn't truncated, return not found
- if len(resp.Answer) == 0 && !resp.Truncated {
- d.addSOA(cfg, resp)
- return
+ if len(resp.Answer) == 0 {
+ return errNoData
}
+ return nil
}
func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
@@ -1907,7 +1915,8 @@ func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel
resp := &dns.Msg{}
req.SetQuestion(name, dns.TypeANY)
- d.doDispatch("udp", nil, req, resp, maxRecursionLevel-1)
+ // TODO: handle error response
+ d.dispatch(nil, req, resp, maxRecursionLevel-1)
return resp.Answer
}
diff --git a/agent/dns_test.go b/agent/dns_test.go
index d10e2d66a5..1ea42be088 100644
--- a/agent/dns_test.go
+++ b/agent/dns_test.go
@@ -1,6 +1,7 @@
package agent
import (
+ "errors"
"fmt"
"math/rand"
"net"
@@ -10,6 +11,11 @@ import (
"testing"
"time"
+ "github.com/hashicorp/serf/coordinate"
+ "github.com/miekg/dns"
+ "github.com/pascaldekloe/goe/verify"
+ "github.com/stretchr/testify/require"
+
"github.com/hashicorp/consul/agent/config"
agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/agent/structs"
@@ -17,10 +23,6 @@ import (
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
- "github.com/hashicorp/serf/coordinate"
- "github.com/miekg/dns"
- "github.com/pascaldekloe/goe/verify"
- "github.com/stretchr/testify/require"
)
const (
@@ -508,6 +510,7 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) {
m := new(dns.Msg)
m.SetQuestion("google.node.consul.", dns.TypeANY)
+ m.SetEdns0(8192, true)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
@@ -871,7 +874,6 @@ func TestDNS_EDNS0_ECS(t *testing.T) {
require.True(t, ok)
require.Equal(t, uint16(1), subnet.Family)
require.Equal(t, tc.SourceNetmask, subnet.SourceNetmask)
- // scope set to 0 for a globally valid reply
require.Equal(t, tc.ExpectedScope, subnet.SourceScope)
require.Equal(t, net.ParseIP(tc.SubnetAddr), subnet.Address)
})
@@ -4180,6 +4182,7 @@ func TestBinarySearch(t *testing.T) {
msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV)
msg.Answer = msgSrc.Answer
msg.Extra = msgSrc.Extra
+ msg.Ns = msgSrc.Ns
index := make(map[string]dns.RR, len(msg.Extra))
indexRRs(msg.Extra, index)
blen := dnsBinaryTruncate(msg, maxSize, index, true)
@@ -5969,9 +5972,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) {
t.Fatalf("err: %v", err)
}
- if len(in.Ns) != 1 {
- t.Fatalf("Bad: %#v", in)
- }
+ require.Len(t, in.Ns, 1)
soaRec, ok := in.Ns[0].(*dns.SOA)
if !ok {
t.Fatalf("Bad: %#v", in.Ns[0])
@@ -5980,10 +5981,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) {
t.Fatalf("Bad: %#v", in.Ns[0])
}
- if in.Rcode != dns.RcodeSuccess {
- t.Fatalf("Bad: %#v", in)
- }
-
+ require.Equal(t, dns.RcodeSuccess, in.Rcode)
}
// Check for ipv4 records on ipv6-only service directly and via the
@@ -6303,6 +6301,51 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) {
}
}
+func TestDNS_EDNS_Truncate_AgentSource(t *testing.T) {
+ if testing.Short() {
+ t.Skip("too slow for testing.Short")
+ }
+
+ t.Parallel()
+ a := NewTestAgent(t, `
+ dns_config {
+ enable_truncate = true
+ }
+ `)
+ defer a.Shutdown()
+ a.DNSDisableCompression(true)
+ testrpc.WaitForLeader(t, a.RPC, "dc1")
+
+ m := MockPreparedQuery{
+ executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
+ // Check that the agent inserted its self-name and datacenter to
+ // the RPC request body.
+ if args.Agent.Datacenter != a.Config.Datacenter ||
+ args.Agent.Node != a.Config.NodeName {
+ t.Fatalf("bad: %#v", args.Agent)
+ }
+ for i := 0; i < 100; i++ {
+ reply.Nodes = append(reply.Nodes, structs.CheckServiceNode{Node: &structs.Node{Node: "apple", Address: fmt.Sprintf("node.address:%d", i)}, Service: &structs.NodeService{Service: "appleService", Address: fmt.Sprintf("service.address:%d", i)}})
+ }
+ return nil
+ },
+ }
+
+ if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ req := new(dns.Msg)
+ req.SetQuestion("foo.query.consul.", dns.TypeSRV)
+ req.SetEdns0(2048, true)
+ req.Compress = false
+
+ c := new(dns.Client)
+ resp, _, err := c.Exchange(req, a.DNSAddr())
+ require.NoError(t, err)
+ require.True(t, resp.Len() < 2048)
+}
+
func TestDNS_trimUDPResponse_NoTrim(t *testing.T) {
t.Parallel()
req := &dns.Msg{}
@@ -6401,6 +6444,111 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) {
}
}
+func TestDNS_trimUDPResponse_TrimLimitWithNS(t *testing.T) {
+ t.Parallel()
+ cfg := loadRuntimeConfig(t, `node_name = "test" data_dir = "a" bind_addr = "127.0.0.1" node_name = "dummy"`)
+
+ req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{}
+ for i := 0; i < cfg.DNSUDPAnswerLimit+1; i++ {
+ target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i)
+ srv := &dns.SRV{
+ Hdr: dns.RR_Header{
+ Name: "redis-cache-redis.service.consul.",
+ Rrtype: dns.TypeSRV,
+ Class: dns.ClassINET,
+ },
+ Target: target,
+ }
+ a := &dns.A{
+ Hdr: dns.RR_Header{
+ Name: target,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 185+i)),
+ }
+ ns := &dns.SOA{
+ Hdr: dns.RR_Header{
+ Name: target,
+ Rrtype: dns.TypeSOA,
+ Class: dns.ClassINET,
+ },
+ Ns: fmt.Sprintf("soa-%d", i),
+ }
+
+ resp.Answer = append(resp.Answer, srv)
+ resp.Extra = append(resp.Extra, a)
+ resp.Ns = append(resp.Ns, ns)
+ if i < cfg.DNSUDPAnswerLimit {
+ expected.Answer = append(expected.Answer, srv)
+ expected.Extra = append(expected.Extra, a)
+ }
+ }
+
+ if trimmed := trimUDPResponse(req, resp, cfg.DNSUDPAnswerLimit); !trimmed {
+ t.Fatalf("Bad %#v", *resp)
+ }
+ require.LessOrEqual(t, resp.Len(), defaultMaxUDPSize)
+ require.Len(t, resp.Ns, 0)
+}
+
+func TestDNS_trimTCPResponse_TrimLimitWithNS(t *testing.T) {
+ t.Parallel()
+ cfg := loadRuntimeConfig(t, `node_name = "test" data_dir = "a" bind_addr = "127.0.0.1" node_name = "dummy"`)
+
+ req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{}
+ for i := 0; i < 5000; i++ {
+ target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i)
+ srv := &dns.SRV{
+ Hdr: dns.RR_Header{
+ Name: "redis-cache-redis.service.consul.",
+ Rrtype: dns.TypeSRV,
+ Class: dns.ClassINET,
+ },
+ Target: target,
+ }
+ a := &dns.A{
+ Hdr: dns.RR_Header{
+ Name: target,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ },
+ A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 185+i)),
+ }
+ ns := &dns.SOA{
+ Hdr: dns.RR_Header{
+ Name: target,
+ Rrtype: dns.TypeSOA,
+ Class: dns.ClassINET,
+ },
+ Ns: fmt.Sprintf("soa-%d", i),
+ }
+
+ resp.Answer = append(resp.Answer, srv)
+ resp.Extra = append(resp.Extra, a)
+ resp.Ns = append(resp.Ns, ns)
+ if i < cfg.DNSUDPAnswerLimit {
+ expected.Answer = append(expected.Answer, srv)
+ expected.Extra = append(expected.Extra, a)
+ }
+ }
+ req.Question = append(req.Question, dns.Question{Qtype: dns.TypeSRV})
+
+ if trimmed := trimTCPResponse(req, resp); !trimmed {
+ t.Fatalf("Bad %#v", *resp)
+ }
+ require.LessOrEqual(t, resp.Len(), 65523)
+ require.Len(t, resp.Ns, 0)
+}
+
+func loadRuntimeConfig(t *testing.T, hcl string) *config.RuntimeConfig {
+ t.Helper()
+ cfg, warns, err := config.Load(config.BuilderOpts{HCL: []string{hcl}}, nil)
+ require.NoError(t, err)
+ require.Len(t, warns, 0)
+ return cfg
+}
+
func TestDNS_trimUDPResponse_TrimSize(t *testing.T) {
t.Parallel()
cfg := config.DefaultRuntimeConfig(`data_dir = "a" bind_addr = "127.0.0.1"`)
@@ -7151,3 +7299,19 @@ func TestDNS_ReloadConfig_DuringQuery(t *testing.T) {
}
}
}
+
+func TestECSNotGlobalError(t *testing.T) {
+ t.Run("wrap nil", func(t *testing.T) {
+ e := ecsNotGlobalError{}
+ require.True(t, errors.Is(e, errECSNotGlobal))
+ require.False(t, errors.Is(e, fmt.Errorf("some other error")))
+ require.Equal(t, nil, errors.Unwrap(e))
+ })
+
+ t.Run("wrap some error", func(t *testing.T) {
+ e := ecsNotGlobalError{error: errNameNotFound}
+ require.True(t, errors.Is(e, errECSNotGlobal))
+ require.False(t, errors.Is(e, fmt.Errorf("some other error")))
+ require.Equal(t, errNameNotFound, errors.Unwrap(e))
+ })
+}