2016-02-07 13:39:37 -08:00

802 lines
22 KiB
Go

package agent
import (
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/consul"
"github.com/hashicorp/consul/consul/structs"
"github.com/miekg/dns"
)
const (
	// maxServiceResponses caps the number of answer records kept in a service
	// or prepared-query response when the client is not connected over TCP.
	maxServiceResponses = 3 // For UDP only
	// maxRecurseRecords caps how many records obtained by resolving a CNAME
	// target are appended after the CNAME record itself.
	maxRecurseRecords = 5
)
// DNSServer is used to wrap an Agent and expose various
// service discovery endpoints using a DNS interface.
type DNSServer struct {
	// agent supplies the agent configuration and the RPC entry point used
	// to answer lookups.
	agent *Agent
	// config holds the DNS-specific settings (staleness, TTLs, truncation,
	// health filtering) consulted by the lookup handlers.
	config *DNSConfig
	// dnsHandler is the mux shared by the UDP and TCP servers.
	dnsHandler *dns.ServeMux
	// dnsServer is the UDP listener.
	dnsServer *dns.Server
	// dnsServerTCP is the TCP listener.
	dnsServerTCP *dns.Server
	// domain is the fully-qualified domain this server answers for.
	domain string
	// recursors lists the upstream resolvers, as host:port strings, used
	// for queries that cannot be answered locally.
	recursors []string
	// logger receives all diagnostic output from the DNS server.
	logger *log.Logger
}
// Shutdown halts the UDP listener and then the TCP listener. A failure to
// stop either one is logged rather than returned, so both are always attempted.
func (d *DNSServer) Shutdown() {
	udpErr := d.dnsServer.Shutdown()
	if udpErr != nil {
		d.logger.Printf("[ERR] dns: error stopping udp server: %v", udpErr)
	}

	tcpErr := d.dnsServerTCP.Shutdown()
	if tcpErr != nil {
		d.logger.Printf("[ERR] dns: error stopping tcp server: %v", tcpErr)
	}
}
// NewDNSServer builds a DNSServer for the given agent, binds UDP and TCP
// listeners on bind, and waits up to one second for both to report that they
// have started.
//
// domain is converted to its fully-qualified form before handlers are
// registered. When recursors is non-empty each entry is normalized with
// recursorAddr and a catch-all "." handler is installed; an entry that cannot
// be normalized yields a nil server and an error. For listener failures and
// for the startup timeout the partially started server is returned together
// with the error.
func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain string, bind string, recursors []string) (*DNSServer, error) {
	domain = dns.Fqdn(domain)

	// Both listeners share one mux and each signals readiness through the
	// same wait group.
	var started sync.WaitGroup
	mux := dns.NewServeMux()

	udp := &dns.Server{
		Addr:              bind,
		Net:               "udp",
		Handler:           mux,
		UDPSize:           65535,
		NotifyStartedFunc: started.Done,
	}
	tcp := &dns.Server{
		Addr:              bind,
		Net:               "tcp",
		Handler:           mux,
		NotifyStartedFunc: started.Done,
	}

	srv := &DNSServer{
		agent:        agent,
		config:       config,
		dnsHandler:   mux,
		dnsServer:    udp,
		dnsServerTCP: tcp,
		domain:       domain,
		recursors:    recursors,
		logger:       log.New(logOutput, "", log.LstdFlags),
	}

	// Reverse lookups first, then the configured domain.
	mux.HandleFunc("arpa.", srv.handlePtr)
	mux.HandleFunc(domain, srv.handleQuery)

	// Only serve the root zone when there is somewhere to forward to.
	if len(recursors) > 0 {
		normalized := make([]string, len(recursors))
		for i := range recursors {
			addr, err := recursorAddr(recursors[i])
			if err != nil {
				return nil, fmt.Errorf("Invalid recursor address: %v", err)
			}
			normalized[i] = addr
		}
		srv.recursors = normalized
		mux.HandleFunc(".", srv.handleRecurse)
	}

	// The counter must be set before either listener can call Done.
	started.Add(2)

	// Each listener runs in its own goroutine and reports a failure on a
	// buffered channel so the send never blocks.
	udpFailed := make(chan error, 1)
	go func() {
		err := udp.ListenAndServe()
		if err == nil {
			return
		}
		srv.logger.Printf("[ERR] dns: error starting udp server: %v", err)
		udpFailed <- fmt.Errorf("dns udp setup failed: %v", err)
	}()

	tcpFailed := make(chan error, 1)
	go func() {
		err := tcp.ListenAndServe()
		if err == nil {
			return
		}
		srv.logger.Printf("[ERR] dns: error starting tcp server: %v", err)
		tcpFailed <- fmt.Errorf("dns tcp setup failed: %v", err)
	}()

	// Turn "both listeners started" into a channel that can be selected on.
	ready := make(chan struct{})
	go func() {
		started.Wait()
		close(ready)
	}()

	// First outcome wins: a listen error, successful startup, or the timeout.
	select {
	case err := <-udpFailed:
		return srv, err
	case err := <-tcpFailed:
		return srv, err
	case <-ready:
		return srv, nil
	case <-time.After(time.Second):
		return srv, fmt.Errorf("timeout setting up DNS server")
	}
}
// recursorAddr normalizes a recursor address into "host:port" form, adding
// the default DNS port (53) when none was given, and resolves it with
// net.ResolveTCPAddr.
//
// Accepted inputs are "host", "host:port", "[v6]", "[v6]:port" and bare IP
// literals such as "10.0.0.1" or "::1". The default port is appended at most
// once, so malformed input such as "[::1]x" is rejected with an error instead
// of being retried indefinitely.
//
// It returns the string form of the resolved address, or the parse or
// resolution error.
func recursorAddr(recursor string) (string, error) {
	if net.ParseIP(recursor) != nil {
		// A bare IP literal has no port; JoinHostPort adds the brackets an
		// IPv6 literal needs.
		recursor = net.JoinHostPort(recursor, "53")
	} else if _, _, err := net.SplitHostPort(recursor); err != nil {
		// The net package exposes no sentinel for this condition, so the
		// message text has to be compared.
		ae, ok := err.(*net.AddrError)
		if !ok || ae.Err != "missing port in address" {
			return "", err
		}
		recursor = fmt.Sprintf("%s:%d", recursor, 53)
	}

	// Validate the final form exactly once; if it is still malformed after
	// adding the port there is nothing more to try.
	if _, _, err := net.SplitHostPort(recursor); err != nil {
		return "", err
	}

	// Get the address
	addr, err := net.ResolveTCPAddr("tcp", recursor)
	if err != nil {
		return "", err
	}

	// Return string
	return addr.String(), nil
}
// handlePtr answers reverse (PTR) lookups under "arpa.". It lists the nodes
// of the local datacenter, and if one has an address whose reverse name
// equals the question name it replies with a PTR to
// "<node>.node.<datacenter>.<domain>". When nothing matches (including when
// the catalog RPC fails) the request is handed to handleRecurse.
func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
	question := req.Question[0]

	// Record timing and log the request once handling has finished.
	defer func(start time.Time) {
		metrics.MeasureSince([]string{"consul", "dns", "ptr_query", d.agent.config.NodeName}, start)
		d.logger.Printf("[DEBUG] dns: request for %v (%v) from client %s (%s)",
			question, time.Since(start), resp.RemoteAddr().String(),
			resp.RemoteAddr().Network())
	}(time.Now())

	// Prepare an authoritative reply.
	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Authoritative = true
	reply.RecursionAvailable = len(d.recursors) > 0

	// The SOA is attached only when explicitly asked for.
	if question.Qtype == dns.TypeSOA {
		d.addSOA(d.domain, reply)
	}

	dc := d.agent.config.Datacenter

	// Normalized question name, compared against each node's reverse name.
	wanted := strings.ToLower(dns.Fqdn(question.Name))

	listReq := structs.DCSpecificRequest{
		Datacenter: dc,
		QueryOptions: structs.QueryOptions{
			Token:      d.agent.config.ACLToken,
			AllowStale: d.config.AllowStale,
		},
	}
	var listed structs.IndexedNodes

	// TODO: Replace ListNodes with an internal RPC that can do the filter
	// server side to avoid transferring the entire node list.
	rpcErr := d.agent.RPC("Catalog.ListNodes", &listReq, &listed)
	if rpcErr == nil {
		for _, n := range listed.Nodes {
			// An address that cannot be reversed yields "" and never matches.
			reversed, _ := dns.ReverseAddr(n.Address)
			if reversed != wanted {
				continue
			}
			reply.Answer = append(reply.Answer, &dns.PTR{
				Hdr: dns.RR_Header{Name: question.Name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: 0},
				Ptr: fmt.Sprintf("%s.node.%s.%s", n.Node, dc, d.domain),
			})
			break
		}
	}

	// No local match: let the recursors have a go.
	if len(reply.Answer) == 0 {
		d.handleRecurse(resp, req)
		return
	}

	writeErr := resp.WriteMsg(reply)
	if writeErr != nil {
		d.logger.Printf("[WARN] dns: failed to respond: %v", writeErr)
	}
}
// handleQuery serves questions that fall under the configured domain. It
// builds an authoritative reply, lets dispatch fill it in, and writes it back
// to the client.
func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
	question := req.Question[0]

	// Record timing and log the request once handling has finished.
	defer func(start time.Time) {
		metrics.MeasureSince([]string{"consul", "dns", "domain_query", d.agent.config.NodeName}, start)
		d.logger.Printf("[DEBUG] dns: request for %v (%v) from client %s (%s)",
			question, time.Since(start), resp.RemoteAddr().String(),
			resp.RemoteAddr().Network())
	}(time.Now())

	// The transport is passed down to the lookups, which trim their answers
	// for anything other than TCP.
	var network string
	switch resp.RemoteAddr().(type) {
	case *net.TCPAddr:
		network = "tcp"
	default:
		network = "udp"
	}

	// Prepare an authoritative reply.
	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Authoritative = true
	reply.RecursionAvailable = len(d.recursors) > 0

	// The SOA is attached only when explicitly asked for.
	if question.Qtype == dns.TypeSOA {
		d.addSOA(d.domain, reply)
	}

	d.dispatch(network, req, reply)

	writeErr := resp.WriteMsg(reply)
	if writeErr != nil {
		d.logger.Printf("[WARN] dns: failed to respond: %v", writeErr)
	}
}
// addSOA appends an SOA record for domain to the authority section of msg.
// The serial is the current Unix time; the name server and mailbox are
// derived from domain.
func (d *DNSServer) addSOA(domain string, msg *dns.Msg) {
	record := new(dns.SOA)
	record.Hdr = dns.RR_Header{
		Name:   domain,
		Rrtype: dns.TypeSOA,
		Class:  dns.ClassINET,
		Ttl:    0,
	}
	record.Ns = "ns." + domain
	record.Mbox = "postmaster." + domain
	record.Serial = uint32(time.Now().Unix())
	record.Refresh = 3600
	record.Retry = 600
	record.Expire = 86400
	record.Minttl = 0

	msg.Ns = append(msg.Ns, record)
}
// dispatch parses the question name of req and routes it to the node,
// service or prepared-query lookup, which fill in resp.
//
// The configured domain suffix is stripped and the remaining labels are read
// from the right. A trailing label that is not "node", "service" or "query"
// is taken as a datacenter name and removed before parsing continues. Names
// that carry nothing before the type label, or run out of labels, are
// answered with NXDOMAIN plus an SOA.
func (d *DNSServer) dispatch(network string, req, resp *dns.Msg) {
	// Queries target the agent's own datacenter unless the name says otherwise.
	datacenter := d.agent.config.Datacenter

	// Lower-cased, fully-qualified question name minus the domain suffix.
	qName := strings.TrimSuffix(strings.ToLower(dns.Fqdn(req.Question[0].Name)), d.domain)

	// reject produces the reply for an unparseable name.
	reject := func() {
		d.logger.Printf("[WARN] dns: QName invalid: %s", qName)
		d.addSOA(d.domain, resp)
		resp.SetRcode(req, dns.RcodeNameError)
	}

	labels := dns.SplitDomainName(qName)
	for {
		n := len(labels)
		if n == 0 {
			reject()
			return
		}

		kind := labels[n-1]
		if kind != "service" && kind != "node" && kind != "query" {
			// Not a lookup type: remember it as the datacenter and look at
			// the label to its left.
			datacenter = kind
			labels = labels[:n-1]
			continue
		}

		// Every lookup type needs at least one label in front of it.
		if n == 1 {
			reject()
			return
		}

		switch kind {
		case "node":
			// Node names may contain dots, so everything left of the type
			// label is the name.
			d.nodeLookup(network, datacenter, strings.Join(labels[:n-1], "."), req, resp)

		case "query":
			// Query names may contain dots as well.
			d.preparedQueryLookup(network, datacenter, strings.Join(labels[:n-1], "."), req, resp)

		case "service":
			rfc2782 := n == 3 &&
				strings.HasPrefix(labels[n-2], "_") &&
				strings.HasPrefix(labels[n-3], "_")

			if rfc2782 {
				// _name._tag.service.consul; "_tcp" is the default protocol
				// slot and is not used as a tag filter.
				tag := labels[n-2][1:]
				if tag == "tcp" {
					tag = ""
				}
				d.serviceLookup(network, datacenter, labels[n-3][1:], tag, req, resp)
			} else {
				// tag[.tag].name.service.consul (Consul 0.3 and prior
				// format); tags may contain dots.
				tag := ""
				if n >= 3 {
					tag = strings.Join(labels[:n-2], ".")
				}
				d.serviceLookup(network, datacenter, labels[n-2], tag, req, resp)
			}
		}
		return
	}
}
// translateAddr picks the address to advertise for node. The node's "wan"
// tagged address is returned when WAN translation is enabled in the agent
// configuration, dc is not the agent's own datacenter, and that tagged address
// is non-empty; otherwise the node's regular address is returned.
func (d *DNSServer) translateAddr(dc string, node *structs.Node) string {
	local := node.Address

	remoteDC := d.agent.config.Datacenter != dc
	if !d.agent.config.TranslateWanAddrs || !remoteDC {
		return local
	}

	if wan := node.TaggedAddresses["wan"]; wan != "" {
		return wan
	}
	return local
}
// nodeLookup answers a question about a single node. Only ANY, A and AAAA
// questions are served; other types leave resp untouched. An RPC failure
// yields SERVFAIL, an unknown node yields NXDOMAIN with an SOA, and otherwise
// the node's address records are appended to the answer section.
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) {
	qType := req.Question[0].Qtype
	switch qType {
	case dns.TypeANY, dns.TypeA, dns.TypeAAAA:
		// Supported question type.
	default:
		return
	}

	args := structs.NodeSpecificRequest{
		Datacenter: datacenter,
		Node:       node,
		QueryOptions: structs.QueryOptions{
			Token:      d.agent.config.ACLToken,
			AllowStale: d.config.AllowStale,
		},
	}
	var out structs.IndexedNodeServices

	// Query the catalog; if a stale read turns out older than MaxStale, ask
	// once more with stale reads disabled.
	for {
		if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil {
			d.logger.Printf("[ERR] dns: rpc error: %v", err)
			resp.SetRcode(req, dns.RcodeServerFailure)
			return
		}
		if !args.AllowStale || out.LastContact <= d.config.MaxStale {
			break
		}
		args.AllowStale = false
		d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
	}

	// Unknown node.
	if out.NodeServices == nil {
		d.addSOA(d.domain, resp)
		resp.SetRcode(req, dns.RcodeNameError)
		return
	}

	found := out.NodeServices.Node
	address := d.translateAddr(datacenter, found)
	if rrs := d.formatNodeRecord(found, address, req.Question[0].Name, qType, d.config.NodeTTL); rrs != nil {
		resp.Answer = append(resp.Answer, rrs...)
	}
}
// formatNodeRecord builds the records describing addr under the name qName.
//
// An IPv4 address gives a single A record (for ANY or A questions), any other
// IP address gives a single AAAA record (for ANY or AAAA questions), and an
// address that does not parse as an IP is treated as a host name: a CNAME
// record is emitted (for ANY, CNAME, A or AAAA questions), followed by up to
// maxRecurseRecords CNAME/A/AAAA records obtained from resolveCNAME. A nil
// slice is returned when the question type does not fit the address.
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration) (records []dns.RR) {
	seconds := uint32(ttl / time.Second)
	anyType := qType == dns.TypeANY

	ip := net.ParseIP(addr)

	// Host name rather than an IP literal.
	if ip == nil {
		if !anyType && qType != dns.TypeCNAME && qType != dns.TypeA && qType != dns.TypeAAAA {
			return records
		}

		alias := &dns.CNAME{
			Hdr: dns.RR_Header{
				Name:   qName,
				Rrtype: dns.TypeCNAME,
				Class:  dns.ClassINET,
				Ttl:    seconds,
			},
			Target: dns.Fqdn(addr),
		}
		records = append(records, alias)

		// Chase the target through the recursors, keeping only address and
		// alias records and stopping at the configured limit.
		kept := 0
		for _, rr := range d.resolveCNAME(alias.Target) {
			rrtype := rr.Header().Rrtype
			if rrtype != dns.TypeCNAME && rrtype != dns.TypeA && rrtype != dns.TypeAAAA {
				continue
			}
			records = append(records, rr)
			kept++
			if kept == maxRecurseRecords {
				break
			}
		}
		return records
	}

	// IPv4 literal.
	if ip.To4() != nil {
		if !anyType && qType != dns.TypeA {
			return records
		}
		return []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   qName,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    seconds,
			},
			A: ip,
		}}
	}

	// Any other IP literal is served as AAAA.
	if !anyType && qType != dns.TypeAAAA {
		return records
	}
	return []dns.RR{&dns.AAAA{
		Hdr: dns.RR_Header{
			Name:   qName,
			Rrtype: dns.TypeAAAA,
			Class:  dns.ClassINET,
			Ttl:    seconds,
		},
		AAAA: ip,
	}}
}
// serviceLookup answers a question about a service, optionally filtered by
// tag. An RPC failure yields SERVFAIL; no nodes left after health filtering
// yields NXDOMAIN with an SOA. Otherwise address records (and SRV records for
// SRV questions) are added to resp, trimmed to maxServiceResponses answers
// unless the client is on TCP.
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) {
	args := structs.ServiceSpecificRequest{
		Datacenter:  datacenter,
		ServiceName: service,
		ServiceTag:  tag,
		TagFilter:   tag != "",
		QueryOptions: structs.QueryOptions{
			Token:      d.agent.config.ACLToken,
			AllowStale: d.config.AllowStale,
		},
	}
	var out structs.IndexedCheckServiceNodes

	// Query health; if a stale read turns out older than MaxStale, ask once
	// more with stale reads disabled.
	for {
		if err := d.agent.RPC("Health.ServiceNodes", &args, &out); err != nil {
			d.logger.Printf("[ERR] dns: rpc error: %v", err)
			resp.SetRcode(req, dns.RcodeServerFailure)
			return
		}
		if !args.AllowStale || out.LastContact <= d.config.MaxStale {
			break
		}
		args.AllowStale = false
		d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
	}

	// A per-service TTL wins over the "*" wildcard entry.
	var ttl time.Duration
	if d.config.ServiceTTL != nil {
		if specific, ok := d.config.ServiceTTL[service]; ok {
			ttl = specific
		} else {
			ttl = d.config.ServiceTTL["*"]
		}
	}

	// Drop nodes according to their health checks.
	out.Nodes = out.Nodes.Filter(d.config.OnlyPassing)
	if len(out.Nodes) == 0 {
		d.addSOA(d.domain, resp)
		resp.SetRcode(req, dns.RcodeNameError)
		return
	}

	// Randomize the order in which nodes are returned.
	out.Nodes.Shuffle()

	// Address records are always attempted; SRV records only on request.
	d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
	if req.Question[0].Qtype == dns.TypeSRV {
		d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
	}

	// Anything other than TCP gets a capped answer section, optionally
	// flagged as truncated.
	tooMany := len(resp.Answer) > maxServiceResponses
	if network != "tcp" && tooMany {
		resp.Answer = resp.Answer[:maxServiceResponses]
		if d.config.EnableTruncate {
			resp.Truncated = true
		}
	}

	// Nothing matched the question type: attach the SOA to the empty reply.
	if len(resp.Answer) == 0 {
		d.addSOA(d.domain, resp)
	}
}
// preparedQueryLookup answers a question by executing the prepared query
// named (or identified) by query. An unknown query or an empty result yields
// NXDOMAIN with an SOA; any other RPC failure yields SERVFAIL. Otherwise
// address records (and SRV records for SRV questions) are added to resp,
// trimmed to maxServiceResponses answers unless the client is on TCP.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, req, resp *dns.Msg) {
	args := structs.PreparedQueryExecuteRequest{
		Datacenter:    datacenter,
		QueryIDOrName: query,
		QueryOptions: structs.QueryOptions{
			Token:      d.agent.config.ACLToken,
			AllowStale: d.config.AllowStale,
		},
	}

	// TODO (slackpad) - What's a safe limit we can set here? It seems like
	// with dup filtering done at this level we need to get everything to
	// match the previous behavior. We can optimize by pushing more filtering
	// into the query execution, but for now I think we need to get the full
	// response. We could also choose a large arbitrary number that will
	// likely work in practice, like 10*maxServiceResponses which should help
	// reduce bandwidth if there are thousands of nodes available.

	endpoint := d.agent.getEndpoint(preparedQueryEndpoint)
	var out structs.PreparedQueryExecuteResponse

	// Execute the query; if a stale read turns out older than MaxStale, ask
	// once more with stale reads disabled.
	for {
		err := d.agent.RPC(endpoint+".Execute", &args, &out)
		switch {
		case err == nil:
			// Fall through to the staleness check below.
		case err.Error() == consul.ErrQueryNotFound.Error():
			// The RPC layer loses the error type, so the message text is
			// compared. A bogus query name is a name error, not a server
			// failure.
			d.addSOA(d.domain, resp)
			resp.SetRcode(req, dns.RcodeNameError)
			return
		default:
			d.logger.Printf("[ERR] dns: rpc error: %v", err)
			resp.SetRcode(req, dns.RcodeServerFailure)
			return
		}

		if !args.AllowStale || out.LastContact <= d.config.MaxStale {
			break
		}
		args.AllowStale = false
		d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
	}

	// A TTL carried by the query wins; it was vetted when the query was
	// created, so a parse failure is only logged. Without one, fall back to
	// the agent's per-service TTLs, then the "*" wildcard entry.
	var ttl time.Duration
	switch {
	case out.DNS.TTL != "":
		parsed, err := time.ParseDuration(out.DNS.TTL)
		if err != nil {
			d.logger.Printf("[WARN] dns: Failed to parse TTL '%s' for prepared query '%s', ignoring", out.DNS.TTL, query)
		}
		ttl = parsed
	case d.config.ServiceTTL != nil:
		if specific, ok := d.config.ServiceTTL[out.Service]; ok {
			ttl = specific
		} else {
			ttl = d.config.ServiceTTL["*"]
		}
	}

	// An empty result set is reported as not found.
	if len(out.Nodes) == 0 {
		d.addSOA(d.domain, resp)
		resp.SetRcode(req, dns.RcodeNameError)
		return
	}

	// Address records are always attempted; SRV records only on request.
	d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
	if req.Question[0].Qtype == dns.TypeSRV {
		d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
	}

	// Anything other than TCP gets a capped answer section, optionally
	// flagged as truncated.
	tooMany := len(resp.Answer) > maxServiceResponses
	if network != "tcp" && tooMany {
		resp.Answer = resp.Answer[:maxServiceResponses]
		if d.config.EnableTruncate {
			resp.Truncated = true
		}
	}

	// Nothing matched the question type: attach the SOA to the empty reply.
	if len(resp.Answer) == 0 {
		d.addSOA(d.domain, resp)
	}
}
// serviceNodeRecords appends to resp.Answer the address records for each
// distinct address among nodes, named after the question in req. The service
// address is used when set, otherwise the node's translated address.
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
	question := req.Question[0]

	// Addresses already emitted; a node can expose the same service on
	// several ports and must not be listed twice.
	seen := make(map[string]bool)

	for _, entry := range nodes {
		address := d.translateAddr(dc, entry.Node)
		if svcAddr := entry.Service.Address; svcAddr != "" {
			address = svcAddr
		}

		if seen[address] {
			continue
		}
		seen[address] = true

		rrs := d.formatNodeRecord(entry.Node, address, question.Name, question.Qtype, ttl)
		if rrs != nil {
			resp.Answer = append(resp.Answer, rrs...)
		}
	}
}
// serviceSRVRecords appends one SRV record to resp.Answer for each distinct
// (node, service address, port) combination among nodes, each targeting
// "<node>.node.<dc>.<domain>". The address records for each target are added
// to resp.Extra.
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
	// Combinations already emitted, so repeated entries are listed once.
	seen := make(map[string]bool)

	for _, entry := range nodes {
		key := fmt.Sprintf("%s:%s:%d", entry.Node.Node, entry.Service.Address, entry.Service.Port)
		if seen[key] {
			continue
		}
		seen[key] = true

		target := fmt.Sprintf("%s.node.%s.%s", entry.Node.Node, dc, d.domain)
		resp.Answer = append(resp.Answer, &dns.SRV{
			Hdr: dns.RR_Header{
				Name:   req.Question[0].Name,
				Rrtype: dns.TypeSRV,
				Class:  dns.ClassINET,
				Ttl:    uint32(ttl / time.Second),
			},
			Priority: 1,
			Weight:   1,
			Port:     uint16(entry.Service.Port),
			Target:   target,
		})

		// The service address is used when set, otherwise the node's
		// translated address.
		address := d.translateAddr(dc, entry.Node)
		if svcAddr := entry.Service.Address; svcAddr != "" {
			address = svcAddr
		}

		// Supply the target's address records in the additional section.
		extras := d.formatNodeRecord(entry.Node, address, target, dns.TypeANY, ttl)
		if extras != nil {
			resp.Extra = append(resp.Extra, extras...)
		}
	}
}
// handleRecurse forwards req to the configured recursors, in order, using the
// same transport (UDP or TCP) the client used. The first successful upstream
// response is relayed to the client unchanged. If every recursor fails, or
// none is configured, a SERVFAIL reply is sent instead.
//
// Failures to write either kind of reply are logged; nothing is returned to
// the caller.
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
	q := req.Question[0]
	network := "udp"

	// The closure reads network when it runs, so the log line reflects the
	// transport selected below.
	defer func(s time.Time) {
		d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v) from client %s (%s)",
			q, network, time.Since(s), resp.RemoteAddr().String(),
			resp.RemoteAddr().Network())
	}(time.Now())

	// Switch to TCP if the client is
	if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok {
		network = "tcp"
	}

	// Try each recursor until one answers.
	c := &dns.Client{Net: network}
	for _, recursor := range d.recursors {
		r, rtt, err := c.Exchange(req, recursor)
		if err != nil {
			d.logger.Printf("[ERR] dns: recurse failed: %v", err)
			continue
		}

		// Forward the response
		d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt)
		if writeErr := resp.WriteMsg(r); writeErr != nil {
			d.logger.Printf("[WARN] dns: failed to respond: %v", writeErr)
		}
		return
	}

	// If all resolvers fail, return a SERVFAIL message
	d.logger.Printf("[ERR] dns: all resolvers failed for %v from client %s (%s)",
		q, resp.RemoteAddr().String(), resp.RemoteAddr().Network())
	m := &dns.Msg{}
	m.SetReply(req)
	m.RecursionAvailable = true
	m.SetRcode(req, dns.RcodeServerFailure)

	// Report a failed write the same way every other reply path does,
	// rather than silently dropping the error.
	if writeErr := resp.WriteMsg(m); writeErr != nil {
		d.logger.Printf("[WARN] dns: failed to respond: %v", writeErr)
	}
}
// resolveCNAME asks the configured recursors, in order and over UDP, for the
// A records of name and returns the answer section of the first successful
// response. It returns nil when no recursors are configured or when all of
// them fail.
func (d *DNSServer) resolveCNAME(name string) []dns.RR {
	// Without recursors there is nobody to ask.
	if len(d.recursors) == 0 {
		return nil
	}

	question := new(dns.Msg)
	question.SetQuestion(name, dns.TypeA)

	client := &dns.Client{Net: "udp"}
	for _, upstream := range d.recursors {
		reply, rtt, err := client.Exchange(question, upstream)
		if err != nil {
			d.logger.Printf("[ERR] dns: cname recurse failed for %v: %v", name, err)
			continue
		}
		d.logger.Printf("[DEBUG] dns: cname recurse RTT for %v (%v)", name, rtt)
		return reply.Answer
	}

	d.logger.Printf("[ERR] dns: all resolvers failed for %v", name)
	return nil
}