package agent import ( "fmt" "github.com/hashicorp/consul/consul/structs" "github.com/miekg/dns" "io" "log" "net" "strings" "time" ) const ( testQuery = "_test.consul." consulDomain = "consul." ) // DNSServer is used to wrap an Agent and expose various // service discovery endpoints using a DNS interface. type DNSServer struct { agent *Agent dnsHandler *dns.ServeMux dnsServer *dns.Server dnsServerTCP *dns.Server domain string recursor string logger *log.Logger } // NewDNSServer starts a new DNS server to provide an agent interface func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) { // Make sure domain is FQDN domain = dns.Fqdn(domain) // Construct the DNS components mux := dns.NewServeMux() // Setup the servers server := &dns.Server{ Addr: bind, Net: "udp", Handler: mux, UDPSize: 65535, } serverTCP := &dns.Server{ Addr: bind, Net: "tcp", Handler: mux, } // Create the server srv := &DNSServer{ agent: agent, dnsHandler: mux, dnsServer: server, dnsServerTCP: serverTCP, domain: domain, recursor: recursor, logger: log.New(logOutput, "", log.LstdFlags), } // Register mux handlers, always handle "consul." mux.HandleFunc(domain, srv.handleQuery) if domain != consulDomain { mux.HandleFunc(consulDomain, srv.handleTest) } if recursor != "" { mux.HandleFunc(".", srv.handleRecurse) } // Async start the DNS Servers, handle a potential error errCh := make(chan error, 1) go func() { err := server.ListenAndServe() srv.logger.Printf("[ERR] dns: error starting udp server: %v", err) errCh <- err }() errChTCP := make(chan error, 1) go func() { err := serverTCP.ListenAndServe() srv.logger.Printf("[ERR] dns: error starting tcp server: %v", err) errChTCP <- err }() // Check the server is running, do a test lookup checkCh := make(chan error, 1) go func() { // This is jank, but we have no way to edge trigger on // the start of our server, so we just wait and hope it is up. time.Sleep(50 * time.Millisecond) m := new(dns.Msg) m.SetQuestion(testQuery, dns.TypeANY) c := new(dns.Client) in, _, err := c.Exchange(m, bind) if err != nil { checkCh <- err return } if len(in.Answer) == 0 { checkCh <- fmt.Errorf("no response to test message") return } close(checkCh) }() // Wait for either the check, listen error, or timeout select { case e := <-errCh: return srv, e case e := <-errChTCP: return srv, e case e := <-checkCh: return srv, e case <-time.After(time.Second): return srv, fmt.Errorf("timeout setting up DNS server") } return srv, nil } // handleQUery is used to handle DNS queries in the configured domain func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] defer func(s time.Time) { d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s)) }(time.Now()) // Check if this is potentially a test query if q.Name == testQuery { d.handleTest(resp, req) return } // Setup the message response m := new(dns.Msg) m.SetReply(req) m.Authoritative = true d.addSOA(d.domain, m) // Dispatch the correct handler d.dispatch(req, m) // Write out the complete response if err := resp.WriteMsg(m); err != nil { d.logger.Printf("[WARN] dns: failed to respond: %v", err) } } // handleTest is used to handle DNS queries in the ".consul." domain func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] defer func(s time.Time) { d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s)) }(time.Now()) if !(q.Qtype == dns.TypeANY || q.Qtype == dns.TypeTXT) { return } if q.Name != testQuery { return } // Always respond with TXT "ok" m := new(dns.Msg) m.SetReply(req) m.Authoritative = true header := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0} txt := &dns.TXT{header, []string{"ok"}} m.Answer = append(m.Answer, txt) d.addSOA(consulDomain, m) if err := resp.WriteMsg(m); err != nil { d.logger.Printf("[WARN] dns: failed to respond: %v", err) } } // addSOA is used to add an SOA record to a message for the given domain func (d *DNSServer) addSOA(domain string, msg *dns.Msg) { soa := &dns.SOA{ Hdr: dns.RR_Header{ Name: domain, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 0, }, Ns: "ns." + domain, Mbox: "postmaster." + domain, Serial: uint32(time.Now().Unix()), Refresh: 3600, Retry: 600, Expire: 86400, Minttl: 0, } msg.Ns = append(msg.Ns, soa) } // dispatch is used to parse a request and invoke the correct handler func (d *DNSServer) dispatch(req, resp *dns.Msg) { // By default the query is in the default datacenter datacenter := d.agent.config.Datacenter // Get the QName without the domain suffix qName := dns.Fqdn(req.Question[0].Name) qName = strings.TrimSuffix(qName, d.domain) // Split into the label parts labels := dns.SplitDomainName(qName) // The last label is either "node", "service" or a datacenter name PARSE: if len(labels) == 0 { goto INVALID } switch labels[len(labels)-1] { case "service": // Handle lookup with and without tag switch len(labels) { case 2: d.serviceLookup(datacenter, labels[0], "", req, resp) case 3: d.serviceLookup(datacenter, labels[1], labels[0], req, resp) default: goto INVALID } case "node": if len(labels) != 2 { goto INVALID } d.nodeLookup(datacenter, labels[0], req, resp) default: // Store the DC, and re-parse datacenter = labels[len(labels)-1] labels = labels[:len(labels)-1] goto PARSE } return INVALID: d.logger.Printf("[WARN] dns: QName invalid: %s", qName) resp.SetRcode(req, dns.RcodeNameError) } // nodeLookup is used to handle a node query func (d *DNSServer) nodeLookup(datacenter, node string, req, resp *dns.Msg) { // Only handle ANY and A type requests qType := req.Question[0].Qtype if qType != dns.TypeANY && qType != dns.TypeA { return } // Make an RPC request args := structs.NodeSpecificRequest{ Datacenter: datacenter, Node: node, } var out structs.NodeServices if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil { d.logger.Printf("[ERR] dns: rpc error: %v", err) resp.SetRcode(req, dns.RcodeServerFailure) return } // If we have no address, return not found! if out.Node.Address == "" { resp.SetRcode(req, dns.RcodeNameError) return } // Parse the IP ip := net.ParseIP(out.Node.Address) if ip == nil { d.logger.Printf("[ERR] dns: failed to parse IP %v", out.Node) resp.SetRcode(req, dns.RcodeServerFailure) return } // Format A record aRec := &dns.A{ Hdr: dns.RR_Header{ Name: req.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0, }, A: ip, } // Add the response resp.Answer = append(resp.Answer, aRec) } // serviceLookup is used to handle a service query func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dns.Msg) { // Make an RPC request args := structs.ServiceSpecificRequest{ Datacenter: datacenter, ServiceName: service, ServiceTag: tag, TagFilter: tag != "", } var out structs.CheckServiceNodes if err := d.agent.RPC("Health.ServiceNodes", &args, &out); err != nil { d.logger.Printf("[ERR] dns: rpc error: %v", err) resp.SetRcode(req, dns.RcodeServerFailure) return } // If we have no nodes, return not found! if len(out) == 0 { resp.SetRcode(req, dns.RcodeNameError) return } // Add various responses depending on the request qType := req.Question[0].Qtype if qType == dns.TypeANY || qType == dns.TypeA { d.serviceARecords(out, req, resp) } if qType == dns.TypeANY || qType == dns.TypeSRV { d.serviceSRVRecords(datacenter, out, req, resp) } } // serviceARecords is used to add the A records for a service lookup func (d *DNSServer) serviceARecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) { handled := make(map[string]struct{}) for _, node := range nodes { // Avoid duplicate entries, possible if a node has // the same service on multiple ports, etc. addr := node.Node.Address if _, ok := handled[addr]; ok { continue } handled[addr] = struct{}{} ip := net.ParseIP(addr) if ip == nil { d.logger.Printf("[ERR] dns: failed to parse IP %v for %v", addr, node.Node) continue } aRec := &dns.A{ Hdr: dns.RR_Header{ Name: req.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0, }, A: ip, } resp.Answer = append(resp.Answer, aRec) } } // serviceARecords is used to add the SRV records for a service lookup func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg) { handled := make(map[string]struct{}) for _, node := range nodes { // Avoid duplicate entries, possible if a node has // the same service the same port, etc. tuple := fmt.Sprintf("%s:%d", node.Node.Node, node.Service.Port) if _, ok := handled[tuple]; ok { continue } handled[tuple] = struct{}{} // Add the SRV record srvRec := &dns.SRV{ Hdr: dns.RR_Header{ Name: req.Question[0].Name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: 0, }, Priority: 1, Weight: 1, Port: uint16(node.Service.Port), Target: fmt.Sprintf("%s.node.%s.%s", node.Node.Node, dc, d.domain), } resp.Answer = append(resp.Answer, srvRec) // Avoid duplicate A records, possible if a node has // the same service on multiple ports, etc. addr := node.Node.Address if _, ok := handled[addr]; ok { continue } handled[addr] = struct{}{} ip := net.ParseIP(addr) if ip == nil { d.logger.Printf("[ERR] dns: failed to parse IP %v for %v", addr, node.Node) continue } aRec := &dns.A{ Hdr: dns.RR_Header{ Name: srvRec.Target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0, }, A: ip, } resp.Extra = append(resp.Extra, aRec) } } // handleRecurse is used to handle recursive DNS queries func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] network := "udp" defer func(s time.Time) { d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v)", q, network, time.Now().Sub(s)) }(time.Now()) // Switch to TCP if the client is if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok { network = "tcp" } // Recursively resolve c := &dns.Client{Net: network} r, rtt, err := c.Exchange(req, d.recursor) // On failure, return a SERVFAIL message if err != nil { d.logger.Printf("[ERR] dns: recurse failed: %v", err) m := &dns.Msg{} m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) resp.WriteMsg(m) return } d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt) // Forward the response if err := resp.WriteMsg(r); err != nil { d.logger.Printf("[WARN] dns: failed to respond: %v", err) } }