agent: Only truncate DNS results for a UDP query

This commit is contained in:
Armon Dadgar 2014-02-14 14:22:49 -08:00
parent d35de5bc11
commit 467db27813
1 changed files with 16 additions and 10 deletions

View File

@ -15,7 +15,7 @@ import (
const ( const (
testQuery = "_test.consul." testQuery = "_test.consul."
consulDomain = "consul." consulDomain = "consul."
maxServiceResponses = 3 // TODO: Increase, currently a bug upstream in dns package maxServiceResponses = 3 // For UDP only
) )
// DNSServer is used to wrap an Agent and expose various // DNSServer is used to wrap an Agent and expose various
@ -137,6 +137,12 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
return return
} }
// Switch to TCP if the client is
network := "udp"
if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok {
network = "tcp"
}
// Setup the message response // Setup the message response
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(req) m.SetReply(req)
@ -144,7 +150,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
d.addSOA(d.domain, m) d.addSOA(d.domain, m)
// Dispatch the correct handler // Dispatch the correct handler
d.dispatch(req, m) d.dispatch(network, req, m)
// Write out the complete response // Write out the complete response
if err := resp.WriteMsg(m); err != nil { if err := resp.WriteMsg(m); err != nil {
@ -200,7 +206,7 @@ func (d *DNSServer) addSOA(domain string, msg *dns.Msg) {
} }
// dispatch is used to parse a request and invoke the correct handler // dispatch is used to parse a request and invoke the correct handler
func (d *DNSServer) dispatch(req, resp *dns.Msg) { func (d *DNSServer) dispatch(network string, req, resp *dns.Msg) {
// By default the query is in the default datacenter // By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter datacenter := d.agent.config.Datacenter
@ -221,9 +227,9 @@ PARSE:
// Handle lookup with and without tag // Handle lookup with and without tag
switch len(labels) { switch len(labels) {
case 2: case 2:
d.serviceLookup(datacenter, labels[0], "", req, resp) d.serviceLookup(network, datacenter, labels[0], "", req, resp)
case 3: case 3:
d.serviceLookup(datacenter, labels[1], labels[0], req, resp) d.serviceLookup(network, datacenter, labels[1], labels[0], req, resp)
default: default:
goto INVALID goto INVALID
} }
@ -232,7 +238,7 @@ PARSE:
if len(labels) != 2 { if len(labels) != 2 {
goto INVALID goto INVALID
} }
d.nodeLookup(datacenter, labels[0], req, resp) d.nodeLookup(network, datacenter, labels[0], req, resp)
default: default:
// Store the DC, and re-parse // Store the DC, and re-parse
@ -247,7 +253,7 @@ INVALID:
} }
// nodeLookup is used to handle a node query // nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(datacenter, node string, req, resp *dns.Msg) { func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) {
// Only handle ANY and A type requests // Only handle ANY and A type requests
qType := req.Question[0].Qtype qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA { if qType != dns.TypeANY && qType != dns.TypeA {
@ -296,7 +302,7 @@ func (d *DNSServer) nodeLookup(datacenter, node string, req, resp *dns.Msg) {
} }
// serviceLookup is used to handle a service query // serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dns.Msg) { func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) {
// Make an RPC request // Make an RPC request
args := structs.ServiceSpecificRequest{ args := structs.ServiceSpecificRequest{
Datacenter: datacenter, Datacenter: datacenter,
@ -323,8 +329,8 @@ func (d *DNSServer) serviceLookup(datacenter, service, tag string, req, resp *dn
// Perform a random shuffle // Perform a random shuffle
shuffleServiceNodes(out.Nodes) shuffleServiceNodes(out.Nodes)
// Restrict the number of responses // If the network is not TCP, restrict the number of responses
if len(out.Nodes) > maxServiceResponses { if network != "tcp" && len(out.Nodes) > maxServiceResponses {
out.Nodes = out.Nodes[:maxServiceResponses] out.Nodes = out.Nodes[:maxServiceResponses]
} }