Merge pull request #10642 from hashicorp/dnephin/backport-1.8-dns-truncate

[1.8.x] dns: properly trim response when EDNS is used
This commit is contained in:
Daniel Nephin 2021-07-19 16:48:45 -04:00 committed by GitHub
commit fad658591c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 353 additions and 177 deletions

3
.changelog/10009.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
dns: fixes a bug with edns truncation where the response could exceed the size limit in some cases.
```

View File

@ -3,6 +3,7 @@ package agent
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net" "net"
"regexp" "regexp"
@ -13,6 +14,9 @@ import (
metrics "github.com/armon/go-metrics" metrics "github.com/armon/go-metrics"
radix "github.com/armon/go-radix" radix "github.com/armon/go-radix"
"github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
cachetype "github.com/hashicorp/consul/agent/cache-types" cachetype "github.com/hashicorp/consul/agent/cache-types"
"github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/config"
agentdns "github.com/hashicorp/consul/agent/dns" agentdns "github.com/hashicorp/consul/agent/dns"
@ -21,8 +25,6 @@ import (
"github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/logging" "github.com/hashicorp/consul/logging"
"github.com/hashicorp/go-hclog"
"github.com/miekg/dns"
) )
const ( const (
@ -74,7 +76,6 @@ type dnsConfig struct {
} }
type serviceLookup struct { type serviceLookup struct {
Network string
Datacenter string Datacenter string
Service string Service string
Tag string Tag string
@ -252,8 +253,11 @@ func (d *DNSServer) ReloadConfig(newCfg *config.RuntimeConfig) error {
// possibly the ECS headers as well if they were present in the // possibly the ECS headers as well if they were present in the
// original request // original request
func setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) { func setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) {
// Enable EDNS if enabled edns := request.IsEdns0()
if edns := request.IsEdns0(); edns != nil { if edns == nil {
return
}
// cannot just use the SetEdns0 function as we need to embed // cannot just use the SetEdns0 function as we need to embed
// the ECS option as well // the ECS option as well
ednsResp := new(dns.OPT) ednsResp := new(dns.OPT)
@ -280,7 +284,6 @@ func setEDNS(request *dns.Msg, response *dns.Msg, ecsGlobal bool) {
response.Extra = append(response.Extra, ednsResp) response.Extra = append(response.Extra, ednsResp)
} }
}
// recursorAddr is used to add a port to the recursor if omitted. // recursorAddr is used to add a port to the recursor if omitted.
func recursorAddr(recursor string) (string, error) { func recursorAddr(recursor string) (string, error) {
@ -453,7 +456,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.Authoritative = true m.Authoritative = true
m.RecursionAvailable = (len(cfg.Recursors) > 0) m.RecursionAvailable = (len(cfg.Recursors) > 0)
ecsGlobal := true var err error
switch req.Question[0].Qtype { switch req.Question[0].Qtype {
case dns.TypeSOA: case dns.TypeSOA:
@ -473,12 +476,18 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.SetRcode(req, dns.RcodeNotImplemented) m.SetRcode(req, dns.RcodeNotImplemented)
default: default:
ecsGlobal = d.dispatch(network, resp.RemoteAddr(), req, m) err = d.dispatch(resp.RemoteAddr(), req, m, maxRecursionLevelDefault)
rCode := rCodeFromError(err)
if rCode == dns.RcodeNameError || errors.Is(err, errNoData) {
d.addSOA(cfg, m)
}
m.SetRcode(req, rCode)
} }
setEDNS(req, m, ecsGlobal) setEDNS(req, m, !errors.Is(err, errECSNotGlobal))
d.trimDNSResponse(cfg, network, req, m)
// Write out the complete response
if err := resp.WriteMsg(m); err != nil { if err := resp.WriteMsg(m); err != nil {
d.logger.Warn("failed to respond", "error", err) d.logger.Warn("failed to respond", "error", err)
} }
@ -566,17 +575,6 @@ func (d *DNSServer) nameservers(cfg *dnsConfig, maxRecursionLevel int) (ns []dns
return return
} }
// dispatch is used to parse a request and invoke the correct handler
func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) {
return d.doDispatch(network, remoteAddr, req, resp, maxRecursionLevelDefault)
}
func (d *DNSServer) invalidQuery(req, resp *dns.Msg, cfg *dnsConfig, qName string) {
d.logger.Warn("QName invalid", "qname", qName)
d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError)
}
func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool { func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool {
switch len(labels) { switch len(labels) {
case 1: case 1:
@ -589,10 +587,39 @@ func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool {
} }
} }
// doDispatch is used to parse a request and invoke the correct handler. var errECSNotGlobal = fmt.Errorf("ECS response is not global")
var errNameNotFound = fmt.Errorf("DNS name not found")
// errNoData is used to indicate no resource records exist for the specified query type.
// Per the recommendation from Section 2.2 of RFC 2308, the server will return a TYPE 2
// NODATA response in which the RCODE is set to NOERROR (RcodeSuccess), the Answer
// section is empty, and the Authority section contains the SOA record.
var errNoData = fmt.Errorf("no DNS Answer")
// ecsNotGlobalError may be used to wrap an error or nil, to indicate that the
// EDNS client subnet source scope is not global.
type ecsNotGlobalError struct {
error
}
func (e ecsNotGlobalError) Error() string {
if e.error == nil {
return ""
}
return e.error.Error()
}
func (e ecsNotGlobalError) Is(other error) bool {
return other == errECSNotGlobal
}
func (e ecsNotGlobalError) Unwrap() error {
return e.error
}
// dispatch is used to parse a request and invoke the correct handler.
// parameter maxRecursionLevel will handle whether recursive call can be performed // parameter maxRecursionLevel will handle whether recursive call can be performed
func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) (ecsGlobal bool) { func (d *DNSServer) dispatch(remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
ecsGlobal = true
// By default the query is in the default datacenter // By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter datacenter := d.agent.config.Datacenter
@ -632,23 +659,23 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
} }
} }
if queryKind == "" { invalid := func() error {
goto INVALID d.logger.Warn("QName invalid", "qname", qName)
return errNameNotFound
} }
switch queryKind { switch queryKind {
case "service": case "service":
n := len(queryParts) n := len(queryParts)
if n < 1 { if n < 1 {
goto INVALID return invalid()
} }
if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) {
goto INVALID return invalid()
} }
lookup := serviceLookup{ lookup := serviceLookup{
Network: network,
Datacenter: datacenter, Datacenter: datacenter,
Connect: false, Connect: false,
Ingress: false, Ingress: false,
@ -669,11 +696,10 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
lookup.Tag = tag lookup.Tag = tag
lookup.Service = queryParts[0][1:] lookup.Service = queryParts[0][1:]
// _name._tag.service.consul // _name._tag.service.consul
d.serviceLookup(cfg, lookup, req, resp) return d.serviceLookup(cfg, lookup, req, resp)
}
// Consul 0.3 and prior format for SRV queries // Consul 0.3 and prior format for SRV queries
} else {
// Support "." in the label, re-join all the parts // Support "." in the label, re-join all the parts
tag := "" tag := ""
if n >= 2 { if n >= 2 {
@ -684,19 +710,18 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
lookup.Service = queryParts[n-1] lookup.Service = queryParts[n-1]
// tag[.tag].name.service.consul // tag[.tag].name.service.consul
d.serviceLookup(cfg, lookup, req, resp) return d.serviceLookup(cfg, lookup, req, resp)
}
case "connect": case "connect":
if len(queryParts) < 1 { if len(queryParts) < 1 {
goto INVALID return invalid()
} }
if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) {
goto INVALID return invalid()
} }
lookup := serviceLookup{ lookup := serviceLookup{
Network: network,
Datacenter: datacenter, Datacenter: datacenter,
Service: queryParts[len(queryParts)-1], Service: queryParts[len(queryParts)-1],
Connect: true, Connect: true,
@ -705,18 +730,18 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
EnterpriseMeta: entMeta, EnterpriseMeta: entMeta,
} }
// name.connect.consul // name.connect.consul
d.serviceLookup(cfg, lookup, req, resp) return d.serviceLookup(cfg, lookup, req, resp)
case "ingress": case "ingress":
if len(queryParts) < 1 { if len(queryParts) < 1 {
goto INVALID return invalid()
} }
if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) { if !d.parseDatacenterAndEnterpriseMeta(querySuffixes, cfg, &datacenter, &entMeta) {
goto INVALID return invalid()
} }
lookup := serviceLookup{ lookup := serviceLookup{
Network: network,
Datacenter: datacenter, Datacenter: datacenter,
Service: queryParts[len(queryParts)-1], Service: queryParts[len(queryParts)-1],
Connect: false, Connect: false,
@ -725,38 +750,40 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
EnterpriseMeta: entMeta, EnterpriseMeta: entMeta,
} }
// name.ingress.consul // name.ingress.consul
d.serviceLookup(cfg, lookup, req, resp) return d.serviceLookup(cfg, lookup, req, resp)
case "node": case "node":
if len(queryParts) < 1 { if len(queryParts) < 1 {
goto INVALID return invalid()
} }
if !d.parseDatacenter(querySuffixes, &datacenter) { if !d.parseDatacenter(querySuffixes, &datacenter) {
goto INVALID return invalid()
} }
// Allow a "." in the node name, just join all the parts // Allow a "." in the node name, just join all the parts
node := strings.Join(queryParts, ".") node := strings.Join(queryParts, ".")
d.nodeLookup(cfg, network, datacenter, node, req, resp, maxRecursionLevel) return d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel)
case "query": case "query":
// ensure we have a query name // ensure we have a query name
if len(queryParts) < 1 { if len(queryParts) < 1 {
goto INVALID return invalid()
} }
if !d.parseDatacenter(querySuffixes, &datacenter) { if !d.parseDatacenter(querySuffixes, &datacenter) {
goto INVALID return invalid()
} }
// Allow a "." in the query name, just join all the parts. // Allow a "." in the query name, just join all the parts.
query := strings.Join(queryParts, ".") query := strings.Join(queryParts, ".")
ecsGlobal = false err := d.preparedQueryLookup(cfg, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) return ecsNotGlobalError{error: err}
case "addr": case "addr":
// <address>.addr.<suffixes>.<domain> - addr must be the second label, datacenter is optional // <address>.addr.<suffixes>.<domain> - addr must be the second label, datacenter is optional
if len(queryParts) != 1 { if len(queryParts) != 1 {
goto INVALID return invalid()
} }
switch len(queryParts[0]) / 2 { switch len(queryParts[0]) / 2 {
@ -764,7 +791,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
case 4: case 4:
ip, err := hex.DecodeString(queryParts[0]) ip, err := hex.DecodeString(queryParts[0])
if err != nil { if err != nil {
goto INVALID return invalid()
} }
resp.Answer = append(resp.Answer, &dns.A{ resp.Answer = append(resp.Answer, &dns.A{
@ -780,7 +807,7 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
case 16: case 16:
ip, err := hex.DecodeString(queryParts[0]) ip, err := hex.DecodeString(queryParts[0])
if err != nil { if err != nil {
goto INVALID return invalid()
} }
resp.Answer = append(resp.Answer, &dns.AAAA{ resp.Answer = append(resp.Answer, &dns.AAAA{
@ -793,15 +820,10 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d
AAAA: ip, AAAA: ip,
}) })
} }
return nil
default:
return invalid()
} }
// early return without error
return
INVALID:
d.logger.Warn("QName invalid", "qname", qName)
d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError)
return
} }
func (d *DNSServer) trimDomain(query string) string { func (d *DNSServer) trimDomain(query string) string {
@ -818,23 +840,30 @@ func (d *DNSServer) trimDomain(query string) string {
return strings.TrimSuffix(query, shorter) return strings.TrimSuffix(query, shorter)
} }
// computeRCode Return the DNS Error code from Consul Error // rCodeFromError return the appropriate DNS response code for a given error
func (d *DNSServer) computeRCode(err error) int { func rCodeFromError(err error) int {
if err == nil { switch {
case err == nil:
return dns.RcodeSuccess return dns.RcodeSuccess
} case errors.Is(err, errNoData):
if structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err) { return dns.RcodeSuccess
case errors.Is(err, errECSNotGlobal):
return rCodeFromError(errors.Unwrap(err))
case errors.Is(err, errNameNotFound):
return dns.RcodeNameError return dns.RcodeNameError
} case structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err):
return dns.RcodeNameError
default:
return dns.RcodeServerFailure return dns.RcodeServerFailure
} }
}
// nodeLookup is used to handle a node query // nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) { func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) error {
// Only handle ANY, A, AAAA, and TXT type requests // Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT { if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
return return nil
} }
// Make an RPC request // Make an RPC request
@ -848,20 +877,12 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string,
} }
out, err := d.lookupNode(cfg, args) out, err := d.lookupNode(cfg, args)
if err != nil { if err != nil {
d.logger.Error("rpc error", "error", err) return fmt.Errorf("failed rpc request: %w", err)
rCode := d.computeRCode(err)
if rCode == dns.RcodeNameError {
d.addSOA(cfg, resp)
}
resp.SetRcode(req, rCode)
return
} }
// If we have no out.NodeServices.Nodeaddress, return not found! // If we have no out.NodeServices.Nodeaddress, return not found!
if out.NodeServices == nil { if out.NodeServices == nil {
d.addSOA(cfg, resp) return errNameNotFound
resp.SetRcode(req, dns.RcodeNameError)
return
} }
// Add the node record // Add the node record
@ -883,6 +904,7 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string,
metas := d.generateMeta(n.Datacenter, q.Name, n, cfg.NodeTTL) metas := d.generateMeta(n.Datacenter, q.Name, n, cfg.NodeTTL)
*metaTarget = append(*metaTarget, metas...) *metaTarget = append(*metaTarget, metas...)
} }
return nil
} }
func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
@ -1021,7 +1043,7 @@ func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasE
// trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit // trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit
// of DNS responses // of DNS responses
func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { func trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
hasExtra := len(resp.Extra) > 0 hasExtra := len(resp.Extra) > 0
// There is some overhead, 65535 does not work // There is some overhead, 65535 does not work
maxSize := 65523 // 64k - 12 bytes DNS raw overhead maxSize := 65523 // 64k - 12 bytes DNS raw overhead
@ -1029,8 +1051,6 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
// We avoid some function calls and allocations by only handling the // We avoid some function calls and allocations by only handling the
// extra data when necessary. // extra data when necessary.
var index map[string]dns.RR var index map[string]dns.RR
originalSize := resp.Len()
originalNumRecords := len(resp.Answer)
// It is not possible to return more than 4k records even with compression // It is not possible to return more than 4k records even with compression
// Since we are performing binary search it is not a big deal, but it // Since we are performing binary search it is not a big deal, but it
@ -1052,6 +1072,10 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
// This enforces the given limit on 64k, the max limit for DNS messages // This enforces the given limit on 64k, the max limit for DNS messages
for len(resp.Answer) > 1 && resp.Len() > maxSize { for len(resp.Answer) > 1 && resp.Len() > maxSize {
truncated = true truncated = true
// first try to remove the NS section may be it will truncate enough
if len(resp.Ns) != 0 {
resp.Ns = []dns.RR{}
}
// More than 100 bytes, find with a binary search // More than 100 bytes, find with a binary search
if resp.Len()-maxSize > 100 { if resp.Len()-maxSize > 100 {
bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra)
@ -1063,13 +1087,7 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
syncExtra(index, resp) syncExtra(index, resp)
} }
} }
if truncated {
d.logger.Debug("TCP answer to question too large, truncated",
"question", req.Question,
"records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords),
"size", fmt.Sprintf("%d/%d", resp.Len(), originalSize),
)
}
return truncated return truncated
} }
@ -1118,6 +1136,10 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
// Even when size is too big for one single record, try to send it anyway // Even when size is too big for one single record, try to send it anyway
// (useful for 512 bytes messages) // (useful for 512 bytes messages)
for len(resp.Answer) > 1 && resp.Len() > maxSize-7 { for len(resp.Answer) > 1 && resp.Len() > maxSize-7 {
// first try to remove the NS section may be it will truncate enough
if len(resp.Ns) != 0 {
resp.Ns = []dns.RR{}
}
// More than 100 bytes, find with a binary search // More than 100 bytes, find with a binary search
if resp.Len()-maxSize > 100 { if resp.Len()-maxSize > 100 {
bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra)
@ -1136,16 +1158,27 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
} }
// trimDNSResponse will trim the response for UDP and TCP // trimDNSResponse will trim the response for UDP and TCP
func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) (trimmed bool) { func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) bool {
var trimmed bool
originalSize := resp.Len()
originalNumRecords := len(resp.Answer)
if network != "tcp" { if network != "tcp" {
trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit) trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit)
} else { } else {
trimmed = d.trimTCPResponse(req, resp) trimmed = trimTCPResponse(req, resp)
} }
// Flag that there are more records to return in the UDP response // Flag that there are more records to return in the UDP response
if trimmed && cfg.EnableTruncate { if trimmed {
if cfg.EnableTruncate {
resp.Truncated = true resp.Truncated = true
} }
d.logger.Debug("DNS response too large, truncated",
"protocol", network,
"question", req.Question,
"records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords),
"size", fmt.Sprintf("%d/%d", resp.Len(), originalSize),
)
}
return trimmed return trimmed
} }
@ -1213,23 +1246,15 @@ func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (st
} }
// serviceLookup is used to handle a service query // serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) { func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) error {
out, err := d.lookupServiceNodes(cfg, lookup) out, err := d.lookupServiceNodes(cfg, lookup)
if err != nil { if err != nil {
d.logger.Error("rpc error", "error", err) return fmt.Errorf("rpc request failed: %w", err)
rCode := d.computeRCode(err)
if rCode == dns.RcodeNameError {
d.addSOA(cfg, resp)
}
resp.SetRcode(req, rCode)
return
} }
// If we have no nodes, return not found! // If we have no nodes, return not found!
if len(out.Nodes) == 0 { if len(out.Nodes) == 0 {
d.addSOA(cfg, resp) return errNameNotFound
resp.SetRcode(req, dns.RcodeNameError)
return
} }
// Perform a random shuffle // Perform a random shuffle
@ -1246,13 +1271,10 @@ func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, res
d.serviceNodeRecords(cfg, lookup.Datacenter, out.Nodes, req, resp, ttl, lookup.MaxRecursionLevel) d.serviceNodeRecords(cfg, lookup.Datacenter, out.Nodes, req, resp, ttl, lookup.MaxRecursionLevel)
} }
d.trimDNSResponse(cfg, lookup.Network, req, resp) if len(resp.Answer) == 0 {
return errNoData
// If the answer is empty and the response isn't truncated, return not found
if len(resp.Answer) == 0 && !resp.Truncated {
d.addSOA(cfg, resp)
return
} }
return nil
} }
func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
@ -1273,7 +1295,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
} }
// preparedQueryLookup is used to handle a prepared query. // preparedQueryLookup is used to handle a prepared query.
func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) { func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
// Execute the prepared query. // Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{ args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter, Datacenter: datacenter,
@ -1311,17 +1333,8 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
} }
out, err := d.lookupPreparedQuery(cfg, args) out, err := d.lookupPreparedQuery(cfg, args)
// If they give a bogus query name, treat that as a name error,
// not a full on server error. We have to use a string compare
// here since the RPC layer loses the type information.
if err != nil { if err != nil {
rCode := d.computeRCode(err) return err
if rCode == dns.RcodeNameError {
d.addSOA(cfg, resp)
}
resp.SetRcode(req, rCode)
return
} }
// TODO (slackpad) - What's a safe limit we can set here? It seems like // TODO (slackpad) - What's a safe limit we can set here? It seems like
@ -1352,9 +1365,7 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
// If we have no nodes, return not found! // If we have no nodes, return not found!
if len(out.Nodes) == 0 { if len(out.Nodes) == 0 {
d.addSOA(cfg, resp) return errNameNotFound
resp.SetRcode(req, dns.RcodeNameError)
return
} }
// Add various responses depending on the request. // Add various responses depending on the request.
@ -1365,13 +1376,10 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
d.serviceNodeRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) d.serviceNodeRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} }
d.trimDNSResponse(cfg, network, req, resp) if len(resp.Answer) == 0 {
return errNoData
// If the answer is empty and the response isn't truncated, return not found
if len(resp.Answer) == 0 && !resp.Truncated {
d.addSOA(cfg, resp)
return
} }
return nil
} }
func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
@ -1907,7 +1915,8 @@ func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel
resp := &dns.Msg{} resp := &dns.Msg{}
req.SetQuestion(name, dns.TypeANY) req.SetQuestion(name, dns.TypeANY)
d.doDispatch("udp", nil, req, resp, maxRecursionLevel-1) // TODO: handle error response
d.dispatch(nil, req, resp, maxRecursionLevel-1)
return resp.Answer return resp.Answer
} }

View File

@ -1,6 +1,7 @@
package agent package agent
import ( import (
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
@ -10,6 +11,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/serf/coordinate"
"github.com/miekg/dns"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/config"
agentdns "github.com/hashicorp/consul/agent/dns" agentdns "github.com/hashicorp/consul/agent/dns"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
@ -17,10 +23,6 @@ import (
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/serf/coordinate"
"github.com/miekg/dns"
"github.com/pascaldekloe/goe/verify"
"github.com/stretchr/testify/require"
) )
const ( const (
@ -508,6 +510,7 @@ func TestDNS_NodeLookup_CNAME(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("google.node.consul.", dns.TypeANY) m.SetQuestion("google.node.consul.", dns.TypeANY)
m.SetEdns0(8192, true)
c := new(dns.Client) c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr()) in, _, err := c.Exchange(m, a.DNSAddr())
@ -871,7 +874,6 @@ func TestDNS_EDNS0_ECS(t *testing.T) {
require.True(t, ok) require.True(t, ok)
require.Equal(t, uint16(1), subnet.Family) require.Equal(t, uint16(1), subnet.Family)
require.Equal(t, tc.SourceNetmask, subnet.SourceNetmask) require.Equal(t, tc.SourceNetmask, subnet.SourceNetmask)
// scope set to 0 for a globally valid reply
require.Equal(t, tc.ExpectedScope, subnet.SourceScope) require.Equal(t, tc.ExpectedScope, subnet.SourceScope)
require.Equal(t, net.ParseIP(tc.SubnetAddr), subnet.Address) require.Equal(t, net.ParseIP(tc.SubnetAddr), subnet.Address)
}) })
@ -4180,6 +4182,7 @@ func TestBinarySearch(t *testing.T) {
msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV) msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV)
msg.Answer = msgSrc.Answer msg.Answer = msgSrc.Answer
msg.Extra = msgSrc.Extra msg.Extra = msgSrc.Extra
msg.Ns = msgSrc.Ns
index := make(map[string]dns.RR, len(msg.Extra)) index := make(map[string]dns.RR, len(msg.Extra))
indexRRs(msg.Extra, index) indexRRs(msg.Extra, index)
blen := dnsBinaryTruncate(msg, maxSize, index, true) blen := dnsBinaryTruncate(msg, maxSize, index, true)
@ -5969,9 +5972,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if len(in.Ns) != 1 { require.Len(t, in.Ns, 1)
t.Fatalf("Bad: %#v", in)
}
soaRec, ok := in.Ns[0].(*dns.SOA) soaRec, ok := in.Ns[0].(*dns.SOA)
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Ns[0]) t.Fatalf("Bad: %#v", in.Ns[0])
@ -5980,10 +5981,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) {
t.Fatalf("Bad: %#v", in.Ns[0]) t.Fatalf("Bad: %#v", in.Ns[0])
} }
if in.Rcode != dns.RcodeSuccess { require.Equal(t, dns.RcodeSuccess, in.Rcode)
t.Fatalf("Bad: %#v", in)
}
} }
// Check for ipv4 records on ipv6-only service directly and via the // Check for ipv4 records on ipv6-only service directly and via the
@ -6303,6 +6301,51 @@ func TestDNS_PreparedQuery_AgentSource(t *testing.T) {
} }
} }
func TestDNS_EDNS_Truncate_AgentSource(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := NewTestAgent(t, `
dns_config {
enable_truncate = true
}
`)
defer a.Shutdown()
a.DNSDisableCompression(true)
testrpc.WaitForLeader(t, a.RPC, "dc1")
m := MockPreparedQuery{
executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error {
// Check that the agent inserted its self-name and datacenter to
// the RPC request body.
if args.Agent.Datacenter != a.Config.Datacenter ||
args.Agent.Node != a.Config.NodeName {
t.Fatalf("bad: %#v", args.Agent)
}
for i := 0; i < 100; i++ {
reply.Nodes = append(reply.Nodes, structs.CheckServiceNode{Node: &structs.Node{Node: "apple", Address: fmt.Sprintf("node.address:%d", i)}, Service: &structs.NodeService{Service: "appleService", Address: fmt.Sprintf("service.address:%d", i)}})
}
return nil
},
}
if err := a.registerEndpoint("PreparedQuery", &m); err != nil {
t.Fatalf("err: %v", err)
}
req := new(dns.Msg)
req.SetQuestion("foo.query.consul.", dns.TypeSRV)
req.SetEdns0(2048, true)
req.Compress = false
c := new(dns.Client)
resp, _, err := c.Exchange(req, a.DNSAddr())
require.NoError(t, err)
require.True(t, resp.Len() < 2048)
}
func TestDNS_trimUDPResponse_NoTrim(t *testing.T) { func TestDNS_trimUDPResponse_NoTrim(t *testing.T) {
t.Parallel() t.Parallel()
req := &dns.Msg{} req := &dns.Msg{}
@ -6401,6 +6444,111 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) {
} }
} }
func TestDNS_trimUDPResponse_TrimLimitWithNS(t *testing.T) {
t.Parallel()
cfg := loadRuntimeConfig(t, `node_name = "test" data_dir = "a" bind_addr = "127.0.0.1" node_name = "dummy"`)
req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{}
for i := 0; i < cfg.DNSUDPAnswerLimit+1; i++ {
target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i)
srv := &dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Target: target,
}
a := &dns.A{
Hdr: dns.RR_Header{
Name: target,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 185+i)),
}
ns := &dns.SOA{
Hdr: dns.RR_Header{
Name: target,
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
},
Ns: fmt.Sprintf("soa-%d", i),
}
resp.Answer = append(resp.Answer, srv)
resp.Extra = append(resp.Extra, a)
resp.Ns = append(resp.Ns, ns)
if i < cfg.DNSUDPAnswerLimit {
expected.Answer = append(expected.Answer, srv)
expected.Extra = append(expected.Extra, a)
}
}
if trimmed := trimUDPResponse(req, resp, cfg.DNSUDPAnswerLimit); !trimmed {
t.Fatalf("Bad %#v", *resp)
}
require.LessOrEqual(t, resp.Len(), defaultMaxUDPSize)
require.Len(t, resp.Ns, 0)
}
func TestDNS_trimTCPResponse_TrimLimitWithNS(t *testing.T) {
t.Parallel()
cfg := loadRuntimeConfig(t, `node_name = "test" data_dir = "a" bind_addr = "127.0.0.1" node_name = "dummy"`)
req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{}
for i := 0; i < 5000; i++ {
target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i)
srv := &dns.SRV{
Hdr: dns.RR_Header{
Name: "redis-cache-redis.service.consul.",
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Target: target,
}
a := &dns.A{
Hdr: dns.RR_Header{
Name: target,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 185+i)),
}
ns := &dns.SOA{
Hdr: dns.RR_Header{
Name: target,
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
},
Ns: fmt.Sprintf("soa-%d", i),
}
resp.Answer = append(resp.Answer, srv)
resp.Extra = append(resp.Extra, a)
resp.Ns = append(resp.Ns, ns)
if i < cfg.DNSUDPAnswerLimit {
expected.Answer = append(expected.Answer, srv)
expected.Extra = append(expected.Extra, a)
}
}
req.Question = append(req.Question, dns.Question{Qtype: dns.TypeSRV})
if trimmed := trimTCPResponse(req, resp); !trimmed {
t.Fatalf("Bad %#v", *resp)
}
require.LessOrEqual(t, resp.Len(), 65523)
require.Len(t, resp.Ns, 0)
}
func loadRuntimeConfig(t *testing.T, hcl string) *config.RuntimeConfig {
t.Helper()
cfg, warns, err := config.Load(config.BuilderOpts{HCL: []string{hcl}}, nil)
require.NoError(t, err)
require.Len(t, warns, 0)
return cfg
}
func TestDNS_trimUDPResponse_TrimSize(t *testing.T) { func TestDNS_trimUDPResponse_TrimSize(t *testing.T) {
t.Parallel() t.Parallel()
cfg := config.DefaultRuntimeConfig(`data_dir = "a" bind_addr = "127.0.0.1"`) cfg := config.DefaultRuntimeConfig(`data_dir = "a" bind_addr = "127.0.0.1"`)
@ -7151,3 +7299,19 @@ func TestDNS_ReloadConfig_DuringQuery(t *testing.T) {
} }
} }
} }
func TestECSNotGlobalError(t *testing.T) {
t.Run("wrap nil", func(t *testing.T) {
e := ecsNotGlobalError{}
require.True(t, errors.Is(e, errECSNotGlobal))
require.False(t, errors.Is(e, fmt.Errorf("some other error")))
require.Equal(t, nil, errors.Unwrap(e))
})
t.Run("wrap some error", func(t *testing.T) {
e := ecsNotGlobalError{error: errNameNotFound}
require.True(t, errors.Is(e, errECSNotGlobal))
require.False(t, errors.Is(e, fmt.Errorf("some other error")))
require.Equal(t, errNameNotFound, errors.Unwrap(e))
})
}