From f669bb7b0fd323f63a5f990102f1c27d9fe4d6fe Mon Sep 17 00:00:00 2001 From: Aestek Date: Wed, 24 Apr 2019 20:11:54 +0200 Subject: [PATCH] Add support for DNS config hot-reload (#4875) The DNS config parameters `recursors` and `dns_config.*` are now hot reloaded on SIGHUP or `consul reload` and do not need an agent restart to be modified. Config is stored in an atomic.Value and loaded at the beginning of each request. Reloading only affects requests that start _after_ the reload. Ongoing requests are not affected. To match the current behavior the recursor handler is loaded and unloaded as needed on config reload. --- agent/agent.go | 6 + agent/dns.go | 372 +++++++++++++++++++++++++-------------------- agent/dns_test.go | 190 ++++++++++++++++++++++- agent/testagent.go | 3 +- 4 files changed, 402 insertions(+), 169 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index aa9de46bb4..96dc6843b4 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -3579,6 +3579,12 @@ func (a *Agent) ReloadConfig(newCfg *config.RuntimeConfig) error { a.loadLimits(newCfg) + for _, s := range a.dnsServers { + if err := s.ReloadConfig(newCfg); err != nil { + return fmt.Errorf("Failed reloading dns config : %v", err) + } + } + // create the config for the rpc server/client consulCfg, err := a.consulConfig() if err != nil { diff --git a/agent/dns.go b/agent/dns.go index 05bf6a93e8..9306d5730d 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -25,7 +25,7 @@ import ( const ( // UDP can fit ~25 A records in a 512B response, and ~14 AAAA - // records. Limit further to prevent unintentional configuration + // records. Limit further to prevent unintentional configuration // abuse that would have a negative effect on application response // times. maxUDPAnswerLimit = 8 @@ -46,7 +46,7 @@ type dnsSOAConfig struct { Refresh uint32 // 3600 by default Retry uint32 // 600 Expire uint32 // 86400 - Minttl uint32 // 0, + Minttl uint32 // 0 } type dnsConfig struct { @@ -60,128 +60,134 @@ type dnsConfig struct { NodeTTL time.Duration OnlyPassing bool RecursorTimeout time.Duration + Recursors []string SegmentName string - ServiceTTL map[string]time.Duration UDPAnswerLimit int ARecordLimit int NodeMetaTXT bool - dnsSOAConfig dnsSOAConfig + SOAConfig dnsSOAConfig + // TTLRadix sets service TTLs by prefix, eg: "database-*" + TTLRadix *radix.Tree + // TTLStict sets TTLs to service by full name match. It Has higher priority than TTLRadix + TTLStrict map[string]time.Duration + DisableCompression bool } // DNSServer is used to wrap an Agent and expose various // service discovery endpoints using a DNS interface. type DNSServer struct { *dns.Server - agent *Agent - config *dnsConfig - domain string - recursors []string - logger *log.Logger - // Those are handling prefix lookups - ttlRadix *radix.Tree - ttlStrict map[string]time.Duration + agent *Agent + mux *dns.ServeMux + domain string + logger *log.Logger - // disableCompression is the config.DisableCompression flag that can - // be safely changed at runtime. It always contains a bool and is - // initialized with the value from config.DisableCompression. - disableCompression atomic.Value + // config stores the config as an atomic value (for hot-reloading). It is always of type *dnsConfig + config atomic.Value + + // recursorEnabled stores whever the recursor handler is enabled as an atomic flag. + // the recursor handler is only enabled if recursors are configured. This flag is used during config hot-reloading + recursorEnabled uint32 } func NewDNSServer(a *Agent) (*DNSServer, error) { - var recursors []string - for _, r := range a.config.DNSRecursors { - ra, err := recursorAddr(r) - if err != nil { - return nil, fmt.Errorf("Invalid recursor address: %v", err) - } - recursors = append(recursors, ra) - } - // Make sure domain is FQDN, make it case insensitive for ServeMux domain := dns.Fqdn(strings.ToLower(a.config.DNSDomain)) - dnscfg := GetDNSConfig(a.config) srv := &DNSServer{ - agent: a, - config: dnscfg, - domain: domain, - logger: a.logger, - recursors: recursors, - ttlRadix: radix.New(), - ttlStrict: make(map[string]time.Duration), + agent: a, + domain: domain, + logger: a.logger, } - if dnscfg.ServiceTTL != nil { - for key, ttl := range dnscfg.ServiceTTL { - // All suffix with '*' are put in radix - // This include '*' that will match anything - if strings.HasSuffix(key, "*") { - srv.ttlRadix.Insert(key[:len(key)-1], ttl) - } else { - srv.ttlStrict[key] = ttl - } - } + cfg, err := GetDNSConfig(a.config) + if err != nil { + return nil, err } - - srv.disableCompression.Store(a.config.DNSDisableCompression) + srv.config.Store(cfg) return srv, nil } // GetDNSConfig takes global config and creates the config used by DNS server -func GetDNSConfig(conf *config.RuntimeConfig) *dnsConfig { - return &dnsConfig{ - AllowStale: conf.DNSAllowStale, - ARecordLimit: conf.DNSARecordLimit, - Datacenter: conf.Datacenter, - EnableTruncate: conf.DNSEnableTruncate, - MaxStale: conf.DNSMaxStale, - NodeName: conf.NodeName, - NodeTTL: conf.DNSNodeTTL, - OnlyPassing: conf.DNSOnlyPassing, - RecursorTimeout: conf.DNSRecursorTimeout, - SegmentName: conf.SegmentName, - ServiceTTL: conf.DNSServiceTTL, - UDPAnswerLimit: conf.DNSUDPAnswerLimit, - NodeMetaTXT: conf.DNSNodeMetaTXT, - UseCache: conf.DNSUseCache, - CacheMaxAge: conf.DNSCacheMaxAge, - dnsSOAConfig: dnsSOAConfig{ +func GetDNSConfig(conf *config.RuntimeConfig) (*dnsConfig, error) { + cfg := &dnsConfig{ + AllowStale: conf.DNSAllowStale, + ARecordLimit: conf.DNSARecordLimit, + Datacenter: conf.Datacenter, + EnableTruncate: conf.DNSEnableTruncate, + MaxStale: conf.DNSMaxStale, + NodeName: conf.NodeName, + NodeTTL: conf.DNSNodeTTL, + OnlyPassing: conf.DNSOnlyPassing, + RecursorTimeout: conf.DNSRecursorTimeout, + SegmentName: conf.SegmentName, + UDPAnswerLimit: conf.DNSUDPAnswerLimit, + NodeMetaTXT: conf.DNSNodeMetaTXT, + DisableCompression: conf.DNSDisableCompression, + UseCache: conf.DNSUseCache, + CacheMaxAge: conf.DNSCacheMaxAge, + SOAConfig: dnsSOAConfig{ Expire: conf.DNSSOA.Expire, Minttl: conf.DNSSOA.Minttl, Refresh: conf.DNSSOA.Refresh, Retry: conf.DNSSOA.Retry, }, } + if conf.DNSServiceTTL != nil { + cfg.TTLRadix = radix.New() + cfg.TTLStrict = make(map[string]time.Duration) + + for key, ttl := range conf.DNSServiceTTL { + // All suffix with '*' are put in radix + // This include '*' that will match anything + if strings.HasSuffix(key, "*") { + cfg.TTLRadix.Insert(key[:len(key)-1], ttl) + } else { + cfg.TTLStrict[key] = ttl + } + } + } + for _, r := range conf.DNSRecursors { + ra, err := recursorAddr(r) + if err != nil { + return nil, fmt.Errorf("Invalid recursor address: %v", err) + } + cfg.Recursors = append(cfg.Recursors, ra) + } + + return cfg, nil } // GetTTLForService Find the TTL for a given service. // return ttl, true if found, 0, false otherwise -func (d *DNSServer) GetTTLForService(service string) (time.Duration, bool) { - if d.config.ServiceTTL != nil { - ttl, ok := d.ttlStrict[service] +func (cfg *dnsConfig) GetTTLForService(service string) (time.Duration, bool) { + if cfg.TTLStrict != nil { + ttl, ok := cfg.TTLStrict[service] if ok { return ttl, true } - _, ttlRaw, ok := d.ttlRadix.LongestPrefix(service) + } + if cfg.TTLRadix != nil { + _, ttlRaw, ok := cfg.TTLRadix.LongestPrefix(service) if ok { return ttlRaw.(time.Duration), true } } - return time.Duration(0), false + return 0, false } func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error { - mux := dns.NewServeMux() - mux.HandleFunc("arpa.", d.handlePtr) - mux.HandleFunc(d.domain, d.handleQuery) - if len(d.recursors) > 0 { - mux.HandleFunc(".", d.handleRecurse) - } + cfg := d.config.Load().(*dnsConfig) + + d.mux = dns.NewServeMux() + d.mux.HandleFunc("arpa.", d.handlePtr) + d.mux.HandleFunc(d.domain, d.handleQuery) + d.toggleRecursorHandlerFromConfig(cfg) d.Server = &dns.Server{ Addr: addr, Net: network, - Handler: mux, + Handler: d.mux, NotifyStartedFunc: notif, } if network == "udp" { @@ -190,6 +196,34 @@ func (d *DNSServer) ListenAndServe(network, addr string, notif func()) error { return d.Server.ListenAndServe() } +// toggleRecursorHandlerFromConfig enables or disables the recursor handler based on config idempotently +func (d *DNSServer) toggleRecursorHandlerFromConfig(cfg *dnsConfig) { + shouldEnable := len(cfg.Recursors) > 0 + + if shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 0, 1) { + d.mux.HandleFunc(".", d.handleRecurse) + d.logger.Println("[DEBUG] dns: recursor enabled") + return + } + + if !shouldEnable && atomic.CompareAndSwapUint32(&d.recursorEnabled, 1, 0) { + d.mux.HandleRemove(".") + d.logger.Println("[DEBUG] dns: recursor disabled") + return + } +} + +// ReloadConfig hot-reloads the server config with new parameters under config.RuntimeConfig.DNS* +func (d *DNSServer) ReloadConfig(newCfg *config.RuntimeConfig) error { + cfg, err := GetDNSConfig(newCfg) + if err != nil { + return err + } + d.config.Store(cfg) + d.toggleRecursorHandlerFromConfig(cfg) + return nil +} + // setEDNS is used to set the responses EDNS size headers and // possibly the ECS headers as well if they were present in the // original request @@ -258,16 +292,18 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { resp.RemoteAddr().Network()) }(time.Now()) + cfg := d.config.Load().(*dnsConfig) + // Setup the message response m := new(dns.Msg) m.SetReply(req) - m.Compress = !d.disableCompression.Load().(bool) + m.Compress = !cfg.DisableCompression m.Authoritative = true - m.RecursionAvailable = (len(d.recursors) > 0) + m.RecursionAvailable = (len(cfg.Recursors) > 0) // Only add the SOA if requested if req.Question[0].Qtype == dns.TypeSOA { - d.addSOA(m) + d.addSOA(cfg, m) } datacenter := d.agent.config.Datacenter @@ -279,7 +315,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { Datacenter: datacenter, QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), - AllowStale: d.config.AllowStale, + AllowStale: cfg.AllowStale, }, } var out structs.IndexedNodes @@ -308,7 +344,7 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) { Datacenter: datacenter, QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), - AllowStale: d.config.AllowStale, + AllowStale: cfg.AllowStale, }, ServiceAddress: serviceAddress, } @@ -360,25 +396,27 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { network = "tcp" } + cfg := d.config.Load().(*dnsConfig) + // Setup the message response m := new(dns.Msg) m.SetReply(req) - m.Compress = !d.disableCompression.Load().(bool) + m.Compress = !cfg.DisableCompression m.Authoritative = true - m.RecursionAvailable = (len(d.recursors) > 0) + m.RecursionAvailable = (len(cfg.Recursors) > 0) ecsGlobal := true switch req.Question[0].Qtype { case dns.TypeSOA: - ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault) - m.Answer = append(m.Answer, d.soa()) + ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault) + m.Answer = append(m.Answer, d.soa(cfg)) m.Ns = append(m.Ns, ns...) m.Extra = append(m.Extra, glue...) m.SetRcode(req, dns.RcodeSuccess) case dns.TypeNS: - ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault) + ns, glue := d.nameservers(cfg, req.IsEdns0() != nil, maxRecursionLevelDefault) m.Answer = ns m.Extra = glue m.SetRcode(req, dns.RcodeSuccess) @@ -398,34 +436,34 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { } } -func (d *DNSServer) soa() *dns.SOA { +func (d *DNSServer) soa(cfg *dnsConfig) *dns.SOA { return &dns.SOA{ Hdr: dns.RR_Header{ Name: d.domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, // Has to be consistent with MinTTL to avoid invalidation - Ttl: d.config.dnsSOAConfig.Minttl, + Ttl: cfg.SOAConfig.Minttl, }, Ns: "ns." + d.domain, Serial: uint32(time.Now().Unix()), Mbox: "hostmaster." + d.domain, - Refresh: d.config.dnsSOAConfig.Refresh, - Retry: d.config.dnsSOAConfig.Retry, - Expire: d.config.dnsSOAConfig.Expire, - Minttl: d.config.dnsSOAConfig.Minttl, + Refresh: cfg.SOAConfig.Refresh, + Retry: cfg.SOAConfig.Retry, + Expire: cfg.SOAConfig.Expire, + Minttl: cfg.SOAConfig.Minttl, } } // addSOA is used to add an SOA record to a message for the given domain -func (d *DNSServer) addSOA(msg *dns.Msg) { - msg.Ns = append(msg.Ns, d.soa()) +func (d *DNSServer) addSOA(cfg *dnsConfig, msg *dns.Msg) { + msg.Ns = append(msg.Ns, d.soa(cfg)) } // nameservers returns the names and ip addresses of up to three random servers // in the current cluster which serve as authoritative name servers for zone. -func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) { - out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel) +func (d *DNSServer) nameservers(cfg *dnsConfig, edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) { + out, err := d.lookupServiceNodes(cfg, d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel) if err != nil { d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err) return nil, nil @@ -456,15 +494,15 @@ func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR, Name: d.domain, Rrtype: dns.TypeNS, Class: dns.ClassINET, - Ttl: uint32(d.config.NodeTTL / time.Second), + Ttl: uint32(cfg.NodeTTL / time.Second), }, Ns: fqdn, } ns = append(ns, nsrr) - glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, maxRecursionLevel, d.config.NodeMetaTXT) + glue, meta := d.formatNodeRecord(cfg, nil, addr, fqdn, dns.TypeANY, cfg.NodeTTL, edns, maxRecursionLevel, cfg.NodeMetaTXT) extra = append(extra, glue...) - if meta != nil && d.config.NodeMetaTXT { + if meta != nil && cfg.NodeMetaTXT { extra = append(extra, meta...) } @@ -499,6 +537,8 @@ func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *d // Provide a flag for remembering whether the datacenter name was parsed already. var dcParsed bool + cfg := d.config.Load().(*dnsConfig) + // The last label is either "node", "service", "query", "_", or a datacenter name PARSE: n := len(labels) @@ -531,7 +571,7 @@ PARSE: } // _name._tag.service.consul - d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel) + d.serviceLookup(cfg, network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel) // Consul 0.3 and prior format for SRV queries } else { @@ -543,7 +583,7 @@ PARSE: } // tag[.tag].name.service.consul - d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel) + d.serviceLookup(cfg, network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel) } case "connect": @@ -552,7 +592,7 @@ PARSE: } // name.connect.consul - d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel) + d.serviceLookup(cfg, network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel) case "node": if n == 1 { @@ -561,7 +601,7 @@ PARSE: // Allow a "." in the node name, just join all the parts node := strings.Join(labels[:n-1], ".") - d.nodeLookup(network, datacenter, node, req, resp, maxRecursionLevel) + d.nodeLookup(cfg, network, datacenter, node, req, resp, maxRecursionLevel) case "query": if n == 1 { @@ -571,7 +611,7 @@ PARSE: // Allow a "." in the query name, just join all the parts. query := strings.Join(labels[:n-1], ".") ecsGlobal = false - d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) + d.preparedQueryLookup(cfg, network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel) case "addr": if n != 2 { @@ -591,7 +631,7 @@ PARSE: Name: qName + d.domain, Rrtype: dns.TypeA, Class: dns.ClassINET, - Ttl: uint32(d.config.NodeTTL / time.Second), + Ttl: uint32(cfg.NodeTTL / time.Second), }, A: ip, }) @@ -607,7 +647,7 @@ PARSE: Name: qName + d.domain, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, - Ttl: uint32(d.config.NodeTTL / time.Second), + Ttl: uint32(cfg.NodeTTL / time.Second), }, AAAA: ip, }) @@ -638,13 +678,13 @@ PARSE: return INVALID: d.logger.Printf("[WARN] dns: QName invalid: %s", qName) - d.addSOA(resp) + d.addSOA(cfg, resp) resp.SetRcode(req, dns.RcodeNameError) return } // nodeLookup is used to handle a node query -func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) { +func (d *DNSServer) nodeLookup(cfg *dnsConfig, network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) { // Only handle ANY, A, AAAA, and TXT type requests qType := req.Question[0].Qtype if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT { @@ -657,10 +697,10 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns. Node: node, QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), - AllowStale: d.config.AllowStale, + AllowStale: cfg.AllowStale, }, } - out, err := d.lookupNode(args) + out, err := d.lookupNode(cfg, args) if err != nil { d.logger.Printf("[ERR] dns: rpc error: %v", err) resp.SetRcode(req, dns.RcodeServerFailure) @@ -669,7 +709,7 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns. // If we have no address, return not found! if out.NodeServices == nil { - d.addSOA(resp) + d.addSOA(cfg, resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -679,7 +719,7 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns. if qType == dns.TypeANY || qType == dns.TypeTXT { generateMeta = true metaInAnswer = true - } else if d.config.NodeMetaTXT { + } else if cfg.NodeMetaTXT { generateMeta = true } @@ -687,21 +727,21 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns. n := out.NodeServices.Node edns := req.IsEdns0() != nil addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses) - records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, maxRecursionLevel, generateMeta) + records, meta := d.formatNodeRecord(cfg, out.NodeServices.Node, addr, req.Question[0].Name, qType, cfg.NodeTTL, edns, maxRecursionLevel, generateMeta) if records != nil { resp.Answer = append(resp.Answer, records...) } if meta != nil && metaInAnswer && generateMeta { resp.Answer = append(resp.Answer, meta...) - } else if meta != nil && generateMeta { + } else if meta != nil && cfg.NodeMetaTXT { resp.Extra = append(resp.Extra, meta...) } } -func (d *DNSServer) lookupNode(args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { +func (d *DNSServer) lookupNode(cfg *dnsConfig, args *structs.NodeSpecificRequest) (*structs.IndexedNodeServices, error) { var out structs.IndexedNodeServices - useCache := d.config.UseCache + useCache := cfg.UseCache RPC: if useCache { raw, _, err := d.agent.cache.Get(cachetype.NodeServicesName, args) @@ -722,7 +762,7 @@ RPC: // Verify that request is not too stale, redo the request if args.AllowStale { - if out.LastContact > d.config.MaxStale { + if out.LastContact > cfg.MaxStale { args.AllowStale = false useCache = false d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") @@ -761,7 +801,7 @@ func encodeKVasRFC1464(key, value string) (txt string) { // The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node // and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the // generated RRs should go and if they should be used at all. -func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int, generateMeta bool) (records, meta []dns.RR) { +func (d *DNSServer) formatNodeRecord(cfg *dnsConfig, node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int, generateMeta bool) (records, meta []dns.RR) { // Parse the IP ip := net.ParseIP(addr) var ipv4 net.IP @@ -807,7 +847,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy records = append(records, cnRec) // Recurse - more := d.resolveCNAME(cnRec.Target, maxRecursionLevel) + more := d.resolveCNAME(cfg, cnRec.Target, maxRecursionLevel) extra := 0 MORE_REC: for _, rr := range more { @@ -1036,21 +1076,21 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { } // trimDNSResponse will trim the response for UDP and TCP -func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed bool) { +func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) (trimmed bool) { if network != "tcp" { - trimmed = trimUDPResponse(req, resp, d.config.UDPAnswerLimit) + trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit) } else { trimmed = d.trimTCPResponse(req, resp) } // Flag that there are more records to return in the UDP response - if trimmed && d.config.EnableTruncate { + if trimmed && cfg.EnableTruncate { resp.Truncated = true } return trimmed } // lookupServiceNodes returns nodes with a given service. -func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) { +func (d *DNSServer) lookupServiceNodes(cfg *dnsConfig, datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) { args := structs.ServiceSpecificRequest{ Connect: connect, Datacenter: datacenter, @@ -1059,14 +1099,14 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect TagFilter: tag != "", QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), - AllowStale: d.config.AllowStale, - MaxAge: d.config.CacheMaxAge, + AllowStale: cfg.AllowStale, + MaxAge: cfg.CacheMaxAge, }, } var out structs.IndexedCheckServiceNodes - if d.config.UseCache { + if cfg.UseCache { raw, m, err := d.agent.cache.Get(cachetype.HealthServicesName, &args) if err != nil { return out, err @@ -1090,7 +1130,7 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect } // redo the request the response was too stale - if args.AllowStale && out.LastContact > d.config.MaxStale { + if args.AllowStale && out.LastContact > cfg.MaxStale { args.AllowStale = false d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") @@ -1103,13 +1143,13 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect // We copy the slice to avoid modifying the result if it comes from the cache nodes := make(structs.CheckServiceNodes, len(out.Nodes)) copy(nodes, out.Nodes) - out.Nodes = nodes.Filter(d.config.OnlyPassing) + out.Nodes = nodes.Filter(cfg.OnlyPassing) return out, nil } // serviceLookup is used to handle a service query -func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) { - out, err := d.lookupServiceNodes(datacenter, service, tag, connect, maxRecursionLevel) +func (d *DNSServer) serviceLookup(cfg *dnsConfig, network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) { + out, err := d.lookupServiceNodes(cfg, datacenter, service, tag, connect, maxRecursionLevel) if err != nil { d.logger.Printf("[ERR] dns: rpc error: %v", err) resp.SetRcode(req, dns.RcodeServerFailure) @@ -1118,7 +1158,7 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn // If we have no nodes, return not found! if len(out.Nodes) == 0 { - d.addSOA(resp) + d.addSOA(cfg, resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -1127,21 +1167,21 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn out.Nodes.Shuffle() // Determine the TTL - ttl, _ := d.GetTTLForService(service) + ttl, _ := cfg.GetTTLForService(service) // Add various responses depending on the request qType := req.Question[0].Qtype if qType == dns.TypeSRV { - d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) + d.serviceSRVRecords(cfg, datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } else { - d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) + d.serviceNodeRecords(cfg, datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } - d.trimDNSResponse(network, req, resp) + d.trimDNSResponse(cfg, network, req, resp) // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { - d.addSOA(resp) + d.addSOA(cfg, resp) return } } @@ -1164,15 +1204,15 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET { } // preparedQueryLookup is used to handle a prepared query. -func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) { +func (d *DNSServer) preparedQueryLookup(cfg *dnsConfig, network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) { // Execute the prepared query. args := structs.PreparedQueryExecuteRequest{ Datacenter: datacenter, QueryIDOrName: query, QueryOptions: structs.QueryOptions{ Token: d.agent.tokens.UserToken(), - AllowStale: d.config.AllowStale, - MaxAge: d.config.CacheMaxAge, + AllowStale: cfg.AllowStale, + MaxAge: cfg.CacheMaxAge, }, // Always pass the local agent through. In the DNS interface, there @@ -1201,13 +1241,13 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot } } - out, err := d.lookupPreparedQuery(args) + out, err := d.lookupPreparedQuery(cfg, args) // If they give a bogus query name, treat that as a name error, // not a full on server error. We have to use a string compare // here since the RPC layer loses the type information. if err != nil && err.Error() == consul.ErrQueryNotFound.Error() { - d.addSOA(resp) + d.addSOA(cfg, resp) resp.SetRcode(req, dns.RcodeNameError) return } else if err != nil { @@ -1234,13 +1274,13 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot if err != nil { d.logger.Printf("[WARN] dns: Failed to parse TTL '%s' for prepared query '%s', ignoring", out.DNS.TTL, query) } - } else if d.config.ServiceTTL != nil { - ttl, _ = d.GetTTLForService(out.Service) + } else { + ttl, _ = cfg.GetTTLForService(out.Service) } // If we have no nodes, return not found! if len(out.Nodes) == 0 { - d.addSOA(resp) + d.addSOA(cfg, resp) resp.SetRcode(req, dns.RcodeNameError) return } @@ -1248,25 +1288,25 @@ func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remot // Add various responses depending on the request. qType := req.Question[0].Qtype if qType == dns.TypeSRV { - d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) + d.serviceSRVRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } else { - d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) + d.serviceNodeRecords(cfg, out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel) } - d.trimDNSResponse(network, req, resp) + d.trimDNSResponse(cfg, network, req, resp) // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { - d.addSOA(resp) + d.addSOA(cfg, resp) return } } -func (d *DNSServer) lookupPreparedQuery(args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { +func (d *DNSServer) lookupPreparedQuery(cfg *dnsConfig, args structs.PreparedQueryExecuteRequest) (*structs.PreparedQueryExecuteResponse, error) { var out structs.PreparedQueryExecuteResponse RPC: - if d.config.UseCache { + if cfg.UseCache { raw, m, err := d.agent.cache.Get(cachetype.PreparedQueryName, &args) if err != nil { return nil, err @@ -1288,7 +1328,7 @@ RPC: // Verify that request is not too stale, redo the request. if args.AllowStale { - if out.LastContact > d.config.MaxStale { + if out.LastContact > cfg.MaxStale { args.AllowStale = false d.logger.Printf("[WARN] dns: Query results too stale, re-requesting") goto RPC @@ -1301,7 +1341,7 @@ RPC: } // serviceNodeRecords is used to add the node records for a service lookup -func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { +func (d *DNSServer) serviceNodeRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { qName := req.Question[0].Name qType := req.Question[0].Qtype handled := make(map[string]struct{}) @@ -1335,13 +1375,13 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode if qType == dns.TypeANY || qType == dns.TypeTXT { generateMeta = true metaInAnswer = true - } else if d.config.NodeMetaTXT { + } else if cfg.NodeMetaTXT { generateMeta = true } // Add the node record had_answer := false - records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel, generateMeta) + records, meta := d.formatNodeRecord(cfg, node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel, generateMeta) if records != nil { switch records[0].(type) { case *dns.CNAME: @@ -1365,7 +1405,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode if had_answer { count++ - if count == d.config.ARecordLimit { + if count == cfg.ARecordLimit { // We stop only if greater than 0 or we reached the limit return } @@ -1423,7 +1463,7 @@ func findWeight(node structs.CheckServiceNode) int { } // serviceARecords is used to add the SRV records for a service lookup -func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { +func (d *DNSServer) serviceSRVRecords(cfg *dnsConfig, dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) { handled := make(map[string]struct{}) edns := req.IsEdns0() != nil @@ -1460,7 +1500,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } // Add the extra record - records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel, d.config.NodeMetaTXT) + records, meta := d.formatNodeRecord(cfg, node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel, cfg.NodeMetaTXT) if len(records) > 0 { // Use the node address if it doesn't differ from the service address if addr == node.Node.Address { @@ -1491,7 +1531,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes } } - if meta != nil && d.config.NodeMetaTXT { + if meta != nil && cfg.NodeMetaTXT { resp.Extra = append(resp.Extra, meta...) } } @@ -1500,6 +1540,8 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes // handleRecurse is used to handle recursive DNS queries func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { + cfg := d.config.Load().(*dnsConfig) + q := req.Question[0] network := "udp" defer func(s time.Time) { @@ -1514,11 +1556,11 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { } // Recursively resolve - c := &dns.Client{Net: network, Timeout: d.config.RecursorTimeout} + c := &dns.Client{Net: network, Timeout: cfg.RecursorTimeout} var r *dns.Msg var rtt time.Duration var err error - for _, recursor := range d.recursors { + for _, recursor := range cfg.Recursors { r, rtt, err = c.Exchange(req, recursor) // Check if the response is valid and has the desired Response code if r != nil && (r.Rcode != dns.RcodeSuccess && r.Rcode != dns.RcodeNameError) { @@ -1530,7 +1572,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { // Compress the response; we don't know if the incoming // response was compressed or not, so by not compressing // we might generate an invalid packet on the way out. - r.Compress = !d.disableCompression.Load().(bool) + r.Compress = !cfg.DisableCompression // Forward the response d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v) Recursor queried: %v", q, rtt, recursor) @@ -1547,7 +1589,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { q, resp.RemoteAddr().String(), resp.RemoteAddr().Network()) m := &dns.Msg{} m.SetReply(req) - m.Compress = !d.disableCompression.Load().(bool) + m.Compress = !cfg.DisableCompression m.RecursionAvailable = true m.SetRcode(req, dns.RcodeServerFailure) if edns := req.IsEdns0(); edns != nil { @@ -1557,7 +1599,7 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { } // resolveCNAME is used to recursively resolve CNAME records -func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR { +func (d *DNSServer) resolveCNAME(cfg *dnsConfig, name string, maxRecursionLevel int) []dns.RR { // If the CNAME record points to a Consul address, resolve it internally // Convert query to lowercase because DNS is case insensitive; d.domain is // already converted @@ -1577,7 +1619,7 @@ func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR { } // Do nothing if we don't have a recursor - if len(d.recursors) == 0 { + if len(cfg.Recursors) == 0 { return nil } @@ -1586,11 +1628,11 @@ func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR { m.SetQuestion(name, dns.TypeA) // Make a DNS lookup request - c := &dns.Client{Net: "udp", Timeout: d.config.RecursorTimeout} + c := &dns.Client{Net: "udp", Timeout: cfg.RecursorTimeout} var r *dns.Msg var rtt time.Duration var err error - for _, recursor := range d.recursors { + for _, recursor := range cfg.Recursors { r, rtt, err = c.Exchange(m, recursor) if err == nil { d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt) diff --git a/agent/dns_test.go b/agent/dns_test.go index 63c52e7265..571684385c 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -5,6 +5,7 @@ import ( "math/rand" "net" "reflect" + "sort" "strings" "testing" "time" @@ -3740,6 +3741,24 @@ func TestDNS_ServiceLookup_OnlyPassing(t *testing.T) { t.Fatalf("Bad: %#v", in.Answer[0]) } } + + newCfg := *a.Config + newCfg.DNSOnlyPassing = false + err := a.ReloadConfig(&newCfg) + require.NoError(t, err) + + // only_passing is now false. we should now get two nodes + m := new(dns.Msg) + m.SetQuestion("db.service.consul.", dns.TypeANY) + + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + require.NoError(t, err) + + require.Equal(t, 2, len(in.Answer)) + ips := []string{in.Answer[0].(*dns.A).A.String(), in.Answer[1].(*dns.A).A.String()} + sort.Strings(ips) + require.Equal(t, []string{"127.0.0.1", "127.0.0.2"}, ips) } func TestDNS_ServiceLookup_Randomize(t *testing.T) { @@ -5190,7 +5209,6 @@ func TestDNS_ServiceLookup_FilterACL(t *testing.T) { }) } } - func TestDNS_ServiceLookup_MetaTXT(t *testing.T) { a := NewTestAgent(t, t.Name(), `dns_config = { enable_additional_node_meta_txt = true }`) defer a.Shutdown() @@ -6341,11 +6359,177 @@ func TestDNS_formatNodeRecord(t *testing.T) { }, } - records, meta := s.formatNodeRecord(node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, false) + records, meta := s.formatNodeRecord(&dnsConfig{}, node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, false) require.Len(t, records, 1) require.Len(t, meta, 0) - records, meta = s.formatNodeRecord(node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, true) + records, meta = s.formatNodeRecord(&dnsConfig{}, node, "198.18.0.1", "test.node.consul", dns.TypeA, 5*time.Minute, false, 3, true) require.Len(t, records, 1) require.Len(t, meta, 2) } + +func TestDNS_ConfigReload(t *testing.T) { + t.Parallel() + + a := NewTestAgent(t, t.Name(), ` + recursors = ["8.8.8.8:53"] + dns_config = { + allow_stale = false + max_stale = "20s" + node_ttl = "10s" + service_ttl = { + "my_services*" = "5s" + "my_specific_service" = "30s" + } + enable_truncate = false + only_passing = false + recursor_timeout = "15s" + disable_compression = false + a_record_limit = 1 + enable_additional_node_meta_txt = false + soa = { + refresh = 1 + retry = 2 + expire = 3 + min_ttl = 4 + } + } + `) + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + for _, s := range a.dnsServers { + cfg := s.config.Load().(*dnsConfig) + require.Equal(t, []string{"8.8.8.8:53"}, cfg.Recursors) + require.False(t, cfg.AllowStale) + require.Equal(t, 20*time.Second, cfg.MaxStale) + require.Equal(t, 10*time.Second, cfg.NodeTTL) + ttl, _ := cfg.GetTTLForService("my_services_1") + require.Equal(t, 5*time.Second, ttl) + ttl, _ = cfg.GetTTLForService("my_specific_service") + require.Equal(t, 30*time.Second, ttl) + require.False(t, cfg.EnableTruncate) + require.False(t, cfg.OnlyPassing) + require.Equal(t, 15*time.Second, cfg.RecursorTimeout) + require.False(t, cfg.DisableCompression) + require.Equal(t, 1, cfg.ARecordLimit) + require.False(t, cfg.NodeMetaTXT) + require.Equal(t, uint32(1), cfg.SOAConfig.Refresh) + require.Equal(t, uint32(2), cfg.SOAConfig.Retry) + require.Equal(t, uint32(3), cfg.SOAConfig.Expire) + require.Equal(t, uint32(4), cfg.SOAConfig.Minttl) + } + + newCfg := *a.Config + newCfg.DNSRecursors = []string{"1.1.1.1:53"} + newCfg.DNSAllowStale = true + newCfg.DNSMaxStale = 21 * time.Second + newCfg.DNSNodeTTL = 11 * time.Second + newCfg.DNSServiceTTL = map[string]time.Duration{ + "2_my_services*": 6 * time.Second, + "2_my_specific_service": 31 * time.Second, + } + newCfg.DNSEnableTruncate = true + newCfg.DNSOnlyPassing = true + newCfg.DNSRecursorTimeout = 16 * time.Second + newCfg.DNSDisableCompression = true + newCfg.DNSARecordLimit = 2 + newCfg.DNSNodeMetaTXT = true + newCfg.DNSSOA.Refresh = 10 + newCfg.DNSSOA.Retry = 20 + newCfg.DNSSOA.Expire = 30 + newCfg.DNSSOA.Minttl = 40 + + err := a.ReloadConfig(&newCfg) + require.NoError(t, err) + + for _, s := range a.dnsServers { + cfg := s.config.Load().(*dnsConfig) + require.Equal(t, []string{"1.1.1.1:53"}, cfg.Recursors) + require.True(t, cfg.AllowStale) + require.Equal(t, 21*time.Second, cfg.MaxStale) + require.Equal(t, 11*time.Second, cfg.NodeTTL) + ttl, _ := cfg.GetTTLForService("my_services_1") + require.Equal(t, time.Duration(0), ttl) + ttl, _ = cfg.GetTTLForService("2_my_services_1") + require.Equal(t, 6*time.Second, ttl) + ttl, _ = cfg.GetTTLForService("my_specific_service") + require.Equal(t, time.Duration(0), ttl) + ttl, _ = cfg.GetTTLForService("2_my_specific_service") + require.Equal(t, 31*time.Second, ttl) + require.True(t, cfg.EnableTruncate) + require.True(t, cfg.OnlyPassing) + require.Equal(t, 16*time.Second, cfg.RecursorTimeout) + require.True(t, cfg.DisableCompression) + require.Equal(t, 2, cfg.ARecordLimit) + require.True(t, cfg.NodeMetaTXT) + require.Equal(t, uint32(10), cfg.SOAConfig.Refresh) + require.Equal(t, uint32(20), cfg.SOAConfig.Retry) + require.Equal(t, uint32(30), cfg.SOAConfig.Expire) + require.Equal(t, uint32(40), cfg.SOAConfig.Minttl) + } +} + +func TestDNS_ReloadConfig_DuringQuery(t *testing.T) { + t.Parallel() + a := NewTestAgent(t, t.Name(), "") + defer a.Shutdown() + testrpc.WaitForLeader(t, a.RPC, "dc1") + + m := MockPreparedQuery{ + executeFn: func(args *structs.PreparedQueryExecuteRequest, reply *structs.PreparedQueryExecuteResponse) error { + time.Sleep(100 * time.Millisecond) + reply.Nodes = structs.CheckServiceNodes{ + { + Node: &structs.Node{ + ID: "my_node", + Address: "127.0.0.1", + }, + Service: &structs.NodeService{ + Address: "127.0.0.1", + Port: 8080, + }, + }, + } + return nil + }, + } + + err := a.registerEndpoint("PreparedQuery", &m) + require.NoError(t, err) + + { + m := new(dns.Msg) + m.SetQuestion("nope.query.consul.", dns.TypeA) + + timeout := time.NewTimer(time.Second) + res := make(chan *dns.Msg) + errs := make(chan error) + + go func() { + c := new(dns.Client) + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil { + errs <- err + return + } + res <- in + }() + + time.Sleep(50 * time.Millisecond) + + // reload the config halfway through, that should not affect the ongoing query + newCfg := *a.Config + newCfg.DNSAllowStale = true + a.ReloadConfig(&newCfg) + + select { + case in := <-res: + require.Equal(t, "127.0.0.1", in.Answer[0].(*dns.A).A.String()) + case err := <-errs: + require.NoError(t, err) + case <-timeout.C: + require.FailNow(t, "timeout") + } + } +} diff --git a/agent/testagent.go b/agent/testagent.go index bc74aad3ee..dbab2a4a8b 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -279,7 +279,8 @@ func (a *TestAgent) Client() *api.Client { // DNSDisableCompression disables compression for all started DNS servers. func (a *TestAgent) DNSDisableCompression(b bool) { for _, srv := range a.dnsServers { - srv.disableCompression.Store(b) + cfg := srv.config.Load().(*dnsConfig) + cfg.DisableCompression = b } }