diff --git a/agent/dns.go b/agent/dns.go index 8298316f57..ac1f28242e 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -498,11 +498,17 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { default: err = d.dispatch(network, resp.RemoteAddr(), req, m, maxRecursionLevelDefault) + rCode := rCodeFromError(err) + if rCode == dns.RcodeNameError || errors.Is(err, errNoAnswer) { + d.addSOA(cfg, m) + } + m.SetRcode(req, rCode) } setEDNS(req, m, !errors.Is(err, errECSNotGlobal)) - // Write out the complete response + //d.trimDNSResponse(cfg, network, req, m) + if err := resp.WriteMsg(m); err != nil { d.logger.Warn("failed to respond", "error", err) } @@ -604,6 +610,32 @@ func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool { var errECSNotGlobal = fmt.Errorf("ECS response is not global") var errNameNotFound = fmt.Errorf("DNS name not found") +var errQueryRefused = fmt.Errorf("query refused") + +// errNoAnswer is used to indicate that the response should set SOA, and the +// success response code. +var errNoAnswer = fmt.Errorf("no DNS Answer") + +// ecsNotGlobalError may be used to wrap an error or nil, to indicate that the +// EDNS client subnet source scope is not global. +type ecsNotGlobalError struct { + error +} + +func (e ecsNotGlobalError) Error() string { + if e.error == nil { + return "" + } + return e.error.Error() +} + +func (e ecsNotGlobalError) Is(other error) bool { + return other == errECSNotGlobal +} + +func (e ecsNotGlobalError) Unwrap() error { + return e.error +} // dispatch is used to parse a request and invoke the correct handler. // parameter maxRecursionLevel will handle whether recursive call can be performed @@ -649,8 +681,6 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns invalid := func() error { d.logger.Warn("QName invalid", "qname", qName) - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) return errNameNotFound } @@ -687,8 +717,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns lookup.Tag = tag lookup.Service = queryParts[0][1:] // _name._tag.service.consul - d.serviceLookup(cfg, lookup, req, resp) - return nil + return d.serviceLookup(cfg, lookup, req, resp) } // Consul 0.3 and prior format for SRV queries @@ -702,8 +731,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns lookup.Service = queryParts[n-1] // tag[.tag].name.service.consul - d.serviceLookup(cfg, lookup, req, resp) - return nil + return d.serviceLookup(cfg, lookup, req, resp) case "connect": if len(queryParts) < 1 { @@ -724,8 +752,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns EnterpriseMeta: entMeta, } // name.connect.consul - d.serviceLookup(cfg, lookup, req, resp) - return nil + return d.serviceLookup(cfg, lookup, req, resp) case "ingress": if len(queryParts) < 1 { @@ -746,8 +773,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns EnterpriseMeta: entMeta, } // name.ingress.consul - d.serviceLookup(cfg, lookup, req, resp) - return nil + return d.serviceLookup(cfg, lookup, req, resp) case "node": if len(queryParts) < 1 { @@ -760,8 +786,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns // Allow a "." in the node name, just join all the parts node := strings.Join(queryParts, ".") - d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel) - return nil + return d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel) case "query": // ensure we have a query name @@ -775,8 +800,8 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns // Allow a "." in the query name, just join all the parts. query := strings.Join(queryParts, ".") - d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) - return errECSNotGlobal + err := d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) + return ecsNotGlobalError{error: err} case "addr": //
.addr.. - addr must be the second label, datacenter is optional @@ -849,23 +874,33 @@ func (d *DNSServer) trimDomain(query string) string { return strings.TrimSuffix(query, shorter) } -// computeRCode Return the DNS Error code from Consul Error -func (d *DNSServer) computeRCode(err error) int { - if err == nil { +// rCodeFromError return the DNS Error code an error +func rCodeFromError(err error) int { + switch { + case err == nil: return dns.RcodeSuccess - } - if structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err) { + case errors.Is(err, errNoAnswer): + // TODO: why do we return success if the answer is empty? + return dns.RcodeSuccess + case errors.Is(err, errECSNotGlobal): + return rCodeFromError(errors.Unwrap(err)) + case errors.Is(err, errQueryRefused): + return dns.RcodeRefused + case errors.Is(err, errNameNotFound): return dns.RcodeNameError + case structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err): + return dns.RcodeNameError + default: + return dns.RcodeServerFailure } - return dns.RcodeServerFailure } // nodeLookup is used to handle a node query -func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) { +func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) error { // Only handle ANY, A, AAAA, and TXT type requests qType := req.Question[0].Qtype if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT { - return + return errQueryRefused } // Make an RPC request @@ -879,20 +914,12 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, res } out, err := d.lookupNode(cfg, args) if err != nil { - d.logger.Error("rpc error", "error", err) - rCode := d.computeRCode(err) - if rCode == dns.RcodeNameError { - d.addSOA(cfg, resp) - } - resp.SetRcode(req, rCode) - return + return fmt.Errorf("failed rpc request: %w", err) } // If we have no out.NodeServices.Nodeaddress, return not found! if out.NodeServices == nil { - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) - return + return errNameNotFound } // Add the node record @@ -914,6 +941,7 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, res metas := d.generateMeta(q.Name, n, cfg.NodeTTL) *metaTarget = append(*metaTarget, metas...) } + return nil } func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { @@ -1217,23 +1245,15 @@ func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (st } // serviceLookup is used to handle a service query -func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) { +func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) error { out, err := d.lookupServiceNodes(cfg, lookup) if err != nil { - d.logger.Error("rpc error", "error", err) - rCode := d.computeRCode(err) - if rCode == dns.RcodeNameError { - d.addSOA(cfg, resp) - } - resp.SetRcode(req, rCode) - return + return fmt.Errorf("rpc request failed: %w", err) } // If we have no nodes, return not found! if len(out.Nodes) == 0 { - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) - return + return errNameNotFound } // Perform a random shuffle @@ -1254,9 +1274,9 @@ func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, res // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { - d.addSOA(cfg, resp) - return + return errNoAnswer } + return nil } func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { @@ -1277,7 +1297,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { } // preparedQueryLookup is used to handle a prepared query. -func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) { +func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error { // Execute the prepared query. args := structs.PreparedQueryExecuteRequest{ Datacenter: datacenter, @@ -1315,17 +1335,8 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que } out, err := d.lookupPreparedQuery(cfg, args) - - // If they give a bogus query name, treat that as a name error, - // not a full on server error. We have to use a string compare - // here since the RPC layer loses the type information. if err != nil { - rCode := d.computeRCode(err) - if rCode == dns.RcodeNameError { - d.addSOA(cfg, resp) - } - resp.SetRcode(req, rCode) - return + return err } // TODO (slackpad) - What's a safe limit we can set here? It seems like @@ -1356,9 +1367,7 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que // If we have no nodes, return not found! if len(out.Nodes) == 0 { - d.addSOA(cfg, resp) - resp.SetRcode(req, dns.RcodeNameError) - return + return errNameNotFound } // Add various responses depending on the request. @@ -1373,9 +1382,9 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { - d.addSOA(cfg, resp) - return + return errNoAnswer } + return nil } func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { diff --git a/agent/dns_test.go b/agent/dns_test.go index 5cfe582069..36f774b68e 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -1,6 +1,7 @@ package agent import ( + "errors" "fmt" "math/rand" "net" @@ -935,7 +936,6 @@ func TestDNS_EDNS0_ECS(t *testing.T) { require.True(t, ok) require.Equal(t, uint16(1), subnet.Family) require.Equal(t, tc.SourceNetmask, subnet.SourceNetmask) - // scope set to 0 for a globally valid reply require.Equal(t, tc.ExpectedScope, subnet.SourceScope) require.Equal(t, net.ParseIP(tc.SubnetAddr), subnet.Address) }) @@ -6391,9 +6391,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { - t.Fatalf("Bad: %#v", in) - } + require.Len(t, in.Ns, 1) soaRec, ok := in.Ns[0].(*dns.SOA) if !ok { t.Fatalf("Bad: %#v", in.Ns[0]) @@ -6402,10 +6400,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { t.Fatalf("Bad: %#v", in.Ns[0]) } - if in.Rcode != dns.RcodeSuccess { - t.Fatalf("Bad: %#v", in) - } - + require.Equal(t, dns.RcodeSuccess, in.Rcode) } // Check for ipv4 records on ipv6-only service directly and via the @@ -7625,3 +7620,19 @@ func TestDNS_ReloadConfig_DuringQuery(t *testing.T) { } } } + +func TestECSNotGlobalError(t *testing.T) { + t.Run("wrap nil", func(t *testing.T) { + e := ecsNotGlobalError{} + require.True(t, errors.Is(e, errECSNotGlobal)) + require.False(t, errors.Is(e, fmt.Errorf("some other error"))) + require.Equal(t, nil, errors.Unwrap(e)) + }) + + t.Run("wrap some error", func(t *testing.T) { + e := ecsNotGlobalError{error: errNameNotFound} + require.True(t, errors.Is(e, errECSNotGlobal)) + require.False(t, errors.Is(e, fmt.Errorf("some other error"))) + require.Equal(t, errNameNotFound, errors.Unwrap(e)) + }) +}