diff --git a/agent/dns.go b/agent/dns.go index 0114ac425c..48f08b71d9 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -1072,7 +1072,7 @@ func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasE // trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit // of DNS responses -func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { +func trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { hasExtra := len(resp.Extra) > 0 // There is some overhead, 65535 does not work maxSize := 65523 // 64k - 12 bytes DNS raw overhead @@ -1080,8 +1080,6 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { // We avoid some function calls and allocations by only handling the // extra data when necessary. var index map[string]dns.RR - originalSize := resp.Len() - originalNumRecords := len(resp.Answer) // It is not possible to return more than 4k records even with compression // Since we are performing binary search it is not a big deal, but it @@ -1103,6 +1101,10 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { // This enforces the given limit on 64k, the max limit for DNS messages for len(resp.Answer) > 1 && resp.Len() > maxSize { truncated = true + // first try to remove the NS section may be it will truncate enough + if len(resp.Ns) != 0 { + resp.Ns = []dns.RR{} + } // More than 100 bytes, find with a binary search if resp.Len()-maxSize > 100 { bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) @@ -1114,13 +1116,7 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { syncExtra(index, resp) } } - if truncated { - d.logger.Debug("TCP answer to question too large, truncated", - "question", req.Question, - "records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords), - "size", fmt.Sprintf("%d/%d", resp.Len(), originalSize), - ) - } + return truncated } @@ -1169,6 +1165,10 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { // Even when size is too big for one single record, try to send it anyway // (useful for 512 bytes messages) for len(resp.Answer) > 1 && resp.Len() > maxSize-7 { + // first try to remove the NS section may be it will truncate enough + if len(resp.Ns) != 0 { + resp.Ns = []dns.RR{} + } // More than 100 bytes, find with a binary search if resp.Len()-maxSize > 100 { bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra) @@ -1189,14 +1189,24 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { // trimDNSResponse will trim the response for UDP and TCP func (d *DNSServer) trimDNSResponse(cfg *dnsConfig, network string, req, resp *dns.Msg) { var trimmed bool + originalSize := resp.Len() + originalNumRecords := len(resp.Answer) if network != "tcp" { trimmed = trimUDPResponse(req, resp, cfg.UDPAnswerLimit) } else { - trimmed = d.trimTCPResponse(req, resp) + trimmed = trimTCPResponse(req, resp) } // Flag that there are more records to return in the UDP response - if trimmed && cfg.EnableTruncate { - resp.Truncated = true + if trimmed { + if cfg.EnableTruncate { + resp.Truncated = true + } + d.logger.Debug("DNS response too large, truncated", + "protocol", network, + "question", req.Question, + "records", fmt.Sprintf("%d/%d", len(resp.Answer), originalNumRecords), + "size", fmt.Sprintf("%d/%d", resp.Len(), originalSize), + ) } } diff --git a/agent/dns_test.go b/agent/dns_test.go index e72b9ce7cd..11448ab1b5 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -4428,6 +4428,7 @@ func TestBinarySearch(t *testing.T) { msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV) msg.Answer = msgSrc.Answer msg.Extra = msgSrc.Extra + msg.Ns = msgSrc.Ns index := make(map[string]dns.RR, len(msg.Extra)) indexRRs(msg.Extra, index) blen := dnsBinaryTruncate(msg, maxSize, index, true) @@ -6888,6 +6889,103 @@ func TestDNS_trimUDPResponse_TrimLimit(t *testing.T) { } } +func TestDNS_trimUDPResponse_TrimLimitWithNS(t *testing.T) { + t.Parallel() + cfg := loadRuntimeConfig(t, `node_name = "test" data_dir = "a" bind_addr = "127.0.0.1" node_name = "dummy"`) + + req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{} + for i := 0; i < cfg.DNSUDPAnswerLimit+1; i++ { + target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i) + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: "redis-cache-redis.service.consul.", + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + }, + Target: target, + } + a := &dns.A{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 185+i)), + } + ns := &dns.SOA{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + }, + Ns: fmt.Sprintf("soa-%d", i), + } + + resp.Answer = append(resp.Answer, srv) + resp.Extra = append(resp.Extra, a) + resp.Ns = append(resp.Ns, ns) + if i < cfg.DNSUDPAnswerLimit { + expected.Answer = append(expected.Answer, srv) + expected.Extra = append(expected.Extra, a) + } + } + + if trimmed := trimUDPResponse(req, resp, cfg.DNSUDPAnswerLimit); !trimmed { + t.Fatalf("Bad %#v", *resp) + } + require.LessOrEqual(t, resp.Len(), defaultMaxUDPSize) + require.Len(t, resp.Ns, 0) +} + +func TestDNS_trimTCPResponse_TrimLimitWithNS(t *testing.T) { + t.Parallel() + cfg := loadRuntimeConfig(t, `node_name = "test" data_dir = "a" bind_addr = "127.0.0.1" node_name = "dummy"`) + + req, resp, expected := &dns.Msg{}, &dns.Msg{}, &dns.Msg{} + for i := 0; i < 5000; i++ { + target := fmt.Sprintf("ip-10-0-1-%d.node.dc1.consul.", 185+i) + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: "redis-cache-redis.service.consul.", + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + }, + Target: target, + } + a := &dns.A{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.ParseIP(fmt.Sprintf("10.0.1.%d", 185+i)), + } + ns := &dns.SOA{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + }, + Ns: fmt.Sprintf("soa-%d", i), + } + + resp.Answer = append(resp.Answer, srv) + resp.Extra = append(resp.Extra, a) + resp.Ns = append(resp.Ns, ns) + if i < cfg.DNSUDPAnswerLimit { + expected.Answer = append(expected.Answer, srv) + expected.Extra = append(expected.Extra, a) + } + } + req.Question = append(req.Question, dns.Question{Qtype: dns.TypeSRV}) + + if trimmed := trimTCPResponse(req, resp); !trimmed { + t.Fatalf("Bad %#v", *resp) + } + require.LessOrEqual(t, resp.Len(), 65523) + require.Len(t, resp.Ns, 0) +} + func loadRuntimeConfig(t *testing.T, hcl string) *config.RuntimeConfig { t.Helper() result, err := config.Load(config.LoadOpts{HCL: []string{hcl}})