From f9db3870975690de66e782f2cfb0906a94b78f2f Mon Sep 17 00:00:00 2001 From: Preetha Appan Date: Wed, 2 Aug 2017 16:44:40 -0500 Subject: [PATCH] Add NS records and A records for each server. Constructs ns host names using the advertise address of the server. --- agent/agent.go | 1 + agent/consul/client.go | 4 ++ agent/consul/server.go | 9 +++ agent/consul/servers/manager.go | 9 +++ agent/consul/servers/router.go | 22 ++++++ agent/dns.go | 46 ++++++++++++ agent/dns_test.go | 119 +++++++++++++++++++++++++++----- 7 files changed, 194 insertions(+), 16 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 98f02b6924..99ee456936 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -65,6 +65,7 @@ type delegate interface { JoinLAN(addrs []string) (n int, err error) RemoveFailedNode(node string) error RPC(method string, args interface{}, reply interface{}) error + ServerAddrs() []string SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error Shutdown() error Stats() map[string]map[string]string diff --git a/agent/consul/client.go b/agent/consul/client.go index 6ee73e12b3..1f625463d7 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -411,6 +411,10 @@ func (c *Client) Stats() map[string]map[string]string { return stats } +func (c *Client) ServerAddrs() []string { + return c.servers.GetServerAddrs() +} + // GetLANCoordinate returns the network coordinate of the current node, as // maintained by Serf. func (c *Client) GetLANCoordinate() (*coordinate.Coordinate, error) { diff --git a/agent/consul/server.go b/agent/consul/server.go index 3160ad73ce..bba2cf9807 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -1047,6 +1047,15 @@ func (s *Server) GetWANCoordinate() (*coordinate.Coordinate, error) { return s.serfWAN.GetCoordinate() } +func (s *Server) ServerAddrs() []string { + ret, err := s.router.FindServerAddrs(s.config.Datacenter) + if err != nil { + s.logger.Printf("[WARN] Unexpected state, no server addresses for datacenter %v, got error: %v", s.config.Datacenter, err) + return nil + } + return ret +} + // Atomically sets a readiness state flag when leadership is obtained, to indicate that server is past its barrier write func (s *Server) setConsistentReadReady() { atomic.StoreInt32(&s.readyForConsistentReads, 1) diff --git a/agent/consul/servers/manager.go b/agent/consul/servers/manager.go index ef149d0087..7272414f8a 100644 --- a/agent/consul/servers/manager.go +++ b/agent/consul/servers/manager.go @@ -223,6 +223,15 @@ func (m *Manager) getServerList() serverList { return m.listValue.Load().(serverList) } +// GetServerAddrs returns a slice with all server addresses +func (m *Manager) GetServerAddrs() []string { + var ret []string + for _, server := range m.getServerList().servers { + ret = append(ret, server.Addr.String()) + } + return ret +} + // saveServerList is a convenience method which hides the locking semantics // of atomic.Value from the caller. func (m *Manager) saveServerList(l serverList) { diff --git a/agent/consul/servers/router.go b/agent/consul/servers/router.go index 315e3a55af..35d4580c1c 100644 --- a/agent/consul/servers/router.go +++ b/agent/consul/servers/router.go @@ -489,3 +489,25 @@ func (r *Router) GetDatacenterMaps() ([]structs.DatacenterMap, error) { } return maps, nil } + +func (r *Router) FindServerAddrs(datacenter string) ([]string, error) { + r.RLock() + defer r.RUnlock() + + // Get the list of managers for this datacenter. This will usually just + // have one entry, but it's possible to have a user-defined area + WAN. + managers, ok := r.managers[datacenter] + if !ok { + return nil, fmt.Errorf("datacenter %v not found", datacenter) + } + + var ret []string + // Try each manager until we get a server. + for _, manager := range managers { + if manager.IsOffline() { + continue + } + ret = append(ret, manager.GetServerAddrs()...) + } + return ret, nil +} diff --git a/agent/dns.go b/agent/dns.go index bf01bd1e85..1434650ff1 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -372,6 +372,7 @@ PARSE: INVALID: d.logger.Printf("[WARN] dns: QName invalid: %s", qName) d.addSOA(d.domain, resp) + d.addNSAndARecordsForDomain(resp) resp.SetRcode(req, dns.RcodeNameError) } @@ -414,6 +415,7 @@ RPC: // If we have no address, return not found! if out.NodeServices == nil { d.addSOA(d.domain, resp) + d.addNSAndARecordsForDomain(resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -427,6 +429,9 @@ RPC: if records != nil { resp.Answer = append(resp.Answer, records...) } + + // Add NS record and A record + d.addNSAndARecordsForDomain(resp) } // formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record @@ -641,6 +646,7 @@ RPC: // If we have no nodes, return not found! if len(out.Nodes) == 0 { d.addSOA(d.domain, resp) + d.addNSAndARecordsForDomain(resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -656,6 +662,9 @@ RPC: d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl) } + // Add NS and A records + d.addNSAndARecordsForDomain(resp) + // If the network is not TCP, restrict the number of responses if network != "tcp" { wasTrimmed := trimUDPResponse(d.config, req, resp) @@ -673,6 +682,40 @@ RPC: } } +// addNSAndARecordsForDomain uses the agent's advertise address to +func (d *DNSServer) addNSAndARecordsForDomain(msg *dns.Msg) { + serverAddrs := d.agent.delegate.ServerAddrs() + for _, addr := range serverAddrs { + ipAddrStr := strings.Split(addr, ":")[0] + nsName := "ns." + ipAddrStr + "." + d.domain + ip := net.ParseIP(ipAddrStr) + if ip != nil { + ns := &dns.NS{ + Hdr: dns.RR_Header{ + Name: d.domain, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: 0, + }, + Ns: nsName, + } + msg.Ns = append(msg.Ns, ns) + + //add an A record for the NS record + a := &dns.A{ + Hdr: dns.RR_Header{ + Name: nsName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: uint32(d.config.NodeTTL / time.Second), + }, + A: ip, + } + msg.Extra = append(msg.Extra, a) + } + } +} + // preparedQueryLookup is used to handle a prepared query. func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req, resp *dns.Msg) { // Execute the prepared query. @@ -710,6 +753,7 @@ RPC: // here since the RPC layer loses the type information. if err.Error() == consul.ErrQueryNotFound.Error() { d.addSOA(d.domain, resp) + d.addNSAndARecordsForDomain(resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -752,6 +796,7 @@ RPC: // If we have no nodes, return not found! if len(out.Nodes) == 0 { d.addSOA(d.domain, resp) + d.addNSAndARecordsForDomain(resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -776,6 +821,7 @@ RPC: // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { + d.addNSAndARecordsForDomain(resp) d.addSOA(d.domain, resp) return } diff --git a/agent/dns_test.go b/agent/dns_test.go index 90e905d6aa..995ddf6c71 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -168,7 +168,7 @@ func TestDNS_NodeLookup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { + if len(in.Ns) != 2 { t.Fatalf("Bad: %#v %#v", in, len(in.Answer)) } @@ -179,6 +179,14 @@ func TestDNS_NodeLookup(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + + nsRec, ok := in.Ns[1].(*dns.NS) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[1]) + } + if nsRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[1]) + } } func TestDNS_CaseInsensitiveNodeLookup(t *testing.T) { @@ -619,7 +627,7 @@ func TestDNS_ServiceLookup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { + if len(in.Ns) != 2 { t.Fatalf("Bad: %#v", in) } @@ -630,6 +638,14 @@ func TestDNS_ServiceLookup(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + + nsRec, ok := in.Ns[1].(*dns.NS) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[1]) + } + if nsRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[1]) + } } } @@ -679,7 +695,6 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { }, } verify.Values(t, "answer", in.Answer, wantAnswer) - wantExtra := []dns.RR{ &dns.CNAME{ Hdr: dns.RR_Header{Name: "foo.node.dc1.consul.", Rrtype: 0x5, Class: 0x1, Rdlength: 0x2}, @@ -689,6 +704,10 @@ func TestDNS_ServiceLookupWithInternalServiceAddress(t *testing.T) { Hdr: dns.RR_Header{Name: "db.service.consul.", Rrtype: 0x1, Class: 0x1, Rdlength: 0x4}, A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 }, + &dns.A{ + Hdr: dns.RR_Header{Name: "ns.127.0.0.1.consul.", Rrtype: 0x1, Class: 0x1, Rdlength: 0x4}, + A: []byte{0x7f, 0x0, 0x0, 0x1}, // 127.0.0.1 + }, } verify.Values(t, "extra", in.Extra, wantExtra) } @@ -842,7 +861,7 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } - if len(in.Extra) != 2 { + if len(in.Extra) != 3 { t.Fatalf("Bad: %#v", in) } @@ -873,6 +892,20 @@ func TestDNS_ExternalServiceToConsulCNAMELookup(t *testing.T) { if aRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Extra[1]) } + + aRec2, ok := in.Extra[2].(*dns.A) + if !ok { + t.Fatalf("Bad: %#v", in.Extra[2]) + } + if aRec2.Hdr.Name != "ns.127.0.0.1.consul." { + t.Fatalf("Bad: %#v", in.Extra[2]) + } + if aRec2.A.String() != "127.0.0.1" { + t.Fatalf("Bad: %#v", in.Extra[2]) + } + if aRec2.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Extra[2]) + } } } @@ -968,7 +1001,7 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } - if len(in.Extra) != 3 { + if len(in.Extra) != 4 { t.Fatalf("Bad: %#v", in) } @@ -1013,6 +1046,20 @@ func TestDNS_ExternalServiceToConsulCNAMENestedLookup(t *testing.T) { if aRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Extra[2]) } + + aRec2, ok := in.Extra[3].(*dns.A) + if !ok { + t.Fatalf("Bad: %#v", in.Extra[3]) + } + if aRec2.Hdr.Name != "ns.127.0.0.1.consul." { + t.Fatalf("Bad: %#v", in.Extra[3]) + } + if aRec2.A.String() != "127.0.0.1" { + t.Fatalf("Bad: %#v", in.Extra[3]) + } + if aRec2.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Extra[3]) + } } } @@ -3758,7 +3805,7 @@ func TestDNS_NonExistingLookup(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { + if len(in.Ns) != 2 { t.Fatalf("Bad: %#v %#v", in, len(in.Answer)) } @@ -3769,6 +3816,14 @@ func TestDNS_NonExistingLookup(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + + nsRec, ok := in.Ns[1].(*dns.NS) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[1]) + } + if nsRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[1]) + } } func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { @@ -3859,21 +3914,28 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { + if len(in.Ns) != 2 { t.Fatalf("Bad: %#v", in) } - - soaRec, ok := in.Ns[0].(*dns.SOA) + soaRec, ok := in.Ns[1].(*dns.SOA) if !ok { - t.Fatalf("Bad: %#v", in.Ns[0]) + t.Fatalf("Bad: %#v", in.Ns[1]) } if soaRec.Hdr.Ttl != 0 { - t.Fatalf("Bad: %#v", in.Ns[0]) + t.Fatalf("Bad: %#v", in.Ns[1]) } if in.Rcode != dns.RcodeSuccess { t.Fatalf("Bad: %#v", in) } + + nsRec, ok := in.Ns[0].(*dns.NS) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[0]) + } + if nsRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[0]) + } } // Check for ipv4 records on ipv6-only service directly and via the @@ -3893,18 +3955,26 @@ func TestDNS_NonExistingLookupEmptyAorAAAA(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { + if len(in.Ns) != 2 { t.Fatalf("Bad: %#v", in) } - soaRec, ok := in.Ns[0].(*dns.SOA) + nsRec, ok := in.Ns[0].(*dns.NS) if !ok { t.Fatalf("Bad: %#v", in.Ns[0]) } - if soaRec.Hdr.Ttl != 0 { + if nsRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + soaRec, ok := in.Ns[1].(*dns.SOA) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[1]) + } + if soaRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[1]) + } + if in.Rcode != dns.RcodeSuccess { t.Fatalf("Bad: %#v", in) } @@ -3944,7 +4014,7 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { + if len(in.Ns) != 2 { t.Fatalf("Bad: %#v", in) } @@ -3955,6 +4025,15 @@ func TestDNS_PreparedQuery_AllowStale(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + + nsRec, ok := in.Ns[1].(*dns.NS) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[1]) + } + if nsRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[1]) + } + } } @@ -3982,7 +4061,7 @@ func TestDNS_InvalidQueries(t *testing.T) { t.Fatalf("err: %v", err) } - if len(in.Ns) != 1 { + if len(in.Ns) != 2 { t.Fatalf("Bad: %#v", in) } @@ -3993,6 +4072,14 @@ func TestDNS_InvalidQueries(t *testing.T) { if soaRec.Hdr.Ttl != 0 { t.Fatalf("Bad: %#v", in.Ns[0]) } + + nsRec, ok := in.Ns[1].(*dns.NS) + if !ok { + t.Fatalf("Bad: %#v", in.Ns[1]) + } + if nsRec.Hdr.Ttl != 0 { + t.Fatalf("Bad: %#v", in.Ns[1]) + } } }