mirror of https://github.com/status-im/consul.git
dns: handle errors from dispatch
This commit is contained in:
parent
9267b09c32
commit
436a02af31
133
agent/dns.go
133
agent/dns.go
|
@ -498,11 +498,17 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
|
||||||
|
|
||||||
default:
|
default:
|
||||||
err = d.dispatch(network, resp.RemoteAddr(), req, m, maxRecursionLevelDefault)
|
err = d.dispatch(network, resp.RemoteAddr(), req, m, maxRecursionLevelDefault)
|
||||||
|
rCode := rCodeFromError(err)
|
||||||
|
if rCode == dns.RcodeNameError || errors.Is(err, errNoAnswer) {
|
||||||
|
d.addSOA(cfg, m)
|
||||||
|
}
|
||||||
|
m.SetRcode(req, rCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
setEDNS(req, m, !errors.Is(err, errECSNotGlobal))
|
setEDNS(req, m, !errors.Is(err, errECSNotGlobal))
|
||||||
|
|
||||||
// Write out the complete response
|
//d.trimDNSResponse(cfg, network, req, m)
|
||||||
|
|
||||||
if err := resp.WriteMsg(m); err != nil {
|
if err := resp.WriteMsg(m); err != nil {
|
||||||
d.logger.Warn("failed to respond", "error", err)
|
d.logger.Warn("failed to respond", "error", err)
|
||||||
}
|
}
|
||||||
|
@ -604,6 +610,32 @@ func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool {
|
||||||
|
|
||||||
var errECSNotGlobal = fmt.Errorf("ECS response is not global")
|
var errECSNotGlobal = fmt.Errorf("ECS response is not global")
|
||||||
var errNameNotFound = fmt.Errorf("DNS name not found")
|
var errNameNotFound = fmt.Errorf("DNS name not found")
|
||||||
|
var errQueryRefused = fmt.Errorf("query refused")
|
||||||
|
|
||||||
|
// errNoAnswer is used to indicate that the response should set SOA, and the
|
||||||
|
// success response code.
|
||||||
|
var errNoAnswer = fmt.Errorf("no DNS Answer")
|
||||||
|
|
||||||
|
// ecsNotGlobalError may be used to wrap an error or nil, to indicate that the
|
||||||
|
// EDNS client subnet source scope is not global.
|
||||||
|
type ecsNotGlobalError struct {
|
||||||
|
error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ecsNotGlobalError) Error() string {
|
||||||
|
if e.error == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return e.error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ecsNotGlobalError) Is(other error) bool {
|
||||||
|
return other == errECSNotGlobal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ecsNotGlobalError) Unwrap() error {
|
||||||
|
return e.error
|
||||||
|
}
|
||||||
|
|
||||||
// dispatch is used to parse a request and invoke the correct handler.
|
// dispatch is used to parse a request and invoke the correct handler.
|
||||||
// parameter maxRecursionLevel will handle whether recursive call can be performed
|
// parameter maxRecursionLevel will handle whether recursive call can be performed
|
||||||
|
@ -649,8 +681,6 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
|
||||||
|
|
||||||
invalid := func() error {
|
invalid := func() error {
|
||||||
d.logger.Warn("QName invalid", "qname", qName)
|
d.logger.Warn("QName invalid", "qname", qName)
|
||||||
d.addSOA(cfg, resp)
|
|
||||||
resp.SetRcode(req, dns.RcodeNameError)
|
|
||||||
return errNameNotFound
|
return errNameNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -687,8 +717,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
|
||||||
lookup.Tag = tag
|
lookup.Tag = tag
|
||||||
lookup.Service = queryParts[0][1:]
|
lookup.Service = queryParts[0][1:]
|
||||||
// _name._tag.service.consul
|
// _name._tag.service.consul
|
||||||
d.serviceLookup(cfg, lookup, req, resp)
|
return d.serviceLookup(cfg, lookup, req, resp)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consul 0.3 and prior format for SRV queries
|
// Consul 0.3 and prior format for SRV queries
|
||||||
|
@ -702,8 +731,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
|
||||||
lookup.Service = queryParts[n-1]
|
lookup.Service = queryParts[n-1]
|
||||||
|
|
||||||
// tag[.tag].name.service.consul
|
// tag[.tag].name.service.consul
|
||||||
d.serviceLookup(cfg, lookup, req, resp)
|
return d.serviceLookup(cfg, lookup, req, resp)
|
||||||
return nil
|
|
||||||
|
|
||||||
case "connect":
|
case "connect":
|
||||||
if len(queryParts) < 1 {
|
if len(queryParts) < 1 {
|
||||||
|
@ -724,8 +752,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
|
||||||
EnterpriseMeta: entMeta,
|
EnterpriseMeta: entMeta,
|
||||||
}
|
}
|
||||||
// name.connect.consul
|
// name.connect.consul
|
||||||
d.serviceLookup(cfg, lookup, req, resp)
|
return d.serviceLookup(cfg, lookup, req, resp)
|
||||||
return nil
|
|
||||||
|
|
||||||
case "ingress":
|
case "ingress":
|
||||||
if len(queryParts) < 1 {
|
if len(queryParts) < 1 {
|
||||||
|
@ -746,8 +773,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
|
||||||
EnterpriseMeta: entMeta,
|
EnterpriseMeta: entMeta,
|
||||||
}
|
}
|
||||||
// name.ingress.consul
|
// name.ingress.consul
|
||||||
d.serviceLookup(cfg, lookup, req, resp)
|
return d.serviceLookup(cfg, lookup, req, resp)
|
||||||
return nil
|
|
||||||
|
|
||||||
case "node":
|
case "node":
|
||||||
if len(queryParts) < 1 {
|
if len(queryParts) < 1 {
|
||||||
|
@ -760,8 +786,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
|
||||||
|
|
||||||
// Allow a "." in the node name, just join all the parts
|
// Allow a "." in the node name, just join all the parts
|
||||||
node := strings.Join(queryParts, ".")
|
node := strings.Join(queryParts, ".")
|
||||||
d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel)
|
return d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel)
|
||||||
return nil
|
|
||||||
|
|
||||||
case "query":
|
case "query":
|
||||||
// ensure we have a query name
|
// ensure we have a query name
|
||||||
|
@ -775,8 +800,8 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
|
||||||
|
|
||||||
// Allow a "." in the query name, just join all the parts.
|
// Allow a "." in the query name, just join all the parts.
|
||||||
query := strings.Join(queryParts, ".")
|
query := strings.Join(queryParts, ".")
|
||||||
d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
|
err := d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
|
||||||
return errECSNotGlobal
|
return ecsNotGlobalError{error: err}
|
||||||
|
|
||||||
case "addr":
|
case "addr":
|
||||||
// <address>.addr.<suffixes>.<domain> - addr must be the second label, datacenter is optional
|
// <address>.addr.<suffixes>.<domain> - addr must be the second label, datacenter is optional
|
||||||
|
@ -849,23 +874,33 @@ func (d *DNSServer) trimDomain(query string) string {
|
||||||
return strings.TrimSuffix(query, shorter)
|
return strings.TrimSuffix(query, shorter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// computeRCode Return the DNS Error code from Consul Error
|
// rCodeFromError return the DNS Error code an error
|
||||||
func (d *DNSServer) computeRCode(err error) int {
|
func rCodeFromError(err error) int {
|
||||||
if err == nil {
|
switch {
|
||||||
|
case err == nil:
|
||||||
return dns.RcodeSuccess
|
return dns.RcodeSuccess
|
||||||
}
|
case errors.Is(err, errNoAnswer):
|
||||||
if structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err) {
|
// TODO: why do we return success if the answer is empty?
|
||||||
|
return dns.RcodeSuccess
|
||||||
|
case errors.Is(err, errECSNotGlobal):
|
||||||
|
return rCodeFromError(errors.Unwrap(err))
|
||||||
|
case errors.Is(err, errQueryRefused):
|
||||||
|
return dns.RcodeRefused
|
||||||
|
case errors.Is(err, errNameNotFound):
|
||||||
return dns.RcodeNameError
|
return dns.RcodeNameError
|
||||||
|
case structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err):
|
||||||
|
return dns.RcodeNameError
|
||||||
|
default:
|
||||||
|
return dns.RcodeServerFailure
|
||||||
}
|
}
|
||||||
return dns.RcodeServerFailure
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// nodeLookup is used to handle a node query
|
// nodeLookup is used to handle a node query
|
||||||
func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) {
|
func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) error {
|
||||||
// Only handle ANY, A, AAAA, and TXT type requests
|
// Only handle ANY, A, AAAA, and TXT type requests
|
||||||
qType := req.Question[0].Qtype
|
qType := req.Question[0].Qtype
|
||||||
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
|
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
|
||||||
return
|
return errQueryRefused
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make an RPC request
|
// Make an RPC request
|
||||||
|
@ -879,20 +914,12 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, res
|
||||||
}
|
}
|
||||||
out, err := d.lookupNode(cfg, args)
|
out, err := d.lookupNode(cfg, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.logger.Error("rpc error", "error", err)
|
return fmt.Errorf("failed rpc request: %w", err)
|
||||||
rCode := d.computeRCode(err)
|
|
||||||
if rCode == dns.RcodeNameError {
|
|
||||||
d.addSOA(cfg, resp)
|
|
||||||
}
|
|
||||||
resp.SetRcode(req, rCode)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have no out.NodeServices.Nodeaddress, return not found!
|
// If we have no out.NodeServices.Nodeaddress, return not found!
|
||||||
if out.NodeServices == nil {
|
if out.NodeServices == nil {
|
||||||
d.addSOA(cfg, resp)
|
return errNameNotFound
|
||||||
resp.SetRcode(req, dns.RcodeNameError)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the node record
|
// Add the node record
|
||||||
|
@ -914,6 +941,7 @@ func (d *DNSServer) nodeLookup(cfg *dnsConfig, datacenter, node string, req, res
|
||||||
metas := d.generateMeta(q.Name, n, cfg.NodeTTL)
|
metas := d.generateMeta(q.Name, n, cfg.NodeTTL)
|
||||||
*metaTarget = append(*metaTarget, metas...)
|
*metaTarget = append(*metaTarget, metas...)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
|
func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) {
|
||||||
|
@ -1217,23 +1245,15 @@ func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, lookup serviceLookup) (st
|
||||||
}
|
}
|
||||||
|
|
||||||
// serviceLookup is used to handle a service query
|
// serviceLookup is used to handle a service query
|
||||||
func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) {
|
func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, resp *dns.Msg) error {
|
||||||
out, err := d.lookupServiceNodes(cfg, lookup)
|
out, err := d.lookupServiceNodes(cfg, lookup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.logger.Error("rpc error", "error", err)
|
return fmt.Errorf("rpc request failed: %w", err)
|
||||||
rCode := d.computeRCode(err)
|
|
||||||
if rCode == dns.RcodeNameError {
|
|
||||||
d.addSOA(cfg, resp)
|
|
||||||
}
|
|
||||||
resp.SetRcode(req, rCode)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have no nodes, return not found!
|
// If we have no nodes, return not found!
|
||||||
if len(out.Nodes) == 0 {
|
if len(out.Nodes) == 0 {
|
||||||
d.addSOA(cfg, resp)
|
return errNameNotFound
|
||||||
resp.SetRcode(req, dns.RcodeNameError)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform a random shuffle
|
// Perform a random shuffle
|
||||||
|
@ -1254,9 +1274,9 @@ func (d *DNSServer) serviceLookup(cfg *dnsConfig, lookup serviceLookup, req, res
|
||||||
|
|
||||||
// If the answer is empty and the response isn't truncated, return not found
|
// If the answer is empty and the response isn't truncated, return not found
|
||||||
if len(resp.Answer) == 0 && !resp.Truncated {
|
if len(resp.Answer) == 0 && !resp.Truncated {
|
||||||
d.addSOA(cfg, resp)
|
return errNoAnswer
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
|
func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
|
||||||
|
@ -1277,7 +1297,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
|
||||||
}
|
}
|
||||||
|
|
||||||
// preparedQueryLookup is used to handle a prepared query.
|
// preparedQueryLookup is used to handle a prepared query.
|
||||||
func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) {
|
func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
|
||||||
// Execute the prepared query.
|
// Execute the prepared query.
|
||||||
args := structs.PreparedQueryExecuteRequest{
|
args := structs.PreparedQueryExecuteRequest{
|
||||||
Datacenter: datacenter,
|
Datacenter: datacenter,
|
||||||
|
@ -1315,17 +1335,8 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := d.lookupPreparedQuery(cfg, args)
|
out, err := d.lookupPreparedQuery(cfg, args)
|
||||||
|
|
||||||
// If they give a bogus query name, treat that as a name error,
|
|
||||||
// not a full on server error. We have to use a string compare
|
|
||||||
// here since the RPC layer loses the type information.
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rCode := d.computeRCode(err)
|
return err
|
||||||
if rCode == dns.RcodeNameError {
|
|
||||||
d.addSOA(cfg, resp)
|
|
||||||
}
|
|
||||||
resp.SetRcode(req, rCode)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (slackpad) - What's a safe limit we can set here? It seems like
|
// TODO (slackpad) - What's a safe limit we can set here? It seems like
|
||||||
|
@ -1356,9 +1367,7 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
|
||||||
|
|
||||||
// If we have no nodes, return not found!
|
// If we have no nodes, return not found!
|
||||||
if len(out.Nodes) == 0 {
|
if len(out.Nodes) == 0 {
|
||||||
d.addSOA(cfg, resp)
|
return errNameNotFound
|
||||||
resp.SetRcode(req, dns.RcodeNameError)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add various responses depending on the request.
|
// Add various responses depending on the request.
|
||||||
|
@ -1373,9 +1382,9 @@ func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, que
|
||||||
|
|
||||||
// If the answer is empty and the response isn't truncated, return not found
|
// If the answer is empty and the response isn't truncated, return not found
|
||||||
if len(resp.Answer) == 0 && !resp.Truncated {
|
if len(resp.Answer) == 0 && !resp.Truncated {
|
||||||
d.addSOA(cfg, resp)
|
return errNoAnswer
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
|
func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
@ -935,7 +936,6 @@ func TestDNS_EDNS0_ECS(t *testing.T) {
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, uint16(1), subnet.Family)
|
require.Equal(t, uint16(1), subnet.Family)
|
||||||
require.Equal(t, tc.SourceNetmask, subnet.SourceNetmask)
|
require.Equal(t, tc.SourceNetmask, subnet.SourceNetmask)
|
||||||
// scope set to 0 for a globally valid reply
|
|
||||||
require.Equal(t, tc.ExpectedScope, subnet.SourceScope)
|
require.Equal(t, tc.ExpectedScope, subnet.SourceScope)
|
||||||
require.Equal(t, net.ParseIP(tc.SubnetAddr), subnet.Address)
|
require.Equal(t, net.ParseIP(tc.SubnetAddr), subnet.Address)
|
||||||
})
|
})
|
||||||
|
@ -6391,9 +6391,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(in.Ns) != 1 {
|
require.Len(t, in.Ns, 1)
|
||||||
t.Fatalf("Bad: %#v", in)
|
|
||||||
}
|
|
||||||
soaRec, ok := in.Ns[0].(*dns.SOA)
|
soaRec, ok := in.Ns[0].(*dns.SOA)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Bad: %#v", in.Ns[0])
|
t.Fatalf("Bad: %#v", in.Ns[0])
|
||||||
|
@ -6402,10 +6400,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) {
|
||||||
t.Fatalf("Bad: %#v", in.Ns[0])
|
t.Fatalf("Bad: %#v", in.Ns[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if in.Rcode != dns.RcodeSuccess {
|
require.Equal(t, dns.RcodeSuccess, in.Rcode)
|
||||||
t.Fatalf("Bad: %#v", in)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for ipv4 records on ipv6-only service directly and via the
|
// Check for ipv4 records on ipv6-only service directly and via the
|
||||||
|
@ -7625,3 +7620,19 @@ func TestDNS_ReloadConfig_DuringQuery(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestECSNotGlobalError(t *testing.T) {
|
||||||
|
t.Run("wrap nil", func(t *testing.T) {
|
||||||
|
e := ecsNotGlobalError{}
|
||||||
|
require.True(t, errors.Is(e, errECSNotGlobal))
|
||||||
|
require.False(t, errors.Is(e, fmt.Errorf("some other error")))
|
||||||
|
require.Equal(t, nil, errors.Unwrap(e))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("wrap some error", func(t *testing.T) {
|
||||||
|
e := ecsNotGlobalError{error: errNameNotFound}
|
||||||
|
require.True(t, errors.Is(e, errECSNotGlobal))
|
||||||
|
require.False(t, errors.Is(e, fmt.Errorf("some other error")))
|
||||||
|
require.Equal(t, errNameNotFound, errors.Unwrap(e))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue