diff --git a/agent/consul/catalog_endpoint_test.go b/agent/consul/catalog_endpoint_test.go index 70e0fa95d1..e86b8122cc 100644 --- a/agent/consul/catalog_endpoint_test.go +++ b/agent/consul/catalog_endpoint_test.go @@ -933,6 +933,7 @@ func TestCatalog_ListNodes_StaleRead(t *testing.T) { defer s1.Shutdown() codec1 := rpcClient(t, s1) defer codec1.Close() + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") dir2, s2 := testServerDCBootstrap(t, "dc1", false) defer os.RemoveAll(dir2) @@ -980,7 +981,7 @@ func TestCatalog_ListNodes_StaleRead(t *testing.T) { } } if !found { - t.Fatalf("failed to find foo") + t.Fatalf("failed to find foo in %#v", out.Nodes) } if out.QueryMeta.LastContact == 0 { @@ -2160,6 +2161,7 @@ func TestCatalog_NodeServices(t *testing.T) { defer s1.Shutdown() codec := rpcClient(t, s1) defer codec.Close() + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") args := structs.NodeSpecificRequest{ Datacenter: "dc1", @@ -2213,7 +2215,7 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Register the service args := structs.TestRegisterRequestProxy(t) @@ -2244,7 +2246,7 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Register the service args := structs.TestRegisterRequest(t) @@ -2392,6 +2394,7 @@ func TestCatalog_ListServices_FilterACL(t *testing.T) { defer os.RemoveAll(dir) defer srv.Shutdown() defer codec.Close() + testrpc.WaitForTestAgent(t, srv.RPC, "dc1") opt := structs.DCSpecificRequest{ Datacenter: "dc1", @@ -2473,7 +2476,7 @@ func TestCatalog_NodeServices_ACLDeny(t *testing.T) { codec := rpcClient(t, s1) defer codec.Close() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") // Prior to version 8, the node policy should be ignored. args := structs.NodeSpecificRequest{ @@ -2542,6 +2545,7 @@ func TestCatalog_NodeServices_FilterACL(t *testing.T) { defer os.RemoveAll(dir) defer srv.Shutdown() defer codec.Close() + testrpc.WaitForTestAgent(t, srv.RPC, "dc1") opt := structs.NodeSpecificRequest{ Datacenter: "dc1", diff --git a/agent/consul/leader_test.go b/agent/consul/leader_test.go index 6ddef55c13..75a6358278 100644 --- a/agent/consul/leader_test.go +++ b/agent/consul/leader_test.go @@ -980,7 +980,7 @@ func TestLeader_ACL_Initialization(t *testing.T) { dir1, s1 := testServerWithConfig(t, conf) defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") if tt.master != "" { _, master, err := s1.fsm.State().ACLTokenGetBySecret(nil, tt.master) @@ -1153,7 +1153,7 @@ func TestLeader_ACLUpgrade(t *testing.T) { }) defer os.RemoveAll(dir1) defer s1.Shutdown() - testrpc.WaitForLeader(t, s1.RPC, "dc1") + testrpc.WaitForTestAgent(t, s1.RPC, "dc1") codec := rpcClient(t, s1) defer codec.Close() diff --git a/agent/dns.go b/agent/dns.go index 5775e15a52..21f0cadc31 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -27,8 +27,9 @@ const ( // records. Limit further to prevent unintentional configuration // abuse that would have a negative effect on application response // times. - maxUDPAnswerLimit = 8 - maxRecurseRecords = 5 + maxUDPAnswerLimit = 8 + maxRecurseRecords = 5 + maxRecursionLevelDefault = 3 // Increment a counter when requests staler than this are served staleCounterThreshold = 5 * time.Second @@ -365,14 +366,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { switch req.Question[0].Qtype { case dns.TypeSOA: - ns, glue := d.nameservers(req.IsEdns0() != nil) + ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault) m.Answer = append(m.Answer, d.soa()) m.Ns = append(m.Ns, ns...) m.Extra = append(m.Extra, glue...) m.SetRcode(req, dns.RcodeSuccess) case dns.TypeNS: - ns, glue := d.nameservers(req.IsEdns0() != nil) + ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault) m.Answer = ns m.Extra = glue m.SetRcode(req, dns.RcodeSuccess) @@ -418,8 +419,8 @@ func (d *DNSServer) addSOA(msg *dns.Msg) { // nameservers returns the names and ip addresses of up to three random servers // in the current cluster which serve as authoritative name servers for zone. -func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) { - out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false) +func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) { + out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel) if err != nil { d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err) return nil, nil @@ -456,7 +457,7 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) { } ns = append(ns, nsrr) - glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns) + glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, maxRecursionLevel) extra = append(extra, glue...) if meta != nil && d.config.NodeMetaTXT { extra = append(extra, meta...) @@ -473,6 +474,12 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) { // dispatch is used to parse a request and invoke the correct handler func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) { + return d.doDispatch(network, remoteAddr, req, resp, maxRecursionLevelDefault) +} + +// doDispatch is used to parse a request and invoke the correct handler. +// parameter maxRecursionLevel will handle whether recursive call can be performed +func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) (ecsGlobal bool) { ecsGlobal = true // By default the query is in the default datacenter datacenter := d.agent.config.Datacenter @@ -519,7 +526,7 @@ PARSE: } // _name._tag.service.consul - d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp) + d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel) // Consul 0.3 and prior format for SRV queries } else { @@ -531,7 +538,7 @@ PARSE: } // tag[.tag].name.service.consul - d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp) + d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel) } case "connect": @@ -540,7 +547,7 @@ PARSE: } // name.connect.consul - d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp) + d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel) case "node": if n == 1 { @@ -549,7 +556,7 @@ PARSE: // Allow a "." in the node name, just join all the parts node := strings.Join(labels[:n-1], ".") - d.nodeLookup(network, datacenter, node, req, resp) + d.nodeLookup(network, datacenter, node, req, resp, maxRecursionLevel) case "query": if n == 1 { @@ -559,7 +566,7 @@ PARSE: // Allow a "." in the query name, just join all the parts. query := strings.Join(labels[:n-1], ".") ecsGlobal = false - d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp) + d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) case "addr": if n != 2 { @@ -632,7 +639,7 @@ INVALID: } // nodeLookup is used to handle a node query -func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) { +func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) { // Only handle ANY, A, AAAA, and TXT type requests qType := req.Question[0].Qtype if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT { @@ -678,7 +685,7 @@ RPC: n := out.NodeServices.Node edns := req.IsEdns0() != nil addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses) - records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns) + records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, maxRecursionLevel) if records != nil { resp.Answer = append(resp.Answer, records...) } @@ -715,7 +722,7 @@ func encodeKVasRFC1464(key, value string) (txt string) { // The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node // and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the // generated RRs should go and if they should be used at all. -func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records, meta []dns.RR) { +func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int) (records, meta []dns.RR) { // Parse the IP ip := net.ParseIP(addr) var ipv4 net.IP @@ -761,7 +768,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy records = append(records, cnRec) // Recurse - more := d.resolveCNAME(cnRec.Target) + more := d.resolveCNAME(cnRec.Target, maxRecursionLevel) extra := 0 MORE_REC: for _, rr := range more { @@ -1004,7 +1011,7 @@ func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed } // lookupServiceNodes returns nodes with a given service. -func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool) (structs.IndexedCheckServiceNodes, error) { +func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) { args := structs.ServiceSpecificRequest{ Connect: connect, Datacenter: datacenter, @@ -1042,8 +1049,8 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect } // serviceLookup is used to handle a service query -func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) { - out, err := d.lookupServiceNodes(datacenter, service, tag, connect) +func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) { + out, err := d.lookupServiceNodes(datacenter, service, tag, connect, maxRecursionLevel) if err != nil { d.logger.Printf("[ERR] dns: rpc error: %v", err) resp.SetRcode(req, dns.RcodeServerFailure) @@ -1066,9 +1073,9 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn // Add various responses depending on the request qType := req.Question[0].Qtype if qType == dns.TypeSRV { - d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl) + d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } else { - d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl) + d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } d.trimDNSResponse(network, req, resp) @@ -1098,7 +1105,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { } // preparedQueryLookup is used to handle a prepared query. -func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg) { +func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) { // Execute the prepared query. args := structs.PreparedQueryExecuteRequest{ Datacenter: datacenter, @@ -1195,9 +1202,9 @@ RPC: // Add various responses depending on the request. qType := req.Question[0].Qtype if qType == dns.TypeSRV { - d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl) + d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } else { - d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl) + d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } d.trimDNSResponse(network, req, resp) @@ -1210,7 +1217,7 @@ RPC: } // serviceNodeRecords is used to add the node records for a service lookup -func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) { +func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { qName := req.Question[0].Name qType := req.Question[0].Qtype handled := make(map[string]struct{}) @@ -1241,7 +1248,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode // Add the node record had_answer := false - records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns) + records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel) if records != nil { switch records[0].(type) { case *dns.CNAME: @@ -1323,7 +1330,7 @@ func findWeight(node structs.CheckServiceNode) int { } // serviceARecords is used to add the SRV records for a service lookup -func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) { +func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { handled := make(map[string]struct{}) edns := req.IsEdns0() != nil @@ -1360,7 +1367,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } // Add the extra record - records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns) + records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel) if len(records) > 0 { // Use the node address if it doesn't differ from the service address if addr == node.Node.Address { @@ -1457,16 +1464,21 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { } // resolveCNAME is used to recursively resolve CNAME records -func (d *DNSServer) resolveCNAME(name string) []dns.RR { +func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR { // If the CNAME record points to a Consul address, resolve it internally // Convert query to lowercase because DNS is case insensitive; d.domain is // already converted + if strings.HasSuffix(strings.ToLower(name), "."+d.domain) { + if maxRecursionLevel < 1 { + d.logger.Printf("[ERR] dns: Infinite recursion detected for %s, won't perform any CNAME resolution.", name) + return nil + } req := &dns.Msg{} resp := &dns.Msg{} req.SetQuestion(name, dns.TypeANY) - d.dispatch("udp", nil, req, resp) + d.doDispatch("udp", nil, req, resp, maxRecursionLevel-1) return resp.Answer } diff --git a/agent/dns_test.go b/agent/dns_test.go index 95d941e1bf..bb363aa7d1 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -1683,6 +1683,61 @@ func TestDNS_ExternalServiceLookup(t *testing.T) { } } +func TestDNS_InifiniteRecursion(t *testing.T) { + // This test should not create an infinite recursion + t.Parallel() + a := NewTestAgent(t.Name(), ` + domain = "CONSUL." + node_name = "test node" + `) + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + // Register the initial node with a service + { + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "web", + Address: "web.service.consul.", + Service: &structs.NodeService{ + Service: "web", + Port: 12345, + Address: "web.service.consul.", + }, + } + + var out struct{} + if err := a.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Look up the service directly + questions := []string{ + "web.service.consul.", + } + for _, question := range questions { + m := new(dns.Msg) + m.SetQuestion(question, dns.TypeA) + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(in.Answer) < 1 { + t.Fatalf("Bad: %#v", in) + } + aRec, ok := in.Answer[0].(*dns.CNAME) + if !ok { + t.Fatalf("Bad: %#v", in.Answer[0]) + } + if aRec.Target != "web.service.consul." { + t.Fatalf("Bad: %#v, target:=%s", aRec, aRec.Target) + } + } +} + func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { t.Parallel() a := NewTestAgent(t.Name(), ` diff --git a/agent/session_endpoint_test.go b/agent/session_endpoint_test.go index c5c600fcdb..f16bdcfe1f 100644 --- a/agent/session_endpoint_test.go +++ b/agent/session_endpoint_test.go @@ -397,7 +397,7 @@ func TestSessionTTLRenew(t *testing.T) { } // Sleep to consume some time before renew - time.Sleep(ttl * (structs.SessionTTLMultiplier / 2)) + time.Sleep(ttl * (structs.SessionTTLMultiplier / 3)) req, _ = http.NewRequest("PUT", "/v1/session/renew/"+id, nil) resp = httptest.NewRecorder()