diff --git a/agent/agent.go b/agent/agent.go index 98f02b6924..99700f405c 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -13,7 +13,6 @@ import ( "net/http" "os" "path/filepath" - "regexp" "strconv" "strings" "sync" @@ -51,9 +50,6 @@ const ( "service, but no reason was provided. This is a default message." ) -// dnsNameRe checks if a name or tag is dns-compatible. -var dnsNameRe = regexp.MustCompile(`^[a-zA-Z0-9\-]+$`) - // delegate defines the interface shared by both // consul.Client and consul.Server. type delegate interface { @@ -65,6 +61,7 @@ type delegate interface { JoinLAN(addrs []string) (n int, err error) RemoveFailedNode(node string) error RPC(method string, args interface{}, reply interface{}) error + ServerAddrs() map[string]string SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error Shutdown() error Stats() map[string]map[string]string @@ -1369,7 +1366,7 @@ func (a *Agent) AddService(service *structs.NodeService, chkTypes []*structs.Che } // Warn if the service name is incompatible with DNS - if !dnsNameRe.MatchString(service.Service) { + if InvalidDnsRe.MatchString(service.Service) { a.logger.Printf("[WARN] Service name %q will not be discoverable "+ "via DNS due to invalid characters. Valid characters include "+ "all alpha-numerics and dashes.", service.Service) @@ -1377,7 +1374,7 @@ func (a *Agent) AddService(service *structs.NodeService, chkTypes []*structs.Che // Warn if any tags are incompatible with DNS for _, tag := range service.Tags { - if !dnsNameRe.MatchString(tag) { + if InvalidDnsRe.MatchString(tag) { a.logger.Printf("[DEBUG] Service tag %q will not be discoverable "+ "via DNS due to invalid characters. Valid characters include "+ "all alpha-numerics and dashes.", tag) diff --git a/agent/consul/client.go b/agent/consul/client.go index 6ee73e12b3..87476e7cfd 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -411,6 +411,10 @@ func (c *Client) Stats() map[string]map[string]string { return stats } +func (c *Client) ServerAddrs() map[string]string { + return c.servers.GetServerAddrs() +} + // GetLANCoordinate returns the network coordinate of the current node, as // maintained by Serf. func (c *Client) GetLANCoordinate() (*coordinate.Coordinate, error) { diff --git a/agent/consul/server.go b/agent/consul/server.go index 3160ad73ce..694426c4d2 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -1047,6 +1047,15 @@ func (s *Server) GetWANCoordinate() (*coordinate.Coordinate, error) { return s.serfWAN.GetCoordinate() } +func (s *Server) ServerAddrs() map[string]string { + ret, err := s.router.FindServerAddrs(s.config.Datacenter) + if err != nil { + s.logger.Printf("[WARN] Unexpected state, no server addresses for datacenter %v, got error: %v", s.config.Datacenter, err) + return nil + } + return ret +} + // Atomically sets a readiness state flag when leadership is obtained, to indicate that server is past its barrier write func (s *Server) setConsistentReadReady() { atomic.StoreInt32(&s.readyForConsistentReads, 1) diff --git a/agent/consul/servers/manager.go b/agent/consul/servers/manager.go index ef149d0087..092efadd7b 100644 --- a/agent/consul/servers/manager.go +++ b/agent/consul/servers/manager.go @@ -223,6 +223,15 @@ func (m *Manager) getServerList() serverList { return m.listValue.Load().(serverList) } +// GetServerAddrs returns a map from node name to address for all servers +func (m *Manager) GetServerAddrs() map[string]string { + ret := make(map[string]string) + for _, server := range m.getServerList().servers { + ret[server.Name] = server.Addr.String() + } + return ret +} + // saveServerList is a convenience method which hides the locking semantics // of atomic.Value from the caller. func (m *Manager) saveServerList(l serverList) { diff --git a/agent/consul/servers/router.go b/agent/consul/servers/router.go index 315e3a55af..f75f354285 100644 --- a/agent/consul/servers/router.go +++ b/agent/consul/servers/router.go @@ -489,3 +489,26 @@ func (r *Router) GetDatacenterMaps() ([]structs.DatacenterMap, error) { } return maps, nil } + +func (r *Router) FindServerAddrs(datacenter string) (map[string]string, error) { + r.RLock() + defer r.RUnlock() + + // Get the list of managers for this datacenter. This will usually just + // have one entry, but it's possible to have a user-defined area + WAN. + managers, ok := r.managers[datacenter] + if !ok { + return nil, fmt.Errorf("datacenter %v not found", datacenter) + } + + ret := make(map[string]string) + for _, manager := range managers { + if manager.IsOffline() { + continue + } + for name, addr := range manager.GetServerAddrs() { + ret[name] = addr + } + } + return ret, nil +} diff --git a/agent/dns.go b/agent/dns.go index bf01bd1e85..64e9226a38 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -9,6 +9,8 @@ import ( "sync/atomic" "time" + "regexp" + "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/consul/structs" @@ -30,6 +32,8 @@ const ( defaultMaxUDPSize = 512 ) +var InvalidDnsRe = regexp.MustCompile(`[^A-Za-z0-9\\-]+`) + // DNSServer is used to wrap an Agent and expose various // service discovery endpoints using a DNS interface. type DNSServer struct { @@ -133,7 +137,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { // Only add the SOA if requested if req.Question[0].Qtype == dns.TypeSOA { - d.addSOA(d.domain, m) + d.addSOA(m) } datacenter := d.agent.config.Datacenter @@ -206,13 +210,26 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { m.Authoritative = true m.RecursionAvailable = (len(d.recursors) > 0) - // Only add the SOA if requested - if req.Question[0].Qtype == dns.TypeSOA { - d.addSOA(d.domain, m) - } + switch req.Question[0].Qtype { + case dns.TypeSOA: + ns, glue := d.nameservers(req.IsEdns0() != nil) + m.Answer = append(m.Answer, d.soa()) + m.Ns = append(m.Ns, ns...) + m.Extra = append(m.Extra, glue...) + m.SetRcode(req, dns.RcodeSuccess) - // Dispatch the correct handler - d.dispatch(network, req, m) + case dns.TypeNS: + ns, glue := d.nameservers(req.IsEdns0() != nil) + m.Answer = ns + m.Extra = glue + m.SetRcode(req, dns.RcodeSuccess) + + case dns.TypeAXFR: + m.SetRcode(req, dns.RcodeNotImplemented) + + default: + d.dispatch(network, req, m) + } // Handle EDNS if edns := req.IsEdns0(); edns != nil { @@ -225,24 +242,92 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { } } -// addSOA is used to add an SOA record to a message for the given domain -func (d *DNSServer) addSOA(domain string, msg *dns.Msg) { - soa := &dns.SOA{ +func (d *DNSServer) soa() *dns.SOA { + return &dns.SOA{ Hdr: dns.RR_Header{ - Name: domain, + Name: d.domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 0, }, - Ns: "ns." + domain, - Mbox: "postmaster." + domain, - Serial: uint32(time.Now().Unix()), + Ns: "ns." + d.domain, + Serial: uint32(time.Now().Unix()), + + // todo(fs): make these configurable + Mbox: "hostmaster." + d.domain, Refresh: 3600, Retry: 600, Expire: 86400, Minttl: 0, } - msg.Ns = append(msg.Ns, soa) +} + +// addSOA is used to add an SOA record to a message for the given domain +func (d *DNSServer) addSOA(msg *dns.Msg) { + msg.Ns = append(msg.Ns, d.soa()) +} + +// nameservers returns the names and ip addresses of up to three random servers +// in the current cluster which serve as authoritative name servers for zone. +func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) { + // get server names and store them in a map to randomize the output + servers := map[string]net.IP{} + for name, addr := range d.agent.delegate.ServerAddrs() { + host, _, err := net.SplitHostPort(addr) + if err != nil { + d.logger.Println("[WARN] Unable to parse address %v, got error: %v", addr, err) + continue + } + + ip := net.ParseIP(host) + if ip == nil { + continue + } + + // Use "NODENAME.node.DC.DOMAIN" as a unique name for the server + // since we use that name in other places as well. + // 'name' is "NODENAME.DC" so we need to split it + // to construct the server name. + lastdot := strings.LastIndexByte(name, '.') + nodeName, dc := name[:lastdot], name[lastdot:] + if InvalidDnsRe.MatchString(nodeName) { + d.logger.Printf("[WARN] dns: Node name %q is not a valid dns host name, will not be added to NS record", nodeName) + continue + } + fqdn := nodeName + ".node" + dc + "." + d.domain + fqdn = dns.Fqdn(strings.ToLower(fqdn)) + + servers[fqdn] = ip + } + + if len(servers) == 0 { + return + } + + for name, ip := range servers { + // NS record + nsrr := &dns.NS{ + Hdr: dns.RR_Header{ + Name: d.domain, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: uint32(d.config.NodeTTL / time.Second), + }, + Ns: name, + } + ns = append(ns, nsrr) + + // A or AAAA glue record + glue := d.formatNodeRecord(ip.String(), name, dns.TypeANY, d.config.NodeTTL, edns) + extra = append(extra, glue...) + + // don't provide more than 3 servers + if len(ns) >= 3 { + return + } + } + + return } // dispatch is used to parse a request and invoke the correct handler @@ -371,7 +456,7 @@ PARSE: return INVALID: d.logger.Printf("[WARN] dns: QName invalid: %s", qName) - d.addSOA(d.domain, resp) + d.addSOA(resp) resp.SetRcode(req, dns.RcodeNameError) } @@ -413,7 +498,7 @@ RPC: // If we have no address, return not found! if out.NodeServices == nil { - d.addSOA(d.domain, resp) + d.addSOA(resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -422,15 +507,14 @@ RPC: n := out.NodeServices.Node edns := req.IsEdns0() != nil addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses) - records := d.formatNodeRecord(out.NodeServices.Node, addr, - req.Question[0].Name, qType, d.config.NodeTTL, edns) + records := d.formatNodeRecord(addr, req.Question[0].Name, qType, d.config.NodeTTL, edns) if records != nil { resp.Answer = append(resp.Answer, records...) } } // formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record -func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) { +func (d *DNSServer) formatNodeRecord(addr, qName string, qType uint16, ttl time.Duration, edns bool) (records []dns.RR) { // Parse the IP ip := net.ParseIP(addr) var ipv4 net.IP @@ -640,7 +724,7 @@ RPC: // If we have no nodes, return not found! if len(out.Nodes) == 0 { - d.addSOA(d.domain, resp) + d.addSOA(resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -668,7 +752,7 @@ RPC: // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { - d.addSOA(d.domain, resp) + d.addSOA(resp) return } } @@ -709,7 +793,7 @@ RPC: // not a full on server error. We have to use a string compare // here since the RPC layer loses the type information. if err.Error() == consul.ErrQueryNotFound.Error() { - d.addSOA(d.domain, resp) + d.addSOA(resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -751,7 +835,7 @@ RPC: // If we have no nodes, return not found! if len(out.Nodes) == 0 { - d.addSOA(d.domain, resp) + d.addSOA(resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -776,7 +860,7 @@ RPC: // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { - d.addSOA(d.domain, resp) + d.addSOA(resp) return } } @@ -810,7 +894,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode handled[addr] = struct{}{} // Add the node record - records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns) + records := d.formatNodeRecord(addr, qName, qType, ttl, edns) if records != nil { resp.Answer = append(resp.Answer, records...) } @@ -854,7 +938,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } // Add the extra record - records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns) + records := d.formatNodeRecord(addr, srvRec.Target, dns.TypeANY, ttl, edns) if len(records) > 0 { // Use the node address if it doesn't differ from the service address if addr == node.Node.Address { diff --git a/agent/dns_test.go b/agent/dns_test.go index 90e905d6aa..152f5f237b 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -179,6 +179,7 @@ func TestDNS_NodeLookup(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + } func TestDNS_CaseInsensitiveNodeLookup(t *testing.T) { @@ -630,12 +631,15 @@ func TestDNS_ServiceLookup(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + } } func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { t.Parallel() - a := NewTestAgent(t.Name(), nil) + cfg := TestConfig() + cfg.NodeName = "my.test-node" + a := NewTestAgent(t.Name(), cfg) defer a.Shutdown() // Register a node with a service. @@ -679,7 +683,6 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { }, } verify.Values(t, "answer", in.Answer, wantAnswer) - wantExtra := []dns.RR{ &dns.CNAME{ Hdr: dns.RR_Header{Name: "foo.node.dc1.consul.", Rrtype: 0x5, Class: 0x1, Rdlength: 0x2}, @@ -769,6 +772,7 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { t.Parallel() cfg := TestConfig() cfg.Domain = "CONSUL." + cfg.NodeName = "test node" a := NewTestAgent(t.Name(), cfg) defer a.Shutdown() @@ -873,12 +877,119 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { if aRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Extra[1]) } + } } +func TestDNS_NSRecords(t *testing.T) { + t.Parallel() + cfg := TestConfig() + cfg.Domain = "CONSUL." + cfg.NodeName = "server1" + a := NewTestAgent(t.Name(), cfg) + defer a.Shutdown() + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + TaggedAddresses: map[string]string{ + "wan": "127.0.0.2", + }, + } + + var out struct{} + if err := a.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("something.node.consul.", dns.TypeNS) + + c := new(dns.Client) + addr, _ := a.Config.ClientListener("", a.Config.Ports.DNS) + in, _, err := c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) + } + + wantAnswer := []dns.RR{ + &dns.NS{ + Hdr: dns.RR_Header{Name: "consul.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 0, Rdlength: 0x13}, + Ns: "server1.node.dc1.consul.", + }, + } + verify.Values(t, "answer", in.Answer, wantAnswer) + wantExtra := []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "server1.node.dc1.consul.", Rrtype: dns.TypeA, Class: dns.ClassINET, Rdlength: 0x4, Ttl: 0}, + A: net.ParseIP("127.0.0.1").To4(), + }, + } + + verify.Values(t, "extra", in.Extra, wantExtra) + +} + +func TestDNS_NSRecords_IPV6(t *testing.T) { + t.Parallel() + cfg := TestConfig() + cfg.Domain = "CONSUL." + cfg.NodeName = "server1" + cfg.AdvertiseAddr = "::1" + cfg.AdvertiseAddrWan = "::1" + a := NewTestAgent(t.Name(), cfg) + defer a.Shutdown() + + // Register node + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: "foo", + Address: "127.0.0.1", + TaggedAddresses: map[string]string{ + "wan": "127.0.0.2", + }, + } + + var out struct{} + if err := a.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + + m := new(dns.Msg) + m.SetQuestion("server1.node.dc1.consul.", dns.TypeNS) + + c := new(dns.Client) + addr, _ := a.Config.ClientListener("", a.Config.Ports.DNS) + in, _, err := c.Exchange(m, addr.String()) + if err != nil { + t.Fatalf("err: %v", err) + } + + wantAnswer := []dns.RR{ + &dns.NS{ + Hdr: dns.RR_Header{Name: "consul.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 0, Rdlength: 0x2}, + Ns: "server1.node.dc1.consul.", + }, + } + verify.Values(t, "answer", in.Answer, wantAnswer) + wantExtra := []dns.RR{ + &dns.AAAA{ + Hdr: dns.RR_Header{Name: "server1.node.dc1.consul.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Rdlength: 0x10, Ttl: 0}, + AAAA: net.ParseIP("::1"), + }, + } + + verify.Values(t, "extra", in.Extra, wantExtra) + +} + func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { t.Parallel() - a := NewTestAgent(t.Name(), nil) + cfg := TestConfig() + cfg.NodeName = "test-node" + a := NewTestAgent(t.Name(), cfg) defer a.Shutdown() // Register the initial node with a service @@ -2694,6 +2805,7 @@ func testDNS_ServiceLookup_responseLimits(t *testing.T, answerLimit int, qType u expectedService, expectedQuery, expectedQueryID int) (bool, error) { cfg := TestConfig() cfg.DNSConfig.UDPAnswerLimit = answerLimit + cfg.NodeName = "test-node" a := NewTestAgent(t.Name(), cfg) defer a.Shutdown() @@ -3862,7 +3974,6 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { if len(in.Ns) != 1 { t.Fatalf("Bad: %#v", in) } - soaRec, ok := in.Ns[0].(*dns.SOA) if !ok { t.Fatalf("Bad: %#v", in.Ns[0]) @@ -3874,6 +3985,7 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { if in.Rcode != dns.RcodeSuccess { t.Fatalf("Bad: %#v", in) } + } // Check for ipv4 records on ipv6-only service directly and via the @@ -3955,6 +4067,7 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + } } @@ -3993,6 +4106,7 @@ func TestDNS_InvalidQueries(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + } } @@ -4688,3 +4802,26 @@ func TestDNS_Compression_Recurse(t *testing.T) { t.Fatalf("doesn't look compressed: %d vs. %d", compressed, unc) } } + +func TestDNSInvalidRegex(t *testing.T) { + tests := []struct { + desc string + in string + invalid bool + }{ + {"Valid Hostname", "testnode", false}, + {"Valid Hostname", "test-node", false}, + {"Invalid Hostname with special chars", "test#$$!node", true}, + {"Invalid Hostname with special chars in the end", "testnode%^", true}, + {"Whitespace", " ", true}, + {"Only special chars", "./$", true}, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if got, want := InvalidDnsRe.MatchString(test.in), test.invalid; got != want { + t.Fatalf("Expected %v to return %v", test.in, want) + } + }) + + } +}