diff --git a/agent/acl.go b/agent/acl.go index b4108b4848..50e913a4d9 100644 --- a/agent/acl.go +++ b/agent/acl.go @@ -2,6 +2,7 @@ package agent import ( "fmt" + "net/http" "github.com/hashicorp/serf/serf" @@ -82,9 +83,10 @@ func (a *Agent) vetServiceUpdateWithAuthorizer(authz acl.Authorizer, serviceID s // agent/local/state.go's deleteService assumes the Catalog.Deregister RPC call // will include "Unknown service"in the error if deregistration fails due to a // service with that ID not existing. - return NotFoundError{Reason: fmt.Sprintf( - "Unknown service ID %q. Ensure that the service ID is passed, not the service name.", - serviceID)} + return HTTPError{ + StatusCode: http.StatusNotFound, + Reason: fmt.Sprintf("Unknown service ID %q. Ensure that the service ID is passed, not the service name.", serviceID), + } } return nil @@ -140,9 +142,10 @@ func (a *Agent) vetCheckUpdateWithAuthorizer(authz acl.Authorizer, checkID struc } } } else { - return NotFoundError{Reason: fmt.Sprintf( - "Unknown check ID %q. Ensure that the check ID is passed, not the check name.", - checkID.String())} + return HTTPError{ + StatusCode: http.StatusNotFound, + Reason: fmt.Sprintf("Unknown check ID %q. Ensure that the check ID is passed, not the check name.", checkID.String()), + } } return nil diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index 54f6c3948a..ef859f84e9 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -16,7 +16,7 @@ type aclBootstrapResponse struct { structs.ACLToken } -var aclDisabled = UnauthorizedError{Reason: "ACL support disabled"} +var aclDisabled = HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"} // checkACLDisabled will return a standard response if ACLs are disabled. This // returns true if they are disabled and we should not continue. @@ -127,7 +127,7 @@ func (s *HTTPHandlers) ACLPolicyCRUD(resp http.ResponseWriter, req *http.Request return nil, err } if policyID == "" && req.Method != "PUT" { - return nil, BadRequestError{Reason: "Missing policy ID"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing policy ID"} } return fn(resp, req, policyID) @@ -175,7 +175,7 @@ func (s *HTTPHandlers) ACLPolicyReadByName(resp http.ResponseWriter, req *http.R return nil, err } if policyName == "" { - return nil, BadRequestError{Reason: "Missing policy Name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing policy Name"} } return s.ACLPolicyRead(resp, req, "", policyName) @@ -207,18 +207,18 @@ func (s *HTTPHandlers) aclPolicyWriteInternal(_resp http.ResponseWriter, req *ht } if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Policy)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Policy decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Policy decoding failed: %v", err)} } args.Policy.Syntax = acl.SyntaxCurrent if create { if args.Policy.ID != "" { - return nil, BadRequestError{Reason: "Cannot specify the ID when creating a new policy"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify the ID when creating a new policy"} } } else { if args.Policy.ID != "" && args.Policy.ID != policyID { - return nil, BadRequestError{Reason: "Policy ID in URL and payload do not match"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Policy ID in URL and payload do not match"} } else if args.Policy.ID == "" { args.Policy.ID = policyID } @@ -317,7 +317,7 @@ func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request) fn = s.ACLTokenClone } if tokenID == "" && req.Method != "PUT" { - return nil, BadRequestError{Reason: "Missing token ID"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing token ID"} } return fn(resp, req, tokenID) @@ -422,12 +422,12 @@ func (s *HTTPHandlers) aclTokenSetInternal(req *http.Request, tokenID string, cr } if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.ACLToken)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Token decoding failed: %v", err)} } if !create { if args.ACLToken.AccessorID != "" && args.ACLToken.AccessorID != tokenID { - return nil, BadRequestError{Reason: "Token Accessor ID in URL and payload do not match"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Token Accessor ID in URL and payload do not match"} } else if args.ACLToken.AccessorID == "" { args.ACLToken.AccessorID = tokenID } @@ -472,7 +472,7 @@ func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request return nil, err } if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.ACLToken)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Token decoding failed: %v", err)} } s.parseToken(req, &args.Token) @@ -546,7 +546,7 @@ func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request) return nil, err } if roleID == "" && req.Method != "PUT" { - return nil, BadRequestError{Reason: "Missing role ID"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing role ID"} } return fn(resp, req, roleID) @@ -562,7 +562,7 @@ func (s *HTTPHandlers) ACLRoleReadByName(resp http.ResponseWriter, req *http.Req return nil, err } if roleName == "" { - return nil, BadRequestError{Reason: "Missing role Name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing role Name"} } return s.ACLRoleRead(resp, req, "", roleName) @@ -621,11 +621,11 @@ func (s *HTTPHandlers) ACLRoleWrite(resp http.ResponseWriter, req *http.Request, } if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Role)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Role decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Role decoding failed: %v", err)} } if args.Role.ID != "" && args.Role.ID != roleID { - return nil, BadRequestError{Reason: "Role ID in URL and payload do not match"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Role ID in URL and payload do not match"} } else if args.Role.ID == "" { args.Role.ID = roleID } @@ -716,7 +716,7 @@ func (s *HTTPHandlers) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Re return nil, err } if bindingRuleID == "" && req.Method != "PUT" { - return nil, BadRequestError{Reason: "Missing binding rule ID"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing binding rule ID"} } return fn(resp, req, bindingRuleID) @@ -770,11 +770,11 @@ func (s *HTTPHandlers) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.R } if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.BindingRule)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)} } if args.BindingRule.ID != "" && args.BindingRule.ID != bindingRuleID { - return nil, BadRequestError{Reason: "BindingRule ID in URL and payload do not match"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "BindingRule ID in URL and payload do not match"} } else if args.BindingRule.ID == "" { args.BindingRule.ID = bindingRuleID } @@ -862,7 +862,7 @@ func (s *HTTPHandlers) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Req return nil, err } if methodName == "" && req.Method != "PUT" { - return nil, BadRequestError{Reason: "Missing auth method name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing auth method name"} } return fn(resp, req, methodName) @@ -916,12 +916,12 @@ func (s *HTTPHandlers) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Re } if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.AuthMethod)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)} } if methodName != "" { if args.AuthMethod.Name != "" && args.AuthMethod.Name != methodName { - return nil, BadRequestError{Reason: "AuthMethod Name in URL and payload do not match"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "AuthMethod Name in URL and payload do not match"} } else if args.AuthMethod.Name == "" { args.AuthMethod.Name = methodName } @@ -969,7 +969,7 @@ func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (in } if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Auth)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to decode request body: %v", err)} } var out structs.ACLToken @@ -1058,11 +1058,11 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request) s.parseDC(req, &request.Datacenter) if err := decodeBody(req.Body, &request.Requests); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to decode request body: %v", err)} } if len(request.Requests) > maxRequests { - return nil, BadRequestError{Reason: fmt.Sprintf("Refusing to process more than %d authorizations at once", maxRequests)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Refusing to process more than %d authorizations at once", maxRequests)} } if len(request.Requests) == 0 { @@ -1083,7 +1083,7 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request) responses, err = structs.CreateACLAuthorizationResponses(authz, request.Requests) if err != nil { - return nil, BadRequestError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()} } } diff --git a/agent/acl_endpoint_test.go b/agent/acl_endpoint_test.go index 01a3f0b5ec..60a512ef41 100644 --- a/agent/acl_endpoint_test.go +++ b/agent/acl_endpoint_test.go @@ -26,6 +26,16 @@ import ( // They are not intended to thoroughly test the backing RPC // functionality as that will be done with other tests. +func isHTTPBadRequest(err error) bool { + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 400 { + return false + } + return true + } + return false +} + func TestACL_Disabled_Response(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -71,7 +81,7 @@ func TestACL_Disabled_Response(t *testing.T) { resp := httptest.NewRecorder() obj, err := tt.fn(resp, req) require.Nil(t, obj) - require.ErrorIs(t, err, UnauthorizedError{Reason: "ACL support disabled"}) + require.ErrorIs(t, err, HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"}) }) } } @@ -270,8 +280,7 @@ func TestACL_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLPolicyCRUD(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Policy CRUD Missing ID in URL", func(t *testing.T) { @@ -279,8 +288,7 @@ func TestACL_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLPolicyCRUD(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Update", func(t *testing.T) { @@ -327,8 +335,7 @@ func TestACL_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLPolicyCreate(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Invalid payload", func(t *testing.T) { @@ -339,8 +346,7 @@ func TestACL_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLPolicyCreate(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Delete", func(t *testing.T) { @@ -497,8 +503,7 @@ func TestACL_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLRoleCRUD(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Role CRUD Missing ID in URL", func(t *testing.T) { @@ -506,8 +511,7 @@ func TestACL_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLRoleCRUD(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Update", func(t *testing.T) { @@ -567,8 +571,7 @@ func TestACL_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLRoleCreate(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Invalid payload", func(t *testing.T) { @@ -579,8 +582,7 @@ func TestACL_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLRoleCreate(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Delete", func(t *testing.T) { @@ -818,8 +820,7 @@ func TestACL_HTTP(t *testing.T) { obj, err := a.srv.ACLTokenCRUD(resp, req) require.Error(t, err) require.Nil(t, obj) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Update Accessor Mismatch", func(t *testing.T) { originalToken := tokenMap[idMap["token-cloned"]] @@ -841,8 +842,7 @@ func TestACL_HTTP(t *testing.T) { obj, err := a.srv.ACLTokenCRUD(resp, req) require.Error(t, err) require.Nil(t, obj) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Delete", func(t *testing.T) { req, _ := http.NewRequest("DELETE", "/v1/acl/token/"+idMap["token-cloned"]+"?token=root", nil) @@ -1284,8 +1284,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLAuthMethodCRUD(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Update Name URL Mismatch", func(t *testing.T) { @@ -1302,8 +1301,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLAuthMethodCRUD(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Update", func(t *testing.T) { @@ -1342,8 +1340,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLAuthMethodCreate(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("List", func(t *testing.T) { @@ -1480,8 +1477,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLBindingRuleCRUD(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Update", func(t *testing.T) { @@ -1529,8 +1525,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLBindingRuleCreate(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("Invalid payload", func(t *testing.T) { @@ -1541,8 +1536,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.ACLBindingRuleCreate(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) }) t.Run("List", func(t *testing.T) { diff --git a/agent/agent_endpoint.go b/agent/agent_endpoint.go index d9a516f966..abaec4e498 100644 --- a/agent/agent_endpoint.go +++ b/agent/agent_endpoint.go @@ -425,7 +425,7 @@ func (s *HTTPHandlers) AgentService(resp http.ResponseWriter, req *http.Request) svcState := s.agent.State.ServiceState(sid) if svcState == nil { - return "", nil, NotFoundError{Reason: fmt.Sprintf("unknown service ID: %s", sid.String())} + return "", nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("unknown service ID: %s", sid.String())} } svc := svcState.Service @@ -555,7 +555,7 @@ func (s *HTTPHandlers) AgentMembers(resp http.ResponseWriter, req *http.Request) // key are ok, otherwise the argument doesn't apply to // the WAN. default: - return nil, BadRequestError{Reason: "Cannot provide a segment with wan=true"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot provide a segment with wan=true"} } } @@ -732,16 +732,16 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re } if err := decodeBody(req.Body, &args); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } // Verify the check has a name. if args.Name == "" { - return nil, BadRequestError{Reason: "Missing check name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing check name"} } if args.Status != "" && !structs.ValidStatus(args.Status) { - return nil, BadRequestError{Reason: "Bad check status"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Bad check status"} } authz, err := s.agent.delegate.ResolveTokenAndDefaultMeta(token, &args.EnterpriseMeta, nil) @@ -760,7 +760,7 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re chkType := args.CheckType() err = chkType.Validate() if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check: %v", err)} } // Store the type of check based on the definition @@ -773,7 +773,7 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re if service != nil { health.ServiceName = service.Service } else { - return nil, NotFoundError{fmt.Sprintf("ServiceID %q does not exist", cid.String())} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("ServiceID %q does not exist", cid.String())} } } @@ -878,7 +878,7 @@ type checkUpdate struct { func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { var update checkUpdate if err := decodeBody(req.Body, &update); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } switch update.Status { @@ -886,7 +886,7 @@ func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Requ case api.HealthWarning: case api.HealthCritical: default: - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)} } ID, err := getPathSuffixUnescaped(req.URL.Path, "/v1/agent/check/update/") @@ -981,7 +981,7 @@ func (s *HTTPHandlers) AgentHealthServiceByID(resp http.ResponseWriter, req *htt return nil, err } if serviceID == "" { - return nil, &BadRequestError{Reason: "Missing serviceID"} + return nil, &HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing serviceID"} } var entMeta acl.EnterpriseMeta @@ -1043,7 +1043,7 @@ func (s *HTTPHandlers) AgentHealthServiceByName(resp http.ResponseWriter, req *h } if serviceName == "" { - return nil, &BadRequestError{Reason: "Missing service Name"} + return nil, &HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service Name"} } var entMeta acl.EnterpriseMeta @@ -1114,18 +1114,18 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http. } if err := decodeBody(req.Body, &args); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } // Verify the service has a name. if args.Name == "" { - return nil, BadRequestError{Reason: "Missing service name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"} } // Check the service address here and in the catalog RPC endpoint // since service registration isn't synchronous. if ipaddr.IsAny(args.Address) { - return nil, BadRequestError{Reason: "Invalid service address"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid service address"} } var token string @@ -1144,27 +1144,27 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http. ns := args.NodeService() if ns.Weights != nil { if err := structs.ValidateWeights(ns.Weights); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Weights: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid Weights: %v", err)} } } if err := structs.ValidateServiceMetadata(ns.Kind, ns.Meta, false); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Service Meta: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid Service Meta: %v", err)} } // Run validation. This is the same validation that would happen on // the catalog endpoint so it helps ensure the sync will work properly. if err := ns.Validate(); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Validation failed: %v", err.Error())} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Validation failed: %v", err.Error())} } // Verify the check type. chkTypes, err := args.CheckTypes() if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check: %v", err)} } for _, check := range chkTypes { if check.Status != "" && !structs.ValidStatus(check.Status) { - return nil, BadRequestError{Reason: "Status for checks must 'passing', 'warning', 'critical'"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Status for checks must 'passing', 'warning', 'critical'"} } } @@ -1172,15 +1172,11 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http. if args.Connect != nil && args.Connect.SidecarService != nil { chkTypes, err := args.Connect.SidecarService.CheckTypes() if err != nil { - return nil, &BadRequestError{ - Reason: fmt.Sprintf("Invalid check in sidecar_service: %v", err), - } + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check in sidecar_service: %v", err)} } for _, check := range chkTypes { if check.Status != "" && !structs.ValidStatus(check.Status) { - return nil, &BadRequestError{ - Reason: "Status for checks must 'passing', 'warning', 'critical'", - } + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Status for checks must 'passing', 'warning', 'critical'"} } } } @@ -1193,12 +1189,11 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http. // See if we have a sidecar to register too sidecar, sidecarChecks, sidecarToken, err := s.agent.sidecarServiceFromNodeService(ns, token) if err != nil { - return nil, &BadRequestError{ - Reason: fmt.Sprintf("Invalid SidecarService: %s", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid SidecarService: %s", err)} } if sidecar != nil { if err := sidecar.Validate(); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Failed Validation: %v", err.Error())} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed Validation: %v", err.Error())} } // Make sure we are allowed to register the sidecar using the token // specified (might be specific to sidecar or the same one as the overall @@ -1299,19 +1294,19 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht sid := structs.NewServiceID(serviceID, nil) if sid.ID == "" { - return nil, BadRequestError{Reason: "Missing service ID"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service ID"} } // Ensure we have some action params := req.URL.Query() if _, ok := params["enable"]; !ok { - return nil, BadRequestError{Reason: "Missing value for enable"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing value for enable"} } raw := params.Get("enable") enable, err := strconv.ParseBool(raw) if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid value for enable: %q", raw)} } // Get the provided token, if any, and vet against any ACL policies. @@ -1340,11 +1335,11 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht if enable { reason := params.Get("reason") if err = s.agent.EnableServiceMaintenance(sid, reason, token); err != nil { - return nil, NotFoundError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} } } else { if err = s.agent.DisableServiceMaintenance(sid); err != nil { - return nil, NotFoundError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} } } s.syncChanges() @@ -1355,13 +1350,13 @@ func (s *HTTPHandlers) AgentNodeMaintenance(resp http.ResponseWriter, req *http. // Ensure we have some action params := req.URL.Query() if _, ok := params["enable"]; !ok { - return nil, BadRequestError{Reason: "Missing value for enable"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing value for enable"} } raw := params.Get("enable") enable, err := strconv.ParseBool(raw) if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid value for enable: %q", raw)} } // Get the provided token, if any, and vet against any ACL policies. @@ -1416,9 +1411,7 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request) } if !logging.ValidateLogLevel(logLevel) { - return nil, BadRequestError{ - Reason: fmt.Sprintf("Unknown log level: %s", logLevel), - } + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Unknown log level: %s", logLevel)} } flusher, ok := resp.(http.Flusher) @@ -1469,7 +1462,7 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request) func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) { if s.checkACLDisabled() { - return nil, UnauthorizedError{Reason: "ACL support disabled"} + return nil, HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"} } // Fetch the ACL token, if any, and enforce agent policy. @@ -1491,7 +1484,7 @@ func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) ( // fields to this later if needed. var args api.AgentToken if err := decodeBody(req.Body, &args); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } // Figure out the target token. @@ -1522,7 +1515,7 @@ func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) ( s.agent.tokens.UpdateReplicationToken(args.Token, token_store.TokenSourceAPI) default: - return NotFoundError{Reason: fmt.Sprintf("Token %q is unknown", target)} + return HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("Token %q is unknown", target)} } // TODO: is it safe to move this out of WithPersistenceLock? @@ -1641,7 +1634,7 @@ func (s *HTTPHandlers) AgentConnectAuthorize(resp http.ResponseWriter, req *http } if err := decodeBody(req.Body, &authReq); err != nil { - return nil, BadRequestError{fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } if !s.validateRequestPartition(resp, &authReq.EnterpriseMeta) { diff --git a/agent/agent_endpoint_test.go b/agent/agent_endpoint_test.go index fc29333e92..9aa0a17274 100644 --- a/agent/agent_endpoint_test.go +++ b/agent/agent_endpoint_test.go @@ -5473,7 +5473,7 @@ func TestAgent_Token(t *testing.T) { url: "acl_token?token=root", body: badJSON(), code: http.StatusBadRequest, - expectedErr: `Bad request: Request decode failed: json: cannot unmarshal bool into Go value of type api.AgentToken`, + expectedErr: `Request decode failed: json: cannot unmarshal bool into Go value of type api.AgentToken`, }, { name: "set user legacy", diff --git a/agent/catalog_endpoint.go b/agent/catalog_endpoint.go index 2ae1b07dc0..b873b8b5ef 100644 --- a/agent/catalog_endpoint.go +++ b/agent/catalog_endpoint.go @@ -136,7 +136,7 @@ func (s *HTTPHandlers) CatalogRegister(resp http.ResponseWriter, req *http.Reque } if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } // Setup the default DC if not provided @@ -166,7 +166,7 @@ func (s *HTTPHandlers) CatalogDeregister(resp http.ResponseWriter, req *http.Req return nil, err } if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } // Setup the default DC if not provided @@ -363,7 +363,7 @@ func (s *HTTPHandlers) catalogServiceNodes(resp http.ResponseWriter, req *http.R return nil, err } if args.ServiceName == "" { - return nil, BadRequestError{Reason: "Missing service name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"} } // Make the RPC request @@ -438,7 +438,7 @@ func (s *HTTPHandlers) CatalogNodeServices(resp http.ResponseWriter, req *http.R return nil, err } if args.Node == "" { - return nil, BadRequestError{Reason: "Missing node name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"} } // Make the RPC request @@ -503,7 +503,7 @@ func (s *HTTPHandlers) CatalogNodeServiceList(resp http.ResponseWriter, req *htt return nil, err } if args.Node == "" { - return nil, BadRequestError{Reason: "Missing node name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"} } // Make the RPC request @@ -554,7 +554,7 @@ func (s *HTTPHandlers) CatalogGatewayServices(resp http.ResponseWriter, req *htt return nil, err } if args.ServiceName == "" { - return nil, BadRequestError{Reason: "Missing gateway name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"} } // Make the RPC request diff --git a/agent/config_endpoint.go b/agent/config_endpoint.go index 637b8ab919..b51c1e273f 100644 --- a/agent/config_endpoint.go +++ b/agent/config_endpoint.go @@ -56,7 +56,7 @@ func (s *HTTPHandlers) configGet(resp http.ResponseWriter, req *http.Request) (i setMeta(resp, &reply.QueryMeta) if reply.Entry == nil { - return nil, NotFoundError{Reason: fmt.Sprintf("%s for %q / %q", ConfigEntryNotFoundErr, pathArgs[0], pathArgs[1])} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("%s for %q / %q", ConfigEntryNotFoundErr, pathArgs[0], pathArgs[1])} } return reply.Entry, nil @@ -75,7 +75,7 @@ func (s *HTTPHandlers) configGet(resp http.ResponseWriter, req *http.Request) (i return reply.Entries, nil default: - return nil, NotFoundError{Reason: "Must provide either a kind or both kind and name"} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Must provide either a kind or both kind and name"} } } @@ -91,12 +91,12 @@ func (s *HTTPHandlers) configDelete(resp http.ResponseWriter, req *http.Request) pathArgs := strings.SplitN(kindAndName, "/", 2) if len(pathArgs) != 2 { - return nil, NotFoundError{Reason: "Must provide both a kind and name to delete"} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Must provide both a kind and name to delete"} } entry, err := structs.MakeConfigEntry(pathArgs[0], pathArgs[1]) if err != nil { - return nil, BadRequestError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()} } args.Entry = entry // Parse enterprise meta. @@ -139,13 +139,13 @@ func (s *HTTPHandlers) ConfigApply(resp http.ResponseWriter, req *http.Request) var raw map[string]interface{} if err := decodeBodyDeprecated(req, &raw, nil); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)} } if entry, err := structs.DecodeConfigEntry(raw); err == nil { args.Entry = entry } else { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)} } // Parse enterprise meta. diff --git a/agent/config_endpoint_test.go b/agent/config_endpoint_test.go index 949b6be118..b8ed9d5507 100644 --- a/agent/config_endpoint_test.go +++ b/agent/config_endpoint_test.go @@ -601,9 +601,8 @@ func TestConfig_Apply_Decoding(t *testing.T) { _, err := a.srv.ConfigApply(resp, req) require.Error(t, err) - badReq, ok := err.(BadRequestError) - require.True(t, ok) - require.Equal(t, "Request decoding failed: Payload does not contain a kind/Kind key at the top level", badReq.Reason) + require.True(t, isHTTPBadRequest(err)) + require.Equal(t, "Request decoding failed: Payload does not contain a kind/Kind key at the top level", err.Error()) }) t.Run("Kind Not String", func(t *testing.T) { @@ -619,9 +618,8 @@ func TestConfig_Apply_Decoding(t *testing.T) { _, err := a.srv.ConfigApply(resp, req) require.Error(t, err) - badReq, ok := err.(BadRequestError) - require.True(t, ok) - require.Equal(t, "Request decoding failed: Kind value in payload is not a string", badReq.Reason) + require.True(t, isHTTPBadRequest(err)) + require.Equal(t, "Request decoding failed: Kind value in payload is not a string", err.Error()) }) t.Run("Lowercase kind", func(t *testing.T) { diff --git a/agent/connect_auth.go b/agent/connect_auth.go index 9bd8a46ebb..80fa6d42ea 100644 --- a/agent/connect_auth.go +++ b/agent/connect_auth.go @@ -3,6 +3,7 @@ package agent import ( "context" "fmt" + "net/http" "github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/agent/cache" @@ -23,7 +24,7 @@ import ( // The ACL token and the auth request are provided and the auth decision (true // means authorized) and reason string are returned. // -// If the request input is invalid the error returned will be a BadRequestError, +// If the request input is invalid the error returned will be a BadRequest HTTPError, // if the token doesn't grant necessary access then an acl.ErrPermissionDenied // error is returned, otherwise error indicates an unexpected server failure. If // access is denied, no error is returned but the first return value is false. @@ -37,23 +38,23 @@ func (a *Agent) ConnectAuthorize(token string, } if req == nil { - return returnErr(BadRequestError{"Invalid request"}) + return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid request"}) } // We need to have a target to check intentions if req.Target == "" { - return returnErr(BadRequestError{"Target service must be specified"}) + return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "Target service must be specified"}) } // Parse the certificate URI from the client ID uri, err := connect.ParseCertURIFromString(req.ClientCertURI) if err != nil { - return returnErr(BadRequestError{"ClientCertURI not a valid Connect identifier"}) + return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "ClientCertURI not a valid Connect identifier"}) } uriService, ok := uri.(*connect.SpiffeIDService) if !ok { - return returnErr(BadRequestError{"ClientCertURI not a valid Service identifier"}) + return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "ClientCertURI not a valid Service identifier"}) } // We need to verify service:write permissions for the given token. diff --git a/agent/connect_ca_endpoint.go b/agent/connect_ca_endpoint.go index 2c6a1dfabc..2e78bc7d89 100644 --- a/agent/connect_ca_endpoint.go +++ b/agent/connect_ca_endpoint.go @@ -20,7 +20,7 @@ func (s *HTTPHandlers) ConnectCARoots(resp http.ResponseWriter, req *http.Reques if pemParam := req.URL.Query().Get("pem"); pemParam != "" { val, err := strconv.ParseBool(pemParam) if err != nil { - return nil, BadRequestError{Reason: "The 'pem' query parameter must be a boolean value"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "The 'pem' query parameter must be a boolean value"} } pemResponse = val } @@ -90,15 +90,14 @@ func (s *HTTPHandlers) ConnectCAConfigurationSet(req *http.Request) (interface{} s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) if err := decodeBody(req.Body, &args.Config); err != nil { - return nil, BadRequestError{ - Reason: fmt.Sprintf("Request decode failed: %v", err), - } + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } var reply interface{} err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply) if err != nil && err.Error() == consul.ErrStateReadOnly.Error() { - return nil, BadRequestError{ + return nil, HTTPError{ + StatusCode: http.StatusBadRequest, Reason: "Provider State is read-only. It must be omitted" + " or identical to the current value", } diff --git a/agent/coordinate_endpoint.go b/agent/coordinate_endpoint.go index ff3df3d06c..14176ab1e1 100644 --- a/agent/coordinate_endpoint.go +++ b/agent/coordinate_endpoint.go @@ -14,7 +14,7 @@ func (s *HTTPHandlers) checkCoordinateDisabled() error { if !s.agent.config.DisableCoordinates { return nil } - return UnauthorizedError{Reason: "Coordinate support disabled"} + return HTTPError{StatusCode: http.StatusUnauthorized, Reason: "Coordinate support disabled"} } // sorter wraps a coordinate list and implements the sort.Interface to sort by @@ -156,7 +156,7 @@ func (s *HTTPHandlers) CoordinateUpdate(resp http.ResponseWriter, req *http.Requ args := structs.CoordinateUpdateRequest{} if err := decodeBody(req.Body, &args); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) diff --git a/agent/coordinate_endpoint_test.go b/agent/coordinate_endpoint_test.go index ff38d83227..331451641f 100644 --- a/agent/coordinate_endpoint_test.go +++ b/agent/coordinate_endpoint_test.go @@ -39,10 +39,14 @@ func TestCoordinate_Disabled_Response(t *testing.T) { req, _ := http.NewRequest("PUT", "/should/not/care", nil) resp := httptest.NewRecorder() obj, err := tt(resp, req) - err, ok := err.(UnauthorizedError) - if !ok { - t.Fatalf("expected unauthorized error but got %v", err) + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 401 { + t.Fatalf("expected status 401 but got %d", err.StatusCode) + } + } else { + t.Fatalf("expected HTTP error but got %v", err) } + if obj != nil { t.Fatalf("bad: %#v", obj) } diff --git a/agent/discovery_chain_endpoint.go b/agent/discovery_chain_endpoint.go index e9bb631850..475ef02d6c 100644 --- a/agent/discovery_chain_endpoint.go +++ b/agent/discovery_chain_endpoint.go @@ -25,7 +25,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re return nil, err } if args.Name == "" { - return nil, BadRequestError{Reason: "Missing chain name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing chain name"} } args.EvaluateInDatacenter = req.URL.Query().Get("compile-dc") @@ -38,12 +38,12 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re if req.Method == "POST" { var raw map[string]interface{} if err := decodeBody(req.Body, &raw); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)} } apiReq, err := decodeDiscoveryChainReadRequest(raw) if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)} } args.OverrideProtocol = apiReq.OverrideProtocol @@ -52,7 +52,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re if apiReq.OverrideMeshGateway.Mode != "" { _, err := structs.ValidateMeshGatewayMode(string(apiReq.OverrideMeshGateway.Mode)) if err != nil { - return nil, BadRequestError{Reason: "Invalid OverrideMeshGateway.Mode parameter"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid OverrideMeshGateway.Mode parameter"} } args.OverrideMeshGateway = apiReq.OverrideMeshGateway } diff --git a/agent/discovery_chain_endpoint_test.go b/agent/discovery_chain_endpoint_test.go index b93cd45c92..8b4a7e2723 100644 --- a/agent/discovery_chain_endpoint_test.go +++ b/agent/discovery_chain_endpoint_test.go @@ -57,8 +57,7 @@ func TestDiscoveryChainRead(t *testing.T) { resp := httptest.NewRecorder() _, err = a.srv.DiscoveryChainRead(resp, req) require.Error(t, err) - _, ok := err.(BadRequestError) - require.True(t, ok) + require.True(t, isHTTPBadRequest(err)) })) require.True(t, t.Run(method+": read default chain", func(t *testing.T) { diff --git a/agent/event_endpoint.go b/agent/event_endpoint.go index 53b0e5d65b..78be383d87 100644 --- a/agent/event_endpoint.go +++ b/agent/event_endpoint.go @@ -25,7 +25,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i return nil, err } if event.Name == "" { - return nil, BadRequestError{Reason: "Missing name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing name"} } // Get the ACL token @@ -55,7 +55,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i // Try to fire the event if err := s.agent.UserEvent(dc, token, event); err != nil { if acl.IsErrPermissionDenied(err) { - return nil, ForbiddenError{Reason: acl.ErrPermissionDenied.Error()} + return nil, HTTPError{StatusCode: http.StatusForbidden, Reason: acl.ErrPermissionDenied.Error()} } resp.WriteHeader(http.StatusInternalServerError) return nil, err diff --git a/agent/event_endpoint_test.go b/agent/event_endpoint_test.go index e5f0b39f73..e921f5aa5f 100644 --- a/agent/event_endpoint_test.go +++ b/agent/event_endpoint_test.go @@ -103,8 +103,12 @@ func TestEventFire_token(t *testing.T) { if !acl.IsErrPermissionDenied(err) { t.Fatalf("bad: %s", err.Error()) } - if err, ok := err.(ForbiddenError); !ok { - t.Fatalf("Expected forbidden but got %v", err) + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 403 { + t.Fatalf("Expected 403 but got %d", err.StatusCode) + } + } else { + t.Fatalf("Expected HTTP Error %v", err) } } } diff --git a/agent/federation_state_endpoint.go b/agent/federation_state_endpoint.go index 94b7a7cdd1..aecb58ff33 100644 --- a/agent/federation_state_endpoint.go +++ b/agent/federation_state_endpoint.go @@ -13,7 +13,7 @@ func (s *HTTPHandlers) FederationStateGet(resp http.ResponseWriter, req *http.Re return nil, err } if datacenterName == "" { - return nil, BadRequestError{Reason: "Missing datacenter name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing datacenter name"} } args := structs.FederationStateQuery{ diff --git a/agent/health_endpoint.go b/agent/health_endpoint.go index 7f904089d8..656fc2a048 100644 --- a/agent/health_endpoint.go +++ b/agent/health_endpoint.go @@ -34,7 +34,7 @@ func (s *HTTPHandlers) HealthChecksInState(resp http.ResponseWriter, req *http.R return nil, err } if args.State == "" { - return nil, BadRequestError{Reason: "Missing check state"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing check state"} } // Make the RPC request @@ -82,7 +82,7 @@ func (s *HTTPHandlers) HealthNodeChecks(resp http.ResponseWriter, req *http.Requ return nil, err } if args.Node == "" { - return nil, BadRequestError{Reason: "Missing node name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"} } // Make the RPC request @@ -132,7 +132,7 @@ func (s *HTTPHandlers) HealthServiceChecks(resp http.ResponseWriter, req *http.R return nil, err } if args.ServiceName == "" { - return nil, BadRequestError{Reason: "Missing service name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"} } // Make the RPC request @@ -224,7 +224,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re return nil, err } if args.ServiceName == "" { - return nil, BadRequestError{Reason: "Missing service name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"} } out, md, err := s.agent.rpcClientHealth.ServiceNodes(req.Context(), args) @@ -242,7 +242,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re // Filter to only passing if specified filter, err := getBoolQueryParam(params, api.HealthPassing) if err != nil { - return nil, BadRequestError{Reason: "Invalid value for ?passing"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid value for ?passing"} } // FIXME: remove filterNonPassing, replace with nodes.Filter, which is used by DNSServer diff --git a/agent/health_endpoint_test.go b/agent/health_endpoint_test.go index f29c2f4be6..c6b4472f39 100644 --- a/agent/health_endpoint_test.go +++ b/agent/health_endpoint_test.go @@ -1405,9 +1405,7 @@ func TestHealthServiceNodes_PassingFilter(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/health/service/consul?passing=nope-nope-nope", nil) resp := httptest.NewRecorder() _, err := a.srv.HealthServiceNodes(resp, req) - if _, ok := err.(BadRequestError); !ok { - t.Fatalf("Expected bad request error but got %v", err) - } + require.True(t, isHTTPBadRequest(err), fmt.Sprintf("Expected bad request HTTP error but got %v", err)) if !strings.Contains(err.Error(), "Invalid value for ?passing") { t.Errorf("bad %s", err.Error()) } @@ -1813,8 +1811,7 @@ func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) { resp := httptest.NewRecorder() _, err := a.srv.HealthConnectServiceNodes(resp, req) assert.NotNil(t, err) - _, ok := err.(BadRequestError) - assert.True(t, ok) + assert.True(t, isHTTPBadRequest(err)) assert.True(t, strings.Contains(err.Error(), "Invalid value for ?passing")) }) diff --git a/agent/http.go b/agent/http.go index 6be651c1a9..8601eee823 100644 --- a/agent/http.go +++ b/agent/http.go @@ -51,41 +51,6 @@ func (e MethodNotAllowedError) Error() string { return fmt.Sprintf("method %s not allowed", e.Method) } -// BadRequestError should be returned by a handler when parameters or the payload are not valid -type BadRequestError struct { - Reason string -} - -func (e BadRequestError) Error() string { - return fmt.Sprintf("Bad request: %s", e.Reason) -} - -// NotFoundError should be returned by a handler when a resource specified does not exist -type NotFoundError struct { - Reason string -} - -func (e NotFoundError) Error() string { - return e.Reason -} - -// UnauthorizedError should be returned by a handler when the request lacks valid authorization. -type UnauthorizedError struct { - Reason string -} - -func (e UnauthorizedError) Error() string { - return e.Reason -} - -type EntityTooLargeError struct { - Reason string -} - -func (e EntityTooLargeError) Error() string { - return e.Reason -} - // CodeWithPayloadError allow returning non HTTP 200 // Error codes while not returning PlainText payload type CodeWithPayloadError struct { @@ -98,12 +63,15 @@ func (e CodeWithPayloadError) Error() string { return e.Reason } -type ForbiddenError struct { - Reason string +// HTTPError is returned by the handler when a specific http error +// code is needed alongside a plain text response. +type HTTPError struct { + StatusCode int + Reason string } -func (e ForbiddenError) Error() string { - return e.Reason +func (h HTTPError) Error() string { + return h.Reason } // HTTPHandlers provides an HTTP api for an agent. @@ -423,8 +391,7 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc if acl.IsErrPermissionDenied(err) || acl.IsErrNotFound(err) { return true } - _, ok := err.(ForbiddenError) - return ok + return false } isMethodNotAllowed := func(err error) bool { @@ -432,35 +399,20 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc return ok } - isBadRequest := func(err error) bool { - _, ok := err.(BadRequestError) - return ok - } - - isNotFound := func(err error) bool { - _, ok := err.(NotFoundError) - return ok - } - - isUnauthorized := func(err error) bool { - _, ok := err.(UnauthorizedError) - return ok - } - isTooManyRequests := func(err error) bool { // Sadness net/rpc can't do nice typed errors so this is all we got return err.Error() == consul.ErrRateLimited.Error() } - isEntityToLarge := func(err error) bool { - _, ok := err.(EntityTooLargeError) - return ok - } - addAllowHeader := func(methods []string) { resp.Header().Add("Allow", strings.Join(methods, ",")) } + isHTTPError := func(err error) bool { + _, ok := err.(HTTPError) + return ok + } + handleErr := func(err error) { if req.Context().Err() != nil { httpLogger.Info("Request cancelled", @@ -490,21 +442,21 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc addAllowHeader(err.(MethodNotAllowedError).Allow) resp.WriteHeader(http.StatusMethodNotAllowed) // 405 fmt.Fprint(resp, err.Error()) - case isBadRequest(err): - resp.WriteHeader(http.StatusBadRequest) - fmt.Fprint(resp, err.Error()) - case isNotFound(err): - resp.WriteHeader(http.StatusNotFound) - fmt.Fprint(resp, err.Error()) - case isUnauthorized(err): - resp.WriteHeader(http.StatusUnauthorized) - fmt.Fprint(resp, err.Error()) + case isHTTPError(err): + err := err.(HTTPError) + code := http.StatusInternalServerError + if err.StatusCode != 0 { + code = err.StatusCode + } + reason := "An unexpected error occurred" + if err.Error() != "" { + reason = err.Error() + } + resp.WriteHeader(code) + fmt.Fprint(resp, reason) case isTooManyRequests(err): resp.WriteHeader(http.StatusTooManyRequests) fmt.Fprint(resp, err.Error()) - case isEntityToLarge(err): - resp.WriteHeader(http.StatusRequestEntityTooLarge) - fmt.Fprint(resp, err.Error()) default: resp.WriteHeader(http.StatusInternalServerError) fmt.Fprint(resp, err.Error()) @@ -1175,7 +1127,7 @@ func (s *HTTPHandlers) checkWriteAccess(req *http.Request) error { } } - return ForbiddenError{Reason: "Access is restricted"} + return HTTPError{StatusCode: http.StatusForbidden, Reason: "Access is restricted"} } func (s *HTTPHandlers) parseFilter(req *http.Request, filter *string) { diff --git a/agent/http_oss.go b/agent/http_oss.go index c14c31d8de..94eb575c36 100644 --- a/agent/http_oss.go +++ b/agent/http_oss.go @@ -14,10 +14,16 @@ import ( func (s *HTTPHandlers) parseEntMeta(req *http.Request, entMeta *acl.EnterpriseMeta) error { if headerNS := req.Header.Get("X-Consul-Namespace"); headerNS != "" { - return BadRequestError{Reason: "Invalid header: \"X-Consul-Namespace\" - Namespaces are a Consul Enterprise feature"} + return HTTPError{ + StatusCode: http.StatusBadRequest, + Reason: "Invalid header: \"X-Consul-Namespace\" - Namespaces are a Consul Enterprise feature", + } } if queryNS := req.URL.Query().Get("ns"); queryNS != "" { - return BadRequestError{Reason: "Invalid query parameter: \"ns\" - Namespaces are a Consul Enterprise feature"} + return HTTPError{ + StatusCode: http.StatusBadRequest, + Reason: "Invalid query parameter: \"ns\" - Namespaces are a Consul Enterprise feature", + } } return s.parseEntMetaPartition(req, entMeta) @@ -32,7 +38,10 @@ func (s *HTTPHandlers) validateEnterpriseIntentionPartition(logName, partition s // No special handling for wildcard namespaces as they are pointless in OSS. - return BadRequestError{Reason: "Invalid " + logName + "(" + partition + ")" + ": Partitions is a Consul Enterprise feature"} + return HTTPError{ + StatusCode: http.StatusBadRequest, + Reason: "Invalid " + logName + "(" + partition + ")" + ": Partitions is a Consul Enterprise feature", + } } func (s *HTTPHandlers) validateEnterpriseIntentionNamespace(logName, ns string, _ bool) error { @@ -44,7 +53,10 @@ func (s *HTTPHandlers) validateEnterpriseIntentionNamespace(logName, ns string, // No special handling for wildcard namespaces as they are pointless in OSS. - return BadRequestError{Reason: "Invalid " + logName + "(" + ns + ")" + ": Namespaces is a Consul Enterprise feature"} + return HTTPError{ + StatusCode: http.StatusBadRequest, + Reason: "Invalid " + logName + "(" + ns + ")" + ": Namespaces is a Consul Enterprise feature", + } } func (s *HTTPHandlers) parseEntMetaNoWildcard(req *http.Request, _ *acl.EnterpriseMeta) error { @@ -72,7 +84,10 @@ func (s *HTTPHandlers) rewordUnknownEnterpriseFieldError(err error) error { func parseACLAuthMethodEnterpriseMeta(req *http.Request, _ *structs.ACLAuthMethodEnterpriseMeta) error { if methodNS := req.URL.Query().Get("authmethod-ns"); methodNS != "" { - return BadRequestError{Reason: "Invalid query parameter: \"authmethod-ns\" - Namespaces are a Consul Enterprise feature"} + return HTTPError{ + StatusCode: http.StatusBadRequest, + Reason: "Invalid query parameter: \"authmethod-ns\" - Namespaces are a Consul Enterprise feature", + } } return nil @@ -91,10 +106,16 @@ func (s *HTTPHandlers) uiTemplateDataTransform(data map[string]interface{}) erro func (s *HTTPHandlers) parseEntMetaPartition(req *http.Request, meta *acl.EnterpriseMeta) error { if headerAP := req.Header.Get("X-Consul-Partition"); headerAP != "" { - return BadRequestError{Reason: "Invalid header: \"X-Consul-Partition\" - Partitions are a Consul Enterprise feature"} + return HTTPError{ + StatusCode: http.StatusBadRequest, + Reason: "Invalid header: \"X-Consul-Partition\" - Partitions are a Consul Enterprise feature", + } } if queryAP := req.URL.Query().Get("partition"); queryAP != "" { - return BadRequestError{Reason: "Invalid query parameter: \"partition\" - Partitions are a Consul Enterprise feature"} + return HTTPError{ + StatusCode: http.StatusBadRequest, + Reason: "Invalid query parameter: \"partition\" - Partitions are a Consul Enterprise feature", + } } return nil diff --git a/agent/http_test.go b/agent/http_test.go index e63849b151..525c4675a1 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -1488,10 +1488,16 @@ func TestAllowedNets(t *testing.T) { t.Fatalf("bad checkWriteAccess for values %+v, got %v", v, err) } - _, isForbiddenErr := err.(ForbiddenError) - if err != nil && !isForbiddenErr { - t.Fatalf("expected ForbiddenError but got: %s", err) + if err != nil { + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 403 { + t.Fatalf("expected 403 but got %d", err.StatusCode) + } + } else { + t.Fatalf("expected HTTP Error but got %v", err) + } } + } } diff --git a/agent/intentions_endpoint.go b/agent/intentions_endpoint.go index b99911f7f6..7c9855a44d 100644 --- a/agent/intentions_endpoint.go +++ b/agent/intentions_endpoint.go @@ -56,8 +56,9 @@ func (s *HTTPHandlers) IntentionCreate(resp http.ResponseWriter, req *http.Reque if err := s.parseEntMetaNoWildcard(req, &entMeta); err != nil { return nil, err } + if entMeta.PartitionOrDefault() != acl.PartitionOrDefault("") { - return nil, BadRequestError{Reason: "Cannot use a partition with this endpoint"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot use a partition with this endpoint"} } args := structs.IntentionRequest{ @@ -70,10 +71,10 @@ func (s *HTTPHandlers) IntentionCreate(resp http.ResponseWriter, req *http.Reque } if args.Intention.DestinationPartition != "" && args.Intention.DestinationPartition != "default" { - return nil, BadRequestError{Reason: "Cannot specify a destination partition with this endpoint"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a destination partition with this endpoint"} } if args.Intention.SourcePartition != "" && args.Intention.SourcePartition != "default" { - return nil, BadRequestError{Reason: "Cannot specify a source partition with this endpoint"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a source partition with this endpoint"} } args.Intention.FillPartitionAndNamespace(&entMeta, false) @@ -324,7 +325,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds the error type if err.Error() == consul.ErrIntentionNotFound.Error() { - return nil, NotFoundError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} } // Not ideal, but there are a number of error scenarios that are not @@ -332,7 +333,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req // to detect a parameter error and return a 400 response. The error // is not a constant type or message, so we have to use strings.Contains if strings.Contains(err.Error(), "UUID") { - return nil, BadRequestError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()} } return nil, err @@ -366,7 +367,7 @@ func (s *HTTPHandlers) IntentionPutExact(resp http.ResponseWriter, req *http.Req s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) if err := decodeBody(req.Body, &args.Intention); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } // Explicitly CLEAR the old legacy ID field @@ -520,7 +521,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter, if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil { // We have to check the string since the RPC sheds the error type if err.Error() == consul.ErrIntentionNotFound.Error() { - return nil, NotFoundError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} } // Not ideal, but there are a number of error scenarios that are not @@ -528,7 +529,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter, // to detect a parameter error and return a 400 response. The error // is not a constant type or message, so we have to use strings.Contains if strings.Contains(err.Error(), "UUID") { - return nil, BadRequestError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()} } return nil, err @@ -552,8 +553,9 @@ func (s *HTTPHandlers) IntentionSpecificUpdate(id string, resp http.ResponseWrit if err := s.parseEntMetaNoWildcard(req, &entMeta); err != nil { return nil, err } + if entMeta.PartitionOrDefault() != acl.PartitionOrDefault("") { - return nil, BadRequestError{Reason: "Cannot use a partition with this endpoint"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot use a partition with this endpoint"} } args := structs.IntentionRequest{ @@ -562,14 +564,14 @@ func (s *HTTPHandlers) IntentionSpecificUpdate(id string, resp http.ResponseWrit s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) if err := decodeBody(req.Body, &args.Intention); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } if args.Intention.DestinationPartition != "" && args.Intention.DestinationPartition != "default" { - return nil, BadRequestError{Reason: "Cannot specify a destination partition with this endpoint"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a destination partition with this endpoint"} } if args.Intention.SourcePartition != "" && args.Intention.SourcePartition != "default" { - return nil, BadRequestError{Reason: "Cannot specify a source partition with this endpoint"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a source partition with this endpoint"} } args.Intention.FillPartitionAndNamespace(&entMeta, false) diff --git a/agent/intentions_endpoint_test.go b/agent/intentions_endpoint_test.go index 62190cf9ba..1180a2c7dd 100644 --- a/agent/intentions_endpoint_test.go +++ b/agent/intentions_endpoint_test.go @@ -483,7 +483,7 @@ func TestIntentionSpecificGet(t *testing.T) { obj, err := a.srv.IntentionSpecific(resp, req) require.Nil(t, obj) require.Error(t, err) - require.IsType(t, BadRequestError{}, err) + require.True(t, isHTTPBadRequest(err)) require.Contains(t, err.Error(), "UUID") }) diff --git a/agent/kvs_endpoint.go b/agent/kvs_endpoint.go index 85273aa8e5..4ac9f4119f 100644 --- a/agent/kvs_endpoint.go +++ b/agent/kvs_endpoint.go @@ -56,7 +56,7 @@ func (s *HTTPHandlers) KVSGet(resp http.ResponseWriter, req *http.Request, args if _, ok := params["recurse"]; ok { method = "KVS.List" } else if args.Key == "" { - return nil, BadRequestError{Reason: "Missing key name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"} } // Do not allow wildcard NS on GET reqs @@ -157,7 +157,7 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args return nil, err } if args.Key == "" { - return nil, BadRequestError{Reason: "Missing key name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"} } if conflictingFlags(resp, req, "cas", "acquire", "release") { return nil, nil @@ -208,7 +208,8 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args // Check the content-length if req.ContentLength > int64(s.agent.config.KVMaxValueSize) { - return nil, EntityTooLargeError{ + return nil, HTTPError{ + StatusCode: http.StatusRequestEntityTooLarge, Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.", req.ContentLength, s.agent.config.KVMaxValueSize, "https://www.consul.io/docs/agent/config/config-files#kv_max_value_size"), } @@ -257,7 +258,7 @@ func (s *HTTPHandlers) KVSDelete(resp http.ResponseWriter, req *http.Request, ar if _, ok := params["recurse"]; ok { applyReq.Op = api.KVDeleteTree } else if args.Key == "" { - return nil, BadRequestError{Reason: "Missing key name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"} } // Check for cas value diff --git a/agent/operator_endpoint.go b/agent/operator_endpoint.go index 4a33497a80..851ef52e1c 100644 --- a/agent/operator_endpoint.go +++ b/agent/operator_endpoint.go @@ -49,10 +49,13 @@ func (s *HTTPHandlers) OperatorRaftPeer(resp http.ResponseWriter, req *http.Requ } if !hasID && !hasAddress { - return nil, BadRequestError{Reason: "Must specify either ?id with the server's ID or ?address with IP:port of peer to remove"} + return nil, HTTPError{ + StatusCode: http.StatusBadRequest, + Reason: "Must specify either ?id with the server's ID or ?address with IP:port of peer to remove", + } } if hasID && hasAddress { - return nil, BadRequestError{Reason: "Must specify only one of ?id or ?address"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Must specify only one of ?id or ?address"} } var reply struct{} @@ -79,7 +82,7 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht var args keyringArgs if req.Method == "POST" || req.Method == "PUT" || req.Method == "DELETE" { if err := decodeBody(req.Body, &args); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } } s.parseToken(req, &args.Token) @@ -88,12 +91,12 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht if relayFactor := req.URL.Query().Get("relay-factor"); relayFactor != "" { n, err := strconv.Atoi(relayFactor) if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing relay factor: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing relay factor: %v", err)} } args.RelayFactor, err = ParseRelayFactor(n) if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid relay-factor: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid relay-factor: %v", err)} } } @@ -102,12 +105,12 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht var err error args.LocalOnly, err = strconv.ParseBool(localOnly) if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing local-only: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing local-only: %v", err)} } err = ValidateLocalOnly(args.LocalOnly, req.Method == "GET") if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Invalid use of local-only: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid use of local-only: %v", err)} } } @@ -226,7 +229,7 @@ func (s *HTTPHandlers) OperatorAutopilotConfiguration(resp http.ResponseWriter, conf := api.NewAutopilotConfiguration() if err := decodeBody(req.Body, &conf); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing autopilot config: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing autopilot config: %v", err)} } args.Config = structs.AutopilotConfig{ @@ -245,7 +248,7 @@ func (s *HTTPHandlers) OperatorAutopilotConfiguration(resp http.ResponseWriter, if _, ok := params["cas"]; ok { casVal, err := strconv.ParseUint(params.Get("cas"), 10, 64) if err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing cas value: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing cas value: %v", err)} } args.Config.ModifyIndex = casVal args.CAS = true diff --git a/agent/prepared_query_endpoint.go b/agent/prepared_query_endpoint.go index b398e24465..f6cbe64941 100644 --- a/agent/prepared_query_endpoint.go +++ b/agent/prepared_query_endpoint.go @@ -23,7 +23,7 @@ func (s *HTTPHandlers) preparedQueryCreate(resp http.ResponseWriter, req *http.R s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) if err := decodeBody(req.Body, &args.Query); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } var reply string @@ -143,7 +143,7 @@ func (s *HTTPHandlers) preparedQueryExecute(id string, resp http.ResponseWriter, // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { - return nil, NotFoundError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} } return nil, err } @@ -196,7 +196,7 @@ RETRY_ONCE: // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { - return nil, NotFoundError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} } return nil, err } @@ -225,7 +225,7 @@ RETRY_ONCE: // We have to check the string since the RPC sheds // the specific error type. if structs.IsErrQueryNotFound(err) { - return nil, NotFoundError{Reason: err.Error()} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()} } return nil, err } @@ -247,7 +247,7 @@ func (s *HTTPHandlers) preparedQueryUpdate(id string, resp http.ResponseWriter, s.parseToken(req, &args.Token) if req.ContentLength > 0 { if err := decodeBody(req.Body, &args.Query); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } } diff --git a/agent/prepared_query_endpoint_test.go b/agent/prepared_query_endpoint_test.go index e4d3056e83..34b8975fda 100644 --- a/agent/prepared_query_endpoint_test.go +++ b/agent/prepared_query_endpoint_test.go @@ -621,8 +621,12 @@ func TestPreparedQuery_Execute(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/query/not-there/execute", body) resp := httptest.NewRecorder() _, err := a.srv.PreparedQuerySpecific(resp, req) - if err, ok := err.(NotFoundError); !ok { - t.Fatalf("Expected not found error but got %v", err) + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 404 { + t.Fatalf("expected status 404 but got %d", err.StatusCode) + } + } else { + t.Fatalf("expected HTTP error but got %v", err) } }) } @@ -756,8 +760,12 @@ func TestPreparedQuery_Explain(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/query/not-there/explain", body) resp := httptest.NewRecorder() _, err := a.srv.PreparedQuerySpecific(resp, req) - if err, ok := err.(NotFoundError); !ok { - t.Fatalf("Expected not found error but got %v", err) + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 404 { + t.Fatalf("expected status 404 but got %d", err.StatusCode) + } + } else { + t.Fatalf("expected HTTP error but got %v", err) } }) @@ -845,8 +853,12 @@ func TestPreparedQuery_Get(t *testing.T) { req, _ := http.NewRequest("GET", "/v1/query/f004177f-2c28-83b7-4229-eacc25fe55d1", body) resp := httptest.NewRecorder() _, err := a.srv.PreparedQuerySpecific(resp, req) - if err, ok := err.(NotFoundError); !ok { - t.Fatalf("Expected not found error but got %v", err) + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 404 { + t.Fatalf("expected status 404 but got %d", err.StatusCode) + } + } else { + t.Fatalf("expected HTTP error but got %v", err) } }) } diff --git a/agent/session_endpoint.go b/agent/session_endpoint.go index afe3faa3c4..74228cbb19 100644 --- a/agent/session_endpoint.go +++ b/agent/session_endpoint.go @@ -40,7 +40,7 @@ func (s *HTTPHandlers) SessionCreate(resp http.ResponseWriter, req *http.Request // Handle optional request body if req.ContentLength > 0 { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Session)); err != nil { - return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)} } } @@ -75,7 +75,7 @@ func (s *HTTPHandlers) SessionDestroy(resp http.ResponseWriter, req *http.Reques return nil, err } if args.Session.ID == "" { - return nil, BadRequestError{Reason: "Missing session"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"} } var out string @@ -103,14 +103,14 @@ func (s *HTTPHandlers) SessionRenew(resp http.ResponseWriter, req *http.Request) } args.Session = args.SessionID if args.SessionID == "" { - return nil, BadRequestError{Reason: "Missing session"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"} } var out structs.IndexedSessions if err := s.agent.RPC("Session.Renew", &args, &out); err != nil { return nil, err } else if out.Sessions == nil { - return nil, NotFoundError{Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)} } return out.Sessions, nil @@ -134,7 +134,7 @@ func (s *HTTPHandlers) SessionGet(resp http.ResponseWriter, req *http.Request) ( } args.Session = args.SessionID if args.SessionID == "" { - return nil, BadRequestError{Reason: "Missing session"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"} } var out structs.IndexedSessions @@ -190,7 +190,7 @@ func (s *HTTPHandlers) SessionsForNode(resp http.ResponseWriter, req *http.Reque return nil, err } if args.Node == "" { - return nil, BadRequestError{Reason: "Missing node name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"} } var out structs.IndexedSessions diff --git a/agent/txn_endpoint.go b/agent/txn_endpoint.go index c75d30bcc6..4e898bfce8 100644 --- a/agent/txn_endpoint.go +++ b/agent/txn_endpoint.go @@ -88,7 +88,8 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( // Check Content-Length first before decoding to return early if req.ContentLength > maxTxnLen { - return nil, 0, EntityTooLargeError{ + return nil, 0, HTTPError{ + StatusCode: http.StatusRequestEntityTooLarge, Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.", req.ContentLength, maxTxnLen, "https://www.consul.io/docs/agent/config/config-files#txn_max_req_len"), } @@ -100,7 +101,8 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( if err.Error() == "http: request body too large" { // The request size is also verified during decoding to double check // if the Content-Length header was not set by the client. - return nil, 0, EntityTooLargeError{ + return nil, 0, HTTPError{ + StatusCode: http.StatusRequestEntityTooLarge, Reason: fmt.Sprintf("Request body too large, max size: %d bytes. See %s.", maxTxnLen, "https://www.consul.io/docs/agent/config/config-files#txn_max_req_len"), } @@ -108,15 +110,16 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( // Note the body is in API format, and not the RPC format. If we can't // decode it, we will return a 400 since we don't have enough context to // associate the error with a given operation. - return nil, 0, BadRequestError{Reason: fmt.Sprintf("Failed to parse body: %v", err)} + return nil, 0, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to parse body: %v", err)} } } // Enforce a reasonable upper limit on the number of operations in a // transaction in order to curb abuse. if size := len(ops); size > maxTxnOps { - return nil, 0, EntityTooLargeError{ - Reason: fmt.Sprintf("Transaction contains too many operations (%d > %d)", size, maxTxnOps), + return nil, 0, HTTPError{ + StatusCode: http.StatusRequestEntityTooLarge, + Reason: fmt.Sprintf("Transaction contains too many operations (%d > %d)", size, maxTxnOps), } } @@ -130,8 +133,9 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) ( case in.KV != nil: size := len(in.KV.Value) if int64(size) > kvMaxValueSize { - return nil, 0, EntityTooLargeError{ - Reason: fmt.Sprintf("Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize), + return nil, 0, HTTPError{ + StatusCode: http.StatusRequestEntityTooLarge, + Reason: fmt.Sprintf("Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize), } } diff --git a/agent/txn_endpoint_test.go b/agent/txn_endpoint_test.go index 2f9d6fbca1..4b529d5dee 100644 --- a/agent/txn_endpoint_test.go +++ b/agent/txn_endpoint_test.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/raft" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" @@ -31,10 +32,7 @@ func TestTxnEndpoint_Bad_JSON(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/txn", buf) resp := httptest.NewRecorder() _, err := a.srv.Txn(resp, req) - err, ok := err.(BadRequestError) - if !ok { - t.Fatalf("expected bad request error but got %v", err) - } + require.True(t, isHTTPBadRequest(err), fmt.Sprintf("Expected bad request HTTP error but got %v", err)) if !strings.Contains(err.Error(), "Failed to parse") { t.Fatalf("expected conflicting args error") } @@ -63,11 +61,19 @@ func TestTxnEndpoint_Bad_Size_Item(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/txn", buf) resp := httptest.NewRecorder() _, err := agent.srv.Txn(resp, req) - if err, ok := err.(EntityTooLargeError); !ok && !wantPass { - t.Fatalf("expected too large error but got %v", err) - } - if err != nil && wantPass { - t.Fatalf("err: %v", err) + + if wantPass { + if err != nil { + t.Fatalf("err: %v", err) + } + } else { + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 413 { + t.Fatalf("expected 413 but got %d", err.StatusCode) + } + } else { + t.Fatalf("excected HTTP error but got %v", err) + } } } @@ -138,11 +144,19 @@ func TestTxnEndpoint_Bad_Size_Net(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/txn", buf) resp := httptest.NewRecorder() _, err := agent.srv.Txn(resp, req) - if err, ok := err.(EntityTooLargeError); !ok && !wantPass { - t.Fatalf("expected too large error but got %v", err) - } - if err != nil && wantPass { - t.Fatalf("err: %v", err) + + if wantPass { + if err != nil { + t.Fatalf("err: %v", err) + } + } else { + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 413 { + t.Fatalf("expected 413 but got %d", err.StatusCode) + } + } else { + t.Fatalf("excected HTTP error but got %v", err) + } } } @@ -205,8 +219,13 @@ func TestTxnEndpoint_Bad_Size_Ops(t *testing.T) { req, _ := http.NewRequest("PUT", "/v1/txn", buf) resp := httptest.NewRecorder() _, err := a.srv.Txn(resp, req) - if err, ok := err.(EntityTooLargeError); !ok { - t.Fatalf("expected too large error but got %v", err) + + if err, ok := err.(HTTPError); ok { + if err.StatusCode != 413 { + t.Fatalf("expected 413 but got %d", err.StatusCode) + } + } else { + t.Fatalf("expected HTTP error but got %v", err) } } diff --git a/agent/ui_endpoint.go b/agent/ui_endpoint.go index dfe14e9d5a..c9fb82d086 100644 --- a/agent/ui_endpoint.go +++ b/agent/ui_endpoint.go @@ -140,7 +140,7 @@ func (s *HTTPHandlers) UINodeInfo(resp http.ResponseWriter, req *http.Request) ( return nil, err } if args.Node == "" { - return nil, BadRequestError{Reason: "Missing node name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"} } // Make the RPC request @@ -272,7 +272,7 @@ func (s *HTTPHandlers) UIGatewayServicesNodes(resp http.ResponseWriter, req *htt return nil, err } if args.ServiceName == "" { - return nil, BadRequestError{Reason: "Missing gateway name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"} } // Make the RPC request @@ -316,12 +316,12 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req return nil, err } if args.ServiceName == "" { - return nil, BadRequestError{Reason: "Missing service name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"} } kind, ok := req.URL.Query()["kind"] if !ok { - return nil, BadRequestError{Reason: "Missing service kind"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service kind"} } args.ServiceKind = structs.ServiceKind(kind[0]) @@ -329,7 +329,7 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req case structs.ServiceKindTypical, structs.ServiceKindIngressGateway: // allowed default: - return nil, BadRequestError{Reason: fmt.Sprintf("Unsupported service kind %q", args.ServiceKind)} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Unsupported service kind %q", args.ServiceKind)} } // Make the RPC request @@ -594,7 +594,7 @@ func (s *HTTPHandlers) UIGatewayIntentions(resp http.ResponseWriter, req *http.R return nil, err } if name == "" { - return nil, BadRequestError{Reason: "Missing gateway name"} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"} } args.Match = &structs.IntentionQueryMatch{ Type: structs.IntentionMatchDestination, @@ -624,14 +624,14 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques // Check the UI was enabled at agent startup (note this is not reloadable // currently). if !s.IsUIEnabled() { - return nil, NotFoundError{Reason: "UI is not enabled"} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "UI is not enabled"} } // Load reloadable proxy config cfg, ok := s.metricsProxyCfg.Load().(config.UIMetricsProxy) if !ok || cfg.BaseURL == "" { // Proxy not configured - return nil, NotFoundError{Reason: "Metrics proxy is not enabled"} + return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Metrics proxy is not enabled"} } // Fetch the ACL token, if provided, but ONLY from headers since other @@ -686,7 +686,7 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques u, err := url.Parse(newURL) if err != nil { log.Error("couldn't parse target URL", "base_url", cfg.BaseURL, "path", subPath) - return nil, BadRequestError{Reason: "Invalid path."} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid path."} } // Clean the new URL path to prevent path traversal attacks and remove any @@ -735,7 +735,7 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques "path", subPath, "target_url", u.String(), ) - return nil, BadRequestError{Reason: "Invalid path."} + return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid path."} } // Add any configured headers