diff --git a/agent/dns.go b/agent/dns.go index 2cc015a426..879f177f8a 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -755,9 +755,9 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { // Beyond 2500 records, performance gets bad // Limit the number of records at once, anyway, it won't fit in 64k // For SRV Records, the max is around 500 records, for A, less than 2k - truncateAt := 2048 + truncateAt := 4096 if req.Question[0].Qtype == dns.TypeSRV { - truncateAt = 640 + truncateAt = 1024 } if len(resp.Answer) > truncateAt { resp.Answer = resp.Answer[:truncateAt] diff --git a/agent/dns_test.go b/agent/dns_test.go index 0770c92a0a..1fc31bbbac 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -3080,32 +3080,39 @@ func TestDNS_TCP_and_UDP_Truncate(t *testing.T) { "tcp", "udp", } - for _, qType := range []uint16{dns.TypeANY, dns.TypeA, dns.TypeSRV} { - for _, question := range questions { - for _, protocol := range protocols { - for _, compress := range []bool{true, false} { - t.Run(fmt.Sprintf("lookup %s %s (qType:=%d) compressed=%v", question, protocol, qType, compress), func(t *testing.T) { - m := new(dns.Msg) - m.SetQuestion(question, dns.TypeANY) - if protocol == "udp" { - m.SetEdns0(8192, true) - } - c := new(dns.Client) - c.Net = protocol - m.Compress = compress - in, out, err := c.Exchange(m, a.DNSAddr()) - if err != nil && err != dns.ErrTruncated { - t.Fatalf("err: %v", err) - } - // Check for the truncate bit - shouldBeTruncated := numServices > 5000 + for _, maxSize := range []uint16{8192, 65535} { + for _, qType := range []uint16{dns.TypeANY, dns.TypeA, dns.TypeSRV} { + for _, question := range questions { + for _, protocol := range protocols { + for _, compress := range []bool{true, false} { + t.Run(fmt.Sprintf("lookup %s %s (qType:=%d) compressed=%v", question, protocol, qType, compress), func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion(question, dns.TypeANY) + maxSz := maxSize + if protocol == "udp" { + maxSz = 8192 + } + m.SetEdns0(uint16(maxSz), true) + c := new(dns.Client) + c.Net = protocol + m.Compress = compress + in, _, err := c.Exchange(m, a.DNSAddr()) + if err != nil && err != dns.ErrTruncated { + t.Fatalf("err: %v", err) + } - if shouldBeTruncated != in.Truncated || len(in.Answer) > 2000 || len(in.Answer) < 1 || in.Len() > 65535 { + // Check for the truncate bit + buf, err := m.Pack() info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) sz:= %d in %v", - service, question, protocol, numServices, len(in.Answer), out) - t.Fatalf("Should have truncated:=%v for %s", shouldBeTruncated, info) - } - }) + service, question, protocol, numServices, len(in.Answer), in) + if err != nil { + t.Fatalf("Error while packing: %v ; info:=%s", err, info) + } + if len(buf) > int(maxSz) { + t.Fatalf("len(buf) := %d > maxSz=%d for %v", len(buf), maxSz, info) + } + }) + } } } }