Avoid to have infinite recursion in DNS lookups when resolving CNAMEs (#4918)

* Avoid to have infinite recursion in DNS lookups when resolving CNAMEs

This will avoid killing Consul when a Service.Address is using CNAME
to a Consul CNAME that creates an infinite recursion.

This will fix https://github.com/hashicorp/consul/issues/4907

* Use maxRecursionLevel = 3 to allow several recursions
This commit is contained in:
Pierre Souchay 2019-01-07 22:53:54 +01:00 committed by Matt Keeler
parent 88d239818c
commit ae7f88f995
5 changed files with 108 additions and 37 deletions

View File

@ -933,6 +933,7 @@ func TestCatalog_ListNodes_StaleRead(t *testing.T) {
defer s1.Shutdown()
codec1 := rpcClient(t, s1)
defer codec1.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
dir2, s2 := testServerDCBootstrap(t, "dc1", false)
defer os.RemoveAll(dir2)
@ -980,7 +981,7 @@ func TestCatalog_ListNodes_StaleRead(t *testing.T) {
}
}
if !found {
t.Fatalf("failed to find foo")
t.Fatalf("failed to find foo in %#v", out.Nodes)
}
if out.QueryMeta.LastContact == 0 {
@ -2160,6 +2161,7 @@ func TestCatalog_NodeServices(t *testing.T) {
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
args := structs.NodeSpecificRequest{
Datacenter: "dc1",
@ -2213,7 +2215,7 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequestProxy(t)
@ -2244,7 +2246,7 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
// Register the service
args := structs.TestRegisterRequest(t)
@ -2392,6 +2394,7 @@ func TestCatalog_ListServices_FilterACL(t *testing.T) {
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer codec.Close()
testrpc.WaitForTestAgent(t, srv.RPC, "dc1")
opt := structs.DCSpecificRequest{
Datacenter: "dc1",
@ -2473,7 +2476,7 @@ func TestCatalog_NodeServices_ACLDeny(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
// Prior to version 8, the node policy should be ignored.
args := structs.NodeSpecificRequest{
@ -2542,6 +2545,7 @@ func TestCatalog_NodeServices_FilterACL(t *testing.T) {
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer codec.Close()
testrpc.WaitForTestAgent(t, srv.RPC, "dc1")
opt := structs.NodeSpecificRequest{
Datacenter: "dc1",

View File

@ -980,7 +980,7 @@ func TestLeader_ACL_Initialization(t *testing.T) {
dir1, s1 := testServerWithConfig(t, conf)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
if tt.master != "" {
_, master, err := s1.fsm.State().ACLTokenGetBySecret(nil, tt.master)
@ -1153,7 +1153,7 @@ func TestLeader_ACLUpgrade(t *testing.T) {
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
codec := rpcClient(t, s1)
defer codec.Close()

View File

@ -29,6 +29,7 @@ const (
// times.
maxUDPAnswerLimit = 8
maxRecurseRecords = 5
maxRecursionLevelDefault = 3
// Increment a counter when requests staler than this are served
staleCounterThreshold = 5 * time.Second
@ -365,14 +366,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
switch req.Question[0].Qtype {
case dns.TypeSOA:
ns, glue := d.nameservers(req.IsEdns0() != nil)
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = append(m.Answer, d.soa())
m.Ns = append(m.Ns, ns...)
m.Extra = append(m.Extra, glue...)
m.SetRcode(req, dns.RcodeSuccess)
case dns.TypeNS:
ns, glue := d.nameservers(req.IsEdns0() != nil)
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = ns
m.Extra = glue
m.SetRcode(req, dns.RcodeSuccess)
@ -418,8 +419,8 @@ func (d *DNSServer) addSOA(msg *dns.Msg) {
// nameservers returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false)
func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel)
if err != nil {
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
return nil, nil
@ -456,7 +457,7 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
}
ns = append(ns, nsrr)
glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns)
glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, maxRecursionLevel)
extra = append(extra, glue...)
if meta != nil && d.config.NodeMetaTXT {
extra = append(extra, meta...)
@ -473,6 +474,12 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
// dispatch is used to parse a request and invoke the correct handler
func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) {
return d.doDispatch(network, remoteAddr, req, resp, maxRecursionLevelDefault)
}
// doDispatch is used to parse a request and invoke the correct handler.
// parameter maxRecursionLevel will handle whether recursive call can be performed
func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) (ecsGlobal bool) {
ecsGlobal = true
// By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter
@ -519,7 +526,7 @@ PARSE:
}
// _name._tag.service.consul
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp)
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel)
// Consul 0.3 and prior format for SRV queries
} else {
@ -531,7 +538,7 @@ PARSE:
}
// tag[.tag].name.service.consul
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel)
}
case "connect":
@ -540,7 +547,7 @@ PARSE:
}
// name.connect.consul
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel)
case "node":
if n == 1 {
@ -549,7 +556,7 @@ PARSE:
// Allow a "." in the node name, just join all the parts
node := strings.Join(labels[:n-1], ".")
d.nodeLookup(network, datacenter, node, req, resp)
d.nodeLookup(network, datacenter, node, req, resp, maxRecursionLevel)
case "query":
if n == 1 {
@ -559,7 +566,7 @@ PARSE:
// Allow a "." in the query name, just join all the parts.
query := strings.Join(labels[:n-1], ".")
ecsGlobal = false
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp)
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)
case "addr":
if n != 2 {
@ -632,7 +639,7 @@ INVALID:
}
// nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) {
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) {
// Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
@ -678,7 +685,7 @@ RPC:
n := out.NodeServices.Node
edns := req.IsEdns0() != nil
addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, maxRecursionLevel)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
@ -715,7 +722,7 @@ func encodeKVasRFC1464(key, value string) (txt string) {
// The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node
// and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the
// generated RRs should go and if they should be used at all.
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records, meta []dns.RR) {
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int) (records, meta []dns.RR) {
// Parse the IP
ip := net.ParseIP(addr)
var ipv4 net.IP
@ -761,7 +768,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
records = append(records, cnRec)
// Recurse
more := d.resolveCNAME(cnRec.Target)
more := d.resolveCNAME(cnRec.Target, maxRecursionLevel)
extra := 0
MORE_REC:
for _, rr := range more {
@ -1004,7 +1011,7 @@ func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed
}
// lookupServiceNodes returns nodes with a given service.
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool) (structs.IndexedCheckServiceNodes, error) {
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) {
args := structs.ServiceSpecificRequest{
Connect: connect,
Datacenter: datacenter,
@ -1042,8 +1049,8 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
}
// serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect)
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect, maxRecursionLevel)
if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
@ -1066,9 +1073,9 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn
// Add various responses depending on the request
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else {
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}
d.trimDNSResponse(network, req, resp)
@ -1098,7 +1105,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
}
// preparedQueryLookup is used to handle a prepared query.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg) {
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) {
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
@ -1195,9 +1202,9 @@ RPC:
// Add various responses depending on the request.
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl)
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else {
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl)
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}
d.trimDNSResponse(network, req, resp)
@ -1210,7 +1217,7 @@ RPC:
}
// serviceNodeRecords is used to add the node records for a service lookup
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
qName := req.Question[0].Name
qType := req.Question[0].Qtype
handled := make(map[string]struct{})
@ -1241,7 +1248,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
// Add the node record
had_answer := false
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns)
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel)
if records != nil {
switch records[0].(type) {
case *dns.CNAME:
@ -1323,7 +1330,7 @@ func findWeight(node structs.CheckServiceNode) int {
}
// serviceARecords is used to add the SRV records for a service lookup
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
handled := make(map[string]struct{})
edns := req.IsEdns0() != nil
@ -1360,7 +1367,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}
// Add the extra record
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns)
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel)
if len(records) > 0 {
// Use the node address if it doesn't differ from the service address
if addr == node.Node.Address {
@ -1457,16 +1464,21 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
}
// resolveCNAME is used to recursively resolve CNAME records
func (d *DNSServer) resolveCNAME(name string) []dns.RR {
func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case insensitive; d.domain is
// already converted
if strings.HasSuffix(strings.ToLower(name), "."+d.domain) {
if maxRecursionLevel < 1 {
d.logger.Printf("[ERR] dns: Infinite recursion detected for %s, won't perform any CNAME resolution.", name)
return nil
}
req := &dns.Msg{}
resp := &dns.Msg{}
req.SetQuestion(name, dns.TypeANY)
d.dispatch("udp", nil, req, resp)
d.doDispatch("udp", nil, req, resp, maxRecursionLevel-1)
return resp.Answer
}

View File

@ -1683,6 +1683,61 @@ func TestDNS_ExternalServiceLookup(t *testing.T) {
}
}
func TestDNS_InifiniteRecursion(t *testing.T) {
// This test should not create an infinite recursion
t.Parallel()
a := NewTestAgent(t.Name(), `
domain = "CONSUL."
node_name = "test node"
`)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// Register the initial node with a service
{
args := &structs.RegisterRequest{
Datacenter: "dc1",
Node: "web",
Address: "web.service.consul.",
Service: &structs.NodeService{
Service: "web",
Port: 12345,
Address: "web.service.consul.",
},
}
var out struct{}
if err := a.RPC("Catalog.Register", args, &out); err != nil {
t.Fatalf("err: %v", err)
}
}
// Look up the service directly
questions := []string{
"web.service.consul.",
}
for _, question := range questions {
m := new(dns.Msg)
m.SetQuestion(question, dns.TypeA)
c := new(dns.Client)
in, _, err := c.Exchange(m, a.DNSAddr())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) < 1 {
t.Fatalf("Bad: %#v", in)
}
aRec, ok := in.Answer[0].(*dns.CNAME)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if aRec.Target != "web.service.consul." {
t.Fatalf("Bad: %#v, target:=%s", aRec, aRec.Target)
}
}
}
func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `

View File

@ -397,7 +397,7 @@ func TestSessionTTLRenew(t *testing.T) {
}
// Sleep to consume some time before renew
time.Sleep(ttl * (structs.SessionTTLMultiplier / 2))
time.Sleep(ttl * (structs.SessionTTLMultiplier / 3))
req, _ = http.NewRequest("PUT", "/v1/session/renew/"+id, nil)
resp = httptest.NewRecorder()