diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index fbc2fee001..d6d5e3f0f2 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -2,7 +2,9 @@ package agent import ( "fmt" + "net" "os" + "reflect" "strings" "testing" "time" @@ -1569,3 +1571,133 @@ func TestDNS_ServiceLookup_SRV_RFC_TCP_Default(t *testing.T) { t.Fatalf("Bad: %#v", in.Extra[0]) } } + +func TestDNS_CNAME_recurse(t *testing.T) { + // Create our recursor - Consul will recurse to this + dnsConf := nextConfig() + dnsAddr := fmt.Sprintf("%s:%d", dnsConf.Addresses.DNS, dnsConf.Ports.DNS) + mux := dns.NewServeMux() + mux.HandleFunc(".", func(resp dns.ResponseWriter, msg *dns.Msg) { + + cnResp := func(src, target string) *dns.CNAME { + return &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: src, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + }, + Target: target, + } + } + + // Create the answer + ans := &dns.Msg{} + ans.SetReply(msg) + ans.Answer = append(ans.Answer, + cnResp("a.example.com.", "b.example.com."), + cnResp("b.example.com.", "c.example.com."), + cnResp("c.example.com.", "d.example.com."), + &dns.A{ + Hdr: dns.RR_Header{ + Name: "d.example.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP("1.2.3.4"), + }) + + // Write the answer back to the client + if err := resp.WriteMsg(ans); err != nil { + t.Fatalf("err: %s", err) + } + }) + server := &dns.Server{ + Addr: dnsAddr, + Net: "udp", + Handler: mux, + } + go server.ListenAndServe() + defer server.Shutdown() + + // Create the Consul server + dconf := &DNSConfig{} + config := nextConfig() + addr, _ := config.ClientListener(config.Addresses.DNS, config.Ports.DNS) + dir, agent := makeAgent(t, config) + defer os.RemoveAll(dir) + defer agent.Shutdown() + + srv, err := NewDNSServer(agent, dconf, agent.logOutput, + config.Domain, addr.String(), []string{dnsAddr}) + if err != nil { + t.Fatalf("err: %v", err) + } + + testutil.WaitForLeader(t, srv.agent.RPC, "dc1") + + // Register a service with a recursing CNAME as the address + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "a.example.com", + Service: &structs.NodeService{ + Service: "db", + Tags: []string{"master"}, + Address: "a.example.com", + Port: 12345, + }, + } + + var out struct{} + if err := srv.agent.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + // Create the DNS query against the Consul server + m := new(dns.Msg) + m.SetQuestion("db.service.consul.", dns.TypeA) + + c := new(dns.Client) + c.Net = "tcp" + in, _, err := c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Should have all 3 CNAMES and the A record + if len(in.Answer) != 4 { + t.Fatalf("Bad: %#v", in) + } + + // Check all the records + expected := []dns.RR{ + &dns.CNAME{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeCNAME, + }, + Target: "abc", + }, + &dns.CNAME{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeCNAME, + }, + Target: "abc", + }, + &dns.CNAME{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeCNAME, + }, + Target: "abc", + }, + &dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeCNAME, + }, + A: net.ParseIP("1.2.3.4"), + }, + } + + if !reflect.DeepEqual(expected, in.Answer) { + t.Fatalf("Bad: %v %v", expected, in.Answer) + } +}