From c5d2bea92c244c65f62181cda1c1bf0eca9bca10 Mon Sep 17 00:00:00 2001 From: Mathew Estafanous <56979977+Mathew-Estafanous@users.noreply.github.com> Date: Mon, 31 Jan 2022 11:17:35 -0500 Subject: [PATCH] Change error-handling across handlers. (#12225) --- agent/catalog_endpoint.go | 24 ++++--------- agent/config_endpoint.go | 8 ++--- agent/coordinate_endpoint.go | 33 ++++++++---------- agent/coordinate_endpoint_test.go | 10 +++--- agent/discovery_chain_endpoint.go | 4 +-- agent/event_endpoint.go | 9 ++--- agent/event_endpoint_test.go | 14 ++++---- agent/health_endpoint.go | 21 +++--------- agent/health_endpoint_test.go | 29 ++++++---------- agent/http.go | 21 ++++++++++-- agent/intentions_endpoint.go | 8 ++--- agent/kvs_endpoint.go | 33 ++++++------------ agent/prepared_query_endpoint.go | 20 +++-------- agent/prepared_query_endpoint_test.go | 24 +++++-------- agent/session_endpoint.go | 24 ++++--------- agent/txn_endpoint.go | 49 +++++++++++---------------- agent/txn_endpoint_test.go | 43 ++++++++++------------- agent/ui_endpoint.go | 24 ++++--------- agent/ui_endpoint_test.go | 5 +-- 19 files changed, 149 insertions(+), 254 deletions(-) diff --git a/agent/catalog_endpoint.go b/agent/catalog_endpoint.go index 8443e31f3e..2ae1b07dc0 100644 --- a/agent/catalog_endpoint.go +++ b/agent/catalog_endpoint.go @@ -136,9 +136,7 @@ func (s *HTTPHandlers) CatalogRegister(resp http.ResponseWriter, req *http.Reque } if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "Request decode failed: %v", err) - return nil, nil + return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} } // Setup the default DC if not provided @@ -168,9 +166,7 @@ func (s *HTTPHandlers) CatalogDeregister(resp http.ResponseWriter, req *http.Req return nil, err } if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "Request decode failed: %v", err) - return nil, nil + return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} } // Setup the default DC if not provided @@ -367,9 +363,7 @@ func (s *HTTPHandlers) catalogServiceNodes(resp http.ResponseWriter, req *http.R return nil, err } if args.ServiceName == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing service name") - return nil, nil + return nil, BadRequestError{Reason: "Missing service name"} } // Make the RPC request @@ -444,9 +438,7 @@ func (s *HTTPHandlers) CatalogNodeServices(resp http.ResponseWriter, req *http.R return nil, err } if args.Node == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing node name") - return nil, nil + return nil, BadRequestError{Reason: "Missing node name"} } // Make the RPC request @@ -511,9 +503,7 @@ func (s *HTTPHandlers) CatalogNodeServiceList(resp http.ResponseWriter, req *htt return nil, err } if args.Node == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing node name") - return nil, nil + return nil, BadRequestError{Reason: "Missing node name"} } // Make the RPC request @@ -564,9 +554,7 @@ func (s *HTTPHandlers) CatalogGatewayServices(resp http.ResponseWriter, req *htt return nil, err } if args.ServiceName == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing gateway name") - return nil, nil + return nil, BadRequestError{Reason: "Missing gateway name"} } // Make the RPC request diff --git a/agent/config_endpoint.go b/agent/config_endpoint.go index 9f916624ef..4bd96d4363 100644 --- a/agent/config_endpoint.go +++ b/agent/config_endpoint.go @@ -90,16 +90,12 @@ func (s *HTTPHandlers) configDelete(resp http.ResponseWriter, req *http.Request) pathArgs := strings.SplitN(kindAndName, "/", 2) if len(pathArgs) != 2 { - resp.WriteHeader(http.StatusNotFound) - fmt.Fprintf(resp, "Must provide both a kind and name to delete") - return nil, nil + return nil, NotFoundError{Reason: "Must provide both a kind and name to delete"} } entry, err := structs.MakeConfigEntry(pathArgs[0], pathArgs[1]) if err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "%v", err) - return nil, nil + return nil, BadRequestError{Reason: err.Error()} } args.Entry = entry // Parse enterprise meta. diff --git a/agent/coordinate_endpoint.go b/agent/coordinate_endpoint.go index 822ef5b39b..ff3df3d06c 100644 --- a/agent/coordinate_endpoint.go +++ b/agent/coordinate_endpoint.go @@ -8,16 +8,13 @@ import ( "github.com/hashicorp/consul/agent/structs" ) -// checkCoordinateDisabled will return a standard response if coordinates are -// disabled. This returns true if they are disabled and we should not continue. -func (s *HTTPHandlers) checkCoordinateDisabled(resp http.ResponseWriter, req *http.Request) bool { +// checkCoordinateDisabled will return an unauthorized error if coordinates are +// disabled. Otherwise, a nil error will be returned. +func (s *HTTPHandlers) checkCoordinateDisabled() error { if !s.agent.config.DisableCoordinates { - return false + return nil } - - resp.WriteHeader(http.StatusUnauthorized) - fmt.Fprint(resp, "Coordinate support disabled") - return true + return UnauthorizedError{Reason: "Coordinate support disabled"} } // sorter wraps a coordinate list and implements the sort.Interface to sort by @@ -44,8 +41,8 @@ func (s *sorter) Less(i, j int) bool { // CoordinateDatacenters returns the WAN nodes in each datacenter, along with // raw network coordinates. func (s *HTTPHandlers) CoordinateDatacenters(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if s.checkCoordinateDisabled(resp, req) { - return nil, nil + if err := s.checkCoordinateDisabled(); err != nil { + return nil, err } var out []structs.DatacenterMap @@ -73,8 +70,8 @@ func (s *HTTPHandlers) CoordinateDatacenters(resp http.ResponseWriter, req *http // CoordinateNodes returns the LAN nodes in the given datacenter, along with // raw network coordinates. func (s *HTTPHandlers) CoordinateNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if s.checkCoordinateDisabled(resp, req) { - return nil, nil + if err := s.checkCoordinateDisabled(); err != nil { + return nil, err } args := structs.DCSpecificRequest{} @@ -98,8 +95,8 @@ func (s *HTTPHandlers) CoordinateNodes(resp http.ResponseWriter, req *http.Reque // CoordinateNode returns the LAN node in the given datacenter, along with // raw network coordinates. func (s *HTTPHandlers) CoordinateNode(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if s.checkCoordinateDisabled(resp, req) { - return nil, nil + if err := s.checkCoordinateDisabled(); err != nil { + return nil, err } node, err := getPathSuffixUnescaped(req.URL.Path, "/v1/coordinate/node/") @@ -153,15 +150,13 @@ func filterCoordinates(req *http.Request, in structs.Coordinates) structs.Coordi // CoordinateUpdate inserts or updates the LAN coordinate of a node. func (s *HTTPHandlers) CoordinateUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if s.checkCoordinateDisabled(resp, req) { - return nil, nil + if err := s.checkCoordinateDisabled(); err != nil { + return nil, err } args := structs.CoordinateUpdateRequest{} if err := decodeBody(req.Body, &args); err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "Request decode failed: %v", err) - return nil, nil + return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} } s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) diff --git a/agent/coordinate_endpoint_test.go b/agent/coordinate_endpoint_test.go index 36b956a8f8..ff38d83227 100644 --- a/agent/coordinate_endpoint_test.go +++ b/agent/coordinate_endpoint_test.go @@ -39,16 +39,14 @@ func TestCoordinate_Disabled_Response(t *testing.T) { req, _ := http.NewRequest("PUT", "/should/not/care", nil) resp := httptest.NewRecorder() obj, err := tt(resp, req) - if err != nil { - t.Fatalf("err: %v", err) + err, ok := err.(UnauthorizedError) + if !ok { + t.Fatalf("expected unauthorized error but got %v", err) } if obj != nil { t.Fatalf("bad: %#v", obj) } - if got, want := resp.Code, http.StatusUnauthorized; got != want { - t.Fatalf("got %d want %d", got, want) - } - if !strings.Contains(resp.Body.String(), "Coordinate support disabled") { + if !strings.Contains(err.Error(), "Coordinate support disabled") { t.Fatalf("bad: %#v", resp) } }) diff --git a/agent/discovery_chain_endpoint.go b/agent/discovery_chain_endpoint.go index ec3d46e9e6..666841ef32 100644 --- a/agent/discovery_chain_endpoint.go +++ b/agent/discovery_chain_endpoint.go @@ -51,9 +51,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re if apiReq.OverrideMeshGateway.Mode != "" { _, err := structs.ValidateMeshGatewayMode(string(apiReq.OverrideMeshGateway.Mode)) if err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Invalid OverrideMeshGateway.Mode parameter") - return nil, nil + return nil, BadRequestError{Reason: "Invalid OverrideMeshGateway.Mode parameter"} } args.OverrideMeshGateway = apiReq.OverrideMeshGateway } diff --git a/agent/event_endpoint.go b/agent/event_endpoint.go index 10421d6202..53b0e5d65b 100644 --- a/agent/event_endpoint.go +++ b/agent/event_endpoint.go @@ -2,7 +2,6 @@ package agent import ( "bytes" - "fmt" "io" "net/http" "strconv" @@ -26,9 +25,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i return nil, err } if event.Name == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing name") - return nil, nil + return nil, BadRequestError{Reason: "Missing name"} } // Get the ACL token @@ -58,9 +55,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i // Try to fire the event if err := s.agent.UserEvent(dc, token, event); err != nil { if acl.IsErrPermissionDenied(err) { - resp.WriteHeader(http.StatusForbidden) - fmt.Fprint(resp, acl.ErrPermissionDenied.Error()) - return nil, nil + return nil, ForbiddenError{Reason: acl.ErrPermissionDenied.Error()} } resp.WriteHeader(http.StatusInternalServerError) return nil, err diff --git a/agent/event_endpoint_test.go b/agent/event_endpoint_test.go index d430c15ddb..e5f0b39f73 100644 --- a/agent/event_endpoint_test.go +++ b/agent/event_endpoint_test.go @@ -88,13 +88,11 @@ func TestEventFire_token(t *testing.T) { url := fmt.Sprintf("/v1/event/fire/%s?token=%s", c.event, token) req, _ := http.NewRequest("PUT", url, nil) resp := httptest.NewRecorder() - if _, err := a.srv.EventFire(resp, req); err != nil { - t.Fatalf("err: %s", err) - } + _, err := a.srv.EventFire(resp, req) // Check the result - body := resp.Body.String() if c.allowed { + body := resp.Body.String() if acl.IsErrPermissionDenied(errors.New(body)) { t.Fatalf("bad: %s", body) } @@ -102,11 +100,11 @@ func TestEventFire_token(t *testing.T) { t.Fatalf("bad: %d", resp.Code) } } else { - if !acl.IsErrPermissionDenied(errors.New(body)) { - t.Fatalf("bad: %s", body) + if !acl.IsErrPermissionDenied(err) { + t.Fatalf("bad: %s", err.Error()) } - if resp.Code != 403 { - t.Fatalf("bad: %d", resp.Code) + if err, ok := err.(ForbiddenError); !ok { + t.Fatalf("Expected forbidden but got %v", err) } } } diff --git a/agent/health_endpoint.go b/agent/health_endpoint.go index f6e803a3cb..faa37b6157 100644 --- a/agent/health_endpoint.go +++ b/agent/health_endpoint.go @@ -1,7 +1,6 @@ package agent import ( - "fmt" "net/http" "net/url" "strconv" @@ -36,9 +35,7 @@ func (s *HTTPHandlers) HealthChecksInState(resp http.ResponseWriter, req *http.R return nil, err } if args.State == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing check state") - return nil, nil + return nil, BadRequestError{Reason: "Missing check state"} } // Make the RPC request @@ -82,9 +79,7 @@ func (s *HTTPHandlers) HealthNodeChecks(resp http.ResponseWriter, req *http.Requ // Pull out the service name args.Node = strings.TrimPrefix(req.URL.Path, "/v1/health/node/") if args.Node == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing node name") - return nil, nil + return nil, BadRequestError{Reason: "Missing node name"} } // Make the RPC request @@ -130,9 +125,7 @@ func (s *HTTPHandlers) HealthServiceChecks(resp http.ResponseWriter, req *http.R // Pull out the service name args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/health/checks/") if args.ServiceName == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing service name") - return nil, nil + return nil, BadRequestError{Reason: "Missing service name"} } // Make the RPC request @@ -218,9 +211,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re // Pull out the service name args.ServiceName = strings.TrimPrefix(req.URL.Path, prefix) if args.ServiceName == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing service name") - return nil, nil + return nil, BadRequestError{Reason: "Missing service name"} } out, md, err := s.agent.rpcClientHealth.ServiceNodes(req.Context(), args) @@ -238,9 +229,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re // Filter to only passing if specified filter, err := getBoolQueryParam(params, api.HealthPassing) if err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Invalid value for ?passing") - return nil, nil + return nil, BadRequestError{Reason: "Invalid value for ?passing"} } // FIXME: remove filterNonPassing, replace with nodes.Filter, which is used by DNSServer diff --git a/agent/health_endpoint_test.go b/agent/health_endpoint_test.go index 5bd5444edf..baa4c43423 100644 --- a/agent/health_endpoint_test.go +++ b/agent/health_endpoint_test.go @@ -1,14 +1,13 @@ package agent import ( - "bytes" "fmt" - "io/ioutil" "net/http" "net/http/httptest" "net/url" "reflect" "strconv" + "strings" "testing" "time" @@ -1241,18 +1240,12 @@ func TestHealthServiceNodes_PassingFilter(t *testing.T) { t.Run("passing_bad", func(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/health/service/consul?passing=nope-nope-nope", nil) resp := httptest.NewRecorder() - a.srv.HealthServiceNodes(resp, req) - - if code := resp.Code; code != 400 { - t.Errorf("bad response code %d, expected %d", code, 400) + _, err := a.srv.HealthServiceNodes(resp, req) + if _, ok := err.(BadRequestError); !ok { + t.Fatalf("Expected bad request error but got %v", err) } - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - if !bytes.Contains(body, []byte("Invalid value for ?passing")) { - t.Errorf("bad %s", body) + if !strings.Contains(err.Error(), "Invalid value for ?passing") { + t.Errorf("bad %s", err.Error()) } }) } @@ -1654,12 +1647,12 @@ func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf( "/v1/health/connect/%s?passing=nope-nope", args.Service.Proxy.DestinationServiceName), nil) resp := httptest.NewRecorder() - a.srv.HealthConnectServiceNodes(resp, req) - assert.Equal(t, 400, resp.Code) + _, err := a.srv.HealthConnectServiceNodes(resp, req) + assert.NotNil(t, err) + _, ok := err.(BadRequestError) + assert.True(t, ok) - body, err := ioutil.ReadAll(resp.Body) - assert.Nil(t, err) - assert.True(t, bytes.Contains(body, []byte("Invalid value for ?passing"))) + assert.True(t, strings.Contains(err.Error(), "Invalid value for ?passing")) }) } diff --git a/agent/http.go b/agent/http.go index 5fc84ea113..b470547ed1 100644 --- a/agent/http.go +++ b/agent/http.go @@ -78,6 +78,14 @@ func (e UnauthorizedError) Error() string { return e.Reason } +type EntityTooLargeError struct { + Reason string +} + +func (e EntityTooLargeError) Error() string { + return e.Reason +} + // CodeWithPayloadError allow returning non HTTP 200 // Error codes while not returning PlainText payload type CodeWithPayloadError struct { @@ -91,10 +99,11 @@ func (e CodeWithPayloadError) Error() string { } type ForbiddenError struct { + Reason string } func (e ForbiddenError) Error() string { - return "Access is restricted" + return e.Reason } // HTTPHandlers provides an HTTP api for an agent. @@ -443,6 +452,11 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc return err.Error() == consul.ErrRateLimited.Error() } + isEntityToLarge := func(err error) bool { + _, ok := err.(EntityTooLargeError) + return ok + } + addAllowHeader := func(methods []string) { resp.Header().Add("Allow", strings.Join(methods, ",")) } @@ -488,6 +502,9 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc case isTooManyRequests(err): resp.WriteHeader(http.StatusTooManyRequests) fmt.Fprint(resp, err.Error()) + case isEntityToLarge(err): + resp.WriteHeader(http.StatusRequestEntityTooLarge) + fmt.Fprint(resp, err.Error()) default: resp.WriteHeader(http.StatusInternalServerError) fmt.Fprint(resp, err.Error()) @@ -1136,7 +1153,7 @@ func (s *HTTPHandlers) checkWriteAccess(req *http.Request) error { } } - return ForbiddenError{} + return ForbiddenError{Reason: "Access is restricted"} } func (s *HTTPHandlers) parseFilter(req *http.Request, filter *string) { diff --git a/agent/intentions_endpoint.go b/agent/intentions_endpoint.go index 2bb55b7f4b..4c326b4f1e 100644 --- a/agent/intentions_endpoint.go +++ b/agent/intentions_endpoint.go @@ -323,9 +323,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds the error type if err.Error() == consul.ErrIntentionNotFound.Error() { - resp.WriteHeader(http.StatusNotFound) - fmt.Fprint(resp, err.Error()) - return nil, nil + return nil, NotFoundError{Reason: err.Error()} } // Not ideal, but there are a number of error scenarios that are not @@ -521,9 +519,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter, if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds the error type if err.Error() == consul.ErrIntentionNotFound.Error() { - resp.WriteHeader(http.StatusNotFound) - fmt.Fprint(resp, err.Error()) - return nil, nil + return nil, NotFoundError{Reason: err.Error()} } // Not ideal, but there are a number of error scenarios that are not diff --git a/agent/kvs_endpoint.go b/agent/kvs_endpoint.go index b6bed301be..4b8cc3348f 100644 --- a/agent/kvs_endpoint.go +++ b/agent/kvs_endpoint.go @@ -55,8 +55,8 @@ func (s *HTTPHandlers) KVSGet(resp http.ResponseWriter, req *http.Request, args params := req.URL.Query() if _, ok := params["recurse"]; ok { method = "KVS.List" - } else if missingKey(resp, args) { - return nil, nil + } else if args.Key == "" { + return nil, BadRequestError{Reason: "Missing key name"} } // Do not allow wildcard NS on GET reqs @@ -156,8 +156,8 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args if err := s.parseEntMetaNoWildcard(req, &args.EnterpriseMeta); err != nil { return nil, err } - if missingKey(resp, args) { - return nil, nil + if args.Key == "" { + return nil, BadRequestError{Reason: "Missing key name"} } if conflictingFlags(resp, req, "cas", "acquire", "release") { return nil, nil @@ -208,13 +208,10 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args // Check the content-length if req.ContentLength > int64(s.agent.config.KVMaxValueSize) { - resp.WriteHeader(http.StatusRequestEntityTooLarge) - fmt.Fprintf(resp, - "Request body(%d bytes) too large, max size: %d bytes. See %s.", - req.ContentLength, s.agent.config.KVMaxValueSize, - "https://www.consul.io/docs/agent/options.html#kv_max_value_size", - ) - return nil, nil + return nil, EntityTooLargeError{ + Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.", + req.ContentLength, s.agent.config.KVMaxValueSize, "https://www.consul.io/docs/agent/options.html#kv_max_value_size"), + } } // Copy the value @@ -259,8 +256,8 @@ func (s *HTTPHandlers) KVSDelete(resp http.ResponseWriter, req *http.Request, ar params := req.URL.Query() if _, ok := params["recurse"]; ok { applyReq.Op = api.KVDeleteTree - } else if missingKey(resp, args) { - return nil, nil + } else if args.Key == "" { + return nil, BadRequestError{Reason: "Missing key name"} } // Check for cas value @@ -286,16 +283,6 @@ func (s *HTTPHandlers) KVSDelete(resp http.ResponseWriter, req *http.Request, ar return true, nil } -// missingKey checks if the key is missing -func missingKey(resp http.ResponseWriter, args *structs.KeyRequest) bool { - if args.Key == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing key name") - return true - } - return false -} - // conflictingFlags determines if non-composable flags were passed in a request. func conflictingFlags(resp http.ResponseWriter, req *http.Request, flags ...string) bool { params := req.URL.Query() diff --git a/agent/prepared_query_endpoint.go b/agent/prepared_query_endpoint.go index 31e900288e..b398e24465 100644 --- a/agent/prepared_query_endpoint.go +++ b/agent/prepared_query_endpoint.go @@ -23,9 +23,7 @@ func (s *HTTPHandlers) preparedQueryCreate(resp http.ResponseWriter, req *http.R s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) if err := decodeBody(req.Body, &args.Query); err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "Request decode failed: %v", err) - return nil, nil + return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} } var reply string @@ -145,9 +143,7 @@ func (s *HTTPHandlers) preparedQueryExecute(id string, resp http.ResponseWriter, // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { - resp.WriteHeader(http.StatusNotFound) - fmt.Fprint(resp, err.Error()) - return nil, nil + return nil, NotFoundError{Reason: err.Error()} } return nil, err } @@ -200,9 +196,7 @@ RETRY_ONCE: // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { - resp.WriteHeader(http.StatusNotFound) - fmt.Fprint(resp, err.Error()) - return nil, nil + return nil, NotFoundError{Reason: err.Error()} } return nil, err } @@ -231,9 +225,7 @@ RETRY_ONCE: // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { - resp.WriteHeader(http.StatusNotFound) - fmt.Fprint(resp, err.Error()) - return nil, nil + return nil, NotFoundError{Reason: err.Error()} } return nil, err } @@ -255,9 +247,7 @@ func (s *HTTPHandlers) preparedQueryUpdate(id string, resp http.ResponseWriter, s.parseToken(req, &args.Token) if req.ContentLength > 0 { if err := decodeBody(req.Body, &args.Query); err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "Request decode failed: %v", err) - return nil, nil + return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} } } diff --git a/agent/prepared_query_endpoint_test.go b/agent/prepared_query_endpoint_test.go index 79658f8c11..e4d3056e83 100644 --- a/agent/prepared_query_endpoint_test.go +++ b/agent/prepared_query_endpoint_test.go @@ -620,11 +620,9 @@ func TestPreparedQuery_Execute(t *testing.T) { body := bytes.NewBuffer(nil) req, _ := http.NewRequest("GET", "/v1/query/not-there/execute", body) resp := httptest.NewRecorder() - if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil { - t.Fatalf("err: %v", err) - } - if resp.Code != 404 { - t.Fatalf("bad code: %d", resp.Code) + _, err := a.srv.PreparedQuerySpecific(resp, req) + if err, ok := err.(NotFoundError); !ok { + t.Fatalf("Expected not found error but got %v", err) } }) } @@ -757,11 +755,9 @@ func TestPreparedQuery_Explain(t *testing.T) { body := bytes.NewBuffer(nil) req, _ := http.NewRequest("GET", "/v1/query/not-there/explain", body) resp := httptest.NewRecorder() - if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil { - t.Fatalf("err: %v", err) - } - if resp.Code != 404 { - t.Fatalf("bad code: %d", resp.Code) + _, err := a.srv.PreparedQuerySpecific(resp, req) + if err, ok := err.(NotFoundError); !ok { + t.Fatalf("Expected not found error but got %v", err) } }) @@ -848,11 +844,9 @@ func TestPreparedQuery_Get(t *testing.T) { body := bytes.NewBuffer(nil) req, _ := http.NewRequest("GET", "/v1/query/f004177f-2c28-83b7-4229-eacc25fe55d1", body) resp := httptest.NewRecorder() - if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil { - t.Fatalf("err: %v", err) - } - if resp.Code != 404 { - t.Fatalf("bad code: %d", resp.Code) + _, err := a.srv.PreparedQuerySpecific(resp, req) + if err, ok := err.(NotFoundError); !ok { + t.Fatalf("Expected not found error but got %v", err) } }) } diff --git a/agent/session_endpoint.go b/agent/session_endpoint.go index 9371bf7418..afe3faa3c4 100644 --- a/agent/session_endpoint.go +++ b/agent/session_endpoint.go @@ -40,9 +40,7 @@ func (s *HTTPHandlers) SessionCreate(resp http.ResponseWriter, req *http.Request // Handle optional request body if req.ContentLength > 0 { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Session)); err != nil { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "Request decode failed: %v", err) - return nil, nil + return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} } } @@ -77,9 +75,7 @@ func (s *HTTPHandlers) SessionDestroy(resp http.ResponseWriter, req *http.Reques return nil, err } if args.Session.ID == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing session") - return nil, nil + return nil, BadRequestError{Reason: "Missing session"} } var out string @@ -107,18 +103,14 @@ func (s *HTTPHandlers) SessionRenew(resp http.ResponseWriter, req *http.Request) } args.Session = args.SessionID if args.SessionID == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing session") - return nil, nil + return nil, BadRequestError{Reason: "Missing session"} } var out structs.IndexedSessions if err := s.agent.RPC("Session.Renew", &args, &out); err != nil { return nil, err } else if out.Sessions == nil { - resp.WriteHeader(http.StatusNotFound) - fmt.Fprintf(resp, "Session id '%s' not found", args.SessionID) - return nil, nil + return nil, NotFoundError{Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)} } return out.Sessions, nil @@ -142,9 +134,7 @@ func (s *HTTPHandlers) SessionGet(resp http.ResponseWriter, req *http.Request) ( } args.Session = args.SessionID if args.SessionID == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing session") - return nil, nil + return nil, BadRequestError{Reason: "Missing session"} } var out structs.IndexedSessions @@ -200,9 +190,7 @@ func (s *HTTPHandlers) SessionsForNode(resp http.ResponseWriter, req *http.Reque return nil, err } if args.Node == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing node name") - return nil, nil + return nil, BadRequestError{Reason: "Missing node name"} } var out structs.IndexedSessions diff --git a/agent/txn_endpoint.go b/agent/txn_endpoint.go index f954ace410..58a1cd4b0f 100644 --- a/agent/txn_endpoint.go +++ b/agent/txn_endpoint.go @@ -63,7 +63,7 @@ func isWrite(op api.KVOp) bool { // internal RPC format. This returns a count of the number of write ops, and // a boolean, that if false means an error response has been generated and // processing should stop. -func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (structs.TxnOps, int, bool) { +func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (structs.TxnOps, int, error) { // The TxnMaxReqLen limit and KVMaxValueSize limit both default to the // suggested raft data size and can be configured independently. The // TxnMaxReqLen is enforced on the cumulative size of the transaction, @@ -87,13 +87,10 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( // Check Content-Length first before decoding to return early if req.ContentLength > maxTxnLen { - resp.WriteHeader(http.StatusRequestEntityTooLarge) - fmt.Fprintf(resp, - "Request body(%d bytes) too large, max size: %d bytes. See %s.", - req.ContentLength, maxTxnLen, - "https://www.consul.io/docs/agent/options.html#txn_max_req_len", - ) - return nil, 0, false + return nil, 0, EntityTooLargeError{ + Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.", + req.ContentLength, maxTxnLen, "https://www.consul.io/docs/agent/options.html#txn_max_req_len"), + } } var ops api.TxnOps @@ -102,30 +99,24 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( if err.Error() == "http: request body too large" { // The request size is also verified during decoding to double check // if the Content-Length header was not set by the client. - resp.WriteHeader(http.StatusRequestEntityTooLarge) - fmt.Fprintf(resp, - "Request body too large, max size: %d bytes. See %s.", - maxTxnLen, - "https://www.consul.io/docs/agent/options.html#txn_max_req_len", - ) + return nil, 0, EntityTooLargeError{ + Reason: fmt.Sprintf("Request body too large, max size: %d bytes. See %s.", + maxTxnLen, "https://www.consul.io/docs/agent/options.html#txn_max_req_len"), + } } else { // Note the body is in API format, and not the RPC format. If we can't // decode it, we will return a 400 since we don't have enough context to // associate the error with a given operation. - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "Failed to parse body: %v", err) + return nil, 0, BadRequestError{Reason: fmt.Sprintf("Failed to parse body: %v", err)} } - return nil, 0, false } // Enforce a reasonable upper limit on the number of operations in a // transaction in order to curb abuse. if size := len(ops); size > maxTxnOps { - resp.WriteHeader(http.StatusRequestEntityTooLarge) - fmt.Fprintf(resp, "Transaction contains too many operations (%d > %d)", - size, maxTxnOps) - - return nil, 0, false + return nil, 0, EntityTooLargeError{ + Reason: fmt.Sprintf("Transaction contains too many operations (%d > %d)", size, maxTxnOps), + } } // Convert the KV API format into the RPC format. Note that fixupKVOps @@ -138,9 +129,9 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( case in.KV != nil: size := len(in.KV.Value) if int64(size) > kvMaxValueSize { - resp.WriteHeader(http.StatusRequestEntityTooLarge) - fmt.Fprintf(resp, "Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize) - return nil, 0, false + return nil, 0, EntityTooLargeError{ + Reason: fmt.Sprintf("Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize), + } } verb := in.KV.Verb @@ -297,7 +288,7 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( } } - return opsRPC, writes, true + return opsRPC, writes, nil } // Txn handles requests to apply multiple operations in a single, atomic @@ -306,9 +297,9 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( // and everything else will be routed through Raft like a normal write. func (s *HTTPHandlers) Txn(resp http.ResponseWriter, req *http.Request) (interface{}, error) { // Convert the ops from the API format to the internal format. - ops, writes, ok := s.convertOps(resp, req) - if !ok { - return nil, nil + ops, writes, err := s.convertOps(resp, req) + if err != nil { + return nil, err } // Fast-path a transaction with only writes to the read-only endpoint, diff --git a/agent/txn_endpoint_test.go b/agent/txn_endpoint_test.go index 6b3a7c4680..2f9d6fbca1 100644 --- a/agent/txn_endpoint_test.go +++ b/agent/txn_endpoint_test.go @@ -30,13 +30,12 @@ func TestTxnEndpoint_Bad_JSON(t *testing.T) { buf := bytes.NewBuffer([]byte("{")) req, _ := http.NewRequest("PUT", "/v1/txn", buf) resp := httptest.NewRecorder() - if _, err := a.srv.Txn(resp, req); err != nil { - t.Fatalf("err: %v", err) + _, err := a.srv.Txn(resp, req) + err, ok := err.(BadRequestError) + if !ok { + t.Fatalf("expected bad request error but got %v", err) } - if resp.Code != 400 { - t.Fatalf("expected 400, got %d", resp.Code) - } - if !bytes.Contains(resp.Body.Bytes(), []byte("Failed to parse")) { + if !strings.Contains(err.Error(), "Failed to parse") { t.Fatalf("expected conflicting args error") } } @@ -63,15 +62,13 @@ func TestTxnEndpoint_Bad_Size_Item(t *testing.T) { `, value))) req, _ := http.NewRequest("PUT", "/v1/txn", buf) resp := httptest.NewRecorder() - if _, err := agent.srv.Txn(resp, req); err != nil { + _, err := agent.srv.Txn(resp, req) + if err, ok := err.(EntityTooLargeError); !ok && !wantPass { + t.Fatalf("expected too large error but got %v", err) + } + if err != nil && wantPass { t.Fatalf("err: %v", err) } - if resp.Code != 413 && !wantPass { - t.Fatalf("expected 413, got %d", resp.Code) - } - if resp.Code != 200 && wantPass { - t.Fatalf("expected 200, got %d", resp.Code) - } } t.Run("exceeds default limits", func(t *testing.T) { @@ -140,15 +137,13 @@ func TestTxnEndpoint_Bad_Size_Net(t *testing.T) { `, value, value, value))) req, _ := http.NewRequest("PUT", "/v1/txn", buf) resp := httptest.NewRecorder() - if _, err := agent.srv.Txn(resp, req); err != nil { + _, err := agent.srv.Txn(resp, req) + if err, ok := err.(EntityTooLargeError); !ok && !wantPass { + t.Fatalf("expected too large error but got %v", err) + } + if err != nil && wantPass { t.Fatalf("err: %v", err) } - if resp.Code != 413 && !wantPass { - t.Fatalf("expected 413, got %d", resp.Code) - } - if resp.Code != 200 && wantPass { - t.Fatalf("expected 200, got %d", resp.Code) - } } t.Run("exceeds default limits", func(t *testing.T) { @@ -209,11 +204,9 @@ func TestTxnEndpoint_Bad_Size_Ops(t *testing.T) { `, strings.Repeat(`{ "KV": { "Verb": "get", "Key": "key" } },`, 2*maxTxnOps)))) req, _ := http.NewRequest("PUT", "/v1/txn", buf) resp := httptest.NewRecorder() - if _, err := a.srv.Txn(resp, req); err != nil { - t.Fatalf("err: %v", err) - } - if resp.Code != 413 { - t.Fatalf("expected 413, got %d", resp.Code) + _, err := a.srv.Txn(resp, req) + if err, ok := err.(EntityTooLargeError); !ok { + t.Fatalf("expected too large error but got %v", err) } } diff --git a/agent/ui_endpoint.go b/agent/ui_endpoint.go index 4c6b3e7686..56de071c08 100644 --- a/agent/ui_endpoint.go +++ b/agent/ui_endpoint.go @@ -139,9 +139,7 @@ func (s *HTTPHandlers) UINodeInfo(resp http.ResponseWriter, req *http.Request) ( return nil, err } if args.Node == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing node name") - return nil, nil + return nil, BadRequestError{Reason: "Missing node name"} } // Make the RPC request @@ -255,9 +253,7 @@ func (s *HTTPHandlers) UIGatewayServicesNodes(resp http.ResponseWriter, req *htt return nil, err } if args.ServiceName == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing gateway name") - return nil, nil + return nil, BadRequestError{Reason: "Missing gateway name"} } // Make the RPC request @@ -301,16 +297,12 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req return nil, err } if args.ServiceName == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing service name") - return nil, nil + return nil, BadRequestError{Reason: "Missing service name"} } kind, ok := req.URL.Query()["kind"] if !ok { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing service kind") - return nil, nil + return nil, BadRequestError{Reason: "Missing service kind"} } args.ServiceKind = structs.ServiceKind(kind[0]) @@ -318,9 +310,7 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req case structs.ServiceKindTypical, structs.ServiceKindIngressGateway: // allowed default: - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(resp, "Unsupported service kind %q", args.ServiceKind) - return nil, nil + return nil, BadRequestError{Reason: fmt.Sprintf("Unsupported service kind %q", args.ServiceKind)} } // Make the RPC request @@ -584,9 +574,7 @@ func (s *HTTPHandlers) UIGatewayIntentions(resp http.ResponseWriter, req *http.R return nil, err } if name == "" { - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, "Missing gateway name") - return nil, nil + return nil, BadRequestError{Reason: "Missing gateway name"} } args.Match = &structs.IntentionQueryMatch{ Type: structs.IntentionMatchDestination, diff --git a/agent/ui_endpoint_test.go b/agent/ui_endpoint_test.go index 84cd971a7a..757975bc5c 100644 --- a/agent/ui_endpoint_test.go +++ b/agent/ui_endpoint_test.go @@ -1409,14 +1409,15 @@ func TestUIServiceTopology(t *testing.T) { retry.Run(t, func(r *retry.R) { resp := httptest.NewRecorder() obj, err := a.srv.UIServiceTopology(resp, tc.httpReq) - assert.Nil(r, err) if tc.wantErr != "" { + assert.NotNil(r, err) assert.Nil(r, tc.want) // should not define a non-nil want - require.Equal(r, tc.wantErr, resp.Body.String()) + require.Contains(r, err.Error(), tc.wantErr) require.Nil(r, obj) return } + assert.Nil(r, err) require.NoError(r, checkIndex(resp)) require.NotNil(r, obj)