diff --git a/command/agent/command.go b/command/agent/command.go index 9ca39af856..606649846a 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -227,8 +227,8 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log return err } - server, err := NewDNSServer(agent, logOutput, config.Domain, - dnsAddr.String(), config.DNSRecursor) + server, err := NewDNSServer(agent, &config.DNSConfig, logOutput, + config.Domain, dnsAddr.String(), config.DNSRecursor) if err != nil { agent.Shutdown() c.Ui.Error(fmt.Sprintf("Error starting dns server: %s", err)) diff --git a/command/agent/dns.go b/command/agent/dns.go index f0b1ed9914..58aa4a2379 100644 --- a/command/agent/dns.go +++ b/command/agent/dns.go @@ -23,6 +23,7 @@ const ( // service discovery endpoints using a DNS interface. type DNSServer struct { agent *Agent + config *DNSConfig dnsHandler *dns.ServeMux dnsServer *dns.Server dnsServerTCP *dns.Server @@ -32,7 +33,7 @@ type DNSServer struct { } // NewDNSServer starts a new DNS server to provide an agent interface -func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) { +func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) { // Make sure domain is FQDN domain = dns.Fqdn(domain) @@ -55,6 +56,7 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor stri // Create the server srv := &DNSServer{ agent: agent, + config: config, dnsHandler: mux, dnsServer: server, dnsServerTCP: serverTCP, @@ -306,16 +308,25 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns. // Make an RPC request args := structs.NodeSpecificRequest{ - Datacenter: datacenter, - Node: node, + Datacenter: datacenter, + Node: node, + QueryOptions: structs.QueryOptions{AllowStale: d.config.AllowStale}, } var out structs.IndexedNodeServices +RPC: if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil { d.logger.Printf("[ERR] dns: rpc error: %v", err) resp.SetRcode(req, dns.RcodeServerFailure) return } + // Verify that request is not too stale, redo the request + if args.AllowStale && out.LastContact > d.config.MaxStale { + args.AllowStale = false + d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") + goto RPC + } + // If we have no address, return not found! if out.NodeServices == nil { resp.SetRcode(req, dns.RcodeNameError) @@ -398,18 +409,27 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uin func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) { // Make an RPC request args := structs.ServiceSpecificRequest{ - Datacenter: datacenter, - ServiceName: service, - ServiceTag: tag, - TagFilter: tag != "", + Datacenter: datacenter, + ServiceName: service, + ServiceTag: tag, + TagFilter: tag != "", + QueryOptions: structs.QueryOptions{AllowStale: d.config.AllowStale}, } var out structs.IndexedCheckServiceNodes +RPC: if err := d.agent.RPC("Health.ServiceNodes", &args, &out); err != nil { d.logger.Printf("[ERR] dns: rpc error: %v", err) resp.SetRcode(req, dns.RcodeServerFailure) return } + // Verify that request is not too stale, redo the request + if args.AllowStale && out.LastContact > d.config.MaxStale { + args.AllowStale = false + d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") + goto RPC + } + // If we have no nodes, return not found! if len(out.Nodes) == 0 { resp.SetRcode(req, dns.RcodeNameError) diff --git a/command/agent/dns_test.go b/command/agent/dns_test.go index d4add63548..2fd098265a 100644 --- a/command/agent/dns_test.go +++ b/command/agent/dns_test.go @@ -14,8 +14,9 @@ func makeDNSServer(t *testing.T) (string, *DNSServer) { conf := nextConfig() addr, _ := conf.ClientListener(conf.Ports.DNS) dir, agent := makeAgent(t, conf) - server, err := NewDNSServer(agent, agent.logOutput, conf.Domain, - addr.String(), "8.8.8.8:53") + config := &DNSConfig{} + server, err := NewDNSServer(agent, config, agent.logOutput, + conf.Domain, addr.String(), "8.8.8.8:53") if err != nil { t.Fatalf("err: %v", err) }