diff --git a/command/agent/dns.go b/command/agent/dns.go index 6c1299dec0..ac23b27286 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -6,14 +6,16 @@ import ( "github.com/miekg/dns" "io" "log" + "math/rand" "net" "strings" "time" ) const ( - testQuery = "_test.consul." - consulDomain = "consul." + testQuery = "_test.consul." + consulDomain = "consul." + maxServiceResponses = 3 // TODO: Increase, currently a bug upstream in dns package ) // DNSServer is used to wrap an Agent and expose various @@ -318,6 +320,14 @@ func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dn // Filter out any service nodes due to health checks out.Nodes = d.filterServiceNodes(out.Nodes) + // Perform a random shuffle + shuffleServiceNodes(out.Nodes) + + // Restrict the number of responses + if len(out.Nodes) > maxServiceResponses { + out.Nodes = out.Nodes[:maxServiceResponses] + } + // Add various responses depending on the request qType := req.Question[0].Qtype if qType == dns.TypeANY || qType == dns.TypeA { @@ -346,6 +356,14 @@ func (d *DNSServer) filterServiceNodes(nodes structs.CheckServiceNodes) structs. return nodes[:n] } +// shuffleServiceNodes does an in-place random shuffle using the Fisher-Yates algorithm +func shuffleServiceNodes(nodes structs.CheckServiceNodes) { + for i := len(nodes) - 1; i > 0; i-- { + j := rand.Int31() % int32(i+1) + nodes[i], nodes[j] = nodes[j], nodes[i] + } +} + // serviceARecords is used to add the A records for a service lookup func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) { handled := make(map[string]struct{}) diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index 30c94b6409..a1776cad64 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -1,9 +1,11 @@ package agent import ( + "fmt" "github.com/hashicorp/consul/consul/structs" "github.com/miekg/dns" "os" + "strings" "testing" "time" ) @@ -256,7 +258,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { if !ok { t.Fatalf("Bad: %#v", in.Answer[1]) } - if srvRec.Port != 12345 { + if srvRec.Port != 12345 && srvRec.Port != 12346 { t.Fatalf("Bad: %#v", srvRec) } if srvRec.Target != "foo.node.dc1.consul." { @@ -267,9 +269,12 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) { if !ok { t.Fatalf("Bad: %#v", in.Answer[1]) } - if srvRec.Port != 12346 { + if srvRec.Port != 12346 && srvRec.Port != 12345 { t.Fatalf("Bad: %#v", srvRec) } + if srvRec.Port == in.Answer[1].(*dns.SRV).Port { + t.Fatalf("should be a different port") + } if srvRec.Target != "foo.node.dc1.consul." { t.Fatalf("Bad: %#v", srvRec) } @@ -352,3 +357,66 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) { t.Fatalf("Bad: %#v", in) } } + +func TestDNS_ServiceLookup_Randomize(t *testing.T) { + dir, srv := makeDNSServer(t) + defer os.RemoveAll(dir) + defer srv.agent.Shutdown() + + // Wait for leader + time.Sleep(100 * time.Millisecond) + + // Register nodes + for i := 0; i < 3*maxServiceResponses; i++ { + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: fmt.Sprintf("foo%d", i), + Address: fmt.Sprintf("127.0.0.%d", i+1), + Service: &structs.NodeService{ + Service: "web", + Port: 8000, + }, + } + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Ensure the response is randomized each time + uniques := map[string]struct{}{} + for i := 0; i < 5; i++ { + m := new(dns.Msg) + m.SetQuestion("web.service.consul.", dns.TypeANY) + + c := new(dns.Client) + in, _, err := c.Exchange(m, srv.agent.config.DNSAddr) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Response length should be truncated + // We should get an SRV + A record for each response (hence 2x) + if len(in.Answer) != 2*maxServiceResponses { + t.Fatalf("Bad: %#v", len(in.Answer)) + } + + // Collect all the names + var names []string + for _, rec := range in.Answer { + switch v := rec.(type) { + case *dns.SRV: + names = append(names, v.Target) + case *dns.A: + names = append(names, v.A.String()) + } + } + nameS := strings.Join(names, "|") + + // Check if unique + if _, ok := uniques[nameS]; ok { + t.Fatalf("non-unique response: %v", nameS) + } + uniques[nameS] = struct{}{} + } +}