diff --git a/agent/dns.go b/agent/dns.go index 56ccc08260..8298316f57 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -3,6 +3,7 @@ package agent import ( "context" "encoding/hex" + "errors" "fmt" "net" "regexp" @@ -476,7 +477,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.Authoritative = true m.RecursionAvailable = (len(cfg.Recursors) > 0) - ecsGlobal := true + var err error switch req.Question[0].Qtype { case dns.TypeSOA: @@ -496,10 +497,10 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.SetRcode(req, dns.RcodeNotImplemented) default: - ecsGlobal = d.dispatch(network, resp.RemoteAddr(), req, m, maxRecursionLevelDefault) + err = d.dispatch(network, resp.RemoteAddr(), req, m, maxRecursionLevelDefault) } - setEDNS(req, m, ecsGlobal) + setEDNS(req, m, !errors.Is(err, errECSNotGlobal)) // Write out the complete response if err := resp.WriteMsg(m); err != nil { @@ -601,9 +602,12 @@ func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool { } } +var errECSNotGlobal = fmt.Errorf("ECS response is not global") +var errNameNotFound = fmt.Errorf("DNS name not found") + // dispatch is used to parse a request and invoke the correct handler. // parameter maxRecursionLevel will handle whether recursive call can be performed -func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) bool { +func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error { // By default the query is in the default datacenter datacenter := d.agent.config.Datacenter @@ -643,11 +647,11 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns } } - invalid := func() bool { + invalid := func() error { d.logger.Warn("QName invalid", "qname", qName) d.addSOA(cfg, resp) resp.SetRcode(req, dns.RcodeNameError) - return true + return errNameNotFound } switch queryKind { @@ -684,7 +688,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns lookup.Service = queryParts[0][1:] // _name._tag.service.consul d.serviceLookup(cfg, lookup, req, resp) - return true + return nil } // Consul 0.3 and prior format for SRV queries @@ -699,7 +703,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns // tag[.tag].name.service.consul d.serviceLookup(cfg, lookup, req, resp) - return true + return nil case "connect": if len(queryParts) < 1 { @@ -721,7 +725,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns } // name.connect.consul d.serviceLookup(cfg, lookup, req, resp) - return true + return nil case "ingress": if len(queryParts) < 1 { @@ -743,7 +747,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns } // name.ingress.consul d.serviceLookup(cfg, lookup, req, resp) - return true + return nil case "node": if len(queryParts) < 1 { @@ -757,7 +761,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns // Allow a "." in the node name, just join all the parts node := strings.Join(queryParts, ".") d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel) - return true + return nil case "query": // ensure we have a query name @@ -772,7 +776,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns // Allow a "." in the query name, just join all the parts. query := strings.Join(queryParts, ".") d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) - return false + return errECSNotGlobal case "addr": //
.addr.. - addr must be the second label, datacenter is optional @@ -825,7 +829,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns resp.Answer = append(resp.Answer, aaaaRecord) } } - return true + return nil default: return invalid() } @@ -1905,6 +1909,7 @@ func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel resp := &dns.Msg{} req.SetQuestion(name, dns.TypeANY) + // TODO: handle error response d.dispatch("udp", nil, req, resp, maxRecursionLevel-1) return resp.Answer