Unify various status errors into one HTTP error type. (#12594)

Replaces specific error types for HTTP Status codes with 
a generic HTTPError type.

Co-authored-by: Chris S. Kim <ckim@hashicorp.com>
This commit is contained in:
Mathew Estafanous 2022-04-29 13:42:49 -04:00 committed by GitHub
parent d04fe6ca2c
commit 474385d153
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 332 additions and 320 deletions

View File

@ -2,6 +2,7 @@ package agent
import ( import (
"fmt" "fmt"
"net/http"
"github.com/hashicorp/serf/serf" "github.com/hashicorp/serf/serf"
@ -82,9 +83,10 @@ func (a *Agent) vetServiceUpdateWithAuthorizer(authz acl.Authorizer, serviceID s
// agent/local/state.go's deleteService assumes the Catalog.Deregister RPC call // agent/local/state.go's deleteService assumes the Catalog.Deregister RPC call
// will include "Unknown service"in the error if deregistration fails due to a // will include "Unknown service"in the error if deregistration fails due to a
// service with that ID not existing. // service with that ID not existing.
return NotFoundError{Reason: fmt.Sprintf( return HTTPError{
"Unknown service ID %q. Ensure that the service ID is passed, not the service name.", StatusCode: http.StatusNotFound,
serviceID)} Reason: fmt.Sprintf("Unknown service ID %q. Ensure that the service ID is passed, not the service name.", serviceID),
}
} }
return nil return nil
@ -140,9 +142,10 @@ func (a *Agent) vetCheckUpdateWithAuthorizer(authz acl.Authorizer, checkID struc
} }
} }
} else { } else {
return NotFoundError{Reason: fmt.Sprintf( return HTTPError{
"Unknown check ID %q. Ensure that the check ID is passed, not the check name.", StatusCode: http.StatusNotFound,
checkID.String())} Reason: fmt.Sprintf("Unknown check ID %q. Ensure that the check ID is passed, not the check name.", checkID.String()),
}
} }
return nil return nil

View File

@ -16,7 +16,7 @@ type aclBootstrapResponse struct {
structs.ACLToken structs.ACLToken
} }
var aclDisabled = UnauthorizedError{Reason: "ACL support disabled"} var aclDisabled = HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"}
// checkACLDisabled will return a standard response if ACLs are disabled. This // checkACLDisabled will return a standard response if ACLs are disabled. This
// returns true if they are disabled and we should not continue. // returns true if they are disabled and we should not continue.
@ -127,7 +127,7 @@ func (s *HTTPHandlers) ACLPolicyCRUD(resp http.ResponseWriter, req *http.Request
return nil, err return nil, err
} }
if policyID == "" && req.Method != "PUT" { if policyID == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing policy ID"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing policy ID"}
} }
return fn(resp, req, policyID) return fn(resp, req, policyID)
@ -175,7 +175,7 @@ func (s *HTTPHandlers) ACLPolicyReadByName(resp http.ResponseWriter, req *http.R
return nil, err return nil, err
} }
if policyName == "" { if policyName == "" {
return nil, BadRequestError{Reason: "Missing policy Name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing policy Name"}
} }
return s.ACLPolicyRead(resp, req, "", policyName) return s.ACLPolicyRead(resp, req, "", policyName)
@ -207,18 +207,18 @@ func (s *HTTPHandlers) aclPolicyWriteInternal(_resp http.ResponseWriter, req *ht
} }
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Policy)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Policy)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Policy decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Policy decoding failed: %v", err)}
} }
args.Policy.Syntax = acl.SyntaxCurrent args.Policy.Syntax = acl.SyntaxCurrent
if create { if create {
if args.Policy.ID != "" { if args.Policy.ID != "" {
return nil, BadRequestError{Reason: "Cannot specify the ID when creating a new policy"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify the ID when creating a new policy"}
} }
} else { } else {
if args.Policy.ID != "" && args.Policy.ID != policyID { if args.Policy.ID != "" && args.Policy.ID != policyID {
return nil, BadRequestError{Reason: "Policy ID in URL and payload do not match"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Policy ID in URL and payload do not match"}
} else if args.Policy.ID == "" { } else if args.Policy.ID == "" {
args.Policy.ID = policyID args.Policy.ID = policyID
} }
@ -317,7 +317,7 @@ func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request)
fn = s.ACLTokenClone fn = s.ACLTokenClone
} }
if tokenID == "" && req.Method != "PUT" { if tokenID == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing token ID"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing token ID"}
} }
return fn(resp, req, tokenID) return fn(resp, req, tokenID)
@ -422,12 +422,12 @@ func (s *HTTPHandlers) aclTokenSetInternal(req *http.Request, tokenID string, cr
} }
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.ACLToken)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.ACLToken)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Token decoding failed: %v", err)}
} }
if !create { if !create {
if args.ACLToken.AccessorID != "" && args.ACLToken.AccessorID != tokenID { if args.ACLToken.AccessorID != "" && args.ACLToken.AccessorID != tokenID {
return nil, BadRequestError{Reason: "Token Accessor ID in URL and payload do not match"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Token Accessor ID in URL and payload do not match"}
} else if args.ACLToken.AccessorID == "" { } else if args.ACLToken.AccessorID == "" {
args.ACLToken.AccessorID = tokenID args.ACLToken.AccessorID = tokenID
} }
@ -472,7 +472,7 @@ func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request
return nil, err return nil, err
} }
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.ACLToken)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.ACLToken)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Token decoding failed: %v", err)}
} }
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)
@ -546,7 +546,7 @@ func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request)
return nil, err return nil, err
} }
if roleID == "" && req.Method != "PUT" { if roleID == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing role ID"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing role ID"}
} }
return fn(resp, req, roleID) return fn(resp, req, roleID)
@ -562,7 +562,7 @@ func (s *HTTPHandlers) ACLRoleReadByName(resp http.ResponseWriter, req *http.Req
return nil, err return nil, err
} }
if roleName == "" { if roleName == "" {
return nil, BadRequestError{Reason: "Missing role Name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing role Name"}
} }
return s.ACLRoleRead(resp, req, "", roleName) return s.ACLRoleRead(resp, req, "", roleName)
@ -621,11 +621,11 @@ func (s *HTTPHandlers) ACLRoleWrite(resp http.ResponseWriter, req *http.Request,
} }
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Role)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Role)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Role decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Role decoding failed: %v", err)}
} }
if args.Role.ID != "" && args.Role.ID != roleID { if args.Role.ID != "" && args.Role.ID != roleID {
return nil, BadRequestError{Reason: "Role ID in URL and payload do not match"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Role ID in URL and payload do not match"}
} else if args.Role.ID == "" { } else if args.Role.ID == "" {
args.Role.ID = roleID args.Role.ID = roleID
} }
@ -716,7 +716,7 @@ func (s *HTTPHandlers) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Re
return nil, err return nil, err
} }
if bindingRuleID == "" && req.Method != "PUT" { if bindingRuleID == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing binding rule ID"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing binding rule ID"}
} }
return fn(resp, req, bindingRuleID) return fn(resp, req, bindingRuleID)
@ -770,11 +770,11 @@ func (s *HTTPHandlers) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.R
} }
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.BindingRule)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.BindingRule)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)}
} }
if args.BindingRule.ID != "" && args.BindingRule.ID != bindingRuleID { if args.BindingRule.ID != "" && args.BindingRule.ID != bindingRuleID {
return nil, BadRequestError{Reason: "BindingRule ID in URL and payload do not match"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "BindingRule ID in URL and payload do not match"}
} else if args.BindingRule.ID == "" { } else if args.BindingRule.ID == "" {
args.BindingRule.ID = bindingRuleID args.BindingRule.ID = bindingRuleID
} }
@ -862,7 +862,7 @@ func (s *HTTPHandlers) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Req
return nil, err return nil, err
} }
if methodName == "" && req.Method != "PUT" { if methodName == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing auth method name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing auth method name"}
} }
return fn(resp, req, methodName) return fn(resp, req, methodName)
@ -916,12 +916,12 @@ func (s *HTTPHandlers) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Re
} }
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.AuthMethod)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.AuthMethod)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)}
} }
if methodName != "" { if methodName != "" {
if args.AuthMethod.Name != "" && args.AuthMethod.Name != methodName { if args.AuthMethod.Name != "" && args.AuthMethod.Name != methodName {
return nil, BadRequestError{Reason: "AuthMethod Name in URL and payload do not match"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "AuthMethod Name in URL and payload do not match"}
} else if args.AuthMethod.Name == "" { } else if args.AuthMethod.Name == "" {
args.AuthMethod.Name = methodName args.AuthMethod.Name = methodName
} }
@ -969,7 +969,7 @@ func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (in
} }
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Auth)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Auth)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to decode request body: %v", err)}
} }
var out structs.ACLToken var out structs.ACLToken
@ -1058,11 +1058,11 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request)
s.parseDC(req, &request.Datacenter) s.parseDC(req, &request.Datacenter)
if err := decodeBody(req.Body, &request.Requests); err != nil { if err := decodeBody(req.Body, &request.Requests); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to decode request body: %v", err)}
} }
if len(request.Requests) > maxRequests { if len(request.Requests) > maxRequests {
return nil, BadRequestError{Reason: fmt.Sprintf("Refusing to process more than %d authorizations at once", maxRequests)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Refusing to process more than %d authorizations at once", maxRequests)}
} }
if len(request.Requests) == 0 { if len(request.Requests) == 0 {
@ -1083,7 +1083,7 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request)
responses, err = structs.CreateACLAuthorizationResponses(authz, request.Requests) responses, err = structs.CreateACLAuthorizationResponses(authz, request.Requests)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()}
} }
} }

View File

@ -26,6 +26,16 @@ import (
// They are not intended to thoroughly test the backing RPC // They are not intended to thoroughly test the backing RPC
// functionality as that will be done with other tests. // functionality as that will be done with other tests.
func isHTTPBadRequest(err error) bool {
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 400 {
return false
}
return true
}
return false
}
func TestACL_Disabled_Response(t *testing.T) { func TestACL_Disabled_Response(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -71,7 +81,7 @@ func TestACL_Disabled_Response(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := tt.fn(resp, req) obj, err := tt.fn(resp, req)
require.Nil(t, obj) require.Nil(t, obj)
require.ErrorIs(t, err, UnauthorizedError{Reason: "ACL support disabled"}) require.ErrorIs(t, err, HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"})
}) })
} }
} }
@ -270,8 +280,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLPolicyCRUD(resp, req) _, err := a.srv.ACLPolicyCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Policy CRUD Missing ID in URL", func(t *testing.T) { t.Run("Policy CRUD Missing ID in URL", func(t *testing.T) {
@ -279,8 +288,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLPolicyCRUD(resp, req) _, err := a.srv.ACLPolicyCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Update", func(t *testing.T) { t.Run("Update", func(t *testing.T) {
@ -327,8 +335,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLPolicyCreate(resp, req) _, err := a.srv.ACLPolicyCreate(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Invalid payload", func(t *testing.T) { t.Run("Invalid payload", func(t *testing.T) {
@ -339,8 +346,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLPolicyCreate(resp, req) _, err := a.srv.ACLPolicyCreate(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Delete", func(t *testing.T) { t.Run("Delete", func(t *testing.T) {
@ -497,8 +503,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLRoleCRUD(resp, req) _, err := a.srv.ACLRoleCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Role CRUD Missing ID in URL", func(t *testing.T) { t.Run("Role CRUD Missing ID in URL", func(t *testing.T) {
@ -506,8 +511,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLRoleCRUD(resp, req) _, err := a.srv.ACLRoleCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Update", func(t *testing.T) { t.Run("Update", func(t *testing.T) {
@ -567,8 +571,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLRoleCreate(resp, req) _, err := a.srv.ACLRoleCreate(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Invalid payload", func(t *testing.T) { t.Run("Invalid payload", func(t *testing.T) {
@ -579,8 +582,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLRoleCreate(resp, req) _, err := a.srv.ACLRoleCreate(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Delete", func(t *testing.T) { t.Run("Delete", func(t *testing.T) {
@ -818,8 +820,7 @@ func TestACL_HTTP(t *testing.T) {
obj, err := a.srv.ACLTokenCRUD(resp, req) obj, err := a.srv.ACLTokenCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
require.Nil(t, obj) require.Nil(t, obj)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Update Accessor Mismatch", func(t *testing.T) { t.Run("Update Accessor Mismatch", func(t *testing.T) {
originalToken := tokenMap[idMap["token-cloned"]] originalToken := tokenMap[idMap["token-cloned"]]
@ -841,8 +842,7 @@ func TestACL_HTTP(t *testing.T) {
obj, err := a.srv.ACLTokenCRUD(resp, req) obj, err := a.srv.ACLTokenCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
require.Nil(t, obj) require.Nil(t, obj)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Delete", func(t *testing.T) { t.Run("Delete", func(t *testing.T) {
req, _ := http.NewRequest("DELETE", "/v1/acl/token/"+idMap["token-cloned"]+"?token=root", nil) req, _ := http.NewRequest("DELETE", "/v1/acl/token/"+idMap["token-cloned"]+"?token=root", nil)
@ -1284,8 +1284,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLAuthMethodCRUD(resp, req) _, err := a.srv.ACLAuthMethodCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Update Name URL Mismatch", func(t *testing.T) { t.Run("Update Name URL Mismatch", func(t *testing.T) {
@ -1302,8 +1301,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLAuthMethodCRUD(resp, req) _, err := a.srv.ACLAuthMethodCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Update", func(t *testing.T) { t.Run("Update", func(t *testing.T) {
@ -1342,8 +1340,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLAuthMethodCreate(resp, req) _, err := a.srv.ACLAuthMethodCreate(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("List", func(t *testing.T) { t.Run("List", func(t *testing.T) {
@ -1480,8 +1477,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLBindingRuleCRUD(resp, req) _, err := a.srv.ACLBindingRuleCRUD(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Update", func(t *testing.T) { t.Run("Update", func(t *testing.T) {
@ -1529,8 +1525,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLBindingRuleCreate(resp, req) _, err := a.srv.ACLBindingRuleCreate(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("Invalid payload", func(t *testing.T) { t.Run("Invalid payload", func(t *testing.T) {
@ -1541,8 +1536,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.ACLBindingRuleCreate(resp, req) _, err := a.srv.ACLBindingRuleCreate(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
}) })
t.Run("List", func(t *testing.T) { t.Run("List", func(t *testing.T) {

View File

@ -425,7 +425,7 @@ func (s *HTTPHandlers) AgentService(resp http.ResponseWriter, req *http.Request)
svcState := s.agent.State.ServiceState(sid) svcState := s.agent.State.ServiceState(sid)
if svcState == nil { if svcState == nil {
return "", nil, NotFoundError{Reason: fmt.Sprintf("unknown service ID: %s", sid.String())} return "", nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("unknown service ID: %s", sid.String())}
} }
svc := svcState.Service svc := svcState.Service
@ -555,7 +555,7 @@ func (s *HTTPHandlers) AgentMembers(resp http.ResponseWriter, req *http.Request)
// key are ok, otherwise the argument doesn't apply to // key are ok, otherwise the argument doesn't apply to
// the WAN. // the WAN.
default: default:
return nil, BadRequestError{Reason: "Cannot provide a segment with wan=true"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot provide a segment with wan=true"}
} }
} }
@ -732,16 +732,16 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
} }
if err := decodeBody(req.Body, &args); err != nil { if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
// Verify the check has a name. // Verify the check has a name.
if args.Name == "" { if args.Name == "" {
return nil, BadRequestError{Reason: "Missing check name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing check name"}
} }
if args.Status != "" && !structs.ValidStatus(args.Status) { if args.Status != "" && !structs.ValidStatus(args.Status) {
return nil, BadRequestError{Reason: "Bad check status"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Bad check status"}
} }
authz, err := s.agent.delegate.ResolveTokenAndDefaultMeta(token, &args.EnterpriseMeta, nil) authz, err := s.agent.delegate.ResolveTokenAndDefaultMeta(token, &args.EnterpriseMeta, nil)
@ -760,7 +760,7 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
chkType := args.CheckType() chkType := args.CheckType()
err = chkType.Validate() err = chkType.Validate()
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check: %v", err)}
} }
// Store the type of check based on the definition // Store the type of check based on the definition
@ -773,7 +773,7 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
if service != nil { if service != nil {
health.ServiceName = service.Service health.ServiceName = service.Service
} else { } else {
return nil, NotFoundError{fmt.Sprintf("ServiceID %q does not exist", cid.String())} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("ServiceID %q does not exist", cid.String())}
} }
} }
@ -878,7 +878,7 @@ type checkUpdate struct {
func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var update checkUpdate var update checkUpdate
if err := decodeBody(req.Body, &update); err != nil { if err := decodeBody(req.Body, &update); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
switch update.Status { switch update.Status {
@ -886,7 +886,7 @@ func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Requ
case api.HealthWarning: case api.HealthWarning:
case api.HealthCritical: case api.HealthCritical:
default: default:
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)}
} }
ID, err := getPathSuffixUnescaped(req.URL.Path, "/v1/agent/check/update/") ID, err := getPathSuffixUnescaped(req.URL.Path, "/v1/agent/check/update/")
@ -981,7 +981,7 @@ func (s *HTTPHandlers) AgentHealthServiceByID(resp http.ResponseWriter, req *htt
return nil, err return nil, err
} }
if serviceID == "" { if serviceID == "" {
return nil, &BadRequestError{Reason: "Missing serviceID"} return nil, &HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing serviceID"}
} }
var entMeta acl.EnterpriseMeta var entMeta acl.EnterpriseMeta
@ -1043,7 +1043,7 @@ func (s *HTTPHandlers) AgentHealthServiceByName(resp http.ResponseWriter, req *h
} }
if serviceName == "" { if serviceName == "" {
return nil, &BadRequestError{Reason: "Missing service Name"} return nil, &HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service Name"}
} }
var entMeta acl.EnterpriseMeta var entMeta acl.EnterpriseMeta
@ -1114,18 +1114,18 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
} }
if err := decodeBody(req.Body, &args); err != nil { if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
// Verify the service has a name. // Verify the service has a name.
if args.Name == "" { if args.Name == "" {
return nil, BadRequestError{Reason: "Missing service name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
} }
// Check the service address here and in the catalog RPC endpoint // Check the service address here and in the catalog RPC endpoint
// since service registration isn't synchronous. // since service registration isn't synchronous.
if ipaddr.IsAny(args.Address) { if ipaddr.IsAny(args.Address) {
return nil, BadRequestError{Reason: "Invalid service address"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid service address"}
} }
var token string var token string
@ -1144,27 +1144,27 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
ns := args.NodeService() ns := args.NodeService()
if ns.Weights != nil { if ns.Weights != nil {
if err := structs.ValidateWeights(ns.Weights); err != nil { if err := structs.ValidateWeights(ns.Weights); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Weights: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid Weights: %v", err)}
} }
} }
if err := structs.ValidateServiceMetadata(ns.Kind, ns.Meta, false); err != nil { if err := structs.ValidateServiceMetadata(ns.Kind, ns.Meta, false); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Service Meta: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid Service Meta: %v", err)}
} }
// Run validation. This is the same validation that would happen on // Run validation. This is the same validation that would happen on
// the catalog endpoint so it helps ensure the sync will work properly. // the catalog endpoint so it helps ensure the sync will work properly.
if err := ns.Validate(); err != nil { if err := ns.Validate(); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Validation failed: %v", err.Error())} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Validation failed: %v", err.Error())}
} }
// Verify the check type. // Verify the check type.
chkTypes, err := args.CheckTypes() chkTypes, err := args.CheckTypes()
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check: %v", err)}
} }
for _, check := range chkTypes { for _, check := range chkTypes {
if check.Status != "" && !structs.ValidStatus(check.Status) { if check.Status != "" && !structs.ValidStatus(check.Status) {
return nil, BadRequestError{Reason: "Status for checks must 'passing', 'warning', 'critical'"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Status for checks must 'passing', 'warning', 'critical'"}
} }
} }
@ -1172,15 +1172,11 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
if args.Connect != nil && args.Connect.SidecarService != nil { if args.Connect != nil && args.Connect.SidecarService != nil {
chkTypes, err := args.Connect.SidecarService.CheckTypes() chkTypes, err := args.Connect.SidecarService.CheckTypes()
if err != nil { if err != nil {
return nil, &BadRequestError{ return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check in sidecar_service: %v", err)}
Reason: fmt.Sprintf("Invalid check in sidecar_service: %v", err),
}
} }
for _, check := range chkTypes { for _, check := range chkTypes {
if check.Status != "" && !structs.ValidStatus(check.Status) { if check.Status != "" && !structs.ValidStatus(check.Status) {
return nil, &BadRequestError{ return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Status for checks must 'passing', 'warning', 'critical'"}
Reason: "Status for checks must 'passing', 'warning', 'critical'",
}
} }
} }
} }
@ -1193,12 +1189,11 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
// See if we have a sidecar to register too // See if we have a sidecar to register too
sidecar, sidecarChecks, sidecarToken, err := s.agent.sidecarServiceFromNodeService(ns, token) sidecar, sidecarChecks, sidecarToken, err := s.agent.sidecarServiceFromNodeService(ns, token)
if err != nil { if err != nil {
return nil, &BadRequestError{ return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid SidecarService: %s", err)}
Reason: fmt.Sprintf("Invalid SidecarService: %s", err)}
} }
if sidecar != nil { if sidecar != nil {
if err := sidecar.Validate(); err != nil { if err := sidecar.Validate(); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Failed Validation: %v", err.Error())} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed Validation: %v", err.Error())}
} }
// Make sure we are allowed to register the sidecar using the token // Make sure we are allowed to register the sidecar using the token
// specified (might be specific to sidecar or the same one as the overall // specified (might be specific to sidecar or the same one as the overall
@ -1299,19 +1294,19 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
sid := structs.NewServiceID(serviceID, nil) sid := structs.NewServiceID(serviceID, nil)
if sid.ID == "" { if sid.ID == "" {
return nil, BadRequestError{Reason: "Missing service ID"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service ID"}
} }
// Ensure we have some action // Ensure we have some action
params := req.URL.Query() params := req.URL.Query()
if _, ok := params["enable"]; !ok { if _, ok := params["enable"]; !ok {
return nil, BadRequestError{Reason: "Missing value for enable"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing value for enable"}
} }
raw := params.Get("enable") raw := params.Get("enable")
enable, err := strconv.ParseBool(raw) enable, err := strconv.ParseBool(raw)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
} }
// Get the provided token, if any, and vet against any ACL policies. // Get the provided token, if any, and vet against any ACL policies.
@ -1340,11 +1335,11 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
if enable { if enable {
reason := params.Get("reason") reason := params.Get("reason")
if err = s.agent.EnableServiceMaintenance(sid, reason, token); err != nil { if err = s.agent.EnableServiceMaintenance(sid, reason, token); err != nil {
return nil, NotFoundError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
} }
} else { } else {
if err = s.agent.DisableServiceMaintenance(sid); err != nil { if err = s.agent.DisableServiceMaintenance(sid); err != nil {
return nil, NotFoundError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
} }
} }
s.syncChanges() s.syncChanges()
@ -1355,13 +1350,13 @@ func (s *HTTPHandlers) AgentNodeMaintenance(resp http.ResponseWriter, req *http.
// Ensure we have some action // Ensure we have some action
params := req.URL.Query() params := req.URL.Query()
if _, ok := params["enable"]; !ok { if _, ok := params["enable"]; !ok {
return nil, BadRequestError{Reason: "Missing value for enable"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing value for enable"}
} }
raw := params.Get("enable") raw := params.Get("enable")
enable, err := strconv.ParseBool(raw) enable, err := strconv.ParseBool(raw)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
} }
// Get the provided token, if any, and vet against any ACL policies. // Get the provided token, if any, and vet against any ACL policies.
@ -1416,9 +1411,7 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request)
} }
if !logging.ValidateLogLevel(logLevel) { if !logging.ValidateLogLevel(logLevel) {
return nil, BadRequestError{ return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Unknown log level: %s", logLevel)}
Reason: fmt.Sprintf("Unknown log level: %s", logLevel),
}
} }
flusher, ok := resp.(http.Flusher) flusher, ok := resp.(http.Flusher)
@ -1469,7 +1462,7 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request)
func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) { func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled() { if s.checkACLDisabled() {
return nil, UnauthorizedError{Reason: "ACL support disabled"} return nil, HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"}
} }
// Fetch the ACL token, if any, and enforce agent policy. // Fetch the ACL token, if any, and enforce agent policy.
@ -1491,7 +1484,7 @@ func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (
// fields to this later if needed. // fields to this later if needed.
var args api.AgentToken var args api.AgentToken
if err := decodeBody(req.Body, &args); err != nil { if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
// Figure out the target token. // Figure out the target token.
@ -1522,7 +1515,7 @@ func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (
s.agent.tokens.UpdateReplicationToken(args.Token, token_store.TokenSourceAPI) s.agent.tokens.UpdateReplicationToken(args.Token, token_store.TokenSourceAPI)
default: default:
return NotFoundError{Reason: fmt.Sprintf("Token %q is unknown", target)} return HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("Token %q is unknown", target)}
} }
// TODO: is it safe to move this out of WithPersistenceLock? // TODO: is it safe to move this out of WithPersistenceLock?
@ -1641,7 +1634,7 @@ func (s *HTTPHandlers) AgentConnectAuthorize(resp http.ResponseWriter, req *http
} }
if err := decodeBody(req.Body, &authReq); err != nil { if err := decodeBody(req.Body, &authReq); err != nil {
return nil, BadRequestError{fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
if !s.validateRequestPartition(resp, &authReq.EnterpriseMeta) { if !s.validateRequestPartition(resp, &authReq.EnterpriseMeta) {

View File

@ -5473,7 +5473,7 @@ func TestAgent_Token(t *testing.T) {
url: "acl_token?token=root", url: "acl_token?token=root",
body: badJSON(), body: badJSON(),
code: http.StatusBadRequest, code: http.StatusBadRequest,
expectedErr: `Bad request: Request decode failed: json: cannot unmarshal bool into Go value of type api.AgentToken`, expectedErr: `Request decode failed: json: cannot unmarshal bool into Go value of type api.AgentToken`,
}, },
{ {
name: "set user legacy", name: "set user legacy",

View File

@ -136,7 +136,7 @@ func (s *HTTPHandlers) CatalogRegister(resp http.ResponseWriter, req *http.Reque
} }
if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
// Setup the default DC if not provided // Setup the default DC if not provided
@ -166,7 +166,7 @@ func (s *HTTPHandlers) CatalogDeregister(resp http.ResponseWriter, req *http.Req
return nil, err return nil, err
} }
if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
// Setup the default DC if not provided // Setup the default DC if not provided
@ -363,7 +363,7 @@ func (s *HTTPHandlers) catalogServiceNodes(resp http.ResponseWriter, req *http.R
return nil, err return nil, err
} }
if args.ServiceName == "" { if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing service name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
} }
// Make the RPC request // Make the RPC request
@ -438,7 +438,7 @@ func (s *HTTPHandlers) CatalogNodeServices(resp http.ResponseWriter, req *http.R
return nil, err return nil, err
} }
if args.Node == "" { if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
} }
// Make the RPC request // Make the RPC request
@ -503,7 +503,7 @@ func (s *HTTPHandlers) CatalogNodeServiceList(resp http.ResponseWriter, req *htt
return nil, err return nil, err
} }
if args.Node == "" { if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
} }
// Make the RPC request // Make the RPC request
@ -554,7 +554,7 @@ func (s *HTTPHandlers) CatalogGatewayServices(resp http.ResponseWriter, req *htt
return nil, err return nil, err
} }
if args.ServiceName == "" { if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing gateway name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"}
} }
// Make the RPC request // Make the RPC request

View File

@ -56,7 +56,7 @@ func (s *HTTPHandlers) configGet(resp http.ResponseWriter, req *http.Request) (i
setMeta(resp, &reply.QueryMeta) setMeta(resp, &reply.QueryMeta)
if reply.Entry == nil { if reply.Entry == nil {
return nil, NotFoundError{Reason: fmt.Sprintf("%s for %q / %q", ConfigEntryNotFoundErr, pathArgs[0], pathArgs[1])} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("%s for %q / %q", ConfigEntryNotFoundErr, pathArgs[0], pathArgs[1])}
} }
return reply.Entry, nil return reply.Entry, nil
@ -75,7 +75,7 @@ func (s *HTTPHandlers) configGet(resp http.ResponseWriter, req *http.Request) (i
return reply.Entries, nil return reply.Entries, nil
default: default:
return nil, NotFoundError{Reason: "Must provide either a kind or both kind and name"} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Must provide either a kind or both kind and name"}
} }
} }
@ -91,12 +91,12 @@ func (s *HTTPHandlers) configDelete(resp http.ResponseWriter, req *http.Request)
pathArgs := strings.SplitN(kindAndName, "/", 2) pathArgs := strings.SplitN(kindAndName, "/", 2)
if len(pathArgs) != 2 { if len(pathArgs) != 2 {
return nil, NotFoundError{Reason: "Must provide both a kind and name to delete"} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Must provide both a kind and name to delete"}
} }
entry, err := structs.MakeConfigEntry(pathArgs[0], pathArgs[1]) entry, err := structs.MakeConfigEntry(pathArgs[0], pathArgs[1])
if err != nil { if err != nil {
return nil, BadRequestError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()}
} }
args.Entry = entry args.Entry = entry
// Parse enterprise meta. // Parse enterprise meta.
@ -139,13 +139,13 @@ func (s *HTTPHandlers) ConfigApply(resp http.ResponseWriter, req *http.Request)
var raw map[string]interface{} var raw map[string]interface{}
if err := decodeBodyDeprecated(req, &raw, nil); err != nil { if err := decodeBodyDeprecated(req, &raw, nil); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)}
} }
if entry, err := structs.DecodeConfigEntry(raw); err == nil { if entry, err := structs.DecodeConfigEntry(raw); err == nil {
args.Entry = entry args.Entry = entry
} else { } else {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)}
} }
// Parse enterprise meta. // Parse enterprise meta.

View File

@ -601,9 +601,8 @@ func TestConfig_Apply_Decoding(t *testing.T) {
_, err := a.srv.ConfigApply(resp, req) _, err := a.srv.ConfigApply(resp, req)
require.Error(t, err) require.Error(t, err)
badReq, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok) require.Equal(t, "Request decoding failed: Payload does not contain a kind/Kind key at the top level", err.Error())
require.Equal(t, "Request decoding failed: Payload does not contain a kind/Kind key at the top level", badReq.Reason)
}) })
t.Run("Kind Not String", func(t *testing.T) { t.Run("Kind Not String", func(t *testing.T) {
@ -619,9 +618,8 @@ func TestConfig_Apply_Decoding(t *testing.T) {
_, err := a.srv.ConfigApply(resp, req) _, err := a.srv.ConfigApply(resp, req)
require.Error(t, err) require.Error(t, err)
badReq, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok) require.Equal(t, "Request decoding failed: Kind value in payload is not a string", err.Error())
require.Equal(t, "Request decoding failed: Kind value in payload is not a string", badReq.Reason)
}) })
t.Run("Lowercase kind", func(t *testing.T) { t.Run("Lowercase kind", func(t *testing.T) {

View File

@ -3,6 +3,7 @@ package agent
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
@ -23,7 +24,7 @@ import (
// The ACL token and the auth request are provided and the auth decision (true // The ACL token and the auth request are provided and the auth decision (true
// means authorized) and reason string are returned. // means authorized) and reason string are returned.
// //
// If the request input is invalid the error returned will be a BadRequestError, // If the request input is invalid the error returned will be a BadRequest HTTPError,
// if the token doesn't grant necessary access then an acl.ErrPermissionDenied // if the token doesn't grant necessary access then an acl.ErrPermissionDenied
// error is returned, otherwise error indicates an unexpected server failure. If // error is returned, otherwise error indicates an unexpected server failure. If
// access is denied, no error is returned but the first return value is false. // access is denied, no error is returned but the first return value is false.
@ -37,23 +38,23 @@ func (a *Agent) ConnectAuthorize(token string,
} }
if req == nil { if req == nil {
return returnErr(BadRequestError{"Invalid request"}) return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid request"})
} }
// We need to have a target to check intentions // We need to have a target to check intentions
if req.Target == "" { if req.Target == "" {
return returnErr(BadRequestError{"Target service must be specified"}) return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "Target service must be specified"})
} }
// Parse the certificate URI from the client ID // Parse the certificate URI from the client ID
uri, err := connect.ParseCertURIFromString(req.ClientCertURI) uri, err := connect.ParseCertURIFromString(req.ClientCertURI)
if err != nil { if err != nil {
return returnErr(BadRequestError{"ClientCertURI not a valid Connect identifier"}) return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "ClientCertURI not a valid Connect identifier"})
} }
uriService, ok := uri.(*connect.SpiffeIDService) uriService, ok := uri.(*connect.SpiffeIDService)
if !ok { if !ok {
return returnErr(BadRequestError{"ClientCertURI not a valid Service identifier"}) return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "ClientCertURI not a valid Service identifier"})
} }
// We need to verify service:write permissions for the given token. // We need to verify service:write permissions for the given token.

View File

@ -20,7 +20,7 @@ func (s *HTTPHandlers) ConnectCARoots(resp http.ResponseWriter, req *http.Reques
if pemParam := req.URL.Query().Get("pem"); pemParam != "" { if pemParam := req.URL.Query().Get("pem"); pemParam != "" {
val, err := strconv.ParseBool(pemParam) val, err := strconv.ParseBool(pemParam)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: "The 'pem' query parameter must be a boolean value"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "The 'pem' query parameter must be a boolean value"}
} }
pemResponse = val pemResponse = val
} }
@ -90,15 +90,14 @@ func (s *HTTPHandlers) ConnectCAConfigurationSet(req *http.Request) (interface{}
s.parseDC(req, &args.Datacenter) s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Config); err != nil { if err := decodeBody(req.Body, &args.Config); err != nil {
return nil, BadRequestError{ return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
Reason: fmt.Sprintf("Request decode failed: %v", err),
}
} }
var reply interface{} var reply interface{}
err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply) err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply)
if err != nil && err.Error() == consul.ErrStateReadOnly.Error() { if err != nil && err.Error() == consul.ErrStateReadOnly.Error() {
return nil, BadRequestError{ return nil, HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Provider State is read-only. It must be omitted" + Reason: "Provider State is read-only. It must be omitted" +
" or identical to the current value", " or identical to the current value",
} }

View File

@ -14,7 +14,7 @@ func (s *HTTPHandlers) checkCoordinateDisabled() error {
if !s.agent.config.DisableCoordinates { if !s.agent.config.DisableCoordinates {
return nil return nil
} }
return UnauthorizedError{Reason: "Coordinate support disabled"} return HTTPError{StatusCode: http.StatusUnauthorized, Reason: "Coordinate support disabled"}
} }
// sorter wraps a coordinate list and implements the sort.Interface to sort by // sorter wraps a coordinate list and implements the sort.Interface to sort by
@ -156,7 +156,7 @@ func (s *HTTPHandlers) CoordinateUpdate(resp http.ResponseWriter, req *http.Requ
args := structs.CoordinateUpdateRequest{} args := structs.CoordinateUpdateRequest{}
if err := decodeBody(req.Body, &args); err != nil { if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
s.parseDC(req, &args.Datacenter) s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)

View File

@ -39,10 +39,14 @@ func TestCoordinate_Disabled_Response(t *testing.T) {
req, _ := http.NewRequest("PUT", "/should/not/care", nil) req, _ := http.NewRequest("PUT", "/should/not/care", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := tt(resp, req) obj, err := tt(resp, req)
err, ok := err.(UnauthorizedError) if err, ok := err.(HTTPError); ok {
if !ok { if err.StatusCode != 401 {
t.Fatalf("expected unauthorized error but got %v", err) t.Fatalf("expected status 401 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
} }
if obj != nil { if obj != nil {
t.Fatalf("bad: %#v", obj) t.Fatalf("bad: %#v", obj)
} }

View File

@ -25,7 +25,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re
return nil, err return nil, err
} }
if args.Name == "" { if args.Name == "" {
return nil, BadRequestError{Reason: "Missing chain name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing chain name"}
} }
args.EvaluateInDatacenter = req.URL.Query().Get("compile-dc") args.EvaluateInDatacenter = req.URL.Query().Get("compile-dc")
@ -38,12 +38,12 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re
if req.Method == "POST" { if req.Method == "POST" {
var raw map[string]interface{} var raw map[string]interface{}
if err := decodeBody(req.Body, &raw); err != nil { if err := decodeBody(req.Body, &raw); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)}
} }
apiReq, err := decodeDiscoveryChainReadRequest(raw) apiReq, err := decodeDiscoveryChainReadRequest(raw)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)}
} }
args.OverrideProtocol = apiReq.OverrideProtocol args.OverrideProtocol = apiReq.OverrideProtocol
@ -52,7 +52,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re
if apiReq.OverrideMeshGateway.Mode != "" { if apiReq.OverrideMeshGateway.Mode != "" {
_, err := structs.ValidateMeshGatewayMode(string(apiReq.OverrideMeshGateway.Mode)) _, err := structs.ValidateMeshGatewayMode(string(apiReq.OverrideMeshGateway.Mode))
if err != nil { if err != nil {
return nil, BadRequestError{Reason: "Invalid OverrideMeshGateway.Mode parameter"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid OverrideMeshGateway.Mode parameter"}
} }
args.OverrideMeshGateway = apiReq.OverrideMeshGateway args.OverrideMeshGateway = apiReq.OverrideMeshGateway
} }

View File

@ -57,8 +57,7 @@ func TestDiscoveryChainRead(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err = a.srv.DiscoveryChainRead(resp, req) _, err = a.srv.DiscoveryChainRead(resp, req)
require.Error(t, err) require.Error(t, err)
_, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err))
require.True(t, ok)
})) }))
require.True(t, t.Run(method+": read default chain", func(t *testing.T) { require.True(t, t.Run(method+": read default chain", func(t *testing.T) {

View File

@ -25,7 +25,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i
return nil, err return nil, err
} }
if event.Name == "" { if event.Name == "" {
return nil, BadRequestError{Reason: "Missing name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing name"}
} }
// Get the ACL token // Get the ACL token
@ -55,7 +55,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i
// Try to fire the event // Try to fire the event
if err := s.agent.UserEvent(dc, token, event); err != nil { if err := s.agent.UserEvent(dc, token, event); err != nil {
if acl.IsErrPermissionDenied(err) { if acl.IsErrPermissionDenied(err) {
return nil, ForbiddenError{Reason: acl.ErrPermissionDenied.Error()} return nil, HTTPError{StatusCode: http.StatusForbidden, Reason: acl.ErrPermissionDenied.Error()}
} }
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
return nil, err return nil, err

View File

@ -103,8 +103,12 @@ func TestEventFire_token(t *testing.T) {
if !acl.IsErrPermissionDenied(err) { if !acl.IsErrPermissionDenied(err) {
t.Fatalf("bad: %s", err.Error()) t.Fatalf("bad: %s", err.Error())
} }
if err, ok := err.(ForbiddenError); !ok { if err, ok := err.(HTTPError); ok {
t.Fatalf("Expected forbidden but got %v", err) if err.StatusCode != 403 {
t.Fatalf("Expected 403 but got %d", err.StatusCode)
}
} else {
t.Fatalf("Expected HTTP Error %v", err)
} }
} }
} }

View File

@ -13,7 +13,7 @@ func (s *HTTPHandlers) FederationStateGet(resp http.ResponseWriter, req *http.Re
return nil, err return nil, err
} }
if datacenterName == "" { if datacenterName == "" {
return nil, BadRequestError{Reason: "Missing datacenter name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing datacenter name"}
} }
args := structs.FederationStateQuery{ args := structs.FederationStateQuery{

View File

@ -34,7 +34,7 @@ func (s *HTTPHandlers) HealthChecksInState(resp http.ResponseWriter, req *http.R
return nil, err return nil, err
} }
if args.State == "" { if args.State == "" {
return nil, BadRequestError{Reason: "Missing check state"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing check state"}
} }
// Make the RPC request // Make the RPC request
@ -82,7 +82,7 @@ func (s *HTTPHandlers) HealthNodeChecks(resp http.ResponseWriter, req *http.Requ
return nil, err return nil, err
} }
if args.Node == "" { if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
} }
// Make the RPC request // Make the RPC request
@ -132,7 +132,7 @@ func (s *HTTPHandlers) HealthServiceChecks(resp http.ResponseWriter, req *http.R
return nil, err return nil, err
} }
if args.ServiceName == "" { if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing service name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
} }
// Make the RPC request // Make the RPC request
@ -224,7 +224,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re
return nil, err return nil, err
} }
if args.ServiceName == "" { if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing service name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
} }
out, md, err := s.agent.rpcClientHealth.ServiceNodes(req.Context(), args) out, md, err := s.agent.rpcClientHealth.ServiceNodes(req.Context(), args)
@ -242,7 +242,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re
// Filter to only passing if specified // Filter to only passing if specified
filter, err := getBoolQueryParam(params, api.HealthPassing) filter, err := getBoolQueryParam(params, api.HealthPassing)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: "Invalid value for ?passing"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid value for ?passing"}
} }
// FIXME: remove filterNonPassing, replace with nodes.Filter, which is used by DNSServer // FIXME: remove filterNonPassing, replace with nodes.Filter, which is used by DNSServer

View File

@ -1405,9 +1405,7 @@ func TestHealthServiceNodes_PassingFilter(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/health/service/consul?passing=nope-nope-nope", nil) req, _ := http.NewRequest("GET", "/v1/health/service/consul?passing=nope-nope-nope", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.HealthServiceNodes(resp, req) _, err := a.srv.HealthServiceNodes(resp, req)
if _, ok := err.(BadRequestError); !ok { require.True(t, isHTTPBadRequest(err), fmt.Sprintf("Expected bad request HTTP error but got %v", err))
t.Fatalf("Expected bad request error but got %v", err)
}
if !strings.Contains(err.Error(), "Invalid value for ?passing") { if !strings.Contains(err.Error(), "Invalid value for ?passing") {
t.Errorf("bad %s", err.Error()) t.Errorf("bad %s", err.Error())
} }
@ -1813,8 +1811,7 @@ func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.HealthConnectServiceNodes(resp, req) _, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.NotNil(t, err) assert.NotNil(t, err)
_, ok := err.(BadRequestError) assert.True(t, isHTTPBadRequest(err))
assert.True(t, ok)
assert.True(t, strings.Contains(err.Error(), "Invalid value for ?passing")) assert.True(t, strings.Contains(err.Error(), "Invalid value for ?passing"))
}) })

View File

@ -51,41 +51,6 @@ func (e MethodNotAllowedError) Error() string {
return fmt.Sprintf("method %s not allowed", e.Method) return fmt.Sprintf("method %s not allowed", e.Method)
} }
// BadRequestError should be returned by a handler when parameters or the payload are not valid
type BadRequestError struct {
Reason string
}
func (e BadRequestError) Error() string {
return fmt.Sprintf("Bad request: %s", e.Reason)
}
// NotFoundError should be returned by a handler when a resource specified does not exist
type NotFoundError struct {
Reason string
}
func (e NotFoundError) Error() string {
return e.Reason
}
// UnauthorizedError should be returned by a handler when the request lacks valid authorization.
type UnauthorizedError struct {
Reason string
}
func (e UnauthorizedError) Error() string {
return e.Reason
}
type EntityTooLargeError struct {
Reason string
}
func (e EntityTooLargeError) Error() string {
return e.Reason
}
// CodeWithPayloadError allow returning non HTTP 200 // CodeWithPayloadError allow returning non HTTP 200
// Error codes while not returning PlainText payload // Error codes while not returning PlainText payload
type CodeWithPayloadError struct { type CodeWithPayloadError struct {
@ -98,12 +63,15 @@ func (e CodeWithPayloadError) Error() string {
return e.Reason return e.Reason
} }
type ForbiddenError struct { // HTTPError is returned by the handler when a specific http error
Reason string // code is needed alongside a plain text response.
type HTTPError struct {
StatusCode int
Reason string
} }
func (e ForbiddenError) Error() string { func (h HTTPError) Error() string {
return e.Reason return h.Reason
} }
// HTTPHandlers provides an HTTP api for an agent. // HTTPHandlers provides an HTTP api for an agent.
@ -423,8 +391,7 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
if acl.IsErrPermissionDenied(err) || acl.IsErrNotFound(err) { if acl.IsErrPermissionDenied(err) || acl.IsErrNotFound(err) {
return true return true
} }
_, ok := err.(ForbiddenError) return false
return ok
} }
isMethodNotAllowed := func(err error) bool { isMethodNotAllowed := func(err error) bool {
@ -432,35 +399,20 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
return ok return ok
} }
isBadRequest := func(err error) bool {
_, ok := err.(BadRequestError)
return ok
}
isNotFound := func(err error) bool {
_, ok := err.(NotFoundError)
return ok
}
isUnauthorized := func(err error) bool {
_, ok := err.(UnauthorizedError)
return ok
}
isTooManyRequests := func(err error) bool { isTooManyRequests := func(err error) bool {
// Sadness net/rpc can't do nice typed errors so this is all we got // Sadness net/rpc can't do nice typed errors so this is all we got
return err.Error() == consul.ErrRateLimited.Error() return err.Error() == consul.ErrRateLimited.Error()
} }
isEntityToLarge := func(err error) bool {
_, ok := err.(EntityTooLargeError)
return ok
}
addAllowHeader := func(methods []string) { addAllowHeader := func(methods []string) {
resp.Header().Add("Allow", strings.Join(methods, ",")) resp.Header().Add("Allow", strings.Join(methods, ","))
} }
isHTTPError := func(err error) bool {
_, ok := err.(HTTPError)
return ok
}
handleErr := func(err error) { handleErr := func(err error) {
if req.Context().Err() != nil { if req.Context().Err() != nil {
httpLogger.Info("Request cancelled", httpLogger.Info("Request cancelled",
@ -490,21 +442,21 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
addAllowHeader(err.(MethodNotAllowedError).Allow) addAllowHeader(err.(MethodNotAllowedError).Allow)
resp.WriteHeader(http.StatusMethodNotAllowed) // 405 resp.WriteHeader(http.StatusMethodNotAllowed) // 405
fmt.Fprint(resp, err.Error()) fmt.Fprint(resp, err.Error())
case isBadRequest(err): case isHTTPError(err):
resp.WriteHeader(http.StatusBadRequest) err := err.(HTTPError)
fmt.Fprint(resp, err.Error()) code := http.StatusInternalServerError
case isNotFound(err): if err.StatusCode != 0 {
resp.WriteHeader(http.StatusNotFound) code = err.StatusCode
fmt.Fprint(resp, err.Error()) }
case isUnauthorized(err): reason := "An unexpected error occurred"
resp.WriteHeader(http.StatusUnauthorized) if err.Error() != "" {
fmt.Fprint(resp, err.Error()) reason = err.Error()
}
resp.WriteHeader(code)
fmt.Fprint(resp, reason)
case isTooManyRequests(err): case isTooManyRequests(err):
resp.WriteHeader(http.StatusTooManyRequests) resp.WriteHeader(http.StatusTooManyRequests)
fmt.Fprint(resp, err.Error()) fmt.Fprint(resp, err.Error())
case isEntityToLarge(err):
resp.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprint(resp, err.Error())
default: default:
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(resp, err.Error()) fmt.Fprint(resp, err.Error())
@ -1175,7 +1127,7 @@ func (s *HTTPHandlers) checkWriteAccess(req *http.Request) error {
} }
} }
return ForbiddenError{Reason: "Access is restricted"} return HTTPError{StatusCode: http.StatusForbidden, Reason: "Access is restricted"}
} }
func (s *HTTPHandlers) parseFilter(req *http.Request, filter *string) { func (s *HTTPHandlers) parseFilter(req *http.Request, filter *string) {

View File

@ -14,10 +14,16 @@ import (
func (s *HTTPHandlers) parseEntMeta(req *http.Request, entMeta *acl.EnterpriseMeta) error { func (s *HTTPHandlers) parseEntMeta(req *http.Request, entMeta *acl.EnterpriseMeta) error {
if headerNS := req.Header.Get("X-Consul-Namespace"); headerNS != "" { if headerNS := req.Header.Get("X-Consul-Namespace"); headerNS != "" {
return BadRequestError{Reason: "Invalid header: \"X-Consul-Namespace\" - Namespaces are a Consul Enterprise feature"} return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid header: \"X-Consul-Namespace\" - Namespaces are a Consul Enterprise feature",
}
} }
if queryNS := req.URL.Query().Get("ns"); queryNS != "" { if queryNS := req.URL.Query().Get("ns"); queryNS != "" {
return BadRequestError{Reason: "Invalid query parameter: \"ns\" - Namespaces are a Consul Enterprise feature"} return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid query parameter: \"ns\" - Namespaces are a Consul Enterprise feature",
}
} }
return s.parseEntMetaPartition(req, entMeta) return s.parseEntMetaPartition(req, entMeta)
@ -32,7 +38,10 @@ func (s *HTTPHandlers) validateEnterpriseIntentionPartition(logName, partition s
// No special handling for wildcard namespaces as they are pointless in OSS. // No special handling for wildcard namespaces as they are pointless in OSS.
return BadRequestError{Reason: "Invalid " + logName + "(" + partition + ")" + ": Partitions is a Consul Enterprise feature"} return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid " + logName + "(" + partition + ")" + ": Partitions is a Consul Enterprise feature",
}
} }
func (s *HTTPHandlers) validateEnterpriseIntentionNamespace(logName, ns string, _ bool) error { func (s *HTTPHandlers) validateEnterpriseIntentionNamespace(logName, ns string, _ bool) error {
@ -44,7 +53,10 @@ func (s *HTTPHandlers) validateEnterpriseIntentionNamespace(logName, ns string,
// No special handling for wildcard namespaces as they are pointless in OSS. // No special handling for wildcard namespaces as they are pointless in OSS.
return BadRequestError{Reason: "Invalid " + logName + "(" + ns + ")" + ": Namespaces is a Consul Enterprise feature"} return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid " + logName + "(" + ns + ")" + ": Namespaces is a Consul Enterprise feature",
}
} }
func (s *HTTPHandlers) parseEntMetaNoWildcard(req *http.Request, _ *acl.EnterpriseMeta) error { func (s *HTTPHandlers) parseEntMetaNoWildcard(req *http.Request, _ *acl.EnterpriseMeta) error {
@ -72,7 +84,10 @@ func (s *HTTPHandlers) rewordUnknownEnterpriseFieldError(err error) error {
func parseACLAuthMethodEnterpriseMeta(req *http.Request, _ *structs.ACLAuthMethodEnterpriseMeta) error { func parseACLAuthMethodEnterpriseMeta(req *http.Request, _ *structs.ACLAuthMethodEnterpriseMeta) error {
if methodNS := req.URL.Query().Get("authmethod-ns"); methodNS != "" { if methodNS := req.URL.Query().Get("authmethod-ns"); methodNS != "" {
return BadRequestError{Reason: "Invalid query parameter: \"authmethod-ns\" - Namespaces are a Consul Enterprise feature"} return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid query parameter: \"authmethod-ns\" - Namespaces are a Consul Enterprise feature",
}
} }
return nil return nil
@ -91,10 +106,16 @@ func (s *HTTPHandlers) uiTemplateDataTransform(data map[string]interface{}) erro
func (s *HTTPHandlers) parseEntMetaPartition(req *http.Request, meta *acl.EnterpriseMeta) error { func (s *HTTPHandlers) parseEntMetaPartition(req *http.Request, meta *acl.EnterpriseMeta) error {
if headerAP := req.Header.Get("X-Consul-Partition"); headerAP != "" { if headerAP := req.Header.Get("X-Consul-Partition"); headerAP != "" {
return BadRequestError{Reason: "Invalid header: \"X-Consul-Partition\" - Partitions are a Consul Enterprise feature"} return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid header: \"X-Consul-Partition\" - Partitions are a Consul Enterprise feature",
}
} }
if queryAP := req.URL.Query().Get("partition"); queryAP != "" { if queryAP := req.URL.Query().Get("partition"); queryAP != "" {
return BadRequestError{Reason: "Invalid query parameter: \"partition\" - Partitions are a Consul Enterprise feature"} return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid query parameter: \"partition\" - Partitions are a Consul Enterprise feature",
}
} }
return nil return nil

View File

@ -1488,10 +1488,16 @@ func TestAllowedNets(t *testing.T) {
t.Fatalf("bad checkWriteAccess for values %+v, got %v", v, err) t.Fatalf("bad checkWriteAccess for values %+v, got %v", v, err)
} }
_, isForbiddenErr := err.(ForbiddenError) if err != nil {
if err != nil && !isForbiddenErr { if err, ok := err.(HTTPError); ok {
t.Fatalf("expected ForbiddenError but got: %s", err) if err.StatusCode != 403 {
t.Fatalf("expected 403 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP Error but got %v", err)
}
} }
} }
} }

View File

@ -56,8 +56,9 @@ func (s *HTTPHandlers) IntentionCreate(resp http.ResponseWriter, req *http.Reque
if err := s.parseEntMetaNoWildcard(req, &entMeta); err != nil { if err := s.parseEntMetaNoWildcard(req, &entMeta); err != nil {
return nil, err return nil, err
} }
if entMeta.PartitionOrDefault() != acl.PartitionOrDefault("") { if entMeta.PartitionOrDefault() != acl.PartitionOrDefault("") {
return nil, BadRequestError{Reason: "Cannot use a partition with this endpoint"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot use a partition with this endpoint"}
} }
args := structs.IntentionRequest{ args := structs.IntentionRequest{
@ -70,10 +71,10 @@ func (s *HTTPHandlers) IntentionCreate(resp http.ResponseWriter, req *http.Reque
} }
if args.Intention.DestinationPartition != "" && args.Intention.DestinationPartition != "default" { if args.Intention.DestinationPartition != "" && args.Intention.DestinationPartition != "default" {
return nil, BadRequestError{Reason: "Cannot specify a destination partition with this endpoint"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a destination partition with this endpoint"}
} }
if args.Intention.SourcePartition != "" && args.Intention.SourcePartition != "default" { if args.Intention.SourcePartition != "" && args.Intention.SourcePartition != "default" {
return nil, BadRequestError{Reason: "Cannot specify a source partition with this endpoint"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a source partition with this endpoint"}
} }
args.Intention.FillPartitionAndNamespace(&entMeta, false) args.Intention.FillPartitionAndNamespace(&entMeta, false)
@ -324,7 +325,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil { if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds the error type // We have to check the string since the RPC sheds the error type
if err.Error() == consul.ErrIntentionNotFound.Error() { if err.Error() == consul.ErrIntentionNotFound.Error() {
return nil, NotFoundError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
} }
// Not ideal, but there are a number of error scenarios that are not // Not ideal, but there are a number of error scenarios that are not
@ -332,7 +333,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req
// to detect a parameter error and return a 400 response. The error // to detect a parameter error and return a 400 response. The error
// is not a constant type or message, so we have to use strings.Contains // is not a constant type or message, so we have to use strings.Contains
if strings.Contains(err.Error(), "UUID") { if strings.Contains(err.Error(), "UUID") {
return nil, BadRequestError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()}
} }
return nil, err return nil, err
@ -366,7 +367,7 @@ func (s *HTTPHandlers) IntentionPutExact(resp http.ResponseWriter, req *http.Req
s.parseDC(req, &args.Datacenter) s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Intention); err != nil { if err := decodeBody(req.Body, &args.Intention); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
// Explicitly CLEAR the old legacy ID field // Explicitly CLEAR the old legacy ID field
@ -520,7 +521,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter,
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil { if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds the error type // We have to check the string since the RPC sheds the error type
if err.Error() == consul.ErrIntentionNotFound.Error() { if err.Error() == consul.ErrIntentionNotFound.Error() {
return nil, NotFoundError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
} }
// Not ideal, but there are a number of error scenarios that are not // Not ideal, but there are a number of error scenarios that are not
@ -528,7 +529,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter,
// to detect a parameter error and return a 400 response. The error // to detect a parameter error and return a 400 response. The error
// is not a constant type or message, so we have to use strings.Contains // is not a constant type or message, so we have to use strings.Contains
if strings.Contains(err.Error(), "UUID") { if strings.Contains(err.Error(), "UUID") {
return nil, BadRequestError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()}
} }
return nil, err return nil, err
@ -552,8 +553,9 @@ func (s *HTTPHandlers) IntentionSpecificUpdate(id string, resp http.ResponseWrit
if err := s.parseEntMetaNoWildcard(req, &entMeta); err != nil { if err := s.parseEntMetaNoWildcard(req, &entMeta); err != nil {
return nil, err return nil, err
} }
if entMeta.PartitionOrDefault() != acl.PartitionOrDefault("") { if entMeta.PartitionOrDefault() != acl.PartitionOrDefault("") {
return nil, BadRequestError{Reason: "Cannot use a partition with this endpoint"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot use a partition with this endpoint"}
} }
args := structs.IntentionRequest{ args := structs.IntentionRequest{
@ -562,14 +564,14 @@ func (s *HTTPHandlers) IntentionSpecificUpdate(id string, resp http.ResponseWrit
s.parseDC(req, &args.Datacenter) s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Intention); err != nil { if err := decodeBody(req.Body, &args.Intention); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
if args.Intention.DestinationPartition != "" && args.Intention.DestinationPartition != "default" { if args.Intention.DestinationPartition != "" && args.Intention.DestinationPartition != "default" {
return nil, BadRequestError{Reason: "Cannot specify a destination partition with this endpoint"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a destination partition with this endpoint"}
} }
if args.Intention.SourcePartition != "" && args.Intention.SourcePartition != "default" { if args.Intention.SourcePartition != "" && args.Intention.SourcePartition != "default" {
return nil, BadRequestError{Reason: "Cannot specify a source partition with this endpoint"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a source partition with this endpoint"}
} }
args.Intention.FillPartitionAndNamespace(&entMeta, false) args.Intention.FillPartitionAndNamespace(&entMeta, false)

View File

@ -483,7 +483,7 @@ func TestIntentionSpecificGet(t *testing.T) {
obj, err := a.srv.IntentionSpecific(resp, req) obj, err := a.srv.IntentionSpecific(resp, req)
require.Nil(t, obj) require.Nil(t, obj)
require.Error(t, err) require.Error(t, err)
require.IsType(t, BadRequestError{}, err) require.True(t, isHTTPBadRequest(err))
require.Contains(t, err.Error(), "UUID") require.Contains(t, err.Error(), "UUID")
}) })

View File

@ -56,7 +56,7 @@ func (s *HTTPHandlers) KVSGet(resp http.ResponseWriter, req *http.Request, args
if _, ok := params["recurse"]; ok { if _, ok := params["recurse"]; ok {
method = "KVS.List" method = "KVS.List"
} else if args.Key == "" { } else if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"}
} }
// Do not allow wildcard NS on GET reqs // Do not allow wildcard NS on GET reqs
@ -157,7 +157,7 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args
return nil, err return nil, err
} }
if args.Key == "" { if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"}
} }
if conflictingFlags(resp, req, "cas", "acquire", "release") { if conflictingFlags(resp, req, "cas", "acquire", "release") {
return nil, nil return nil, nil
@ -208,7 +208,8 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args
// Check the content-length // Check the content-length
if req.ContentLength > int64(s.agent.config.KVMaxValueSize) { if req.ContentLength > int64(s.agent.config.KVMaxValueSize) {
return nil, EntityTooLargeError{ return nil, HTTPError{
StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.", Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.",
req.ContentLength, s.agent.config.KVMaxValueSize, "https://www.consul.io/docs/agent/config/config-files#kv_max_value_size"), req.ContentLength, s.agent.config.KVMaxValueSize, "https://www.consul.io/docs/agent/config/config-files#kv_max_value_size"),
} }
@ -257,7 +258,7 @@ func (s *HTTPHandlers) KVSDelete(resp http.ResponseWriter, req *http.Request, ar
if _, ok := params["recurse"]; ok { if _, ok := params["recurse"]; ok {
applyReq.Op = api.KVDeleteTree applyReq.Op = api.KVDeleteTree
} else if args.Key == "" { } else if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"}
} }
// Check for cas value // Check for cas value

View File

@ -49,10 +49,13 @@ func (s *HTTPHandlers) OperatorRaftPeer(resp http.ResponseWriter, req *http.Requ
} }
if !hasID && !hasAddress { if !hasID && !hasAddress {
return nil, BadRequestError{Reason: "Must specify either ?id with the server's ID or ?address with IP:port of peer to remove"} return nil, HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Must specify either ?id with the server's ID or ?address with IP:port of peer to remove",
}
} }
if hasID && hasAddress { if hasID && hasAddress {
return nil, BadRequestError{Reason: "Must specify only one of ?id or ?address"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Must specify only one of ?id or ?address"}
} }
var reply struct{} var reply struct{}
@ -79,7 +82,7 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht
var args keyringArgs var args keyringArgs
if req.Method == "POST" || req.Method == "PUT" || req.Method == "DELETE" { if req.Method == "POST" || req.Method == "PUT" || req.Method == "DELETE" {
if err := decodeBody(req.Body, &args); err != nil { if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
} }
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)
@ -88,12 +91,12 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht
if relayFactor := req.URL.Query().Get("relay-factor"); relayFactor != "" { if relayFactor := req.URL.Query().Get("relay-factor"); relayFactor != "" {
n, err := strconv.Atoi(relayFactor) n, err := strconv.Atoi(relayFactor)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing relay factor: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing relay factor: %v", err)}
} }
args.RelayFactor, err = ParseRelayFactor(n) args.RelayFactor, err = ParseRelayFactor(n)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid relay-factor: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid relay-factor: %v", err)}
} }
} }
@ -102,12 +105,12 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht
var err error var err error
args.LocalOnly, err = strconv.ParseBool(localOnly) args.LocalOnly, err = strconv.ParseBool(localOnly)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing local-only: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing local-only: %v", err)}
} }
err = ValidateLocalOnly(args.LocalOnly, req.Method == "GET") err = ValidateLocalOnly(args.LocalOnly, req.Method == "GET")
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid use of local-only: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid use of local-only: %v", err)}
} }
} }
@ -226,7 +229,7 @@ func (s *HTTPHandlers) OperatorAutopilotConfiguration(resp http.ResponseWriter,
conf := api.NewAutopilotConfiguration() conf := api.NewAutopilotConfiguration()
if err := decodeBody(req.Body, &conf); err != nil { if err := decodeBody(req.Body, &conf); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing autopilot config: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing autopilot config: %v", err)}
} }
args.Config = structs.AutopilotConfig{ args.Config = structs.AutopilotConfig{
@ -245,7 +248,7 @@ func (s *HTTPHandlers) OperatorAutopilotConfiguration(resp http.ResponseWriter,
if _, ok := params["cas"]; ok { if _, ok := params["cas"]; ok {
casVal, err := strconv.ParseUint(params.Get("cas"), 10, 64) casVal, err := strconv.ParseUint(params.Get("cas"), 10, 64)
if err != nil { if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing cas value: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing cas value: %v", err)}
} }
args.Config.ModifyIndex = casVal args.Config.ModifyIndex = casVal
args.CAS = true args.CAS = true

View File

@ -23,7 +23,7 @@ func (s *HTTPHandlers) preparedQueryCreate(resp http.ResponseWriter, req *http.R
s.parseDC(req, &args.Datacenter) s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Query); err != nil { if err := decodeBody(req.Body, &args.Query); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
var reply string var reply string
@ -143,7 +143,7 @@ func (s *HTTPHandlers) preparedQueryExecute(id string, resp http.ResponseWriter,
// We have to check the string since the RPC sheds // We have to check the string since the RPC sheds
// the specific error type. // the specific error type.
if structs.IsErrQueryNotFound(err) { if structs.IsErrQueryNotFound(err) {
return nil, NotFoundError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
} }
return nil, err return nil, err
} }
@ -196,7 +196,7 @@ RETRY_ONCE:
// We have to check the string since the RPC sheds // We have to check the string since the RPC sheds
// the specific error type. // the specific error type.
if structs.IsErrQueryNotFound(err) { if structs.IsErrQueryNotFound(err) {
return nil, NotFoundError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
} }
return nil, err return nil, err
} }
@ -225,7 +225,7 @@ RETRY_ONCE:
// We have to check the string since the RPC sheds // We have to check the string since the RPC sheds
// the specific error type. // the specific error type.
if structs.IsErrQueryNotFound(err) { if structs.IsErrQueryNotFound(err) {
return nil, NotFoundError{Reason: err.Error()} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
} }
return nil, err return nil, err
} }
@ -247,7 +247,7 @@ func (s *HTTPHandlers) preparedQueryUpdate(id string, resp http.ResponseWriter,
s.parseToken(req, &args.Token) s.parseToken(req, &args.Token)
if req.ContentLength > 0 { if req.ContentLength > 0 {
if err := decodeBody(req.Body, &args.Query); err != nil { if err := decodeBody(req.Body, &args.Query); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
} }

View File

@ -621,8 +621,12 @@ func TestPreparedQuery_Execute(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/query/not-there/execute", body) req, _ := http.NewRequest("GET", "/v1/query/not-there/execute", body)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.PreparedQuerySpecific(resp, req) _, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok { if err, ok := err.(HTTPError); ok {
t.Fatalf("Expected not found error but got %v", err) if err.StatusCode != 404 {
t.Fatalf("expected status 404 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
} }
}) })
} }
@ -756,8 +760,12 @@ func TestPreparedQuery_Explain(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/query/not-there/explain", body) req, _ := http.NewRequest("GET", "/v1/query/not-there/explain", body)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.PreparedQuerySpecific(resp, req) _, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok { if err, ok := err.(HTTPError); ok {
t.Fatalf("Expected not found error but got %v", err) if err.StatusCode != 404 {
t.Fatalf("expected status 404 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
} }
}) })
@ -845,8 +853,12 @@ func TestPreparedQuery_Get(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/query/f004177f-2c28-83b7-4229-eacc25fe55d1", body) req, _ := http.NewRequest("GET", "/v1/query/f004177f-2c28-83b7-4229-eacc25fe55d1", body)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.PreparedQuerySpecific(resp, req) _, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok { if err, ok := err.(HTTPError); ok {
t.Fatalf("Expected not found error but got %v", err) if err.StatusCode != 404 {
t.Fatalf("expected status 404 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
} }
}) })
} }

View File

@ -40,7 +40,7 @@ func (s *HTTPHandlers) SessionCreate(resp http.ResponseWriter, req *http.Request
// Handle optional request body // Handle optional request body
if req.ContentLength > 0 { if req.ContentLength > 0 {
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Session)); err != nil { if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Session)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
} }
} }
@ -75,7 +75,7 @@ func (s *HTTPHandlers) SessionDestroy(resp http.ResponseWriter, req *http.Reques
return nil, err return nil, err
} }
if args.Session.ID == "" { if args.Session.ID == "" {
return nil, BadRequestError{Reason: "Missing session"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"}
} }
var out string var out string
@ -103,14 +103,14 @@ func (s *HTTPHandlers) SessionRenew(resp http.ResponseWriter, req *http.Request)
} }
args.Session = args.SessionID args.Session = args.SessionID
if args.SessionID == "" { if args.SessionID == "" {
return nil, BadRequestError{Reason: "Missing session"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"}
} }
var out structs.IndexedSessions var out structs.IndexedSessions
if err := s.agent.RPC("Session.Renew", &args, &out); err != nil { if err := s.agent.RPC("Session.Renew", &args, &out); err != nil {
return nil, err return nil, err
} else if out.Sessions == nil { } else if out.Sessions == nil {
return nil, NotFoundError{Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)}
} }
return out.Sessions, nil return out.Sessions, nil
@ -134,7 +134,7 @@ func (s *HTTPHandlers) SessionGet(resp http.ResponseWriter, req *http.Request) (
} }
args.Session = args.SessionID args.Session = args.SessionID
if args.SessionID == "" { if args.SessionID == "" {
return nil, BadRequestError{Reason: "Missing session"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"}
} }
var out structs.IndexedSessions var out structs.IndexedSessions
@ -190,7 +190,7 @@ func (s *HTTPHandlers) SessionsForNode(resp http.ResponseWriter, req *http.Reque
return nil, err return nil, err
} }
if args.Node == "" { if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
} }
var out structs.IndexedSessions var out structs.IndexedSessions

View File

@ -88,7 +88,8 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
// Check Content-Length first before decoding to return early // Check Content-Length first before decoding to return early
if req.ContentLength > maxTxnLen { if req.ContentLength > maxTxnLen {
return nil, 0, EntityTooLargeError{ return nil, 0, HTTPError{
StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.", Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.",
req.ContentLength, maxTxnLen, "https://www.consul.io/docs/agent/config/config-files#txn_max_req_len"), req.ContentLength, maxTxnLen, "https://www.consul.io/docs/agent/config/config-files#txn_max_req_len"),
} }
@ -100,7 +101,8 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
if err.Error() == "http: request body too large" { if err.Error() == "http: request body too large" {
// The request size is also verified during decoding to double check // The request size is also verified during decoding to double check
// if the Content-Length header was not set by the client. // if the Content-Length header was not set by the client.
return nil, 0, EntityTooLargeError{ return nil, 0, HTTPError{
StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Request body too large, max size: %d bytes. See %s.", Reason: fmt.Sprintf("Request body too large, max size: %d bytes. See %s.",
maxTxnLen, "https://www.consul.io/docs/agent/config/config-files#txn_max_req_len"), maxTxnLen, "https://www.consul.io/docs/agent/config/config-files#txn_max_req_len"),
} }
@ -108,15 +110,16 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
// Note the body is in API format, and not the RPC format. If we can't // Note the body is in API format, and not the RPC format. If we can't
// decode it, we will return a 400 since we don't have enough context to // decode it, we will return a 400 since we don't have enough context to
// associate the error with a given operation. // associate the error with a given operation.
return nil, 0, BadRequestError{Reason: fmt.Sprintf("Failed to parse body: %v", err)} return nil, 0, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to parse body: %v", err)}
} }
} }
// Enforce a reasonable upper limit on the number of operations in a // Enforce a reasonable upper limit on the number of operations in a
// transaction in order to curb abuse. // transaction in order to curb abuse.
if size := len(ops); size > maxTxnOps { if size := len(ops); size > maxTxnOps {
return nil, 0, EntityTooLargeError{ return nil, 0, HTTPError{
Reason: fmt.Sprintf("Transaction contains too many operations (%d > %d)", size, maxTxnOps), StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Transaction contains too many operations (%d > %d)", size, maxTxnOps),
} }
} }
@ -130,8 +133,9 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
case in.KV != nil: case in.KV != nil:
size := len(in.KV.Value) size := len(in.KV.Value)
if int64(size) > kvMaxValueSize { if int64(size) > kvMaxValueSize {
return nil, 0, EntityTooLargeError{ return nil, 0, HTTPError{
Reason: fmt.Sprintf("Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize), StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize),
} }
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
@ -31,10 +32,7 @@ func TestTxnEndpoint_Bad_JSON(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/txn", buf) req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.Txn(resp, req) _, err := a.srv.Txn(resp, req)
err, ok := err.(BadRequestError) require.True(t, isHTTPBadRequest(err), fmt.Sprintf("Expected bad request HTTP error but got %v", err))
if !ok {
t.Fatalf("expected bad request error but got %v", err)
}
if !strings.Contains(err.Error(), "Failed to parse") { if !strings.Contains(err.Error(), "Failed to parse") {
t.Fatalf("expected conflicting args error") t.Fatalf("expected conflicting args error")
} }
@ -63,11 +61,19 @@ func TestTxnEndpoint_Bad_Size_Item(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/txn", buf) req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := agent.srv.Txn(resp, req) _, err := agent.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok && !wantPass {
t.Fatalf("expected too large error but got %v", err) if wantPass {
} if err != nil {
if err != nil && wantPass { t.Fatalf("err: %v", err)
t.Fatalf("err: %v", err) }
} else {
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 413 {
t.Fatalf("expected 413 but got %d", err.StatusCode)
}
} else {
t.Fatalf("excected HTTP error but got %v", err)
}
} }
} }
@ -138,11 +144,19 @@ func TestTxnEndpoint_Bad_Size_Net(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/txn", buf) req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := agent.srv.Txn(resp, req) _, err := agent.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok && !wantPass {
t.Fatalf("expected too large error but got %v", err) if wantPass {
} if err != nil {
if err != nil && wantPass { t.Fatalf("err: %v", err)
t.Fatalf("err: %v", err) }
} else {
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 413 {
t.Fatalf("expected 413 but got %d", err.StatusCode)
}
} else {
t.Fatalf("excected HTTP error but got %v", err)
}
} }
} }
@ -205,8 +219,13 @@ func TestTxnEndpoint_Bad_Size_Ops(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/txn", buf) req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
_, err := a.srv.Txn(resp, req) _, err := a.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok {
t.Fatalf("expected too large error but got %v", err) if err, ok := err.(HTTPError); ok {
if err.StatusCode != 413 {
t.Fatalf("expected 413 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
} }
} }

View File

@ -140,7 +140,7 @@ func (s *HTTPHandlers) UINodeInfo(resp http.ResponseWriter, req *http.Request) (
return nil, err return nil, err
} }
if args.Node == "" { if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
} }
// Make the RPC request // Make the RPC request
@ -272,7 +272,7 @@ func (s *HTTPHandlers) UIGatewayServicesNodes(resp http.ResponseWriter, req *htt
return nil, err return nil, err
} }
if args.ServiceName == "" { if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing gateway name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"}
} }
// Make the RPC request // Make the RPC request
@ -316,12 +316,12 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req
return nil, err return nil, err
} }
if args.ServiceName == "" { if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing service name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
} }
kind, ok := req.URL.Query()["kind"] kind, ok := req.URL.Query()["kind"]
if !ok { if !ok {
return nil, BadRequestError{Reason: "Missing service kind"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service kind"}
} }
args.ServiceKind = structs.ServiceKind(kind[0]) args.ServiceKind = structs.ServiceKind(kind[0])
@ -329,7 +329,7 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req
case structs.ServiceKindTypical, structs.ServiceKindIngressGateway: case structs.ServiceKindTypical, structs.ServiceKindIngressGateway:
// allowed // allowed
default: default:
return nil, BadRequestError{Reason: fmt.Sprintf("Unsupported service kind %q", args.ServiceKind)} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Unsupported service kind %q", args.ServiceKind)}
} }
// Make the RPC request // Make the RPC request
@ -594,7 +594,7 @@ func (s *HTTPHandlers) UIGatewayIntentions(resp http.ResponseWriter, req *http.R
return nil, err return nil, err
} }
if name == "" { if name == "" {
return nil, BadRequestError{Reason: "Missing gateway name"} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"}
} }
args.Match = &structs.IntentionQueryMatch{ args.Match = &structs.IntentionQueryMatch{
Type: structs.IntentionMatchDestination, Type: structs.IntentionMatchDestination,
@ -624,14 +624,14 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques
// Check the UI was enabled at agent startup (note this is not reloadable // Check the UI was enabled at agent startup (note this is not reloadable
// currently). // currently).
if !s.IsUIEnabled() { if !s.IsUIEnabled() {
return nil, NotFoundError{Reason: "UI is not enabled"} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "UI is not enabled"}
} }
// Load reloadable proxy config // Load reloadable proxy config
cfg, ok := s.metricsProxyCfg.Load().(config.UIMetricsProxy) cfg, ok := s.metricsProxyCfg.Load().(config.UIMetricsProxy)
if !ok || cfg.BaseURL == "" { if !ok || cfg.BaseURL == "" {
// Proxy not configured // Proxy not configured
return nil, NotFoundError{Reason: "Metrics proxy is not enabled"} return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Metrics proxy is not enabled"}
} }
// Fetch the ACL token, if provided, but ONLY from headers since other // Fetch the ACL token, if provided, but ONLY from headers since other
@ -686,7 +686,7 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques
u, err := url.Parse(newURL) u, err := url.Parse(newURL)
if err != nil { if err != nil {
log.Error("couldn't parse target URL", "base_url", cfg.BaseURL, "path", subPath) log.Error("couldn't parse target URL", "base_url", cfg.BaseURL, "path", subPath)
return nil, BadRequestError{Reason: "Invalid path."} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid path."}
} }
// Clean the new URL path to prevent path traversal attacks and remove any // Clean the new URL path to prevent path traversal attacks and remove any
@ -735,7 +735,7 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques
"path", subPath, "path", subPath,
"target_url", u.String(), "target_url", u.String(),
) )
return nil, BadRequestError{Reason: "Invalid path."} return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid path."}
} }
// Add any configured headers // Add any configured headers