Unify various status errors into one HTTP error type. (#12594)

Replaces specific error types for HTTP Status codes with 
a generic HTTPError type.

Co-authored-by: Chris S. Kim <ckim@hashicorp.com>
This commit is contained in:
Mathew Estafanous 2022-04-29 13:42:49 -04:00 committed by GitHub
parent d04fe6ca2c
commit 474385d153
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 332 additions and 320 deletions

View File

@ -2,6 +2,7 @@ package agent
import (
"fmt"
"net/http"
"github.com/hashicorp/serf/serf"
@ -82,9 +83,10 @@ func (a *Agent) vetServiceUpdateWithAuthorizer(authz acl.Authorizer, serviceID s
// agent/local/state.go's deleteService assumes the Catalog.Deregister RPC call
// will include "Unknown service"in the error if deregistration fails due to a
// service with that ID not existing.
return NotFoundError{Reason: fmt.Sprintf(
"Unknown service ID %q. Ensure that the service ID is passed, not the service name.",
serviceID)}
return HTTPError{
StatusCode: http.StatusNotFound,
Reason: fmt.Sprintf("Unknown service ID %q. Ensure that the service ID is passed, not the service name.", serviceID),
}
}
return nil
@ -140,9 +142,10 @@ func (a *Agent) vetCheckUpdateWithAuthorizer(authz acl.Authorizer, checkID struc
}
}
} else {
return NotFoundError{Reason: fmt.Sprintf(
"Unknown check ID %q. Ensure that the check ID is passed, not the check name.",
checkID.String())}
return HTTPError{
StatusCode: http.StatusNotFound,
Reason: fmt.Sprintf("Unknown check ID %q. Ensure that the check ID is passed, not the check name.", checkID.String()),
}
}
return nil

View File

@ -16,7 +16,7 @@ type aclBootstrapResponse struct {
structs.ACLToken
}
var aclDisabled = UnauthorizedError{Reason: "ACL support disabled"}
var aclDisabled = HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"}
// checkACLDisabled will return a standard response if ACLs are disabled. This
// returns true if they are disabled and we should not continue.
@ -127,7 +127,7 @@ func (s *HTTPHandlers) ACLPolicyCRUD(resp http.ResponseWriter, req *http.Request
return nil, err
}
if policyID == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing policy ID"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing policy ID"}
}
return fn(resp, req, policyID)
@ -175,7 +175,7 @@ func (s *HTTPHandlers) ACLPolicyReadByName(resp http.ResponseWriter, req *http.R
return nil, err
}
if policyName == "" {
return nil, BadRequestError{Reason: "Missing policy Name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing policy Name"}
}
return s.ACLPolicyRead(resp, req, "", policyName)
@ -207,18 +207,18 @@ func (s *HTTPHandlers) aclPolicyWriteInternal(_resp http.ResponseWriter, req *ht
}
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Policy)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Policy decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Policy decoding failed: %v", err)}
}
args.Policy.Syntax = acl.SyntaxCurrent
if create {
if args.Policy.ID != "" {
return nil, BadRequestError{Reason: "Cannot specify the ID when creating a new policy"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify the ID when creating a new policy"}
}
} else {
if args.Policy.ID != "" && args.Policy.ID != policyID {
return nil, BadRequestError{Reason: "Policy ID in URL and payload do not match"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Policy ID in URL and payload do not match"}
} else if args.Policy.ID == "" {
args.Policy.ID = policyID
}
@ -317,7 +317,7 @@ func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request)
fn = s.ACLTokenClone
}
if tokenID == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing token ID"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing token ID"}
}
return fn(resp, req, tokenID)
@ -422,12 +422,12 @@ func (s *HTTPHandlers) aclTokenSetInternal(req *http.Request, tokenID string, cr
}
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.ACLToken)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Token decoding failed: %v", err)}
}
if !create {
if args.ACLToken.AccessorID != "" && args.ACLToken.AccessorID != tokenID {
return nil, BadRequestError{Reason: "Token Accessor ID in URL and payload do not match"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Token Accessor ID in URL and payload do not match"}
} else if args.ACLToken.AccessorID == "" {
args.ACLToken.AccessorID = tokenID
}
@ -472,7 +472,7 @@ func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request
return nil, err
}
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.ACLToken)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Token decoding failed: %v", err)}
}
s.parseToken(req, &args.Token)
@ -546,7 +546,7 @@ func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request)
return nil, err
}
if roleID == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing role ID"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing role ID"}
}
return fn(resp, req, roleID)
@ -562,7 +562,7 @@ func (s *HTTPHandlers) ACLRoleReadByName(resp http.ResponseWriter, req *http.Req
return nil, err
}
if roleName == "" {
return nil, BadRequestError{Reason: "Missing role Name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing role Name"}
}
return s.ACLRoleRead(resp, req, "", roleName)
@ -621,11 +621,11 @@ func (s *HTTPHandlers) ACLRoleWrite(resp http.ResponseWriter, req *http.Request,
}
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Role)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Role decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Role decoding failed: %v", err)}
}
if args.Role.ID != "" && args.Role.ID != roleID {
return nil, BadRequestError{Reason: "Role ID in URL and payload do not match"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Role ID in URL and payload do not match"}
} else if args.Role.ID == "" {
args.Role.ID = roleID
}
@ -716,7 +716,7 @@ func (s *HTTPHandlers) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Re
return nil, err
}
if bindingRuleID == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing binding rule ID"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing binding rule ID"}
}
return fn(resp, req, bindingRuleID)
@ -770,11 +770,11 @@ func (s *HTTPHandlers) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.R
}
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.BindingRule)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)}
}
if args.BindingRule.ID != "" && args.BindingRule.ID != bindingRuleID {
return nil, BadRequestError{Reason: "BindingRule ID in URL and payload do not match"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "BindingRule ID in URL and payload do not match"}
} else if args.BindingRule.ID == "" {
args.BindingRule.ID = bindingRuleID
}
@ -862,7 +862,7 @@ func (s *HTTPHandlers) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Req
return nil, err
}
if methodName == "" && req.Method != "PUT" {
return nil, BadRequestError{Reason: "Missing auth method name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing auth method name"}
}
return fn(resp, req, methodName)
@ -916,12 +916,12 @@ func (s *HTTPHandlers) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Re
}
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.AuthMethod)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)}
}
if methodName != "" {
if args.AuthMethod.Name != "" && args.AuthMethod.Name != methodName {
return nil, BadRequestError{Reason: "AuthMethod Name in URL and payload do not match"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "AuthMethod Name in URL and payload do not match"}
} else if args.AuthMethod.Name == "" {
args.AuthMethod.Name = methodName
}
@ -969,7 +969,7 @@ func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (in
}
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Auth)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to decode request body: %v", err)}
}
var out structs.ACLToken
@ -1058,11 +1058,11 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request)
s.parseDC(req, &request.Datacenter)
if err := decodeBody(req.Body, &request.Requests); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to decode request body: %v", err)}
}
if len(request.Requests) > maxRequests {
return nil, BadRequestError{Reason: fmt.Sprintf("Refusing to process more than %d authorizations at once", maxRequests)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Refusing to process more than %d authorizations at once", maxRequests)}
}
if len(request.Requests) == 0 {
@ -1083,7 +1083,7 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request)
responses, err = structs.CreateACLAuthorizationResponses(authz, request.Requests)
if err != nil {
return nil, BadRequestError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()}
}
}

View File

@ -26,6 +26,16 @@ import (
// They are not intended to thoroughly test the backing RPC
// functionality as that will be done with other tests.
func isHTTPBadRequest(err error) bool {
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 400 {
return false
}
return true
}
return false
}
func TestACL_Disabled_Response(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
@ -71,7 +81,7 @@ func TestACL_Disabled_Response(t *testing.T) {
resp := httptest.NewRecorder()
obj, err := tt.fn(resp, req)
require.Nil(t, obj)
require.ErrorIs(t, err, UnauthorizedError{Reason: "ACL support disabled"})
require.ErrorIs(t, err, HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"})
})
}
}
@ -270,8 +280,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLPolicyCRUD(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Policy CRUD Missing ID in URL", func(t *testing.T) {
@ -279,8 +288,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLPolicyCRUD(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Update", func(t *testing.T) {
@ -327,8 +335,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLPolicyCreate(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Invalid payload", func(t *testing.T) {
@ -339,8 +346,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLPolicyCreate(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Delete", func(t *testing.T) {
@ -497,8 +503,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLRoleCRUD(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Role CRUD Missing ID in URL", func(t *testing.T) {
@ -506,8 +511,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLRoleCRUD(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Update", func(t *testing.T) {
@ -567,8 +571,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLRoleCreate(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Invalid payload", func(t *testing.T) {
@ -579,8 +582,7 @@ func TestACL_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLRoleCreate(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Delete", func(t *testing.T) {
@ -818,8 +820,7 @@ func TestACL_HTTP(t *testing.T) {
obj, err := a.srv.ACLTokenCRUD(resp, req)
require.Error(t, err)
require.Nil(t, obj)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Update Accessor Mismatch", func(t *testing.T) {
originalToken := tokenMap[idMap["token-cloned"]]
@ -841,8 +842,7 @@ func TestACL_HTTP(t *testing.T) {
obj, err := a.srv.ACLTokenCRUD(resp, req)
require.Error(t, err)
require.Nil(t, obj)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Delete", func(t *testing.T) {
req, _ := http.NewRequest("DELETE", "/v1/acl/token/"+idMap["token-cloned"]+"?token=root", nil)
@ -1284,8 +1284,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLAuthMethodCRUD(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Update Name URL Mismatch", func(t *testing.T) {
@ -1302,8 +1301,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLAuthMethodCRUD(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Update", func(t *testing.T) {
@ -1342,8 +1340,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLAuthMethodCreate(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("List", func(t *testing.T) {
@ -1480,8 +1477,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLBindingRuleCRUD(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Update", func(t *testing.T) {
@ -1529,8 +1525,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLBindingRuleCreate(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("Invalid payload", func(t *testing.T) {
@ -1541,8 +1536,7 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.ACLBindingRuleCreate(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
})
t.Run("List", func(t *testing.T) {

View File

@ -425,7 +425,7 @@ func (s *HTTPHandlers) AgentService(resp http.ResponseWriter, req *http.Request)
svcState := s.agent.State.ServiceState(sid)
if svcState == nil {
return "", nil, NotFoundError{Reason: fmt.Sprintf("unknown service ID: %s", sid.String())}
return "", nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("unknown service ID: %s", sid.String())}
}
svc := svcState.Service
@ -555,7 +555,7 @@ func (s *HTTPHandlers) AgentMembers(resp http.ResponseWriter, req *http.Request)
// key are ok, otherwise the argument doesn't apply to
// the WAN.
default:
return nil, BadRequestError{Reason: "Cannot provide a segment with wan=true"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot provide a segment with wan=true"}
}
}
@ -732,16 +732,16 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
}
if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Verify the check has a name.
if args.Name == "" {
return nil, BadRequestError{Reason: "Missing check name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing check name"}
}
if args.Status != "" && !structs.ValidStatus(args.Status) {
return nil, BadRequestError{Reason: "Bad check status"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Bad check status"}
}
authz, err := s.agent.delegate.ResolveTokenAndDefaultMeta(token, &args.EnterpriseMeta, nil)
@ -760,7 +760,7 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
chkType := args.CheckType()
err = chkType.Validate()
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check: %v", err)}
}
// Store the type of check based on the definition
@ -773,7 +773,7 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
if service != nil {
health.ServiceName = service.Service
} else {
return nil, NotFoundError{fmt.Sprintf("ServiceID %q does not exist", cid.String())}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("ServiceID %q does not exist", cid.String())}
}
}
@ -878,7 +878,7 @@ type checkUpdate struct {
func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var update checkUpdate
if err := decodeBody(req.Body, &update); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
switch update.Status {
@ -886,7 +886,7 @@ func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Requ
case api.HealthWarning:
case api.HealthCritical:
default:
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)}
}
ID, err := getPathSuffixUnescaped(req.URL.Path, "/v1/agent/check/update/")
@ -981,7 +981,7 @@ func (s *HTTPHandlers) AgentHealthServiceByID(resp http.ResponseWriter, req *htt
return nil, err
}
if serviceID == "" {
return nil, &BadRequestError{Reason: "Missing serviceID"}
return nil, &HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing serviceID"}
}
var entMeta acl.EnterpriseMeta
@ -1043,7 +1043,7 @@ func (s *HTTPHandlers) AgentHealthServiceByName(resp http.ResponseWriter, req *h
}
if serviceName == "" {
return nil, &BadRequestError{Reason: "Missing service Name"}
return nil, &HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service Name"}
}
var entMeta acl.EnterpriseMeta
@ -1114,18 +1114,18 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
}
if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Verify the service has a name.
if args.Name == "" {
return nil, BadRequestError{Reason: "Missing service name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
}
// Check the service address here and in the catalog RPC endpoint
// since service registration isn't synchronous.
if ipaddr.IsAny(args.Address) {
return nil, BadRequestError{Reason: "Invalid service address"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid service address"}
}
var token string
@ -1144,27 +1144,27 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
ns := args.NodeService()
if ns.Weights != nil {
if err := structs.ValidateWeights(ns.Weights); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Weights: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid Weights: %v", err)}
}
}
if err := structs.ValidateServiceMetadata(ns.Kind, ns.Meta, false); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Service Meta: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid Service Meta: %v", err)}
}
// Run validation. This is the same validation that would happen on
// the catalog endpoint so it helps ensure the sync will work properly.
if err := ns.Validate(); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Validation failed: %v", err.Error())}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Validation failed: %v", err.Error())}
}
// Verify the check type.
chkTypes, err := args.CheckTypes()
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check: %v", err)}
}
for _, check := range chkTypes {
if check.Status != "" && !structs.ValidStatus(check.Status) {
return nil, BadRequestError{Reason: "Status for checks must 'passing', 'warning', 'critical'"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Status for checks must 'passing', 'warning', 'critical'"}
}
}
@ -1172,15 +1172,11 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
if args.Connect != nil && args.Connect.SidecarService != nil {
chkTypes, err := args.Connect.SidecarService.CheckTypes()
if err != nil {
return nil, &BadRequestError{
Reason: fmt.Sprintf("Invalid check in sidecar_service: %v", err),
}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid check in sidecar_service: %v", err)}
}
for _, check := range chkTypes {
if check.Status != "" && !structs.ValidStatus(check.Status) {
return nil, &BadRequestError{
Reason: "Status for checks must 'passing', 'warning', 'critical'",
}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Status for checks must 'passing', 'warning', 'critical'"}
}
}
}
@ -1193,12 +1189,11 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
// See if we have a sidecar to register too
sidecar, sidecarChecks, sidecarToken, err := s.agent.sidecarServiceFromNodeService(ns, token)
if err != nil {
return nil, &BadRequestError{
Reason: fmt.Sprintf("Invalid SidecarService: %s", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid SidecarService: %s", err)}
}
if sidecar != nil {
if err := sidecar.Validate(); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Failed Validation: %v", err.Error())}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed Validation: %v", err.Error())}
}
// Make sure we are allowed to register the sidecar using the token
// specified (might be specific to sidecar or the same one as the overall
@ -1299,19 +1294,19 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
sid := structs.NewServiceID(serviceID, nil)
if sid.ID == "" {
return nil, BadRequestError{Reason: "Missing service ID"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service ID"}
}
// Ensure we have some action
params := req.URL.Query()
if _, ok := params["enable"]; !ok {
return nil, BadRequestError{Reason: "Missing value for enable"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing value for enable"}
}
raw := params.Get("enable")
enable, err := strconv.ParseBool(raw)
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
}
// Get the provided token, if any, and vet against any ACL policies.
@ -1340,11 +1335,11 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
if enable {
reason := params.Get("reason")
if err = s.agent.EnableServiceMaintenance(sid, reason, token); err != nil {
return nil, NotFoundError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
}
} else {
if err = s.agent.DisableServiceMaintenance(sid); err != nil {
return nil, NotFoundError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
}
}
s.syncChanges()
@ -1355,13 +1350,13 @@ func (s *HTTPHandlers) AgentNodeMaintenance(resp http.ResponseWriter, req *http.
// Ensure we have some action
params := req.URL.Query()
if _, ok := params["enable"]; !ok {
return nil, BadRequestError{Reason: "Missing value for enable"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing value for enable"}
}
raw := params.Get("enable")
enable, err := strconv.ParseBool(raw)
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
}
// Get the provided token, if any, and vet against any ACL policies.
@ -1416,9 +1411,7 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request)
}
if !logging.ValidateLogLevel(logLevel) {
return nil, BadRequestError{
Reason: fmt.Sprintf("Unknown log level: %s", logLevel),
}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Unknown log level: %s", logLevel)}
}
flusher, ok := resp.(http.Flusher)
@ -1469,7 +1462,7 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request)
func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled() {
return nil, UnauthorizedError{Reason: "ACL support disabled"}
return nil, HTTPError{StatusCode: http.StatusUnauthorized, Reason: "ACL support disabled"}
}
// Fetch the ACL token, if any, and enforce agent policy.
@ -1491,7 +1484,7 @@ func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (
// fields to this later if needed.
var args api.AgentToken
if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Figure out the target token.
@ -1522,7 +1515,7 @@ func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (
s.agent.tokens.UpdateReplicationToken(args.Token, token_store.TokenSourceAPI)
default:
return NotFoundError{Reason: fmt.Sprintf("Token %q is unknown", target)}
return HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("Token %q is unknown", target)}
}
// TODO: is it safe to move this out of WithPersistenceLock?
@ -1641,7 +1634,7 @@ func (s *HTTPHandlers) AgentConnectAuthorize(resp http.ResponseWriter, req *http
}
if err := decodeBody(req.Body, &authReq); err != nil {
return nil, BadRequestError{fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
if !s.validateRequestPartition(resp, &authReq.EnterpriseMeta) {

View File

@ -5473,7 +5473,7 @@ func TestAgent_Token(t *testing.T) {
url: "acl_token?token=root",
body: badJSON(),
code: http.StatusBadRequest,
expectedErr: `Bad request: Request decode failed: json: cannot unmarshal bool into Go value of type api.AgentToken`,
expectedErr: `Request decode failed: json: cannot unmarshal bool into Go value of type api.AgentToken`,
},
{
name: "set user legacy",

View File

@ -136,7 +136,7 @@ func (s *HTTPHandlers) CatalogRegister(resp http.ResponseWriter, req *http.Reque
}
if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Setup the default DC if not provided
@ -166,7 +166,7 @@ func (s *HTTPHandlers) CatalogDeregister(resp http.ResponseWriter, req *http.Req
return nil, err
}
if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Setup the default DC if not provided
@ -363,7 +363,7 @@ func (s *HTTPHandlers) catalogServiceNodes(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing service name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
}
// Make the RPC request
@ -438,7 +438,7 @@ func (s *HTTPHandlers) CatalogNodeServices(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
}
// Make the RPC request
@ -503,7 +503,7 @@ func (s *HTTPHandlers) CatalogNodeServiceList(resp http.ResponseWriter, req *htt
return nil, err
}
if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
}
// Make the RPC request
@ -554,7 +554,7 @@ func (s *HTTPHandlers) CatalogGatewayServices(resp http.ResponseWriter, req *htt
return nil, err
}
if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing gateway name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"}
}
// Make the RPC request

View File

@ -56,7 +56,7 @@ func (s *HTTPHandlers) configGet(resp http.ResponseWriter, req *http.Request) (i
setMeta(resp, &reply.QueryMeta)
if reply.Entry == nil {
return nil, NotFoundError{Reason: fmt.Sprintf("%s for %q / %q", ConfigEntryNotFoundErr, pathArgs[0], pathArgs[1])}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("%s for %q / %q", ConfigEntryNotFoundErr, pathArgs[0], pathArgs[1])}
}
return reply.Entry, nil
@ -75,7 +75,7 @@ func (s *HTTPHandlers) configGet(resp http.ResponseWriter, req *http.Request) (i
return reply.Entries, nil
default:
return nil, NotFoundError{Reason: "Must provide either a kind or both kind and name"}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Must provide either a kind or both kind and name"}
}
}
@ -91,12 +91,12 @@ func (s *HTTPHandlers) configDelete(resp http.ResponseWriter, req *http.Request)
pathArgs := strings.SplitN(kindAndName, "/", 2)
if len(pathArgs) != 2 {
return nil, NotFoundError{Reason: "Must provide both a kind and name to delete"}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Must provide both a kind and name to delete"}
}
entry, err := structs.MakeConfigEntry(pathArgs[0], pathArgs[1])
if err != nil {
return nil, BadRequestError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()}
}
args.Entry = entry
// Parse enterprise meta.
@ -139,13 +139,13 @@ func (s *HTTPHandlers) ConfigApply(resp http.ResponseWriter, req *http.Request)
var raw map[string]interface{}
if err := decodeBodyDeprecated(req, &raw, nil); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)}
}
if entry, err := structs.DecodeConfigEntry(raw); err == nil {
args.Entry = entry
} else {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)}
}
// Parse enterprise meta.

View File

@ -601,9 +601,8 @@ func TestConfig_Apply_Decoding(t *testing.T) {
_, err := a.srv.ConfigApply(resp, req)
require.Error(t, err)
badReq, ok := err.(BadRequestError)
require.True(t, ok)
require.Equal(t, "Request decoding failed: Payload does not contain a kind/Kind key at the top level", badReq.Reason)
require.True(t, isHTTPBadRequest(err))
require.Equal(t, "Request decoding failed: Payload does not contain a kind/Kind key at the top level", err.Error())
})
t.Run("Kind Not String", func(t *testing.T) {
@ -619,9 +618,8 @@ func TestConfig_Apply_Decoding(t *testing.T) {
_, err := a.srv.ConfigApply(resp, req)
require.Error(t, err)
badReq, ok := err.(BadRequestError)
require.True(t, ok)
require.Equal(t, "Request decoding failed: Kind value in payload is not a string", badReq.Reason)
require.True(t, isHTTPBadRequest(err))
require.Equal(t, "Request decoding failed: Kind value in payload is not a string", err.Error())
})
t.Run("Lowercase kind", func(t *testing.T) {

View File

@ -3,6 +3,7 @@ package agent
import (
"context"
"fmt"
"net/http"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
@ -23,7 +24,7 @@ import (
// The ACL token and the auth request are provided and the auth decision (true
// means authorized) and reason string are returned.
//
// If the request input is invalid the error returned will be a BadRequestError,
// If the request input is invalid the error returned will be a BadRequest HTTPError,
// if the token doesn't grant necessary access then an acl.ErrPermissionDenied
// error is returned, otherwise error indicates an unexpected server failure. If
// access is denied, no error is returned but the first return value is false.
@ -37,23 +38,23 @@ func (a *Agent) ConnectAuthorize(token string,
}
if req == nil {
return returnErr(BadRequestError{"Invalid request"})
return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid request"})
}
// We need to have a target to check intentions
if req.Target == "" {
return returnErr(BadRequestError{"Target service must be specified"})
return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "Target service must be specified"})
}
// Parse the certificate URI from the client ID
uri, err := connect.ParseCertURIFromString(req.ClientCertURI)
if err != nil {
return returnErr(BadRequestError{"ClientCertURI not a valid Connect identifier"})
return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "ClientCertURI not a valid Connect identifier"})
}
uriService, ok := uri.(*connect.SpiffeIDService)
if !ok {
return returnErr(BadRequestError{"ClientCertURI not a valid Service identifier"})
return returnErr(HTTPError{StatusCode: http.StatusBadRequest, Reason: "ClientCertURI not a valid Service identifier"})
}
// We need to verify service:write permissions for the given token.

View File

@ -20,7 +20,7 @@ func (s *HTTPHandlers) ConnectCARoots(resp http.ResponseWriter, req *http.Reques
if pemParam := req.URL.Query().Get("pem"); pemParam != "" {
val, err := strconv.ParseBool(pemParam)
if err != nil {
return nil, BadRequestError{Reason: "The 'pem' query parameter must be a boolean value"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "The 'pem' query parameter must be a boolean value"}
}
pemResponse = val
}
@ -90,15 +90,14 @@ func (s *HTTPHandlers) ConnectCAConfigurationSet(req *http.Request) (interface{}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Config); err != nil {
return nil, BadRequestError{
Reason: fmt.Sprintf("Request decode failed: %v", err),
}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
var reply interface{}
err := s.agent.RPC("ConnectCA.ConfigurationSet", &args, &reply)
if err != nil && err.Error() == consul.ErrStateReadOnly.Error() {
return nil, BadRequestError{
return nil, HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Provider State is read-only. It must be omitted" +
" or identical to the current value",
}

View File

@ -14,7 +14,7 @@ func (s *HTTPHandlers) checkCoordinateDisabled() error {
if !s.agent.config.DisableCoordinates {
return nil
}
return UnauthorizedError{Reason: "Coordinate support disabled"}
return HTTPError{StatusCode: http.StatusUnauthorized, Reason: "Coordinate support disabled"}
}
// sorter wraps a coordinate list and implements the sort.Interface to sort by
@ -156,7 +156,7 @@ func (s *HTTPHandlers) CoordinateUpdate(resp http.ResponseWriter, req *http.Requ
args := structs.CoordinateUpdateRequest{}
if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)

View File

@ -39,10 +39,14 @@ func TestCoordinate_Disabled_Response(t *testing.T) {
req, _ := http.NewRequest("PUT", "/should/not/care", nil)
resp := httptest.NewRecorder()
obj, err := tt(resp, req)
err, ok := err.(UnauthorizedError)
if !ok {
t.Fatalf("expected unauthorized error but got %v", err)
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 401 {
t.Fatalf("expected status 401 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
}
if obj != nil {
t.Fatalf("bad: %#v", obj)
}

View File

@ -25,7 +25,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re
return nil, err
}
if args.Name == "" {
return nil, BadRequestError{Reason: "Missing chain name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing chain name"}
}
args.EvaluateInDatacenter = req.URL.Query().Get("compile-dc")
@ -38,12 +38,12 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re
if req.Method == "POST" {
var raw map[string]interface{}
if err := decodeBody(req.Body, &raw); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)}
}
apiReq, err := decodeDiscoveryChainReadRequest(raw)
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decoding failed: %v", err)}
}
args.OverrideProtocol = apiReq.OverrideProtocol
@ -52,7 +52,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re
if apiReq.OverrideMeshGateway.Mode != "" {
_, err := structs.ValidateMeshGatewayMode(string(apiReq.OverrideMeshGateway.Mode))
if err != nil {
return nil, BadRequestError{Reason: "Invalid OverrideMeshGateway.Mode parameter"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid OverrideMeshGateway.Mode parameter"}
}
args.OverrideMeshGateway = apiReq.OverrideMeshGateway
}

View File

@ -57,8 +57,7 @@ func TestDiscoveryChainRead(t *testing.T) {
resp := httptest.NewRecorder()
_, err = a.srv.DiscoveryChainRead(resp, req)
require.Error(t, err)
_, ok := err.(BadRequestError)
require.True(t, ok)
require.True(t, isHTTPBadRequest(err))
}))
require.True(t, t.Run(method+": read default chain", func(t *testing.T) {

View File

@ -25,7 +25,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i
return nil, err
}
if event.Name == "" {
return nil, BadRequestError{Reason: "Missing name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing name"}
}
// Get the ACL token
@ -55,7 +55,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i
// Try to fire the event
if err := s.agent.UserEvent(dc, token, event); err != nil {
if acl.IsErrPermissionDenied(err) {
return nil, ForbiddenError{Reason: acl.ErrPermissionDenied.Error()}
return nil, HTTPError{StatusCode: http.StatusForbidden, Reason: acl.ErrPermissionDenied.Error()}
}
resp.WriteHeader(http.StatusInternalServerError)
return nil, err

View File

@ -103,8 +103,12 @@ func TestEventFire_token(t *testing.T) {
if !acl.IsErrPermissionDenied(err) {
t.Fatalf("bad: %s", err.Error())
}
if err, ok := err.(ForbiddenError); !ok {
t.Fatalf("Expected forbidden but got %v", err)
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 403 {
t.Fatalf("Expected 403 but got %d", err.StatusCode)
}
} else {
t.Fatalf("Expected HTTP Error %v", err)
}
}
}

View File

@ -13,7 +13,7 @@ func (s *HTTPHandlers) FederationStateGet(resp http.ResponseWriter, req *http.Re
return nil, err
}
if datacenterName == "" {
return nil, BadRequestError{Reason: "Missing datacenter name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing datacenter name"}
}
args := structs.FederationStateQuery{

View File

@ -34,7 +34,7 @@ func (s *HTTPHandlers) HealthChecksInState(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.State == "" {
return nil, BadRequestError{Reason: "Missing check state"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing check state"}
}
// Make the RPC request
@ -82,7 +82,7 @@ func (s *HTTPHandlers) HealthNodeChecks(resp http.ResponseWriter, req *http.Requ
return nil, err
}
if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
}
// Make the RPC request
@ -132,7 +132,7 @@ func (s *HTTPHandlers) HealthServiceChecks(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing service name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
}
// Make the RPC request
@ -224,7 +224,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re
return nil, err
}
if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing service name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
}
out, md, err := s.agent.rpcClientHealth.ServiceNodes(req.Context(), args)
@ -242,7 +242,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re
// Filter to only passing if specified
filter, err := getBoolQueryParam(params, api.HealthPassing)
if err != nil {
return nil, BadRequestError{Reason: "Invalid value for ?passing"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid value for ?passing"}
}
// FIXME: remove filterNonPassing, replace with nodes.Filter, which is used by DNSServer

View File

@ -1405,9 +1405,7 @@ func TestHealthServiceNodes_PassingFilter(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/health/service/consul?passing=nope-nope-nope", nil)
resp := httptest.NewRecorder()
_, err := a.srv.HealthServiceNodes(resp, req)
if _, ok := err.(BadRequestError); !ok {
t.Fatalf("Expected bad request error but got %v", err)
}
require.True(t, isHTTPBadRequest(err), fmt.Sprintf("Expected bad request HTTP error but got %v", err))
if !strings.Contains(err.Error(), "Invalid value for ?passing") {
t.Errorf("bad %s", err.Error())
}
@ -1813,8 +1811,7 @@ func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) {
resp := httptest.NewRecorder()
_, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.NotNil(t, err)
_, ok := err.(BadRequestError)
assert.True(t, ok)
assert.True(t, isHTTPBadRequest(err))
assert.True(t, strings.Contains(err.Error(), "Invalid value for ?passing"))
})

View File

@ -51,41 +51,6 @@ func (e MethodNotAllowedError) Error() string {
return fmt.Sprintf("method %s not allowed", e.Method)
}
// BadRequestError should be returned by a handler when parameters or the payload are not valid
type BadRequestError struct {
Reason string
}
func (e BadRequestError) Error() string {
return fmt.Sprintf("Bad request: %s", e.Reason)
}
// NotFoundError should be returned by a handler when a resource specified does not exist
type NotFoundError struct {
Reason string
}
func (e NotFoundError) Error() string {
return e.Reason
}
// UnauthorizedError should be returned by a handler when the request lacks valid authorization.
type UnauthorizedError struct {
Reason string
}
func (e UnauthorizedError) Error() string {
return e.Reason
}
type EntityTooLargeError struct {
Reason string
}
func (e EntityTooLargeError) Error() string {
return e.Reason
}
// CodeWithPayloadError allow returning non HTTP 200
// Error codes while not returning PlainText payload
type CodeWithPayloadError struct {
@ -98,12 +63,15 @@ func (e CodeWithPayloadError) Error() string {
return e.Reason
}
type ForbiddenError struct {
Reason string
// HTTPError is returned by the handler when a specific http error
// code is needed alongside a plain text response.
type HTTPError struct {
StatusCode int
Reason string
}
func (e ForbiddenError) Error() string {
return e.Reason
func (h HTTPError) Error() string {
return h.Reason
}
// HTTPHandlers provides an HTTP api for an agent.
@ -423,8 +391,7 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
if acl.IsErrPermissionDenied(err) || acl.IsErrNotFound(err) {
return true
}
_, ok := err.(ForbiddenError)
return ok
return false
}
isMethodNotAllowed := func(err error) bool {
@ -432,35 +399,20 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
return ok
}
isBadRequest := func(err error) bool {
_, ok := err.(BadRequestError)
return ok
}
isNotFound := func(err error) bool {
_, ok := err.(NotFoundError)
return ok
}
isUnauthorized := func(err error) bool {
_, ok := err.(UnauthorizedError)
return ok
}
isTooManyRequests := func(err error) bool {
// Sadness net/rpc can't do nice typed errors so this is all we got
return err.Error() == consul.ErrRateLimited.Error()
}
isEntityToLarge := func(err error) bool {
_, ok := err.(EntityTooLargeError)
return ok
}
addAllowHeader := func(methods []string) {
resp.Header().Add("Allow", strings.Join(methods, ","))
}
isHTTPError := func(err error) bool {
_, ok := err.(HTTPError)
return ok
}
handleErr := func(err error) {
if req.Context().Err() != nil {
httpLogger.Info("Request cancelled",
@ -490,21 +442,21 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
addAllowHeader(err.(MethodNotAllowedError).Allow)
resp.WriteHeader(http.StatusMethodNotAllowed) // 405
fmt.Fprint(resp, err.Error())
case isBadRequest(err):
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, err.Error())
case isNotFound(err):
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
case isUnauthorized(err):
resp.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(resp, err.Error())
case isHTTPError(err):
err := err.(HTTPError)
code := http.StatusInternalServerError
if err.StatusCode != 0 {
code = err.StatusCode
}
reason := "An unexpected error occurred"
if err.Error() != "" {
reason = err.Error()
}
resp.WriteHeader(code)
fmt.Fprint(resp, reason)
case isTooManyRequests(err):
resp.WriteHeader(http.StatusTooManyRequests)
fmt.Fprint(resp, err.Error())
case isEntityToLarge(err):
resp.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprint(resp, err.Error())
default:
resp.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(resp, err.Error())
@ -1175,7 +1127,7 @@ func (s *HTTPHandlers) checkWriteAccess(req *http.Request) error {
}
}
return ForbiddenError{Reason: "Access is restricted"}
return HTTPError{StatusCode: http.StatusForbidden, Reason: "Access is restricted"}
}
func (s *HTTPHandlers) parseFilter(req *http.Request, filter *string) {

View File

@ -14,10 +14,16 @@ import (
func (s *HTTPHandlers) parseEntMeta(req *http.Request, entMeta *acl.EnterpriseMeta) error {
if headerNS := req.Header.Get("X-Consul-Namespace"); headerNS != "" {
return BadRequestError{Reason: "Invalid header: \"X-Consul-Namespace\" - Namespaces are a Consul Enterprise feature"}
return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid header: \"X-Consul-Namespace\" - Namespaces are a Consul Enterprise feature",
}
}
if queryNS := req.URL.Query().Get("ns"); queryNS != "" {
return BadRequestError{Reason: "Invalid query parameter: \"ns\" - Namespaces are a Consul Enterprise feature"}
return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid query parameter: \"ns\" - Namespaces are a Consul Enterprise feature",
}
}
return s.parseEntMetaPartition(req, entMeta)
@ -32,7 +38,10 @@ func (s *HTTPHandlers) validateEnterpriseIntentionPartition(logName, partition s
// No special handling for wildcard namespaces as they are pointless in OSS.
return BadRequestError{Reason: "Invalid " + logName + "(" + partition + ")" + ": Partitions is a Consul Enterprise feature"}
return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid " + logName + "(" + partition + ")" + ": Partitions is a Consul Enterprise feature",
}
}
func (s *HTTPHandlers) validateEnterpriseIntentionNamespace(logName, ns string, _ bool) error {
@ -44,7 +53,10 @@ func (s *HTTPHandlers) validateEnterpriseIntentionNamespace(logName, ns string,
// No special handling for wildcard namespaces as they are pointless in OSS.
return BadRequestError{Reason: "Invalid " + logName + "(" + ns + ")" + ": Namespaces is a Consul Enterprise feature"}
return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid " + logName + "(" + ns + ")" + ": Namespaces is a Consul Enterprise feature",
}
}
func (s *HTTPHandlers) parseEntMetaNoWildcard(req *http.Request, _ *acl.EnterpriseMeta) error {
@ -72,7 +84,10 @@ func (s *HTTPHandlers) rewordUnknownEnterpriseFieldError(err error) error {
func parseACLAuthMethodEnterpriseMeta(req *http.Request, _ *structs.ACLAuthMethodEnterpriseMeta) error {
if methodNS := req.URL.Query().Get("authmethod-ns"); methodNS != "" {
return BadRequestError{Reason: "Invalid query parameter: \"authmethod-ns\" - Namespaces are a Consul Enterprise feature"}
return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid query parameter: \"authmethod-ns\" - Namespaces are a Consul Enterprise feature",
}
}
return nil
@ -91,10 +106,16 @@ func (s *HTTPHandlers) uiTemplateDataTransform(data map[string]interface{}) erro
func (s *HTTPHandlers) parseEntMetaPartition(req *http.Request, meta *acl.EnterpriseMeta) error {
if headerAP := req.Header.Get("X-Consul-Partition"); headerAP != "" {
return BadRequestError{Reason: "Invalid header: \"X-Consul-Partition\" - Partitions are a Consul Enterprise feature"}
return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid header: \"X-Consul-Partition\" - Partitions are a Consul Enterprise feature",
}
}
if queryAP := req.URL.Query().Get("partition"); queryAP != "" {
return BadRequestError{Reason: "Invalid query parameter: \"partition\" - Partitions are a Consul Enterprise feature"}
return HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Invalid query parameter: \"partition\" - Partitions are a Consul Enterprise feature",
}
}
return nil

View File

@ -1488,10 +1488,16 @@ func TestAllowedNets(t *testing.T) {
t.Fatalf("bad checkWriteAccess for values %+v, got %v", v, err)
}
_, isForbiddenErr := err.(ForbiddenError)
if err != nil && !isForbiddenErr {
t.Fatalf("expected ForbiddenError but got: %s", err)
if err != nil {
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 403 {
t.Fatalf("expected 403 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP Error but got %v", err)
}
}
}
}

View File

@ -56,8 +56,9 @@ func (s *HTTPHandlers) IntentionCreate(resp http.ResponseWriter, req *http.Reque
if err := s.parseEntMetaNoWildcard(req, &entMeta); err != nil {
return nil, err
}
if entMeta.PartitionOrDefault() != acl.PartitionOrDefault("") {
return nil, BadRequestError{Reason: "Cannot use a partition with this endpoint"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot use a partition with this endpoint"}
}
args := structs.IntentionRequest{
@ -70,10 +71,10 @@ func (s *HTTPHandlers) IntentionCreate(resp http.ResponseWriter, req *http.Reque
}
if args.Intention.DestinationPartition != "" && args.Intention.DestinationPartition != "default" {
return nil, BadRequestError{Reason: "Cannot specify a destination partition with this endpoint"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a destination partition with this endpoint"}
}
if args.Intention.SourcePartition != "" && args.Intention.SourcePartition != "default" {
return nil, BadRequestError{Reason: "Cannot specify a source partition with this endpoint"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a source partition with this endpoint"}
}
args.Intention.FillPartitionAndNamespace(&entMeta, false)
@ -324,7 +325,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds the error type
if err.Error() == consul.ErrIntentionNotFound.Error() {
return nil, NotFoundError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
}
// Not ideal, but there are a number of error scenarios that are not
@ -332,7 +333,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req
// to detect a parameter error and return a 400 response. The error
// is not a constant type or message, so we have to use strings.Contains
if strings.Contains(err.Error(), "UUID") {
return nil, BadRequestError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()}
}
return nil, err
@ -366,7 +367,7 @@ func (s *HTTPHandlers) IntentionPutExact(resp http.ResponseWriter, req *http.Req
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Intention); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Explicitly CLEAR the old legacy ID field
@ -520,7 +521,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter,
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds the error type
if err.Error() == consul.ErrIntentionNotFound.Error() {
return nil, NotFoundError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
}
// Not ideal, but there are a number of error scenarios that are not
@ -528,7 +529,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter,
// to detect a parameter error and return a 400 response. The error
// is not a constant type or message, so we have to use strings.Contains
if strings.Contains(err.Error(), "UUID") {
return nil, BadRequestError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: err.Error()}
}
return nil, err
@ -552,8 +553,9 @@ func (s *HTTPHandlers) IntentionSpecificUpdate(id string, resp http.ResponseWrit
if err := s.parseEntMetaNoWildcard(req, &entMeta); err != nil {
return nil, err
}
if entMeta.PartitionOrDefault() != acl.PartitionOrDefault("") {
return nil, BadRequestError{Reason: "Cannot use a partition with this endpoint"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot use a partition with this endpoint"}
}
args := structs.IntentionRequest{
@ -562,14 +564,14 @@ func (s *HTTPHandlers) IntentionSpecificUpdate(id string, resp http.ResponseWrit
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Intention); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
if args.Intention.DestinationPartition != "" && args.Intention.DestinationPartition != "default" {
return nil, BadRequestError{Reason: "Cannot specify a destination partition with this endpoint"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a destination partition with this endpoint"}
}
if args.Intention.SourcePartition != "" && args.Intention.SourcePartition != "default" {
return nil, BadRequestError{Reason: "Cannot specify a source partition with this endpoint"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Cannot specify a source partition with this endpoint"}
}
args.Intention.FillPartitionAndNamespace(&entMeta, false)

View File

@ -483,7 +483,7 @@ func TestIntentionSpecificGet(t *testing.T) {
obj, err := a.srv.IntentionSpecific(resp, req)
require.Nil(t, obj)
require.Error(t, err)
require.IsType(t, BadRequestError{}, err)
require.True(t, isHTTPBadRequest(err))
require.Contains(t, err.Error(), "UUID")
})

View File

@ -56,7 +56,7 @@ func (s *HTTPHandlers) KVSGet(resp http.ResponseWriter, req *http.Request, args
if _, ok := params["recurse"]; ok {
method = "KVS.List"
} else if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"}
}
// Do not allow wildcard NS on GET reqs
@ -157,7 +157,7 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args
return nil, err
}
if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"}
}
if conflictingFlags(resp, req, "cas", "acquire", "release") {
return nil, nil
@ -208,7 +208,8 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args
// Check the content-length
if req.ContentLength > int64(s.agent.config.KVMaxValueSize) {
return nil, EntityTooLargeError{
return nil, HTTPError{
StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.",
req.ContentLength, s.agent.config.KVMaxValueSize, "https://www.consul.io/docs/agent/config/config-files#kv_max_value_size"),
}
@ -257,7 +258,7 @@ func (s *HTTPHandlers) KVSDelete(resp http.ResponseWriter, req *http.Request, ar
if _, ok := params["recurse"]; ok {
applyReq.Op = api.KVDeleteTree
} else if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing key name"}
}
// Check for cas value

View File

@ -49,10 +49,13 @@ func (s *HTTPHandlers) OperatorRaftPeer(resp http.ResponseWriter, req *http.Requ
}
if !hasID && !hasAddress {
return nil, BadRequestError{Reason: "Must specify either ?id with the server's ID or ?address with IP:port of peer to remove"}
return nil, HTTPError{
StatusCode: http.StatusBadRequest,
Reason: "Must specify either ?id with the server's ID or ?address with IP:port of peer to remove",
}
}
if hasID && hasAddress {
return nil, BadRequestError{Reason: "Must specify only one of ?id or ?address"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Must specify only one of ?id or ?address"}
}
var reply struct{}
@ -79,7 +82,7 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht
var args keyringArgs
if req.Method == "POST" || req.Method == "PUT" || req.Method == "DELETE" {
if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
}
s.parseToken(req, &args.Token)
@ -88,12 +91,12 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht
if relayFactor := req.URL.Query().Get("relay-factor"); relayFactor != "" {
n, err := strconv.Atoi(relayFactor)
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing relay factor: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing relay factor: %v", err)}
}
args.RelayFactor, err = ParseRelayFactor(n)
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid relay-factor: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid relay-factor: %v", err)}
}
}
@ -102,12 +105,12 @@ func (s *HTTPHandlers) OperatorKeyringEndpoint(resp http.ResponseWriter, req *ht
var err error
args.LocalOnly, err = strconv.ParseBool(localOnly)
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing local-only: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing local-only: %v", err)}
}
err = ValidateLocalOnly(args.LocalOnly, req.Method == "GET")
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid use of local-only: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Invalid use of local-only: %v", err)}
}
}
@ -226,7 +229,7 @@ func (s *HTTPHandlers) OperatorAutopilotConfiguration(resp http.ResponseWriter,
conf := api.NewAutopilotConfiguration()
if err := decodeBody(req.Body, &conf); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing autopilot config: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing autopilot config: %v", err)}
}
args.Config = structs.AutopilotConfig{
@ -245,7 +248,7 @@ func (s *HTTPHandlers) OperatorAutopilotConfiguration(resp http.ResponseWriter,
if _, ok := params["cas"]; ok {
casVal, err := strconv.ParseUint(params.Get("cas"), 10, 64)
if err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Error parsing cas value: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Error parsing cas value: %v", err)}
}
args.Config.ModifyIndex = casVal
args.CAS = true

View File

@ -23,7 +23,7 @@ func (s *HTTPHandlers) preparedQueryCreate(resp http.ResponseWriter, req *http.R
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Query); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
var reply string
@ -143,7 +143,7 @@ func (s *HTTPHandlers) preparedQueryExecute(id string, resp http.ResponseWriter,
// We have to check the string since the RPC sheds
// the specific error type.
if structs.IsErrQueryNotFound(err) {
return nil, NotFoundError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
}
return nil, err
}
@ -196,7 +196,7 @@ RETRY_ONCE:
// We have to check the string since the RPC sheds
// the specific error type.
if structs.IsErrQueryNotFound(err) {
return nil, NotFoundError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
}
return nil, err
}
@ -225,7 +225,7 @@ RETRY_ONCE:
// We have to check the string since the RPC sheds
// the specific error type.
if structs.IsErrQueryNotFound(err) {
return nil, NotFoundError{Reason: err.Error()}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: err.Error()}
}
return nil, err
}
@ -247,7 +247,7 @@ func (s *HTTPHandlers) preparedQueryUpdate(id string, resp http.ResponseWriter,
s.parseToken(req, &args.Token)
if req.ContentLength > 0 {
if err := decodeBody(req.Body, &args.Query); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
}

View File

@ -621,8 +621,12 @@ func TestPreparedQuery_Execute(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/query/not-there/execute", body)
resp := httptest.NewRecorder()
_, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok {
t.Fatalf("Expected not found error but got %v", err)
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 404 {
t.Fatalf("expected status 404 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
}
})
}
@ -756,8 +760,12 @@ func TestPreparedQuery_Explain(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/query/not-there/explain", body)
resp := httptest.NewRecorder()
_, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok {
t.Fatalf("Expected not found error but got %v", err)
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 404 {
t.Fatalf("expected status 404 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
}
})
@ -845,8 +853,12 @@ func TestPreparedQuery_Get(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/query/f004177f-2c28-83b7-4229-eacc25fe55d1", body)
resp := httptest.NewRecorder()
_, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok {
t.Fatalf("Expected not found error but got %v", err)
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 404 {
t.Fatalf("expected status 404 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
}
})
}

View File

@ -40,7 +40,7 @@ func (s *HTTPHandlers) SessionCreate(resp http.ResponseWriter, req *http.Request
// Handle optional request body
if req.ContentLength > 0 {
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Session)); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
}
@ -75,7 +75,7 @@ func (s *HTTPHandlers) SessionDestroy(resp http.ResponseWriter, req *http.Reques
return nil, err
}
if args.Session.ID == "" {
return nil, BadRequestError{Reason: "Missing session"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"}
}
var out string
@ -103,14 +103,14 @@ func (s *HTTPHandlers) SessionRenew(resp http.ResponseWriter, req *http.Request)
}
args.Session = args.SessionID
if args.SessionID == "" {
return nil, BadRequestError{Reason: "Missing session"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"}
}
var out structs.IndexedSessions
if err := s.agent.RPC("Session.Renew", &args, &out); err != nil {
return nil, err
} else if out.Sessions == nil {
return nil, NotFoundError{Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)}
}
return out.Sessions, nil
@ -134,7 +134,7 @@ func (s *HTTPHandlers) SessionGet(resp http.ResponseWriter, req *http.Request) (
}
args.Session = args.SessionID
if args.SessionID == "" {
return nil, BadRequestError{Reason: "Missing session"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing session"}
}
var out structs.IndexedSessions
@ -190,7 +190,7 @@ func (s *HTTPHandlers) SessionsForNode(resp http.ResponseWriter, req *http.Reque
return nil, err
}
if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
}
var out structs.IndexedSessions

View File

@ -88,7 +88,8 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
// Check Content-Length first before decoding to return early
if req.ContentLength > maxTxnLen {
return nil, 0, EntityTooLargeError{
return nil, 0, HTTPError{
StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.",
req.ContentLength, maxTxnLen, "https://www.consul.io/docs/agent/config/config-files#txn_max_req_len"),
}
@ -100,7 +101,8 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
if err.Error() == "http: request body too large" {
// The request size is also verified during decoding to double check
// if the Content-Length header was not set by the client.
return nil, 0, EntityTooLargeError{
return nil, 0, HTTPError{
StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Request body too large, max size: %d bytes. See %s.",
maxTxnLen, "https://www.consul.io/docs/agent/config/config-files#txn_max_req_len"),
}
@ -108,15 +110,16 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
// Note the body is in API format, and not the RPC format. If we can't
// decode it, we will return a 400 since we don't have enough context to
// associate the error with a given operation.
return nil, 0, BadRequestError{Reason: fmt.Sprintf("Failed to parse body: %v", err)}
return nil, 0, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Failed to parse body: %v", err)}
}
}
// Enforce a reasonable upper limit on the number of operations in a
// transaction in order to curb abuse.
if size := len(ops); size > maxTxnOps {
return nil, 0, EntityTooLargeError{
Reason: fmt.Sprintf("Transaction contains too many operations (%d > %d)", size, maxTxnOps),
return nil, 0, HTTPError{
StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Transaction contains too many operations (%d > %d)", size, maxTxnOps),
}
}
@ -130,8 +133,9 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
case in.KV != nil:
size := len(in.KV.Value)
if int64(size) > kvMaxValueSize {
return nil, 0, EntityTooLargeError{
Reason: fmt.Sprintf("Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize),
return nil, 0, HTTPError{
StatusCode: http.StatusRequestEntityTooLarge,
Reason: fmt.Sprintf("Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize),
}
}

View File

@ -12,6 +12,7 @@ import (
"github.com/hashicorp/raft"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
@ -31,10 +32,7 @@ func TestTxnEndpoint_Bad_JSON(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder()
_, err := a.srv.Txn(resp, req)
err, ok := err.(BadRequestError)
if !ok {
t.Fatalf("expected bad request error but got %v", err)
}
require.True(t, isHTTPBadRequest(err), fmt.Sprintf("Expected bad request HTTP error but got %v", err))
if !strings.Contains(err.Error(), "Failed to parse") {
t.Fatalf("expected conflicting args error")
}
@ -63,11 +61,19 @@ func TestTxnEndpoint_Bad_Size_Item(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder()
_, err := agent.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok && !wantPass {
t.Fatalf("expected too large error but got %v", err)
}
if err != nil && wantPass {
t.Fatalf("err: %v", err)
if wantPass {
if err != nil {
t.Fatalf("err: %v", err)
}
} else {
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 413 {
t.Fatalf("expected 413 but got %d", err.StatusCode)
}
} else {
t.Fatalf("excected HTTP error but got %v", err)
}
}
}
@ -138,11 +144,19 @@ func TestTxnEndpoint_Bad_Size_Net(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder()
_, err := agent.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok && !wantPass {
t.Fatalf("expected too large error but got %v", err)
}
if err != nil && wantPass {
t.Fatalf("err: %v", err)
if wantPass {
if err != nil {
t.Fatalf("err: %v", err)
}
} else {
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 413 {
t.Fatalf("expected 413 but got %d", err.StatusCode)
}
} else {
t.Fatalf("excected HTTP error but got %v", err)
}
}
}
@ -205,8 +219,13 @@ func TestTxnEndpoint_Bad_Size_Ops(t *testing.T) {
req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder()
_, err := a.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok {
t.Fatalf("expected too large error but got %v", err)
if err, ok := err.(HTTPError); ok {
if err.StatusCode != 413 {
t.Fatalf("expected 413 but got %d", err.StatusCode)
}
} else {
t.Fatalf("expected HTTP error but got %v", err)
}
}

View File

@ -140,7 +140,7 @@ func (s *HTTPHandlers) UINodeInfo(resp http.ResponseWriter, req *http.Request) (
return nil, err
}
if args.Node == "" {
return nil, BadRequestError{Reason: "Missing node name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing node name"}
}
// Make the RPC request
@ -272,7 +272,7 @@ func (s *HTTPHandlers) UIGatewayServicesNodes(resp http.ResponseWriter, req *htt
return nil, err
}
if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing gateway name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"}
}
// Make the RPC request
@ -316,12 +316,12 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req
return nil, err
}
if args.ServiceName == "" {
return nil, BadRequestError{Reason: "Missing service name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service name"}
}
kind, ok := req.URL.Query()["kind"]
if !ok {
return nil, BadRequestError{Reason: "Missing service kind"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing service kind"}
}
args.ServiceKind = structs.ServiceKind(kind[0])
@ -329,7 +329,7 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req
case structs.ServiceKindTypical, structs.ServiceKindIngressGateway:
// allowed
default:
return nil, BadRequestError{Reason: fmt.Sprintf("Unsupported service kind %q", args.ServiceKind)}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: fmt.Sprintf("Unsupported service kind %q", args.ServiceKind)}
}
// Make the RPC request
@ -594,7 +594,7 @@ func (s *HTTPHandlers) UIGatewayIntentions(resp http.ResponseWriter, req *http.R
return nil, err
}
if name == "" {
return nil, BadRequestError{Reason: "Missing gateway name"}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Missing gateway name"}
}
args.Match = &structs.IntentionQueryMatch{
Type: structs.IntentionMatchDestination,
@ -624,14 +624,14 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques
// Check the UI was enabled at agent startup (note this is not reloadable
// currently).
if !s.IsUIEnabled() {
return nil, NotFoundError{Reason: "UI is not enabled"}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "UI is not enabled"}
}
// Load reloadable proxy config
cfg, ok := s.metricsProxyCfg.Load().(config.UIMetricsProxy)
if !ok || cfg.BaseURL == "" {
// Proxy not configured
return nil, NotFoundError{Reason: "Metrics proxy is not enabled"}
return nil, HTTPError{StatusCode: http.StatusNotFound, Reason: "Metrics proxy is not enabled"}
}
// Fetch the ACL token, if provided, but ONLY from headers since other
@ -686,7 +686,7 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques
u, err := url.Parse(newURL)
if err != nil {
log.Error("couldn't parse target URL", "base_url", cfg.BaseURL, "path", subPath)
return nil, BadRequestError{Reason: "Invalid path."}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid path."}
}
// Clean the new URL path to prevent path traversal attacks and remove any
@ -735,7 +735,7 @@ func (s *HTTPHandlers) UIMetricsProxy(resp http.ResponseWriter, req *http.Reques
"path", subPath,
"target_url", u.String(),
)
return nil, BadRequestError{Reason: "Invalid path."}
return nil, HTTPError{StatusCode: http.StatusBadRequest, Reason: "Invalid path."}
}
// Add any configured headers