agent: Shuffle DNS responses, limit records

This commit is contained in:
Armon Dadgar 2014-02-14 12:26:51 -08:00
parent 1d10b9d6ba
commit d35de5bc11
2 changed files with 90 additions and 4 deletions

View File

@ -6,6 +6,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"io" "io"
"log" "log"
"math/rand"
"net" "net"
"strings" "strings"
"time" "time"
@ -14,6 +15,7 @@ import (
const ( const (
testQuery = "_test.consul." testQuery = "_test.consul."
consulDomain = "consul." consulDomain = "consul."
maxServiceResponses = 3 // TODO: Increase, currently a bug upstream in dns package
) )
// DNSServer is used to wrap an Agent and expose various // DNSServer is used to wrap an Agent and expose various
@ -318,6 +320,14 @@ func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dn
// Filter out any service nodes due to health checks // Filter out any service nodes due to health checks
out.Nodes = d.filterServiceNodes(out.Nodes) out.Nodes = d.filterServiceNodes(out.Nodes)
// Perform a random shuffle
shuffleServiceNodes(out.Nodes)
// Restrict the number of responses
if len(out.Nodes) > maxServiceResponses {
out.Nodes = out.Nodes[:maxServiceResponses]
}
// Add various responses depending on the request // Add various responses depending on the request
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
if qType == dns.TypeANY || qType == dns.TypeA { if qType == dns.TypeANY || qType == dns.TypeA {
@ -346,6 +356,14 @@ func (d *DNSServer) filterServiceNodes(nodes structs.CheckServiceNodes) structs.
return nodes[:n] return nodes[:n]
} }
// shuffleServiceNodes randomly permutes nodes in place using the
// Fisher-Yates algorithm. Slices with fewer than two elements are left
// untouched. Randomness comes from the package-level math/rand source.
func shuffleServiceNodes(nodes structs.CheckServiceNodes) {
	for i := len(nodes) - 1; i > 0; i-- {
		// Pick j uniformly from [0, i]. rand.Intn avoids the modulo bias of
		// rand.Int31() % n and keeps the index arithmetic in plain int.
		j := rand.Intn(i + 1)
		nodes[i], nodes[j] = nodes[j], nodes[i]
	}
}
// serviceARecords is used to add the A records for a service lookup // serviceARecords is used to add the A records for a service lookup
func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) { func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) {
handled := make(map[string]struct{}) handled := make(map[string]struct{})

View File

@ -1,9 +1,11 @@
package agent package agent
import ( import (
"fmt"
"github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/consul/structs"
"github.com/miekg/dns" "github.com/miekg/dns"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
) )
@ -256,7 +258,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Answer[1]) t.Fatalf("Bad: %#v", in.Answer[1])
} }
if srvRec.Port != 12345 { if srvRec.Port != 12345 && srvRec.Port != 12346 {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
} }
if srvRec.Target != "foo.node.dc1.consul." { if srvRec.Target != "foo.node.dc1.consul." {
@ -267,9 +269,12 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
if !ok { if !ok {
t.Fatalf("Bad: %#v", in.Answer[1]) t.Fatalf("Bad: %#v", in.Answer[1])
} }
if srvRec.Port != 12346 { if srvRec.Port != 12346 && srvRec.Port != 12345 {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
} }
if srvRec.Port == in.Answer[1].(*dns.SRV).Port {
t.Fatalf("should be a different port")
}
if srvRec.Target != "foo.node.dc1.consul." { if srvRec.Target != "foo.node.dc1.consul." {
t.Fatalf("Bad: %#v", srvRec) t.Fatalf("Bad: %#v", srvRec)
} }
@ -352,3 +357,66 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) {
t.Fatalf("Bad: %#v", in) t.Fatalf("Bad: %#v", in)
} }
} }
// TestDNS_ServiceLookup_Randomize registers more "web" service nodes than
// maxServiceResponses and checks that each DNS answer is truncated to
// maxServiceResponses nodes and that the answers vary between queries.
func TestDNS_ServiceLookup_Randomize(t *testing.T) {
	dir, srv := makeDNSServer(t)
	defer os.RemoveAll(dir)
	defer srv.agent.Shutdown()

	// Wait for leader
	time.Sleep(100 * time.Millisecond)

	// Register more nodes than fit in one response so truncation applies
	for i := 0; i < 3*maxServiceResponses; i++ {
		args := &structs.RegisterRequest{
			Datacenter: "dc1",
			Node:       fmt.Sprintf("foo%d", i),
			Address:    fmt.Sprintf("127.0.0.%d", i+1),
			Service: &structs.NodeService{
				Service: "web",
				Port:    8000,
			},
		}
		var out struct{}
		if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Issue the same query repeatedly and record every distinct response
	const attempts = 5
	uniques := make(map[string]struct{}, attempts)
	c := new(dns.Client)
	for i := 0; i < attempts; i++ {
		m := new(dns.Msg)
		m.SetQuestion("web.service.consul.", dns.TypeANY)

		in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
		if err != nil {
			t.Fatalf("err: %v", err)
		}

		// Response length should be truncated
		// We should get an SRV + A record for each response (hence 2x)
		if len(in.Answer) != 2*maxServiceResponses {
			t.Fatalf("Bad: %#v", len(in.Answer))
		}

		// Collect all the names, in answer order, as a fingerprint of the response
		names := make([]string, 0, len(in.Answer))
		for _, rec := range in.Answer {
			switch v := rec.(type) {
			case *dns.SRV:
				names = append(names, v.Target)
			case *dns.A:
				names = append(names, v.A.String())
			}
		}
		uniques[strings.Join(names, "|")] = struct{}{}
	}

	// The responses are shuffled randomly and there is a finite number of
	// possible orderings, so an occasional repeat is legitimate. Demanding
	// zero duplicates makes the test fail spuriously; only require that the
	// answers are not all identical.
	if len(uniques) < 2 {
		t.Fatalf("responses not randomized: %d unique out of %d", len(uniques), attempts)
	}
}