diff --git a/agent/consul/prepared_query_endpoint.go b/agent/consul/prepared_query_endpoint.go index 6f6c5274da..01957a7c61 100644 --- a/agent/consul/prepared_query_endpoint.go +++ b/agent/consul/prepared_query_endpoint.go @@ -1,7 +1,6 @@ package consul import ( - "errors" "fmt" "strings" "time" @@ -16,11 +15,6 @@ import ( "github.com/hashicorp/go-uuid" ) -var ( - // ErrQueryNotFound is returned if the query lookup failed. - ErrQueryNotFound = errors.New("Query not found") -) - // PreparedQuery manages the prepared query endpoint. type PreparedQuery struct { srv *Server @@ -228,7 +222,7 @@ func (p *PreparedQuery) Get(args *structs.PreparedQuerySpecificRequest, return err } if query == nil { - return ErrQueryNotFound + return structs.ErrQueryNotFound } // If no prefix ACL applies to this query, then they are @@ -303,7 +297,7 @@ func (p *PreparedQuery) Explain(args *structs.PreparedQueryExecuteRequest, return err } if query == nil { - return ErrQueryNotFound + return structs.ErrQueryNotFound } // Place the query into a list so we can run the standard ACL filter on @@ -350,7 +344,7 @@ func (p *PreparedQuery) Execute(args *structs.PreparedQueryExecuteRequest, return err } if query == nil { - return ErrQueryNotFound + return structs.ErrQueryNotFound } // Execute the query for the local DC. diff --git a/agent/consul/prepared_query_endpoint_test.go b/agent/consul/prepared_query_endpoint_test.go index eb1edddc7b..b2764e8cac 100644 --- a/agent/consul/prepared_query_endpoint_test.go +++ b/agent/consul/prepared_query_endpoint_test.go @@ -175,7 +175,7 @@ func TestPreparedQuery_Apply(t *testing.T) { } var resp structs.IndexedPreparedQueries if err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Get", req, &resp); err != nil { - if err.Error() != ErrQueryNotFound.Error() { + if !structs.IsErrQueryNotFound(err) { t.Fatalf("err: %v", err) } } @@ -317,7 +317,7 @@ func TestPreparedQuery_Apply_ACLDeny(t *testing.T) { } var resp structs.IndexedPreparedQueries if err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Get", req, &resp); err != nil { - if err.Error() != ErrQueryNotFound.Error() { + if !structs.IsErrQueryNotFound(err) { t.Fatalf("err: %v", err) } } @@ -412,7 +412,7 @@ func TestPreparedQuery_Apply_ACLDeny(t *testing.T) { } var resp structs.IndexedPreparedQueries if err = msgpackrpc.CallWithCodec(codec, "PreparedQuery.Get", req, &resp); err != nil { - if err.Error() != ErrQueryNotFound.Error() { + if !structs.IsErrQueryNotFound(err) { t.Fatalf("err: %v", err) } } @@ -1082,7 +1082,7 @@ func TestPreparedQuery_Get(t *testing.T) { } var resp structs.IndexedPreparedQueries if err := msgpackrpc.CallWithCodec(codec, "PreparedQuery.Get", req, &resp); err != nil { - if err.Error() != ErrQueryNotFound.Error() { + if !structs.IsErrQueryNotFound(err) { t.Fatalf("err: %v", err) } } @@ -1429,7 +1429,7 @@ func TestPreparedQuery_Explain(t *testing.T) { } var resp structs.IndexedPreparedQueries if err := msgpackrpc.CallWithCodec(codec, "PreparedQuery.Explain", req, &resp); err != nil { - if err.Error() != ErrQueryNotFound.Error() { + if !structs.IsErrQueryNotFound(err) { t.Fatalf("err: %v", err) } } @@ -1617,7 +1617,7 @@ func TestPreparedQuery_Execute(t *testing.T) { var reply structs.PreparedQueryExecuteResponse err := msgpackrpc.CallWithCodec(codec1, "PreparedQuery.Execute", &req, &reply) - assert.EqualError(t, err, ErrQueryNotFound.Error()) + assert.EqualError(t, err, structs.ErrQueryNotFound.Error()) assert.Len(t, reply.Nodes, 0) }) diff --git a/agent/dns.go b/agent/dns.go index cc508749a9..187b0463c5 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -16,7 +16,6 @@ import ( "github.com/coredns/coredns/plugin/pkg/dnsutil" cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/config" - "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/ipaddr" @@ -828,8 +827,7 @@ func (d *DNSServer) computeRCode(err error) int { if err == nil { return dns.RcodeSuccess } - dErr := err.Error() - if dErr == structs.ErrNoDCPath.Error() || dErr == consul.ErrQueryNotFound.Error() { + if structs.IsErrNoDCPath(err) || structs.IsErrQueryNotFound(err) { return dns.RcodeNameError } return dns.RcodeServerFailure diff --git a/agent/dns_test.go b/agent/dns_test.go index 1df0c4d9db..a5aa787c4b 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -5790,15 +5790,17 @@ func TestDNS_AddressLookupIPV6(t *testing.T) { } } -func TestDNS_NonExistingDC(t *testing.T) { +// TestDNS_NonExistentDC_Server verifies NXDOMAIN is returned when +// Consul server agent is queried for a service in a non-existent +// domain. +func TestDNS_NonExistentDC_Server(t *testing.T) { t.Parallel() a := NewTestAgent(t, "") defer a.Shutdown() testrpc.WaitForLeader(t, a.RPC, "dc1") - // lookup a non-existing node, we should receive a SOA m := new(dns.Msg) - m.SetQuestion("consul.dc2.consul.", dns.TypeANY) + m.SetQuestion("consul.service.dc2.consul.", dns.TypeANY) c := new(dns.Client) in, _, err := c.Exchange(m, a.DNSAddr()) @@ -5811,6 +5813,47 @@ func TestDNS_NonExistingDC(t *testing.T) { } } +// TestDNS_NonExistentDC_RPC verifies NXDOMAIN is returned when +// Consul server agent is queried over RPC by a non-server agent +// for a service in a non-existent domain +func TestDNS_NonExistentDC_RPC(t *testing.T) { + t.Parallel() + s := NewTestAgent(t, ` + node_name = "test-server" + `) + + defer s.Shutdown() + c := NewTestAgent(t, ` + node_name = "test-client" + bootstrap = false + server = false + `) + defer c.Shutdown() + testrpc.WaitForLeader(t, s.RPC, "dc1") + + // Join LAN cluster + addr := fmt.Sprintf("127.0.0.1:%d", s.Config.SerfPortLAN) + _, err := c.JoinLAN([]string{addr}) + require.NoError(t, err) + retry.Run(t, func(r *retry.R) { + require.Len(r, s.LANMembers(), 2) + require.Len(r, c.LANMembers(), 2) + }) + + m := new(dns.Msg) + m.SetQuestion("consul.service.dc2.consul.", dns.TypeANY) + + d := new(dns.Client) + in, _, err := d.Exchange(m, c.DNSAddr()) + if err != nil { + t.Fatalf("err: %v", err) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("Expected RCode: %#v, had: %#v", dns.RcodeNameError, in.Rcode) + } +} + func TestDNS_NonExistingLookup(t *testing.T) { t.Parallel() a := NewTestAgent(t, "") diff --git a/agent/prepared_query_endpoint.go b/agent/prepared_query_endpoint.go index 5a33f89e7c..ab3de36dbb 100644 --- a/agent/prepared_query_endpoint.go +++ b/agent/prepared_query_endpoint.go @@ -7,7 +7,6 @@ import ( "strings" cachetype "github.com/hashicorp/consul/agent/cache-types" - "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" ) @@ -144,7 +143,7 @@ func (s *HTTPServer) preparedQueryExecute(id string, resp http.ResponseWriter, r if err := s.agent.RPC("PreparedQuery.Execute", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. - if err.Error() == consul.ErrQueryNotFound.Error() { + if structs.IsErrQueryNotFound(err) { resp.WriteHeader(http.StatusNotFound) fmt.Fprint(resp, err.Error()) return nil, nil @@ -198,7 +197,7 @@ RETRY_ONCE: if err := s.agent.RPC("PreparedQuery.Explain", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. - if err.Error() == consul.ErrQueryNotFound.Error() { + if structs.IsErrQueryNotFound(err) { resp.WriteHeader(http.StatusNotFound) fmt.Fprint(resp, err.Error()) return nil, nil @@ -229,7 +228,7 @@ RETRY_ONCE: if err := s.agent.RPC("PreparedQuery.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds // the specific error type. - if err.Error() == consul.ErrQueryNotFound.Error() { + if structs.IsErrQueryNotFound(err) { resp.WriteHeader(http.StatusNotFound) fmt.Fprint(resp, err.Error()) return nil, nil diff --git a/agent/structs/errors.go b/agent/structs/errors.go index 9d04b1ed56..c5ef31c8bc 100644 --- a/agent/structs/errors.go +++ b/agent/structs/errors.go @@ -14,6 +14,7 @@ const ( errSegmentsNotSupported = "Network segments are not supported in this version of Consul" errRPCRateExceeded = "RPC rate limit exceeded" errServiceNotFound = "Service not found: " + errQueryNotFound = "Query not found" ) var ( @@ -24,8 +25,17 @@ var ( ErrSegmentsNotSupported = errors.New(errSegmentsNotSupported) ErrRPCRateExceeded = errors.New(errRPCRateExceeded) ErrDCNotAvailable = errors.New(errDCNotAvailable) + ErrQueryNotFound = errors.New(errQueryNotFound) ) +func IsErrNoDCPath(err error) bool { + return err != nil && strings.Contains(err.Error(), errNoDCPath) +} + +func IsErrQueryNotFound(err error) bool { + return err != nil && strings.Contains(err.Error(), errQueryNotFound) +} + func IsErrNoLeader(err error) bool { return err != nil && strings.Contains(err.Error(), errNoLeader) }