dns: error response from dispatch

So that dispatch can communicate status back to the caller.
This commit is contained in:
Daniel Nephin 2021-04-13 16:07:10 -04:00
parent 68d6f1315f
commit 9267b09c32
1 changed files with 18 additions and 13 deletions

View File

@ -3,6 +3,7 @@ package agent
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net" "net"
"regexp" "regexp"
@ -476,7 +477,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.Authoritative = true m.Authoritative = true
m.RecursionAvailable = (len(cfg.Recursors) > 0) m.RecursionAvailable = (len(cfg.Recursors) > 0)
ecsGlobal := true var err error
switch req.Question[0].Qtype { switch req.Question[0].Qtype {
case dns.TypeSOA: case dns.TypeSOA:
@ -496,10 +497,10 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.SetRcode(req, dns.RcodeNotImplemented) m.SetRcode(req, dns.RcodeNotImplemented)
default: default:
ecsGlobal = d.dispatch(network, resp.RemoteAddr(), req, m, maxRecursionLevelDefault) err = d.dispatch(network, resp.RemoteAddr(), req, m, maxRecursionLevelDefault)
} }
setEDNS(req, m, ecsGlobal) setEDNS(req, m, !errors.Is(err, errECSNotGlobal))
// Write out the complete response // Write out the complete response
if err := resp.WriteMsg(m); err != nil { if err := resp.WriteMsg(m); err != nil {
@ -601,9 +602,12 @@ func (d *DNSServer) parseDatacenter(labels []string, datacenter *string) bool {
} }
} }
var errECSNotGlobal = fmt.Errorf("ECS response is not global")
var errNameNotFound = fmt.Errorf("DNS name not found")
// dispatch is used to parse a request and invoke the correct handler. // dispatch is used to parse a request and invoke the correct handler.
// parameter maxRecursionLevel will handle whether recursive call can be performed // parameter maxRecursionLevel will handle whether recursive call can be performed
func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) bool { func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) error {
// By default the query is in the default datacenter // By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter datacenter := d.agent.config.Datacenter
@ -643,11 +647,11 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
} }
} }
invalid := func() bool { invalid := func() error {
d.logger.Warn("QName invalid", "qname", qName) d.logger.Warn("QName invalid", "qname", qName)
d.addSOA(cfg, resp) d.addSOA(cfg, resp)
resp.SetRcode(req, dns.RcodeNameError) resp.SetRcode(req, dns.RcodeNameError)
return true return errNameNotFound
} }
switch queryKind { switch queryKind {
@ -684,7 +688,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
lookup.Service = queryParts[0][1:] lookup.Service = queryParts[0][1:]
// _name._tag.service.consul // _name._tag.service.consul
d.serviceLookup(cfg, lookup, req, resp) d.serviceLookup(cfg, lookup, req, resp)
return true return nil
} }
// Consul 0.3 and prior format for SRV queries // Consul 0.3 and prior format for SRV queries
@ -699,7 +703,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
// tag[.tag].name.service.consul // tag[.tag].name.service.consul
d.serviceLookup(cfg, lookup, req, resp) d.serviceLookup(cfg, lookup, req, resp)
return true return nil
case "connect": case "connect":
if len(queryParts) < 1 { if len(queryParts) < 1 {
@ -721,7 +725,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
} }
// name.connect.consul // name.connect.consul
d.serviceLookup(cfg, lookup, req, resp) d.serviceLookup(cfg, lookup, req, resp)
return true return nil
case "ingress": case "ingress":
if len(queryParts) < 1 { if len(queryParts) < 1 {
@ -743,7 +747,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
} }
// name.ingress.consul // name.ingress.consul
d.serviceLookup(cfg, lookup, req, resp) d.serviceLookup(cfg, lookup, req, resp)
return true return nil
case "node": case "node":
if len(queryParts) < 1 { if len(queryParts) < 1 {
@ -757,7 +761,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
// Allow a "." in the node name, just join all the parts // Allow a "." in the node name, just join all the parts
node := strings.Join(queryParts, ".") node := strings.Join(queryParts, ".")
d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel) d.nodeLookup(cfg, datacenter, node, req, resp, maxRecursionLevel)
return true return nil
case "query": case "query":
// ensure we have a query name // ensure we have a query name
@ -772,7 +776,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
// Allow a "." in the query name, just join all the parts. // Allow a "." in the query name, just join all the parts.
query := strings.Join(queryParts, ".") query := strings.Join(queryParts, ".")
d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
return false return errECSNotGlobal
case "addr": case "addr":
// <address>.addr.<suffixes>.<domain> - addr must be the second label, datacenter is optional // <address>.addr.<suffixes>.<domain> - addr must be the second label, datacenter is optional
@ -825,7 +829,7 @@ func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns
resp.Answer = append(resp.Answer, aaaaRecord) resp.Answer = append(resp.Answer, aaaaRecord)
} }
} }
return true return nil
default: default:
return invalid() return invalid()
} }
@ -1905,6 +1909,7 @@ func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel
resp := &dns.Msg{} resp := &dns.Msg{}
req.SetQuestion(name, dns.TypeANY) req.SetQuestion(name, dns.TypeANY)
// TODO: handle error response
d.dispatch("udp", nil, req, resp, maxRecursionLevel-1) d.dispatch("udp", nil, req, resp, maxRecursionLevel-1)
return resp.Answer return resp.Answer