Support DNS recursion and TCP queries

This commit is contained in:
Armon Dadgar 2014-01-03 15:43:35 -08:00
parent b9e0eef1ff
commit 29fe144b5b
4 changed files with 115 additions and 18 deletions

View File

@ -52,6 +52,7 @@ func (c *Command) readConfig() *Config {
"address to bind RPC listener to") "address to bind RPC listener to")
cmdFlags.StringVar(&cmdConfig.DataDir, "data", "", "path to the data directory") cmdFlags.StringVar(&cmdConfig.DataDir, "data", "", "path to the data directory")
cmdFlags.StringVar(&cmdConfig.Datacenter, "dc", "", "node datacenter") cmdFlags.StringVar(&cmdConfig.Datacenter, "dc", "", "node datacenter")
cmdFlags.StringVar(&cmdConfig.DNSRecursor, "recursor", "", "address of dns recursor")
cmdFlags.BoolVar(&cmdConfig.Server, "server", false, "run agent as server") cmdFlags.BoolVar(&cmdConfig.Server, "server", false, "run agent as server")
cmdFlags.BoolVar(&cmdConfig.Bootstrap, "bootstrap", false, "enable server bootstrap mode") cmdFlags.BoolVar(&cmdConfig.Bootstrap, "bootstrap", false, "enable server bootstrap mode")
if err := cmdFlags.Parse(c.args); err != nil { if err := cmdFlags.Parse(c.args); err != nil {
@ -148,7 +149,8 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
} }
if config.DNSAddr != "" { if config.DNSAddr != "" {
server, err := NewDNSServer(agent, logOutput, config.Domain, config.DNSAddr) server, err := NewDNSServer(agent, logOutput, config.Domain,
config.DNSAddr, config.DNSRecursor)
if err != nil { if err != nil {
agent.Shutdown() agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Error starting dns server: %s", err)) c.Ui.Error(fmt.Sprintf("Error starting dns server: %s", err))

View File

@ -30,6 +30,10 @@ type Config struct {
// DNSAddr is the address of the DNS server for the agent // DNSAddr is the address of the DNS server for the agent
DNSAddr string DNSAddr string
// DNSRecursor can be set to allow the DNS server to recursively
// resolve non-consul domains
DNSRecursor string
// Domain is the DNS domain for the records. Defaults to "consul." // Domain is the DNS domain for the records. Defaults to "consul."
Domain string Domain string
@ -154,6 +158,9 @@ func MergeConfig(a, b *Config) *Config {
if b.DNSAddr != "" { if b.DNSAddr != "" {
result.DNSAddr = b.DNSAddr result.DNSAddr = b.DNSAddr
} }
if b.DNSRecursor != "" {
result.DNSRecursor = b.DNSRecursor
}
if b.Domain != "" { if b.Domain != "" {
result.Domain = b.Domain result.Domain = b.Domain
} }

View File

@ -22,32 +22,41 @@ type DNSServer struct {
agent *Agent agent *Agent
dnsHandler *dns.ServeMux dnsHandler *dns.ServeMux
dnsServer *dns.Server dnsServer *dns.Server
dnsServerTCP *dns.Server
domain string domain string
recursor string
logger *log.Logger logger *log.Logger
} }
// NewDNSServer starts a new DNS server to provide an agent interface // NewDNSServer starts a new DNS server to provide an agent interface
func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSServer, error) { func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) {
// Make sure domain is FQDN // Make sure domain is FQDN
domain = dns.Fqdn(domain) domain = dns.Fqdn(domain)
// Construct the DNS components // Construct the DNS components
mux := dns.NewServeMux() mux := dns.NewServeMux()
// Setup the server // Setup the servers
server := &dns.Server{ server := &dns.Server{
Addr: bind, Addr: bind,
Net: "udp", Net: "udp",
Handler: mux, Handler: mux,
UDPSize: 65535, UDPSize: 65535,
} }
serverTCP := &dns.Server{
Addr: bind,
Net: "tcp",
Handler: mux,
}
// Create the server // Create the server
srv := &DNSServer{ srv := &DNSServer{
agent: agent, agent: agent,
dnsHandler: mux, dnsHandler: mux,
dnsServer: server, dnsServer: server,
dnsServerTCP: serverTCP,
domain: domain, domain: domain,
recursor: recursor,
logger: log.New(logOutput, "", log.LstdFlags), logger: log.New(logOutput, "", log.LstdFlags),
} }
@ -56,15 +65,25 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
if domain != consulDomain { if domain != consulDomain {
mux.HandleFunc(consulDomain, srv.handleTest) mux.HandleFunc(consulDomain, srv.handleTest)
} }
if recursor != "" {
mux.HandleFunc(".", srv.handleRecurse)
}
// Async start the DNS Server, handle a potential error // Async start the DNS Servers, handle a potential error
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
err := server.ListenAndServe() err := server.ListenAndServe()
srv.logger.Printf("[ERR] dns: error starting server: %v", err) srv.logger.Printf("[ERR] dns: error starting udp server: %v", err)
errCh <- err errCh <- err
}() }()
errChTCP := make(chan error, 1)
go func() {
err := serverTCP.ListenAndServe()
srv.logger.Printf("[ERR] dns: error starting tcp server: %v", err)
errChTCP <- err
}()
// Check the server is running, do a test lookup // Check the server is running, do a test lookup
checkCh := make(chan error, 1) checkCh := make(chan error, 1)
go func() { go func() {
@ -93,6 +112,8 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind string) (*DNSS
select { select {
case e := <-errCh: case e := <-errCh:
return srv, e return srv, e
case e := <-errChTCP:
return srv, e
case e := <-checkCh: case e := <-checkCh:
return srv, e return srv, e
case <-time.After(time.Second): case <-time.After(time.Second):
@ -119,10 +140,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.SetReply(req) m.SetReply(req)
m.Authoritative = true m.Authoritative = true
d.addSOA(d.domain, m) d.addSOA(d.domain, m)
defer resp.WriteMsg(m)
// Dispatch the correct handler // Dispatch the correct handler
d.dispatch(req, m) d.dispatch(req, m)
// Write out the complete response
if err := resp.WriteMsg(m); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
}
} }
// handleTest is used to handle DNS queries in the ".consul." domain // handleTest is used to handle DNS queries in the ".consul." domain
@ -147,7 +172,9 @@ func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) {
txt := &dns.TXT{header, []string{"ok"}} txt := &dns.TXT{header, []string{"ok"}}
m.Answer = append(m.Answer, txt) m.Answer = append(m.Answer, txt)
d.addSOA(consulDomain, m) d.addSOA(consulDomain, m)
resp.WriteMsg(m) if err := resp.WriteMsg(m); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
}
} }
// addSOA is used to add an SOA record to a message for the given domain // addSOA is used to add an SOA record to a message for the given domain
@ -353,3 +380,40 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.ServiceNodes, req
resp.Extra = append(resp.Extra, aRec) resp.Extra = append(resp.Extra, aRec)
} }
} }
// handleRecurse is used to handle recursive DNS queries
func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0]
network := "udp"
defer func(s time.Time) {
d.logger.Printf("[DEBUG] dns: request for %v (%s) (%v)", q, network, time.Now().Sub(s))
}(time.Now())
// Switch to TCP if the client is
if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok {
network = "tcp"
}
// Recursively resolve
c := &dns.Client{Net: network}
r, rtt, err := c.Exchange(req, d.recursor)
// On failure, return a SERVFAIL message
if err != nil {
d.logger.Printf("[ERR] dns: recurse failed: %v", err)
m := &dns.Msg{}
m.SetReply(req)
m.SetRcode(req, dns.RcodeServerFailure)
resp.WriteMsg(m)
return
}
d.logger.Printf("[DEBUG] dns: recurse RTT for %v (%v)", q, rtt)
// Seems to be a bug that forcing compression fixes...
r.Compress = true
// Forward the response
if err := resp.WriteMsg(r); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
}
}

View File

@ -11,7 +11,8 @@ import (
func makeDNSServer(t *testing.T) (string, *DNSServer) { func makeDNSServer(t *testing.T) (string, *DNSServer) {
conf := nextConfig() conf := nextConfig()
dir, agent := makeAgent(t, conf) dir, agent := makeAgent(t, conf)
server, err := NewDNSServer(agent, agent.logOutput, conf.Domain, conf.DNSAddr) server, err := NewDNSServer(agent, agent.logOutput, conf.Domain,
conf.DNSAddr, "8.8.8.8:53")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -173,3 +174,26 @@ func TestDNS_ServiceLookup(t *testing.T) {
t.Fatalf("Bad: %#v", in.Extra[0]) t.Fatalf("Bad: %#v", in.Extra[0])
} }
} }
func TestDNS_Recurse(t *testing.T) {
dir, srv := makeDNSServer(t)
defer os.RemoveAll(dir)
defer srv.agent.Shutdown()
m := new(dns.Msg)
m.SetQuestion("apple.com.", dns.TypeANY)
c := new(dns.Client)
c.Net = "tcp"
in, _, err := c.Exchange(m, srv.agent.config.DNSAddr)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(in.Answer) == 0 {
t.Fatalf("Bad: %#v", in)
}
if in.Rcode != dns.RcodeSuccess {
t.Fatalf("Bad: %#v", in)
}
}