mirror of https://github.com/status-im/consul.git
agent: Shuffle DNS responses, limit records
This commit is contained in:
parent
1d10b9d6ba
commit
d35de5bc11
|
@ -6,6 +6,7 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -14,6 +15,7 @@ import (
|
||||||
const (
|
const (
|
||||||
testQuery = "_test.consul."
|
testQuery = "_test.consul."
|
||||||
consulDomain = "consul."
|
consulDomain = "consul."
|
||||||
|
maxServiceResponses = 3 // TODO: Increase, currently a bug upstream in dns package
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSServer is used to wrap an Agent and expose various
|
// DNSServer is used to wrap an Agent and expose various
|
||||||
|
@ -318,6 +320,14 @@ func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dn
|
||||||
// Filter out any service nodes due to health checks
|
// Filter out any service nodes due to health checks
|
||||||
out.Nodes = d.filterServiceNodes(out.Nodes)
|
out.Nodes = d.filterServiceNodes(out.Nodes)
|
||||||
|
|
||||||
|
// Perform a random shuffle
|
||||||
|
shuffleServiceNodes(out.Nodes)
|
||||||
|
|
||||||
|
// Restrict the number of responses
|
||||||
|
if len(out.Nodes) > maxServiceResponses {
|
||||||
|
out.Nodes = out.Nodes[:maxServiceResponses]
|
||||||
|
}
|
||||||
|
|
||||||
// Add various responses depending on the request
|
// Add various responses depending on the request
|
||||||
qType := req.Question[0].Qtype
|
qType := req.Question[0].Qtype
|
||||||
if qType == dns.TypeANY || qType == dns.TypeA {
|
if qType == dns.TypeANY || qType == dns.TypeA {
|
||||||
|
@ -346,6 +356,14 @@ func (d *DNSServer) filterServiceNodes(nodes structs.CheckServiceNodes) structs.
|
||||||
return nodes[:n]
|
return nodes[:n]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shuffleServiceNodes does an in-place random shuffle using the Fisher-Yates algorithm
|
||||||
|
func shuffleServiceNodes(nodes structs.CheckServiceNodes) {
|
||||||
|
for i := len(nodes) - 1; i > 0; i-- {
|
||||||
|
j := rand.Int31() % int32(i+1)
|
||||||
|
nodes[i], nodes[j] = nodes[j], nodes[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// serviceARecords is used to add the A records for a service lookup
|
// serviceARecords is used to add the A records for a service lookup
|
||||||
func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) {
|
func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) {
|
||||||
handled := make(map[string]struct{})
|
handled := make(map[string]struct{})
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/hashicorp/consul/consul/structs"
|
"github.com/hashicorp/consul/consul/structs"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -256,7 +258,7 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Bad: %#v", in.Answer[1])
|
t.Fatalf("Bad: %#v", in.Answer[1])
|
||||||
}
|
}
|
||||||
if srvRec.Port != 12345 {
|
if srvRec.Port != 12345 && srvRec.Port != 12346 {
|
||||||
t.Fatalf("Bad: %#v", srvRec)
|
t.Fatalf("Bad: %#v", srvRec)
|
||||||
}
|
}
|
||||||
if srvRec.Target != "foo.node.dc1.consul." {
|
if srvRec.Target != "foo.node.dc1.consul." {
|
||||||
|
@ -267,9 +269,12 @@ func TestDNS_ServiceLookup_Dedup(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Bad: %#v", in.Answer[1])
|
t.Fatalf("Bad: %#v", in.Answer[1])
|
||||||
}
|
}
|
||||||
if srvRec.Port != 12346 {
|
if srvRec.Port != 12346 && srvRec.Port != 12345 {
|
||||||
t.Fatalf("Bad: %#v", srvRec)
|
t.Fatalf("Bad: %#v", srvRec)
|
||||||
}
|
}
|
||||||
|
if srvRec.Port == in.Answer[1].(*dns.SRV).Port {
|
||||||
|
t.Fatalf("should be a different port")
|
||||||
|
}
|
||||||
if srvRec.Target != "foo.node.dc1.consul." {
|
if srvRec.Target != "foo.node.dc1.consul." {
|
||||||
t.Fatalf("Bad: %#v", srvRec)
|
t.Fatalf("Bad: %#v", srvRec)
|
||||||
}
|
}
|
||||||
|
@ -352,3 +357,66 @@ func TestDNS_ServiceLookup_FilterCritical(t *testing.T) {
|
||||||
t.Fatalf("Bad: %#v", in)
|
t.Fatalf("Bad: %#v", in)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNS_ServiceLookup_Randomize(t *testing.T) {
|
||||||
|
dir, srv := makeDNSServer(t)
|
||||||
|
defer os.RemoveAll(dir)
|
||||||
|
defer srv.agent.Shutdown()
|
||||||
|
|
||||||
|
// Wait for leader
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Register nodes
|
||||||
|
for i := 0; i < 3*maxServiceResponses; i++ {
|
||||||
|
args := &structs.RegisterRequest{
|
||||||
|
Datacenter: "dc1",
|
||||||
|
Node: fmt.Sprintf("foo%d", i),
|
||||||
|
Address: fmt.Sprintf("127.0.0.%d", i+1),
|
||||||
|
Service: &structs.NodeService{
|
||||||
|
Service: "web",
|
||||||
|
Port: 8000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var out struct{}
|
||||||
|
if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the response is randomized each time
|
||||||
|
uniques := map[string]struct{}{}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("web.service.consul.", dns.TypeANY)
|
||||||
|
|
||||||
|
c := new(dns.Client)
|
||||||
|
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response length should be truncated
|
||||||
|
// We should get an SRV + A record for each response (hence 2x)
|
||||||
|
if len(in.Answer) != 2*maxServiceResponses {
|
||||||
|
t.Fatalf("Bad: %#v", len(in.Answer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect all the names
|
||||||
|
var names []string
|
||||||
|
for _, rec := range in.Answer {
|
||||||
|
switch v := rec.(type) {
|
||||||
|
case *dns.SRV:
|
||||||
|
names = append(names, v.Target)
|
||||||
|
case *dns.A:
|
||||||
|
names = append(names, v.A.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nameS := strings.Join(names, "|")
|
||||||
|
|
||||||
|
// Check if unique
|
||||||
|
if _, ok := uniques[nameS]; ok {
|
||||||
|
t.Fatalf("non-unique response: %v", nameS)
|
||||||
|
}
|
||||||
|
uniques[nameS] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue